Skip to main content

rust_igraph/algorithms/spanning/
mst.rs

1//! Minimum spanning tree (ALGO-MST-001).
2//!
3//! Counterpart of `igraph_minimum_spanning_tree` from
4//! `references/igraph/src/misc/spanning_trees.c` (~480 lines). Returns
5//! the edge IDs that constitute a spanning tree (or spanning forest if
6//! the graph is disconnected) of the input graph. Three internal
7//! variants:
8//!
9//! * `Unweighted` — arbitrary spanning tree via BFS from each unvisited
10//!   vertex. `O(|V|+|E|)`.
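//! * `Prim` — Prim's algorithm driven by a `std::collections::BinaryHeap`
//!   of candidate edges (wrapped in `Reverse` to get min-heap order) keyed
//!   by weight; ties resolved by edge-ID order (deterministic). Entries
//!   whose far endpoint already joined the tree are discarded when popped.
//!   `O(|E| log |V|)`.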
14//! * `Kruskal` — sort edges by weight ascending (`f64::total_cmp` for a
15//!   total order on the IEEE-754 hierarchy, then by edge ID for ties),
16//!   then walk with a path-compressed union-find. `O(|E| log |E|)`.
17//!
18//! [`MstAlgorithm::Automatic`] picks `Unweighted` when `weights` is
19//! `None` and `Kruskal` otherwise — matching the upstream C dispatch
20//! comment at `spanning_trees.c:466`.
21//!
22//! Directed edge directions are ignored, mirroring the upstream
23//! behaviour: every directed edge is treated as undirected. Self-loops
24//! never enter the tree (an MST edge must connect two distinct
25//! components); parallel edges are deduplicated by the algorithm
26//! (Kruskal picks the lightest; Prim's already-added bit short-circuits
27//! the rest).
28//!
29//! The algorithm is fully deterministic for a fixed input — there is no
30//! RNG anywhere in the pipeline.
31
32use std::cmp::Ordering;
33
34use crate::core::graph::EdgeId;
35use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
36
/// Selector for the minimum-spanning-tree algorithm.
///
/// Counterpart of the C `igraph_mst_algorithm_t` enum at
/// `spanning_trees.c:461-484`.
///
/// Passed as the `method` argument of [`minimum_spanning_tree`], which
/// resolves `Automatic` to a concrete variant before dispatching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MstAlgorithm {
    /// Let the implementation pick: `Unweighted` when `weights` is
    /// `None`, otherwise `Kruskal`.
    Automatic,
    /// Ignore edge weights and run a BFS spanning forest. Any weights
    /// passed alongside this variant are not inspected.
    Unweighted,
    /// Prim's algorithm backed by a binary min-heap of candidate edges;
    /// requires weights.
    Prim,
    /// Kruskal's algorithm with sort + union-find; requires weights.
    Kruskal,
}
53
54/// Computes a minimum spanning tree (or forest, if disconnected) of
55/// `graph` and returns the IDs of the edges that constitute the tree.
56///
57/// Counterpart of `igraph_minimum_spanning_tree` from
58/// `references/igraph/src/misc/spanning_trees.c:461`.
59///
60/// # Arguments
61///
62/// * `graph` - Input graph. Edge directions are ignored.
63/// * `weights` - Optional edge weights, indexed by edge ID. When
64///   provided, `weights.len()` must equal `graph.ecount()`. NaN
65///   entries are rejected. When `None`, edge weights are treated as
66///   uniform and the spanning forest is an arbitrary BFS tree.
67/// * `method` - Which underlying algorithm to use; see
68///   [`MstAlgorithm`].
69///
70/// # Returns
71///
72/// The edge IDs (in the order the algorithm picked them) that form
73/// the spanning tree / forest. For a graph with `c` connected
74/// components, the result has exactly `vcount − c` edges. The order
75/// reflects the algorithm: BFS layer-by-layer for `Unweighted`,
76/// pop-order for `Prim`, ascending-weight for `Kruskal`.
77///
78/// # Errors
79///
80/// Returns [`IgraphError::InvalidArgument`] when
81/// * `weights.len()` does not match `graph.ecount()`,
82/// * any weight is NaN, or
83/// * `method` is a weighted variant (`Prim` / `Kruskal`) but `weights`
84///   is `None`.
85///
86/// # Examples
87///
88/// ```
89/// use rust_igraph::{Graph, MstAlgorithm, minimum_spanning_tree};
90///
91/// // Square 0-1-2-3-0 with diagonals 0-2 and 1-3, weights chosen so
92/// // the unique MST is the four outer edges (weight 1) and skips both
93/// // diagonals (weight 10).
94/// let mut g = Graph::with_vertices(4);
95/// g.add_edge(0, 1).unwrap(); // eid 0, w=1
96/// g.add_edge(1, 2).unwrap(); // eid 1, w=1
97/// g.add_edge(2, 3).unwrap(); // eid 2, w=1
98/// g.add_edge(3, 0).unwrap(); // eid 3, w=1
99/// g.add_edge(0, 2).unwrap(); // eid 4, w=10
100/// g.add_edge(1, 3).unwrap(); // eid 5, w=10
101///
102/// let w = [1.0, 1.0, 1.0, 1.0, 10.0, 10.0];
103/// let mut tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
104/// tree.sort_unstable();
105/// // A 4-vertex graph needs vcount-1 = 3 tree edges (one square edge
106/// // closes the cycle).
107/// assert_eq!(tree.len(), 3);
108/// assert!(tree.iter().all(|&e| e < 4));
109/// ```
110///
111/// # References
112///
113/// * R. C. Prim, *Shortest connection networks and some
114///   generalizations*, Bell System Technical Journal **36** (1957),
115///   1389-1401.
116/// * J. B. Kruskal, *On the shortest spanning subtree of a graph and
117///   the traveling salesman problem*, Proc. Amer. Math. Soc. **7**
118///   (1956), 48-50.
119pub fn minimum_spanning_tree(
120    graph: &Graph,
121    weights: Option<&[f64]>,
122    method: MstAlgorithm,
123) -> IgraphResult<Vec<EdgeId>> {
124    let resolved = match method {
125        MstAlgorithm::Automatic => {
126            if weights.is_none() {
127                MstAlgorithm::Unweighted
128            } else {
129                MstAlgorithm::Kruskal
130            }
131        }
132        other => other,
133    };
134
135    match resolved {
136        MstAlgorithm::Unweighted => mst_unweighted(graph),
137        MstAlgorithm::Prim => {
138            let w = require_weights(graph, weights)?;
139            mst_prim(graph, w)
140        }
141        MstAlgorithm::Kruskal => {
142            let w = require_weights(graph, weights)?;
143            mst_kruskal(graph, w)
144        }
145        MstAlgorithm::Automatic => unreachable!("Automatic resolved above"),
146    }
147}
148
149fn require_weights<'a>(graph: &Graph, weights: Option<&'a [f64]>) -> IgraphResult<&'a [f64]> {
150    let w = weights.ok_or_else(|| {
151        IgraphError::InvalidArgument(
152            "weights required for the chosen MST algorithm (Prim/Kruskal); supply Some(&[..])"
153                .to_string(),
154        )
155    })?;
156    let m = graph.ecount();
157    if w.len() != m {
158        return Err(IgraphError::InvalidArgument(format!(
159            "weights length {} does not match edge count {}",
160            w.len(),
161            m
162        )));
163    }
164    if w.iter().any(|x| x.is_nan()) {
165        return Err(IgraphError::InvalidArgument(
166            "weights must not contain NaN values".to_string(),
167        ));
168    }
169    Ok(w)
170}
171
/// Upper bound on the number of edges a spanning forest can contain:
/// at most `vcount - 1` (zero for an empty graph) and never more than
/// the number of edges available.
fn max_tree_edges(vcount: usize, ecount: usize) -> usize {
    // `saturating_sub` maps the empty graph (vcount == 0) to a bound of 0.
    ecount.min(vcount.saturating_sub(1))
}
181
182// ---------- Unweighted: BFS spanning forest ----------
183
184fn mst_unweighted(graph: &Graph) -> IgraphResult<Vec<EdgeId>> {
185    let n = graph.vcount() as usize;
186    let m = graph.ecount();
187    let mut result: Vec<EdgeId> = Vec::with_capacity(max_tree_edges(n, m));
188    if n == 0 {
189        return Ok(result);
190    }
191    let mut already_added = vec![false; n];
192    let mut added_edges = vec![false; m];
193    let mut queue: std::collections::VecDeque<VertexId> = std::collections::VecDeque::new();
194
195    for start in 0..n {
196        if already_added[start] {
197            continue;
198        }
199        already_added[start] = true;
200        let start_v = u32::try_from(start)
201            .map_err(|_| IgraphError::InvalidArgument("vertex id overflows u32".to_string()))?;
202        queue.push_back(start_v);
203        while let Some(v) = queue.pop_front() {
204            let eids = graph.incident(v)?;
205            for &edge in &eids {
206                let e_idx = edge as usize;
207                if added_edges[e_idx] {
208                    continue;
209                }
210                let to = graph.edge_other(edge, v)?;
211                let to_idx = to as usize;
212                if already_added[to_idx] {
213                    continue;
214                }
215                already_added[to_idx] = true;
216                added_edges[e_idx] = true;
217                result.push(edge);
218                queue.push_back(to);
219            }
220        }
221    }
222    Ok(result)
223}
224
225// ---------- Prim: eager binary min-heap ----------
226
227/// Min-heap entry for Prim. Ordered by `(weight asc, edge asc)` —
228/// edge-ID is the tie-breaker so equal-weight runs are deterministic.
229#[derive(Debug, Clone, Copy)]
230struct HeapEntry {
231    weight: f64,
232    edge: EdgeId,
233    from: VertexId,
234}
235
236impl PartialEq for HeapEntry {
237    fn eq(&self, other: &Self) -> bool {
238        self.cmp(other) == Ordering::Equal
239    }
240}
241impl Eq for HeapEntry {}
242impl PartialOrd for HeapEntry {
243    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
244        Some(self.cmp(other))
245    }
246}
247impl Ord for HeapEntry {
248    fn cmp(&self, other: &Self) -> Ordering {
249        self.weight
250            .total_cmp(&other.weight)
251            .then(self.edge.cmp(&other.edge))
252    }
253}
254
255fn mst_prim(graph: &Graph, weights: &[f64]) -> IgraphResult<Vec<EdgeId>> {
256    let n = graph.vcount() as usize;
257    let m = graph.ecount();
258    let mut result: Vec<EdgeId> = Vec::with_capacity(max_tree_edges(n, m));
259    if n == 0 {
260        return Ok(result);
261    }
262    let mut already_added = vec![false; n];
263    let mut added_edges = vec![false; m];
264    // std::collections::BinaryHeap is a max-heap; we wrap entries in
265    // Reverse so the smallest weight comes out first.
266    let mut heap: std::collections::BinaryHeap<std::cmp::Reverse<HeapEntry>> =
267        std::collections::BinaryHeap::new();
268
269    for start in 0..n {
270        if already_added[start] {
271            continue;
272        }
273        let start_v = u32::try_from(start)
274            .map_err(|_| IgraphError::InvalidArgument("vertex id overflows u32".to_string()))?;
275        already_added[start] = true;
276        push_incident_into_heap(graph, weights, &already_added, start_v, &mut heap)?;
277
278        while let Some(std::cmp::Reverse(entry)) = heap.pop() {
279            let edge = entry.edge;
280            let e_idx = edge as usize;
281            if added_edges[e_idx] {
282                continue;
283            }
284            let to = graph.edge_other(edge, entry.from)?;
285            let to_idx = to as usize;
286            if already_added[to_idx] {
287                // Edge would close a cycle inside the current
288                // partial-tree component; skip.
289                continue;
290            }
291            already_added[to_idx] = true;
292            added_edges[e_idx] = true;
293            result.push(edge);
294            push_incident_into_heap(graph, weights, &already_added, to, &mut heap)?;
295        }
296    }
297    Ok(result)
298}
299
300fn push_incident_into_heap(
301    graph: &Graph,
302    weights: &[f64],
303    already_added: &[bool],
304    v: VertexId,
305    heap: &mut std::collections::BinaryHeap<std::cmp::Reverse<HeapEntry>>,
306) -> IgraphResult<()> {
307    let eids = graph.incident(v)?;
308    for &edge in &eids {
309        let other = graph.edge_other(edge, v)?;
310        if already_added[other as usize] {
311            continue;
312        }
313        let w = weights[edge as usize];
314        heap.push(std::cmp::Reverse(HeapEntry {
315            weight: w,
316            edge,
317            from: v,
318        }));
319    }
320    Ok(())
321}
322
323// ---------- Kruskal: sort + union-find ----------
324
325fn mst_kruskal(graph: &Graph, weights: &[f64]) -> IgraphResult<Vec<EdgeId>> {
326    let n = graph.vcount() as usize;
327    let m = graph.ecount();
328    let mut result: Vec<EdgeId> = Vec::with_capacity(max_tree_edges(n, m));
329    if n == 0 {
330        return Ok(result);
331    }
332
333    // Indices into the edge list, sorted by weight ascending. Ties are
334    // broken by edge ID so the picked tree is deterministic across
335    // platforms regardless of `sort_by`'s internal stability.
336    let mut order: Vec<EdgeId> = (0..u32::try_from(m)
337        .map_err(|_| IgraphError::InvalidArgument("edge count overflows u32".to_string()))?)
338        .collect();
339    order.sort_by(|a, b| {
340        weights[*a as usize]
341            .total_cmp(&weights[*b as usize])
342            .then(a.cmp(b))
343    });
344
345    let mut parent: Vec<u32> = (0..u32::try_from(n)
346        .map_err(|_| IgraphError::InvalidArgument("vertex count overflows u32".to_string()))?)
347        .collect();
348
349    let target = max_tree_edges(n, m);
350    for &edge in &order {
351        if result.len() == target {
352            break;
353        }
354        let (u, v) = graph.edge(edge)?;
355        let ru = uf_find(&mut parent, u as usize);
356        let rv = uf_find(&mut parent, v as usize);
357        if ru != rv {
358            // Path-compressed merge: hang ru under rv (matches C
359            // `merge_comp` which sets parent[ci] = cj).
360            let rv_u32 = u32::try_from(rv)
361                .map_err(|_| IgraphError::InvalidArgument("vertex id overflows u32".to_string()))?;
362            parent[ru] = rv_u32;
363            result.push(edge);
364        }
365    }
366    Ok(result)
367}
368
/// Returns the root of the union-find component containing `i`,
/// compressing the whole path that was walked.
///
/// `parent[x] == x` marks `x` as a root. After the call every vertex
/// on the path from `i` to the root points directly at the root, so
/// later lookups of any of those vertices take a single step. The
/// returned root is unaffected by the compression.
///
/// # Panics
///
/// Panics if `i`, or any parent link reached from it, is out of bounds
/// for `parent`.
fn uf_find(parent: &mut [u32], mut i: usize) -> usize {
    // Pass 1: climb to the root without modifying anything.
    let mut root = i;
    loop {
        let up = parent[root] as usize;
        if up == root {
            break;
        }
        root = up;
    }

    // A root stores its own index, so its slot already holds the root
    // as a `u32` and no narrowing conversion is needed.
    let root_u32 = parent[root];

    // Pass 2: re-point every vertex on the walked path at the root.
    while i != root {
        let up = parent[i] as usize;
        parent[i] = root_u32;
        i = up;
    }

    root
}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `v` sorted ascending so edge sets can be compared
    /// regardless of the order in which an algorithm picked them.
    fn collect_sorted(mut v: Vec<EdgeId>) -> Vec<EdgeId> {
        v.sort_unstable();
        v
    }

    /// Sums `weights[e]` over the given edge IDs.
    fn weight_sum(weights: &[f64], edges: &[EdgeId]) -> f64 {
        edges.iter().map(|&e| weights[e as usize]).sum()
    }

    /// Asserts that `edges` has exactly `vcount - expected_components`
    /// entries and that adding them one by one to a union-find never
    /// joins two vertices that are already connected (no cycle).
    fn assert_is_forest(graph: &Graph, edges: &[EdgeId], expected_components: usize) {
        let n = graph.vcount() as usize;
        // Forest invariant: |edges| == n - components, plus union-find
        // detects no cycle on the picked edges.
        assert_eq!(
            edges.len(),
            n - expected_components,
            "expected {} tree edges for {} vertices in {} components, got {}",
            n - expected_components,
            n,
            expected_components,
            edges.len()
        );
        let mut parent: Vec<u32> = (0..u32::try_from(n).expect("vcount fits u32")).collect();
        for &eid in edges {
            let (u, v) = graph.edge(eid).expect("edge exists");
            let ru = uf_find(&mut parent, u as usize);
            let rv = uf_find(&mut parent, v as usize);
            assert_ne!(ru, rv, "edge {eid} closes a cycle in the spanning tree");
            parent[ru] = u32::try_from(rv).expect("vid fits u32");
        }
    }

    // ---------- Empty / singleton ----------

    #[test]
    fn empty_graph_returns_empty_tree() {
        let g = Graph::with_vertices(0);
        for method in [
            MstAlgorithm::Automatic,
            MstAlgorithm::Unweighted,
            MstAlgorithm::Prim,
            MstAlgorithm::Kruskal,
        ] {
            // The weighted variants reject `None`, so hand them an empty
            // slice (ecount is 0); the others run without weights.
            let weights: Option<&[f64]> =
                if matches!(method, MstAlgorithm::Prim | MstAlgorithm::Kruskal) {
                    Some(&[])
                } else {
                    None
                };
            let tree = minimum_spanning_tree(&g, weights, method).unwrap();
            assert!(tree.is_empty(), "{method:?} on empty graph");
        }
    }

    #[test]
    fn single_vertex_no_edges_is_empty_tree() {
        let g = Graph::with_vertices(1);
        let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Automatic).unwrap();
        assert!(tree.is_empty());
    }

    #[test]
    fn isolated_vertices_are_a_forest_with_no_edges() {
        let g = Graph::with_vertices(5);
        let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert!(tree.is_empty());
        assert_is_forest(&g, &tree, 5);
    }

    // ---------- Unweighted ----------

    #[test]
    fn unweighted_chain_picks_all_edges() {
        // 0-1-2-3 is a tree already; MST = all 3 edges.
        let mut g = Graph::with_vertices(4);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 3)]).unwrap();
        let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert_eq!(tree.len(), 3);
        assert_is_forest(&g, &tree, 1);
    }

    #[test]
    fn unweighted_triangle_drops_one_edge() {
        // A triangle has 3 vertices, so any spanning tree keeps 2 of
        // its 3 edges.
        let mut g = Graph::with_vertices(3);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 0)]).unwrap();
        let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert_eq!(tree.len(), 2);
        assert_is_forest(&g, &tree, 1);
    }

    #[test]
    fn unweighted_forest_disjoint_components() {
        // Two disjoint trees: 0-1-2 and 3-4 ⇒ MST keeps everything.
        let mut g = Graph::with_vertices(5);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (3, 4)]).unwrap();
        let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert_eq!(tree.len(), 3);
        assert_is_forest(&g, &tree, 2);
    }

    #[test]
    fn unweighted_complete_k5_keeps_n_minus_1_edges() {
        // Build the complete graph on 5 vertices (10 edges).
        let mut g = Graph::with_vertices(5);
        for i in 0u32..5 {
            for j in (i + 1)..5 {
                g.add_edge(i, j).unwrap();
            }
        }
        let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert_eq!(tree.len(), 4);
        assert_is_forest(&g, &tree, 1);
    }

    // ---------- Prim ----------

    #[test]
    fn prim_picks_unique_mst_when_weights_are_distinct() {
        // Square + diagonals; outer edges have weight 1, diagonals 10.
        // Every minimum spanning tree consists of three of the four
        // outer edges. Note that, despite the test name, the outer
        // weights tie, so the tree itself is not unique; only its
        // total weight (3) and the exclusion of the diagonals are.
        let mut g = Graph::with_vertices(4);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 3), (3, 0), (0, 2), (1, 3)])
            .unwrap();
        let w = [1.0, 1.0, 1.0, 1.0, 10.0, 10.0];
        let tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap();
        assert_eq!(tree.len(), 3);
        // None of the diagonals (eids 4, 5) should appear.
        assert!(tree.iter().all(|&e| e < 4));
        assert_is_forest(&g, &tree, 1);
        let total = weight_sum(&w, &tree);
        assert!((total - 3.0).abs() < 1e-12);
    }

    #[test]
    fn prim_handles_disconnected_graphs_as_forest() {
        // Components {0,1,2} and {3,4}: both are already trees, so all
        // three edges are kept.
        let mut g = Graph::with_vertices(5);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (3, 4)]).unwrap();
        let w = [1.0, 2.0, 3.0];
        let tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap();
        assert_eq!(tree.len(), 3);
        assert_is_forest(&g, &tree, 2);
    }

    #[test]
    fn prim_and_kruskal_agree_on_distinct_weight_mst() {
        // Mini road graph: 5 vertices, 7 edges, all distinct weights ⇒
        // MST is unique.
        let mut g = Graph::with_vertices(5);
        g.add_edges(vec![
            (0u32, 1u32),
            (0, 2),
            (1, 2),
            (1, 3),
            (2, 3),
            (3, 4),
            (2, 4),
        ])
        .unwrap();
        let w = [1.0, 4.0, 2.0, 7.0, 3.0, 5.0, 6.0];
        let prim = collect_sorted(minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap());
        let kruskal =
            collect_sorted(minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap());
        assert_eq!(prim, kruskal);
        assert_eq!(prim.len(), 4);
        // With these weights the unique MST is {(0,1)=1, (1,2)=2, (2,3)=3, (3,4)=5}
        // — eids 0, 2, 4, 5.
        assert_eq!(prim, vec![0u32, 2, 4, 5]);
    }

    // ---------- Kruskal ----------

    #[test]
    fn kruskal_breaks_ties_by_edge_id() {
        // Triangle with equal weights ⇒ Kruskal keeps the two
        // lowest-id edges (0 and 1), dropping edge 2.
        let mut g = Graph::with_vertices(3);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 0)]).unwrap();
        let w = [1.0, 1.0, 1.0];
        let tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
        assert_eq!(tree, vec![0u32, 1]);
    }

    #[test]
    fn kruskal_picks_lightest_parallel_edge() {
        // Two parallel edges between 0 and 1, weights 5 and 1.
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 1).unwrap();
        let w = [5.0, 1.0];
        let tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
        // Lightest parallel edge wins; the heavier one is dropped.
        assert_eq!(tree, vec![1u32]);
    }

    #[test]
    fn kruskal_ignores_self_loops() {
        // 0-0 self-loop + 0-1 edge: self-loop never appears in the
        // spanning tree (a tree edge must connect two distinct
        // components), even though it is the lightest edge here.
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 0).unwrap(); // eid 0
        g.add_edge(0, 1).unwrap(); // eid 1
        let w = [0.5, 2.0];
        let tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
        assert_eq!(tree, vec![1u32]);
    }

    // ---------- Determinism / convergence ----------

    #[test]
    fn unweighted_is_deterministic_across_runs() {
        // Two runs on the same K5 must return the same edges in the
        // same order.
        let mut g = Graph::with_vertices(5);
        for i in 0u32..5 {
            for j in (i + 1)..5 {
                g.add_edge(i, j).unwrap();
            }
        }
        let a = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        let b = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn prim_and_kruskal_have_equal_total_weight_even_when_trees_differ() {
        // Square with all equal weights ⇒ many valid MSTs, but total
        // weight is invariant.
        let mut g = Graph::with_vertices(4);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 3), (3, 0)])
            .unwrap();
        let w = [1.0, 1.0, 1.0, 1.0];
        let prim = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap();
        let kruskal = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
        assert_eq!(prim.len(), 3);
        assert_eq!(kruskal.len(), 3);
        assert!((weight_sum(&w, &prim) - weight_sum(&w, &kruskal)).abs() < 1e-12);
    }

    // ---------- Directed → undirected ----------

    #[test]
    fn directed_graph_is_treated_as_undirected() {
        // Directed 0→1, 1→2, 2→0 — three edges; the MST/forest still
        // has 2 edges (vcount-1=2 for one connected component).
        let mut g = Graph::new(3, true).unwrap();
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 0)]).unwrap();
        let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert_eq!(tree.len(), 2);
        assert_is_forest(&g, &tree, 1);
    }

    // ---------- Automatic dispatch ----------

    #[test]
    fn automatic_picks_unweighted_when_no_weights() {
        // `Automatic` without weights must match `Unweighted` exactly,
        // including edge order.
        let mut g = Graph::with_vertices(4);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 3)]).unwrap();
        let a = minimum_spanning_tree(&g, None, MstAlgorithm::Automatic).unwrap();
        let b = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn automatic_picks_kruskal_when_weights_given() {
        // `Automatic` with weights must match `Kruskal` exactly,
        // including edge order.
        let mut g = Graph::with_vertices(4);
        g.add_edges(vec![(0u32, 1u32), (1, 2), (2, 3), (3, 0)])
            .unwrap();
        let w = [1.0, 2.0, 3.0, 4.0];
        let a = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Automatic).unwrap();
        let b = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
        assert_eq!(a, b);
    }

    // ---------- Error paths ----------

    #[test]
    fn prim_without_weights_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let err = minimum_spanning_tree(&g, None, MstAlgorithm::Prim).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn kruskal_without_weights_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let err = minimum_spanning_tree(&g, None, MstAlgorithm::Kruskal).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn mismatched_weight_length_errors() {
        let mut g = Graph::with_vertices(3);
        g.add_edges(vec![(0u32, 1u32), (1, 2)]).unwrap();
        let w = [1.0, 2.0, 3.0]; // length 3, ecount = 2
        let err = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn nan_weights_are_rejected() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let w = [f64::NAN];
        let err = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    // ---------- Stress / weight ordering ----------

    #[test]
    #[allow(clippy::many_single_char_names)]
    fn kruskal_total_weight_minimal_on_random_distinct_weights() {
        // Brute-force enumeration on a small graph: 5 vertices, 8
        // edges, all distinct weights. Compare Kruskal's total weight
        // against every spanning tree from a brute-force enumeration.
        let mut g = Graph::with_vertices(5);
        let edges = [
            (0u32, 1u32),
            (0, 2),
            (0, 3),
            (1, 2),
            (1, 4),
            (2, 3),
            (2, 4),
            (3, 4),
        ];
        for (u, v) in edges {
            g.add_edge(u, v).unwrap();
        }
        let w: [f64; 8] = [4.0, 1.0, 7.0, 2.0, 5.0, 6.0, 3.0, 8.0];
        let kruskal_tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
        let kruskal_total = weight_sum(&w, &kruskal_tree);

        // Brute-force: enumerate every 4-edge subset; check tree (no
        // cycle, connected) — track min total.
        let n_vertices = 5usize;
        let n_edges = 8usize;
        let mut min_total = f64::INFINITY;
        // All 256 bitmasks are scanned, but only the C(8,4) = 70 with
        // exactly four bits set are examined; cheap.
        for mask in 0u32..(1 << n_edges) {
            if mask.count_ones() != 4 {
                continue;
            }
            let mut parent: Vec<u32> =
                (0..u32::try_from(n_vertices).expect("vcount fits u32")).collect();
            let mut total = 0.0f64;
            let mut cycle = false;
            for idx in 0..n_edges {
                // Skip edges that are not part of this subset.
                if mask & (1 << idx) == 0 {
                    continue;
                }
                let (u, v) = edges[idx];
                let ru = uf_find(&mut parent, u as usize);
                let rv = uf_find(&mut parent, v as usize);
                if ru == rv {
                    cycle = true;
                    break;
                }
                parent[ru] = u32::try_from(rv).unwrap();
                total += w[idx];
            }
            if cycle {
                continue;
            }
            // Check connectivity ⇒ one component.
            let roots: std::collections::HashSet<usize> =
                (0..n_vertices).map(|x| uf_find(&mut parent, x)).collect();
            if roots.len() != 1 {
                continue;
            }
            if total < min_total {
                min_total = total;
            }
        }
        assert!((kruskal_total - min_total).abs() < 1e-12);
    }
}
772
773#[cfg(all(test, feature = "proptest-harness"))]
774mod proptests {
775    use super::*;
776    use proptest::prelude::*;
777
778    /// Spanning tree invariants that must hold for every MST result:
779    ///   1. The picked edges are pairwise distinct (no duplicates).
780    ///   2. They form an acyclic subgraph (union-find sees no cycle).
781    ///   3. Edge count == vcount - components (forest invariant).
782    fn assert_forest_invariants(graph: &Graph, edges: &[EdgeId]) -> Result<(), TestCaseError> {
783        // Distinctness.
784        let unique: std::collections::HashSet<EdgeId> = edges.iter().copied().collect();
785        prop_assert_eq!(unique.len(), edges.len(), "duplicate edge in MST result");
786
787        let n = graph.vcount() as usize;
788        let mut parent: Vec<u32> = (0..u32::try_from(n).expect("vcount fits u32")).collect();
789        for &eid in edges {
790            let (u, v) = graph
791                .edge(eid)
792                .map_err(|e| TestCaseError::Fail(e.to_string().into()))?;
793            let ru = uf_find(&mut parent, u as usize);
794            let rv = uf_find(&mut parent, v as usize);
795            prop_assert!(ru != rv, "edge {} closes a cycle in the MST", eid);
796            let rv_u32 =
797                u32::try_from(rv).map_err(|_| TestCaseError::Fail("vid overflow".into()))?;
798            parent[ru] = rv_u32;
799        }
800
801        // Compute components on the spanning forest (= components of
802        // the original graph because every component contributes its
803        // own spanning tree).
804        let mut tree_roots = std::collections::HashSet::new();
805        for v in 0..n {
806            tree_roots.insert(uf_find(&mut parent, v));
807        }
808
809        // Also compute components on the *full* graph for the
810        // expected count.
811        let mut full_parent: Vec<u32> = (0..u32::try_from(n).expect("vcount fits u32")).collect();
812        for eid in 0..u32::try_from(graph.ecount()).unwrap_or(u32::MAX) {
813            let (u, v) = match graph.edge(eid) {
814                Ok(pair) => pair,
815                Err(_) => continue,
816            };
817            let ru = uf_find(&mut full_parent, u as usize);
818            let rv = uf_find(&mut full_parent, v as usize);
819            if ru != rv {
820                full_parent[ru] = u32::try_from(rv).unwrap_or(u32::MAX);
821            }
822        }
823        let mut full_roots = std::collections::HashSet::new();
824        for v in 0..n {
825            full_roots.insert(uf_find(&mut full_parent, v));
826        }
827
828        prop_assert_eq!(
829            tree_roots.len(),
830            full_roots.len(),
831            "spanning forest must have the same component count as the original graph"
832        );
833        prop_assert_eq!(
834            edges.len(),
835            n - full_roots.len(),
836            "MST edge count must equal vcount - components"
837        );
838        Ok(())
839    }
840
841    prop_compose! {
842        fn small_undirected_graph()(n in 2u32..=10u32, edges_seed in any::<u64>())
843            -> Graph {
844            let mut g = Graph::with_vertices(n);
845            let mut rng = edges_seed;
846            let target_m = ((n * (n - 1)) / 2).min(n + 6) as usize;
847            let mut added = 0usize;
848            let mut guard = 0usize;
849            while added < target_m && guard < target_m * 8 + 4 {
850                rng = rng
851                    .wrapping_mul(6_364_136_223_846_793_005)
852                    .wrapping_add(1_442_695_040_888_963_407);
853                let u = ((rng >> 33) % u64::from(n)) as u32;
854                rng = rng
855                    .wrapping_mul(6_364_136_223_846_793_005)
856                    .wrapping_add(1_442_695_040_888_963_407);
857                let v = ((rng >> 33) % u64::from(n)) as u32;
858                guard += 1;
859                if u == v {
860                    continue;
861                }
862                if g.add_edge(u, v).is_ok() {
863                    added += 1;
864                }
865            }
866            g
867        }
868    }
869
870    prop_compose! {
871        fn small_weighted_graph()(g in small_undirected_graph(), seed in any::<u64>())
872            -> (Graph, Vec<f64>) {
873            let m = g.ecount();
874            let mut weights = Vec::with_capacity(m);
875            let mut rng = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
876            for _ in 0..m {
877                rng = rng
878                    .wrapping_mul(6_364_136_223_846_793_005)
879                    .wrapping_add(1_442_695_040_888_963_407);
880                // Weights in [0.5, 100.5) — strictly positive, finite,
881                // never NaN.
882                let bits = (rng >> 11) as f64 / (1u64 << 53) as f64;
883                weights.push(0.5 + 100.0 * bits);
884            }
885            (g, weights)
886        }
887    }
888
889    proptest! {
890        #[test]
891        fn unweighted_is_a_spanning_forest(g in small_undirected_graph()) {
892            let tree = minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
893            assert_forest_invariants(&g, &tree)?;
894        }
895
896        #[test]
897        fn prim_is_a_spanning_forest((g, w) in small_weighted_graph()) {
898            let tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap();
899            assert_forest_invariants(&g, &tree)?;
900        }
901
902        #[test]
903        fn kruskal_is_a_spanning_forest((g, w) in small_weighted_graph()) {
904            let tree = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
905            assert_forest_invariants(&g, &tree)?;
906        }
907
908        #[test]
909        fn prim_and_kruskal_have_equal_total_weight((g, w) in small_weighted_graph()) {
910            let prim = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Prim).unwrap();
911            let kruskal = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
912            let sp: f64 = prim.iter().map(|&e| w[e as usize]).sum();
913            let sk: f64 = kruskal.iter().map(|&e| w[e as usize]).sum();
914            // Matroid optimality: every MST has the same total weight.
915            prop_assert!((sp - sk).abs() < 1e-9, "prim={} kruskal={}", sp, sk);
916        }
917
918        #[test]
919        fn automatic_matches_underlying((g, w) in small_weighted_graph()) {
920            // Automatic w/ weights == Kruskal.
921            let auto = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Automatic).unwrap();
922            let kruskal = minimum_spanning_tree(&g, Some(&w), MstAlgorithm::Kruskal).unwrap();
923            prop_assert_eq!(auto, kruskal);
924
925            // Automatic w/o weights == Unweighted.
926            let auto = minimum_spanning_tree(&g, None, MstAlgorithm::Automatic).unwrap();
927            let unweighted =
928                minimum_spanning_tree(&g, None, MstAlgorithm::Unweighted).unwrap();
929            prop_assert_eq!(auto, unweighted);
930        }
931    }
932}