rust_igraph/algorithms/games/
negative_sampling.rs1#![allow(
8 clippy::cast_possible_truncation,
9 clippy::cast_sign_loss,
10 clippy::many_single_char_names
11)]
12
13use crate::core::rng::SplitMix64;
14use crate::core::{Graph, IgraphResult, VertexId};
15
16pub fn sample_negative_edges(
54 graph: &Graph,
55 count: usize,
56 seed: u64,
57) -> IgraphResult<Vec<(VertexId, VertexId)>> {
58 let nv = graph.vcount();
59 if nv < 2 || count == 0 {
60 return Ok(Vec::new());
61 }
62
63 let directed = graph.is_directed();
64 let nv64 = u64::from(nv);
65 let max_possible = if directed {
66 nv64 * (nv64 - 1)
67 } else {
68 nv64 * (nv64 - 1) / 2
69 };
70
71 let existing_edges = graph.ecount() as u64;
72 let available = max_possible.saturating_sub(existing_edges);
73 if available == 0 {
74 return Ok(Vec::new());
75 }
76
77 let target = count.min(available as usize);
78 let mut rng = SplitMix64::new(seed);
79 let mut result: Vec<(VertexId, VertexId)> = Vec::with_capacity(target);
80
81 let max_attempts = target as u64 * 20 + 1000;
82 let mut attempts: u64 = 0;
83
84 while result.len() < target && attempts < max_attempts {
85 attempts += 1;
86
87 let u = rng.gen_index(nv as usize) as VertexId;
88 let v = rng.gen_index(nv as usize) as VertexId;
89
90 if u == v {
91 continue;
92 }
93
94 let (a, b) = if !directed && u > v { (v, u) } else { (u, v) };
95
96 if graph.has_edge(a, b) {
97 continue;
98 }
99
100 if result.contains(&(a, b)) {
101 continue;
102 }
103
104 result.push((a, b));
105 }
106
107 Ok(result)
108}
109
110pub fn sample_negative_edges_excluding(
135 graph: &Graph,
136 count: usize,
137 exclude: &[(VertexId, VertexId)],
138 seed: u64,
139) -> IgraphResult<Vec<(VertexId, VertexId)>> {
140 let nv = graph.vcount();
141 if nv < 2 || count == 0 {
142 return Ok(Vec::new());
143 }
144
145 let directed = graph.is_directed();
146 let mut rng = SplitMix64::new(seed);
147 let mut result: Vec<(VertexId, VertexId)> = Vec::with_capacity(count);
148
149 let max_attempts = count as u64 * 20 + 1000;
150 let mut attempts: u64 = 0;
151
152 while result.len() < count && attempts < max_attempts {
153 attempts += 1;
154
155 let u = rng.gen_index(nv as usize) as VertexId;
156 let v = rng.gen_index(nv as usize) as VertexId;
157
158 if u == v {
159 continue;
160 }
161
162 let (a, b) = if !directed && u > v { (v, u) } else { (u, v) };
163
164 if graph.has_edge(a, b) {
165 continue;
166 }
167
168 if exclude.contains(&(a, b)) {
169 continue;
170 }
171
172 if result.contains(&(a, b)) {
173 continue;
174 }
175
176 result.push((a, b));
177 }
178
179 Ok(result)
180}
181
182pub fn sample_negative_edges_degree_biased(
205 graph: &Graph,
206 count: usize,
207 seed: u64,
208) -> IgraphResult<Vec<(VertexId, VertexId)>> {
209 let nv = graph.vcount();
210 if nv < 2 || count == 0 {
211 return Ok(Vec::new());
212 }
213
214 let directed = graph.is_directed();
215 let mut rng = SplitMix64::new(seed);
216
217 let mut cumulative_degree: Vec<u64> = Vec::with_capacity(nv as usize);
218 let mut total_degree: u64 = 0;
219 for vid in 0..nv {
220 let deg = graph.degree(vid)?;
221 total_degree += (deg as u64) + 1;
222 cumulative_degree.push(total_degree);
223 }
224
225 let mut result: Vec<(VertexId, VertexId)> = Vec::with_capacity(count);
226 let max_attempts = count as u64 * 20 + 1000;
227 let mut attempts: u64 = 0;
228
229 while result.len() < count && attempts < max_attempts {
230 attempts += 1;
231
232 let u = sample_by_cumulative(&cumulative_degree, total_degree, &mut rng);
233 let v = sample_by_cumulative(&cumulative_degree, total_degree, &mut rng);
234
235 if u == v {
236 continue;
237 }
238
239 let (a, b) = if !directed && u > v { (v, u) } else { (u, v) };
240
241 if graph.has_edge(a, b) {
242 continue;
243 }
244
245 if result.contains(&(a, b)) {
246 continue;
247 }
248
249 result.push((a, b));
250 }
251
252 Ok(result)
253}
254
255fn sample_by_cumulative(cumulative: &[u64], total: u64, rng: &mut SplitMix64) -> VertexId {
258 let r = rng.next_u64() % total;
259 let idx = match cumulative.binary_search(&(r + 1)) {
262 Ok(i) | Err(i) => i,
263 };
264 idx as VertexId
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected path 0 - 1 - 2 - 3.
    fn path4() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3)], false, Some(4)).unwrap()
    }

    /// Undirected complete graph on 4 vertices: every non-loop pair is an edge.
    fn complete4() -> Graph {
        Graph::from_edges(
            &[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)],
            false,
            Some(4),
        )
        .unwrap()
    }

    #[test]
    fn basic_negative_sampling() {
        let g = path4();
        let neg = sample_negative_edges(&g, 3, 42).unwrap();
        assert_eq!(neg.len(), 3);
        for (u, v) in neg {
            assert_ne!(u, v);
            assert!(u < v);
            assert!(!g.has_edge(u, v));
        }
    }

    #[test]
    fn no_duplicates() {
        let neg = sample_negative_edges(&path4(), 3, 42).unwrap();
        let unique: std::collections::HashSet<_> = neg.iter().copied().collect();
        assert_eq!(unique.len(), neg.len());
    }

    #[test]
    fn complete_graph_no_negatives() {
        let neg = sample_negative_edges(&complete4(), 10, 42).unwrap();
        assert!(neg.is_empty());
    }

    #[test]
    fn deterministic() {
        let g = path4();
        let first = sample_negative_edges(&g, 3, 99).unwrap();
        let second = sample_negative_edges(&g, 3, 99).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn empty_graph() {
        let neg = sample_negative_edges(&Graph::with_vertices(0), 5, 42).unwrap();
        assert!(neg.is_empty());
    }

    #[test]
    fn single_vertex() {
        let neg = sample_negative_edges(&Graph::with_vertices(1), 5, 42).unwrap();
        assert!(neg.is_empty());
    }

    #[test]
    fn zero_count() {
        let neg = sample_negative_edges(&path4(), 0, 42).unwrap();
        assert!(neg.is_empty());
    }

    #[test]
    fn directed_graph() {
        let g = Graph::from_edges(&[(0, 1), (1, 2)], true, Some(3)).unwrap();
        let neg = sample_negative_edges(&g, 4, 42).unwrap();
        assert_eq!(neg.len(), 4);
        for (u, v) in neg {
            assert_ne!(u, v);
            assert!(!g.has_edge(u, v));
        }
    }

    #[test]
    fn excluding_works() {
        let g = path4();
        let neg = sample_negative_edges_excluding(&g, 10, &[(0, 3)], 42).unwrap();
        for (u, v) in neg {
            assert_ne!((u, v), (0, 3));
            assert!(!g.has_edge(u, v));
        }
    }

    #[test]
    fn degree_biased_valid() {
        let g = path4();
        let neg = sample_negative_edges_degree_biased(&g, 3, 42).unwrap();
        assert_eq!(neg.len(), 3);
        for (u, v) in neg {
            assert!(u < v);
            assert!(!g.has_edge(u, v));
        }
    }

    #[test]
    fn degree_biased_prefers_high_degree() {
        // Star centred on vertex 0 with leaves 1..=4; vertices 5..20 are isolated.
        let edges: Vec<(u32, u32)> = (1..5).map(|leaf| (0, leaf)).collect();
        let g = Graph::from_edges(&edges, false, Some(20)).unwrap();

        let neg = sample_negative_edges_degree_biased(&g, 30, 42).unwrap();
        let touching_hub = neg
            .iter()
            .filter(|pair| pair.0 == 0 || pair.1 == 0)
            .count();
        assert!(touching_hub >= 3);
    }

    #[test]
    fn respects_max_available() {
        // Path 0 - 1 - 2 leaves exactly one free pair: (0, 2).
        let g = Graph::from_edges(&[(0, 1), (1, 2)], false, Some(3)).unwrap();
        let neg = sample_negative_edges(&g, 100, 42).unwrap();
        assert_eq!(neg, vec![(0, 2)]);
    }
}