Skip to main content

rust_igraph/algorithms/spanning/
random_spanning_tree.rs

1//! Random spanning tree via loop-erased random walk (ALGO-RST-001).
2//!
3//! Counterpart of `igraph_random_spanning_tree()` from
4//! `references/igraph/src/misc/spanning_trees.c`.
5//!
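//! Uniformly samples spanning trees of a connected graph (or spanning
//! forests of a disconnected graph) by running a random walk and keeping
//! the edge through which each vertex is first entered. This is the
//! first-entry (Aldous-Broder style) construction; no loops are erased,
//! even though the module keeps the "loop-erased random walk" name.
//! Edge directions are ignored.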
9
10use crate::core::rng::SplitMix64;
11use crate::core::{Graph, IgraphError, IgraphResult};
12
13/// Uniformly sample a random spanning tree (or forest) of a graph.
14///
15/// Uses loop-erased random walk (Wilson's algorithm). Edge directions
16/// are ignored. Multi-edges are supported and affect sampling frequency.
17///
18/// If `start_vertex` is `Some(v)`, only the component containing `v`
19/// is spanned; the result has `component_size - 1` edges.
20/// If `start_vertex` is `None`, a random spanning forest of all
21/// components is returned.
22///
23/// Returns a vector of edge IDs forming the spanning tree/forest.
24///
25/// # Errors
26///
27/// - `InvalidArgument` if `start_vertex` is out of range.
28///
29/// # Examples
30///
31/// ```
32/// use rust_igraph::{Graph, random_spanning_tree};
33///
34/// // Triangle: any spanning tree has exactly 2 edges.
35/// let mut g = Graph::with_vertices(3);
36/// g.add_edge(0, 1).unwrap();
37/// g.add_edge(1, 2).unwrap();
38/// g.add_edge(0, 2).unwrap();
39/// let tree = random_spanning_tree(&g, Some(0), 42).unwrap();
40/// assert_eq!(tree.len(), 2);
41/// ```
42pub fn random_spanning_tree(
43    graph: &Graph,
44    start_vertex: Option<u32>,
45    seed: u64,
46) -> IgraphResult<Vec<u32>> {
47    let vcount = graph.vcount();
48
49    if let Some(v) = start_vertex {
50        if v >= vcount {
51            return Err(IgraphError::InvalidArgument(format!(
52                "random_spanning_tree: vertex {v} out of range (vcount={vcount})"
53            )));
54        }
55    }
56
57    if vcount == 0 {
58        return Ok(Vec::new());
59    }
60
61    let adj = build_incidence(graph)?;
62    let mut rng = SplitMix64::new(seed);
63    let mut visited = vec![false; vcount as usize];
64    let mut result: Vec<u32> = Vec::new();
65
66    if let Some(vid) = start_vertex {
67        let comp_size = count_component(graph, vid, &adj)?;
68        lerw(
69            graph,
70            &adj,
71            vid,
72            comp_size,
73            &mut visited,
74            &mut rng,
75            &mut result,
76        )?;
77    } else {
78        let components = find_components(vcount, &adj);
79        for (root, comp_size) in components {
80            lerw(
81                graph,
82                &adj,
83                root,
84                comp_size,
85                &mut visited,
86                &mut rng,
87                &mut result,
88            )?;
89        }
90    }
91
92    Ok(result)
93}
94
95/// For each vertex, store a list of `(edge_id, other_vertex)` pairs,
96/// treating the graph as undirected.
97fn build_incidence(graph: &Graph) -> IgraphResult<Vec<Vec<(u32, u32)>>> {
98    let vcount = graph.vcount();
99    let ecount = graph.ecount();
100    let mut inc: Vec<Vec<(u32, u32)>> = vec![Vec::new(); vcount as usize];
101
102    for eid in 0..ecount {
103        let eid_u32 = u32::try_from(eid).map_err(|_| IgraphError::Internal("edge id overflow"))?;
104        let (from, to) = graph.edge(eid_u32)?;
105        inc[from as usize].push((eid_u32, to));
106        inc[to as usize].push((eid_u32, from));
107    }
108
109    Ok(inc)
110}
111
112/// Count vertices reachable from `start` treating graph as undirected.
113fn count_component(graph: &Graph, start: u32, adj: &[Vec<(u32, u32)>]) -> IgraphResult<u32> {
114    let vcount = graph.vcount();
115    let mut visited = vec![false; vcount as usize];
116    let mut queue = std::collections::VecDeque::new();
117    visited[start as usize] = true;
118    queue.push_back(start);
119    let mut count: u32 = 1;
120
121    while let Some(v) = queue.pop_front() {
122        for &(_, nb) in &adj[v as usize] {
123            if !visited[nb as usize] {
124                visited[nb as usize] = true;
125                count = count
126                    .checked_add(1)
127                    .ok_or(IgraphError::Internal("component size overflow"))?;
128                queue.push_back(nb);
129            }
130        }
131    }
132
133    Ok(count)
134}
135
/// List every connected component as a `(representative, size)` pair.
///
/// Vertices `0..vcount` are scanned in increasing order; each vertex not yet
/// reached starts a breadth-first search over `adj` (lists of
/// `(edge_id, neighbor)` pairs). The representative is therefore the smallest
/// vertex ID of its component, and components appear in increasing order of
/// representative. The size count saturates at `u32::MAX` instead of wrapping.
fn find_components(vcount: u32, adj: &[Vec<(u32, u32)>]) -> Vec<(u32, u32)> {
    use std::collections::VecDeque;

    let mut seen = vec![false; vcount as usize];
    let mut found: Vec<(u32, u32)> = Vec::new();
    // One queue serves all searches: it is always empty when a search ends.
    let mut frontier: VecDeque<u32> = VecDeque::new();

    for root in 0..vcount {
        // Mark the root and skip it if an earlier search already reached it.
        if std::mem::replace(&mut seen[root as usize], true) {
            continue;
        }

        let mut members: u32 = 1;
        frontier.push_back(root);

        while let Some(at) = frontier.pop_front() {
            for &(_, neighbor) in &adj[at as usize] {
                let flag = &mut seen[neighbor as usize];
                if *flag {
                    continue;
                }
                *flag = true;
                members = members.saturating_add(1);
                frontier.push_back(neighbor);
            }
        }

        found.push((root, members));
    }

    found
}
165
166/// Loop-erased random walk from `start` until all `comp_size` vertices
167/// in the component are visited.
168fn lerw(
169    graph: &Graph,
170    adj: &[Vec<(u32, u32)>],
171    start: u32,
172    comp_size: u32,
173    visited: &mut [bool],
174    rng: &mut SplitMix64,
175    result: &mut Vec<u32>,
176) -> IgraphResult<()> {
177    let _ = graph;
178    visited[start as usize] = true;
179    let mut visited_count: u32 = 1;
180    let mut current = start;
181
182    while visited_count < comp_size {
183        let edges = &adj[current as usize];
184        if edges.is_empty() {
185            break;
186        }
187
188        let idx = rng.gen_index(edges.len());
189        let (eid, next) = edges[idx];
190
191        if !visited[next as usize] {
192            result.push(eid);
193            visited[next as usize] = true;
194            visited_count = visited_count
195                .checked_add(1)
196                .ok_or(IgraphError::Internal("visited count overflow"))?;
197        }
198
199        current = next;
200    }
201
202    Ok(())
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::VecDeque;

    /// Two disjoint triangles: vertices 0-1-2 use edge IDs 0..3, vertices 3-4-5 use 3..6.
    const TWO_TRIANGLES: [(u32, u32); 6] = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5)];

    /// Create a graph with `vertex_count` vertices via `Graph::with_vertices`
    /// and add `edge_list` in order. The tests below rely on edge IDs
    /// following insertion order.
    fn build_graph(vertex_count: u32, edge_list: &[(u32, u32)]) -> Graph {
        let mut graph = Graph::with_vertices(vertex_count);
        edge_list.iter().for_each(|&(a, b)| {
            graph.add_edge(a, b).unwrap();
        });
        graph
    }

    #[test]
    fn empty_graph() {
        let graph = Graph::with_vertices(0);
        let tree = random_spanning_tree(&graph, None, 0).unwrap();
        assert!(tree.is_empty());
    }

    #[test]
    fn single_vertex() {
        let graph = Graph::with_vertices(1);
        let tree = random_spanning_tree(&graph, Some(0), 0).unwrap();
        assert!(tree.is_empty());
    }

    #[test]
    fn single_edge() {
        let graph = build_graph(2, &[(0, 1)]);
        let tree = random_spanning_tree(&graph, Some(0), 0).unwrap();
        assert_eq!(tree.len(), 1);
        assert_eq!(tree[0], 0);
    }

    #[test]
    fn triangle() {
        let graph = build_graph(3, &[(0, 1), (1, 2), (0, 2)]);
        let tree = random_spanning_tree(&graph, Some(0), 42).unwrap();
        assert_eq!(tree.len(), 2);
        // Every returned ID must name one of the three edges.
        assert!(tree.iter().all(|&eid| eid < 3));
    }

    #[test]
    fn k4_complete() {
        let graph = build_graph(4, &[(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]);
        let tree = random_spanning_tree(&graph, Some(0), 123).unwrap();
        // A spanning tree of K4 has 3 edges.
        assert_eq!(tree.len(), 3);
    }

    #[test]
    fn path_graph() {
        // The path 0-1-2-3 is its own unique spanning tree.
        let graph = build_graph(4, &[(0, 1), (1, 2), (2, 3)]);
        let tree = random_spanning_tree(&graph, Some(0), 0).unwrap();
        assert_eq!(tree.len(), 3);

        let mut ordered = tree.clone();
        ordered.sort_unstable();
        assert_eq!(ordered, vec![0, 1, 2]);
    }

    #[test]
    fn spanning_forest_disconnected() {
        let graph = build_graph(6, &TWO_TRIANGLES);
        let forest = random_spanning_tree(&graph, None, 42).unwrap();
        // Two edges per triangle.
        assert_eq!(forest.len(), 4);
    }

    #[test]
    fn start_vertex_component_only() {
        let graph = build_graph(6, &TWO_TRIANGLES);
        // Only the triangle containing vertex 0 is spanned.
        let tree = random_spanning_tree(&graph, Some(0), 42).unwrap();
        assert_eq!(tree.len(), 2);
        // Edge IDs 0, 1 and 2 belong to the first triangle.
        assert!(tree.iter().all(|&eid| eid < 3));
    }

    #[test]
    fn invalid_vertex_error() {
        let graph = Graph::with_vertices(3);
        let outcome = random_spanning_tree(&graph, Some(5), 0);
        assert!(matches!(
            outcome.unwrap_err(),
            IgraphError::InvalidArgument(_)
        ));
    }

    #[test]
    fn deterministic_with_same_seed() {
        let graph = build_graph(
            5,
            &[(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3), (3, 4)],
        );
        let first = random_spanning_tree(&graph, Some(0), 999).unwrap();
        let second = random_spanning_tree(&graph, Some(0), 999).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn different_seeds_may_differ() {
        let graph = build_graph(
            5,
            &[
                (0, 1),
                (0, 2),
                (0, 3),
                (0, 4),
                (1, 2),
                (1, 3),
                (2, 3),
                (2, 4),
                (3, 4),
            ],
        );
        // Compare seed `s` against seed `s + 100` and stop at the first mismatch.
        let different = (0..20).any(|s| {
            let lhs = random_spanning_tree(&graph, Some(0), s).unwrap();
            let rhs = random_spanning_tree(&graph, Some(0), s + 100).unwrap();
            lhs != rhs
        });
        assert!(
            different,
            "with enough seeds, different trees should appear"
        );
    }

    #[test]
    fn result_forms_spanning_tree() {
        // A valid spanning tree has n - 1 edges and connects all n vertices.
        let graph = build_graph(
            6,
            &[(0, 1), (0, 2), (1, 2), (2, 3), (3, 4), (3, 5), (4, 5)],
        );
        let tree = random_spanning_tree(&graph, Some(0), 77).unwrap();
        assert_eq!(tree.len(), 5);

        // Adjacency restricted to the selected edges.
        let mut neighbors: Vec<Vec<u32>> = vec![Vec::new(); 6];
        for &eid in &tree {
            let (a, b) = graph.edge(eid).unwrap();
            neighbors[a as usize].push(b);
            neighbors[b as usize].push(a);
        }

        // Breadth-first search from vertex 0 must reach all six vertices.
        let mut reached = [false; 6];
        reached[0] = true;
        let mut frontier = VecDeque::from([0u32]);
        let mut reached_total = 1;
        while let Some(at) = frontier.pop_front() {
            for &next in &neighbors[at as usize] {
                if reached[next as usize] {
                    continue;
                }
                reached[next as usize] = true;
                reached_total += 1;
                frontier.push_back(next);
            }
        }
        assert_eq!(reached_total, 6);
    }

    #[test]
    fn directed_graph_works() {
        // Edge directions are ignored, so a directed 3-cycle behaves like a triangle.
        let mut graph = Graph::new(3, true).unwrap();
        for (a, b) in [(0, 1), (1, 2), (2, 0)] {
            graph.add_edge(a, b).unwrap();
        }
        let tree = random_spanning_tree(&graph, Some(0), 42).unwrap();
        assert_eq!(tree.len(), 2);
    }

    #[test]
    fn isolated_vertices_forest() {
        // Five isolated vertices: the forest has no edges.
        let graph = Graph::with_vertices(5);
        let forest = random_spanning_tree(&graph, None, 0).unwrap();
        assert!(forest.is_empty());
    }

    #[test]
    fn multi_edge_graph() {
        // Edge IDs: 0 and 1 both join 0-1, 2 joins 1-2.
        let graph = build_graph(3, &[(0, 1), (0, 1), (1, 2)]);
        let tree = random_spanning_tree(&graph, Some(0), 42).unwrap();
        assert_eq!(tree.len(), 2);
        // Edge 2 is the only way to reach vertex 2, so it is always selected.
        assert!(tree.contains(&2));
    }
}