rust_igraph/algorithms/games/
edge_sampling.rs1#![allow(
8 clippy::cast_possible_truncation,
9 clippy::cast_precision_loss,
10 clippy::cast_sign_loss,
11 clippy::many_single_char_names
12)]
13
14use crate::core::rng::SplitMix64;
15use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
16
17#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct EdgeSplit {
20 pub train: Vec<(VertexId, VertexId)>,
22 pub test: Vec<(VertexId, VertexId)>,
24}
25
26pub fn sample_edges(
53 graph: &Graph,
54 count: usize,
55 seed: u64,
56) -> IgraphResult<Vec<(VertexId, VertexId)>> {
57 let edges = collect_edges(graph);
58 let actual = count.min(edges.len());
59 Ok(shuffle_and_take(edges, actual, seed))
60}
61
62pub fn split_edges(graph: &Graph, test_fraction: f64, seed: u64) -> IgraphResult<EdgeSplit> {
97 if !(0.0..=1.0).contains(&test_fraction) {
98 return Err(IgraphError::InvalidArgument(format!(
99 "test_fraction must be in [0.0, 1.0], got {test_fraction}"
100 )));
101 }
102
103 let edges = collect_edges(graph);
104 let ne = edges.len();
105 let test_count = (ne as f64 * test_fraction).round() as usize;
106
107 let shuffled = shuffle_all(edges, seed);
108 let test = shuffled[..test_count].to_vec();
109 let train = shuffled[test_count..].to_vec();
110
111 Ok(EdgeSplit { train, test })
112}
113
114pub fn split_edges_connected(
137 graph: &Graph,
138 test_fraction: f64,
139 seed: u64,
140) -> IgraphResult<EdgeSplit> {
141 if !(0.0..=1.0).contains(&test_fraction) {
142 return Err(IgraphError::InvalidArgument(format!(
143 "test_fraction must be in [0.0, 1.0], got {test_fraction}"
144 )));
145 }
146
147 let edges = collect_edges(graph);
148 let ne = edges.len();
149 let target_test = (ne as f64 * test_fraction).round() as usize;
150
151 let shuffled = shuffle_all(edges, seed);
152
153 let mut train: Vec<(VertexId, VertexId)> = Vec::with_capacity(ne);
154 let mut test: Vec<(VertexId, VertexId)> = Vec::new();
155
156 let nv = graph.vcount() as usize;
163 let mut train_degree: Vec<u32> = vec![0; nv];
164
165 for &(u, v) in &shuffled {
167 train_degree[u as usize] += 1;
168 train_degree[v as usize] += 1;
169 }
170
171 for &(u, v) in &shuffled {
173 if test.len() >= target_test {
174 train.push((u, v));
175 continue;
176 }
177
178 let u_deg = train_degree[u as usize];
180 let v_deg = train_degree[v as usize];
181
182 if u_deg > 1 && v_deg > 1 {
183 test.push((u, v));
184 train_degree[u as usize] -= 1;
185 train_degree[v as usize] -= 1;
186 } else {
187 train.push((u, v));
188 }
189 }
190
191 Ok(EdgeSplit { train, test })
192}
193
194fn collect_edges(graph: &Graph) -> Vec<(VertexId, VertexId)> {
197 graph.edges().collect()
198}
199
200fn shuffle_and_take(
201 mut items: Vec<(VertexId, VertexId)>,
202 k: usize,
203 seed: u64,
204) -> Vec<(VertexId, VertexId)> {
205 let n = items.len();
206 let mut rng = SplitMix64::new(seed);
207 let take = k.min(n);
208 for i in 0..take {
209 let j = i + rng.gen_index(n - i);
210 items.swap(i, j);
211 }
212 items.truncate(take);
213 items
214}
215
216fn shuffle_all(mut items: Vec<(VertexId, VertexId)>, seed: u64) -> Vec<(VertexId, VertexId)> {
217 let n = items.len();
218 let mut rng = SplitMix64::new(seed);
219 for i in 0..n.saturating_sub(1) {
220 let j = i + rng.gen_index(n - i);
221 items.swap(i, j);
222 }
223 items
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected 5-cycle: every vertex has degree 2.
    fn cycle5() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], false, Some(5)).unwrap()
    }

    /// Undirected 5-vertex graph with 8 edges (the 5-cycle plus three chords).
    fn dense5() -> Graph {
        Graph::from_edges(
            &[
                (0, 1),
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 0),
                (0, 2),
                (1, 3),
                (2, 4),
            ],
            false,
            Some(5),
        )
        .unwrap()
    }

    /// Per-vertex degree counted over the training edges of `split`.
    fn train_degrees(g: &Graph, split: &EdgeSplit) -> Vec<u32> {
        let nv = g.vcount() as usize;
        let mut deg = vec![0u32; nv];
        for &(u, v) in &split.train {
            deg[u as usize] += 1;
            deg[v as usize] += 1;
        }
        deg
    }

    #[test]
    fn sample_edges_basic() {
        let g = cycle5();
        let sampled = sample_edges(&g, 3, 42).unwrap();
        assert_eq!(sampled.len(), 3);
        for &(u, v) in &sampled {
            assert!(g.has_edge(u, v));
        }
    }

    #[test]
    fn sample_edges_all() {
        // Requests beyond the edge count are clamped to the edge count.
        let g = cycle5();
        let sampled = sample_edges(&g, 100, 42).unwrap();
        assert_eq!(sampled.len(), 5);
    }

    #[test]
    fn sample_edges_zero() {
        let g = cycle5();
        let sampled = sample_edges(&g, 0, 42).unwrap();
        assert!(sampled.is_empty());
    }

    #[test]
    fn sample_edges_no_duplicates() {
        let g = dense5();
        let sampled = sample_edges(&g, 5, 42).unwrap();
        for i in 0..sampled.len() {
            for j in (i + 1)..sampled.len() {
                assert_ne!(sampled[i], sampled[j]);
            }
        }
    }

    #[test]
    fn sample_edges_deterministic() {
        let g = dense5();
        let s1 = sample_edges(&g, 4, 99).unwrap();
        let s2 = sample_edges(&g, 4, 99).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn split_basic() {
        // 8 edges in total.
        let g = dense5();
        let split = split_edges(&g, 0.25, 42).unwrap();
        assert_eq!(split.train.len() + split.test.len(), 8);
        // round(8 * 0.25) = 2
        assert_eq!(split.test.len(), 2);
    }

    #[test]
    fn split_all_train() {
        let g = cycle5();
        let split = split_edges(&g, 0.0, 42).unwrap();
        assert_eq!(split.train.len(), 5);
        assert!(split.test.is_empty());
    }

    #[test]
    fn split_all_test() {
        let g = cycle5();
        let split = split_edges(&g, 1.0, 42).unwrap();
        assert!(split.train.is_empty());
        assert_eq!(split.test.len(), 5);
    }

    #[test]
    fn split_invalid_fraction() {
        let g = cycle5();
        assert!(split_edges(&g, 1.5, 42).is_err());
        assert!(split_edges(&g, -0.1, 42).is_err());
    }

    #[test]
    fn split_deterministic() {
        let g = dense5();
        let s1 = split_edges(&g, 0.3, 99).unwrap();
        let s2 = split_edges(&g, 0.3, 99).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn split_connected_basic() {
        let g = dense5();
        let split = split_edges_connected(&g, 0.25, 42).unwrap();
        assert_eq!(split.train.len() + split.test.len(), 8);
        // Every vertex must keep at least one training edge.
        let deg = train_degrees(&g, &split);
        for d in &deg {
            assert!(*d >= 1, "vertex isolated in training set");
        }
    }

    #[test]
    fn split_connected_cycle() {
        let g = cycle5();
        let split = split_edges_connected(&g, 0.5, 42).unwrap();
        // In a 5-cycle every vertex has degree 2, so removed edges may not
        // share an endpoint; at most 2 such edges exist.
        assert!(split.test.len() <= 2);
        let deg = train_degrees(&g, &split);
        for d in &deg {
            assert!(*d >= 1);
        }
    }

    #[test]
    fn split_connected_invalid_fraction() {
        let g = cycle5();
        assert!(split_edges_connected(&g, 2.0, 42).is_err());
    }

    #[test]
    fn empty_graph() {
        // Three vertices, no edges.
        let g = Graph::with_vertices(3);
        let sampled = sample_edges(&g, 5, 42).unwrap();
        assert!(sampled.is_empty());

        let split = split_edges(&g, 0.5, 42).unwrap();
        assert!(split.train.is_empty());
        assert!(split.test.is_empty());
    }

    #[test]
    fn directed_graph() {
        let g = Graph::from_edges(&[(0, 1), (1, 2), (2, 0)], true, Some(3)).unwrap();
        let sampled = sample_edges(&g, 2, 42).unwrap();
        assert_eq!(sampled.len(), 2);
        for &(u, v) in &sampled {
            assert!(g.has_edge(u, v));
        }
    }
}