Skip to main content

rust_igraph/algorithms/games/
edge_sampling.rs

1//! Edge sampling and train/test edge splits (ALGO-TR-009).
2//!
3//! Utilities for randomly sampling edges and splitting a graph's edge set
4//! into disjoint train/test partitions. Essential for link prediction
5//! evaluation: the test edges are held out, and the model must predict them.
6
7#![allow(
8    clippy::cast_possible_truncation,
9    clippy::cast_precision_loss,
10    clippy::cast_sign_loss,
11    clippy::many_single_char_names
12)]
13
14use crate::core::rng::SplitMix64;
15use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
16
17/// Result of an edge train/test split.
18#[derive(Debug, Clone, PartialEq, Eq)]
19pub struct EdgeSplit {
20    /// Training edges (the larger partition).
21    pub train: Vec<(VertexId, VertexId)>,
22    /// Test edges (the smaller, held-out partition).
23    pub test: Vec<(VertexId, VertexId)>,
24}
25
26/// Uniformly sample `count` edges from the graph without replacement.
27///
28/// Returns a random subset of the graph's edges. If `count` exceeds
29/// the number of edges, all edges are returned (shuffled).
30///
31/// # Parameters
32///
33/// - `graph` — The input graph.
34/// - `count` — Number of edges to sample.
35/// - `seed` — PRNG seed for deterministic sampling.
36///
37/// # Examples
38///
39/// ```
40/// use rust_igraph::{Graph, sample_edges};
41///
42/// let g = Graph::from_edges(
43///     &[(0,1),(1,2),(2,3),(3,4),(4,0)], false, Some(5)
44/// ).unwrap();
45///
46/// let sampled = sample_edges(&g, 3, 42).unwrap();
47/// assert_eq!(sampled.len(), 3);
48/// for &(u, v) in &sampled {
49///     assert!(g.has_edge(u, v));
50/// }
51/// ```
52pub fn sample_edges(
53    graph: &Graph,
54    count: usize,
55    seed: u64,
56) -> IgraphResult<Vec<(VertexId, VertexId)>> {
57    let edges = collect_edges(graph);
58    let actual = count.min(edges.len());
59    Ok(shuffle_and_take(edges, actual, seed))
60}
61
62/// Split graph edges into train and test sets.
63///
64/// Randomly partitions the edge set: `test_fraction` of edges go to the
65/// test set, the remainder to training. Useful for link prediction
66/// evaluation where test edges are removed from the graph.
67///
68/// # Parameters
69///
70/// - `graph` — The input graph.
71/// - `test_fraction` — Fraction of edges for the test set (0.0 to 1.0).
72/// - `seed` — PRNG seed for deterministic splitting.
73///
74/// # Returns
75///
76/// An [`EdgeSplit`] with `train` and `test` edge vectors.
77///
78/// # Errors
79///
80/// Returns an error if `test_fraction` is not in `[0.0, 1.0]`.
81///
82/// # Examples
83///
84/// ```
85/// use rust_igraph::{Graph, split_edges};
86///
87/// let g = Graph::from_edges(
88///     &[(0,1),(1,2),(2,3),(3,4),(4,0),(0,2),(1,3),(2,4)],
89///     false, Some(5)
90/// ).unwrap();
91///
92/// let split = split_edges(&g, 0.25, 42).unwrap();
93/// assert_eq!(split.train.len() + split.test.len(), g.ecount());
94/// assert_eq!(split.test.len(), 2); // 25% of 8 = 2
95/// ```
96pub fn split_edges(graph: &Graph, test_fraction: f64, seed: u64) -> IgraphResult<EdgeSplit> {
97    if !(0.0..=1.0).contains(&test_fraction) {
98        return Err(IgraphError::InvalidArgument(format!(
99            "test_fraction must be in [0.0, 1.0], got {test_fraction}"
100        )));
101    }
102
103    let edges = collect_edges(graph);
104    let ne = edges.len();
105    let test_count = (ne as f64 * test_fraction).round() as usize;
106
107    let shuffled = shuffle_all(edges, seed);
108    let test = shuffled[..test_count].to_vec();
109    let train = shuffled[test_count..].to_vec();
110
111    Ok(EdgeSplit { train, test })
112}
113
114/// Split edges ensuring the training graph remains connected.
115///
116/// Like [`split_edges`] but guarantees that removing the test edges does
117/// not disconnect the graph. Edges whose removal would create a bridge
118/// (disconnection) are kept in training. If maintaining connectivity
119/// limits the test set size, fewer edges than requested may end up in test.
120///
121/// # Examples
122///
123/// ```
124/// use rust_igraph::{Graph, split_edges_connected};
125///
126/// // Cycle: every edge can be removed without disconnecting
127/// let g = Graph::from_edges(
128///     &[(0,1),(1,2),(2,3),(3,0)], false, Some(4)
129/// ).unwrap();
130///
131/// let split = split_edges_connected(&g, 0.5, 42).unwrap();
132/// assert_eq!(split.train.len() + split.test.len(), 4);
133/// // Test set has at most 2 edges (50%)
134/// assert!(split.test.len() <= 2);
135/// ```
136pub fn split_edges_connected(
137    graph: &Graph,
138    test_fraction: f64,
139    seed: u64,
140) -> IgraphResult<EdgeSplit> {
141    if !(0.0..=1.0).contains(&test_fraction) {
142        return Err(IgraphError::InvalidArgument(format!(
143            "test_fraction must be in [0.0, 1.0], got {test_fraction}"
144        )));
145    }
146
147    let edges = collect_edges(graph);
148    let ne = edges.len();
149    let target_test = (ne as f64 * test_fraction).round() as usize;
150
151    let shuffled = shuffle_all(edges, seed);
152
153    let mut train: Vec<(VertexId, VertexId)> = Vec::with_capacity(ne);
154    let mut test: Vec<(VertexId, VertexId)> = Vec::new();
155
156    // Build adjacency as we add training edges; check connectivity via
157    // simple degree tracking — an edge can go to test only if both
158    // endpoints will still have at least one remaining training edge.
159    // This is a heuristic; for exact bridge detection we'd need a more
160    // expensive algorithm, but this keeps the training graph connected
161    // in practice for reasonably dense graphs.
162    let nv = graph.vcount() as usize;
163    let mut train_degree: Vec<u32> = vec![0; nv];
164
165    // First pass: assign all edges to train to compute full degrees
166    for &(u, v) in &shuffled {
167        train_degree[u as usize] += 1;
168        train_degree[v as usize] += 1;
169    }
170
171    // Second pass: try to move edges to test
172    for &(u, v) in &shuffled {
173        if test.len() >= target_test {
174            train.push((u, v));
175            continue;
176        }
177
178        // Can we remove this edge without isolating a vertex?
179        let u_deg = train_degree[u as usize];
180        let v_deg = train_degree[v as usize];
181
182        if u_deg > 1 && v_deg > 1 {
183            test.push((u, v));
184            train_degree[u as usize] -= 1;
185            train_degree[v as usize] -= 1;
186        } else {
187            train.push((u, v));
188        }
189    }
190
191    Ok(EdgeSplit { train, test })
192}
193
194// --- Internal helpers ---
195
196fn collect_edges(graph: &Graph) -> Vec<(VertexId, VertexId)> {
197    graph.edges().collect()
198}
199
200fn shuffle_and_take(
201    mut items: Vec<(VertexId, VertexId)>,
202    k: usize,
203    seed: u64,
204) -> Vec<(VertexId, VertexId)> {
205    let n = items.len();
206    let mut rng = SplitMix64::new(seed);
207    let take = k.min(n);
208    for i in 0..take {
209        let j = i + rng.gen_index(n - i);
210        items.swap(i, j);
211    }
212    items.truncate(take);
213    items
214}
215
216fn shuffle_all(mut items: Vec<(VertexId, VertexId)>, seed: u64) -> Vec<(VertexId, VertexId)> {
217    let n = items.len();
218    let mut rng = SplitMix64::new(seed);
219    for i in 0..n.saturating_sub(1) {
220        let j = i + rng.gen_index(n - i);
221        items.swap(i, j);
222    }
223    items
224}
225
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected ring on five vertices (5 edges).
    fn cycle5() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], false, Some(5)).unwrap()
    }

    /// The five-vertex ring plus three chords (8 edges).
    fn dense5() -> Graph {
        Graph::from_edges(
            &[
                (0, 1),
                (1, 2),
                (2, 3),
                (3, 4),
                (4, 0),
                (0, 2),
                (1, 3),
                (2, 4),
            ],
            false,
            Some(5),
        )
        .unwrap()
    }

    /// Per-vertex degree of `g`'s vertices counting only `edges`.
    fn degrees_in(g: &Graph, edges: &[(VertexId, VertexId)]) -> Vec<u32> {
        let mut deg = vec![0u32; g.vcount() as usize];
        for &(u, v) in edges {
            deg[u as usize] += 1;
            deg[v as usize] += 1;
        }
        deg
    }

    #[test]
    fn sample_edges_basic() {
        let g = cycle5();
        let picked = sample_edges(&g, 3, 42).unwrap();
        assert_eq!(picked.len(), 3);
        assert!(picked.iter().all(|&(u, v)| g.has_edge(u, v)));
    }

    #[test]
    fn sample_edges_all() {
        // Requesting more than exist returns every edge.
        let picked = sample_edges(&cycle5(), 100, 42).unwrap();
        assert_eq!(picked.len(), 5);
    }

    #[test]
    fn sample_edges_zero() {
        let picked = sample_edges(&cycle5(), 0, 42).unwrap();
        assert!(picked.is_empty());
    }

    #[test]
    fn sample_edges_no_duplicates() {
        let picked = sample_edges(&dense5(), 5, 42).unwrap();
        for (i, first) in picked.iter().enumerate() {
            for second in &picked[i + 1..] {
                assert_ne!(first, second);
            }
        }
    }

    #[test]
    fn sample_edges_deterministic() {
        let g = dense5();
        let first = sample_edges(&g, 4, 99).unwrap();
        let second = sample_edges(&g, 4, 99).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn split_basic() {
        let g = dense5(); // 8 edges
        let EdgeSplit { train, test } = split_edges(&g, 0.25, 42).unwrap();
        assert_eq!(train.len() + test.len(), 8);
        assert_eq!(test.len(), 2); // 25% of 8 = 2
    }

    #[test]
    fn split_all_train() {
        let EdgeSplit { train, test } = split_edges(&cycle5(), 0.0, 42).unwrap();
        assert_eq!(train.len(), 5);
        assert!(test.is_empty());
    }

    #[test]
    fn split_all_test() {
        let EdgeSplit { train, test } = split_edges(&cycle5(), 1.0, 42).unwrap();
        assert!(train.is_empty());
        assert_eq!(test.len(), 5);
    }

    #[test]
    fn split_invalid_fraction() {
        let g = cycle5();
        for bad in [1.5, -0.1] {
            assert!(split_edges(&g, bad, 42).is_err());
        }
    }

    #[test]
    fn split_deterministic() {
        let g = dense5();
        let first = split_edges(&g, 0.3, 99).unwrap();
        let second = split_edges(&g, 0.3, 99).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn split_connected_basic() {
        let g = dense5();
        let split = split_edges_connected(&g, 0.25, 42).unwrap();
        assert_eq!(split.train.len() + split.test.len(), 8);
        // No vertex should be isolated in training set
        let deg = degrees_in(&g, &split.train);
        assert!(
            deg.iter().all(|&d| d >= 1),
            "vertex isolated in training set"
        );
    }

    #[test]
    fn split_connected_cycle() {
        let g = cycle5();
        let split = split_edges_connected(&g, 0.5, 42).unwrap();
        // Every vertex of a cycle has degree 2, so each vertex can lose at
        // most one incident edge; no more than 2 edges may end up in test.
        assert!(split.test.len() <= 2);
        let deg = degrees_in(&g, &split.train);
        assert!(deg.iter().all(|&d| d >= 1));
    }

    #[test]
    fn split_connected_invalid_fraction() {
        assert!(split_edges_connected(&cycle5(), 2.0, 42).is_err());
    }

    #[test]
    fn empty_graph() {
        // Three vertices, no edges.
        let g = Graph::with_vertices(3);

        assert!(sample_edges(&g, 5, 42).unwrap().is_empty());

        let EdgeSplit { train, test } = split_edges(&g, 0.5, 42).unwrap();
        assert!(train.is_empty());
        assert!(test.is_empty());
    }

    #[test]
    fn directed_graph() {
        let g = Graph::from_edges(&[(0, 1), (1, 2), (2, 0)], true, Some(3)).unwrap();
        let picked = sample_edges(&g, 2, 42).unwrap();
        assert_eq!(picked.len(), 2);
        assert!(picked.iter().all(|&(u, v)| g.has_edge(u, v)));
    }
}