Skip to main content

rust_igraph/algorithms/hrg/
mcmc.rs

1//! MCMC-based HRG fitting engine.
2//!
3//! Implements the Clauset-Moore-Newman (2008) algorithm for fitting
4//! hierarchical random graph models via Markov Chain Monte Carlo over
5//! the space of dendrograms. The engine supports:
6//! - [`hrg_fit`]: MCMC search for the maximum-likelihood dendrogram
7//! - [`hrg_consensus`]: sample split frequencies from MCMC
8//! - [`hrg_predict`]: predict missing edges based on MCMC ensemble
9
10#![allow(
11    clippy::doc_markdown,
12    clippy::cast_possible_truncation,
13    clippy::cast_possible_wrap,
14    clippy::cast_sign_loss,
15    clippy::cast_precision_loss,
16    clippy::cast_lossless,
17    clippy::needless_range_loop,
18    clippy::many_single_char_names
19)]
20
21use crate::core::rng::SplitMix64;
22use crate::core::{Graph, IgraphError, IgraphResult};
23
24use super::HrgTree;
25
26// ── Internal dendrogram representation ──────────────────────────────────
27
/// Side tag for a left child (used in `Dendro::leaf_side`).
const LEFT: u8 = 0;
/// Side tag for a right child; the only other value besides `LEFT`.
const RIGHT: u8 = LEFT + 1;
30
/// Reference from an internal dendrogram node to one of its two children.
#[derive(Clone, Copy, Debug)]
enum Child {
    /// A leaf, identified by its graph vertex id.
    Leaf(u32),
    /// Another internal node, identified by its index in `Dendro::nodes`.
    Internal(usize),
}
37
38/// An internal node in the dendrogram.
39#[derive(Debug, Clone)]
40struct DendroNode {
41    left: Child,
42    right: Child,
43    parent: Option<usize>,
44    /// Number of leaves in this subtree.
45    n: u32,
46    /// Number of edges spanning left and right subtrees.
47    e: u32,
48    /// Probability p = e / (nL * nR).
49    p: f64,
50    /// Log-likelihood contribution of this node.
51    log_l: f64,
52    /// Minimum leaf label in subtree (for order property).
53    label: u32,
54}
55
56/// The full dendrogram structure for MCMC.
57struct Dendro {
58    nodes: Vec<DendroNode>,
59    /// Parent of leaf v (index of the internal node).
60    leaf_parent: Vec<usize>,
61    /// Which side leaf v is on relative to its parent (LEFT or RIGHT).
62    leaf_side: Vec<u8>,
63    /// Adjacency list for the input graph (symmetric).
64    adj: Vec<Vec<u32>>,
65    /// Total log-likelihood of current dendrogram.
66    total_log_l: f64,
67    /// Number of leaf vertices.
68    n: usize,
69}
70
71impl Dendro {
72    /// Build a random initial dendrogram from a graph.
73    #[allow(clippy::cast_possible_truncation)]
74    fn from_graph(graph: &Graph, rng: &mut SplitMix64) -> IgraphResult<Self> {
75        let n = graph.vcount() as usize;
76        if n < 3 {
77            return Err(IgraphError::InvalidArgument(
78                "HRG fit requires at least 3 vertices".into(),
79            ));
80        }
81
82        // Build symmetric adjacency list (ignoring self-loops + multi-edges)
83        let adj = build_adjacency(graph)?;
84
85        // Create n-1 internal nodes with random initial structure.
86        // Strategy: random sequential insertion (like the C code).
87        let num_internal = n - 1;
88        let mut nodes: Vec<DendroNode> = Vec::with_capacity(num_internal);
89        for _ in 0..num_internal {
90            nodes.push(DendroNode {
91                left: Child::Leaf(0),
92                right: Child::Leaf(0),
93                parent: None,
94                n: 0,
95                e: 0,
96                p: 0.0,
97                log_l: 0.0,
98                label: 0,
99            });
100        }
101
102        let mut leaf_parent = vec![0usize; n];
103        let mut leaf_side = vec![LEFT; n];
104
105        // Random permutation of leaves
106        let mut perm: Vec<u32> = (0..n as u32).collect();
107        for i in (1..n).rev() {
108            let j = rng.gen_index(i + 1);
109            perm.swap(i, j);
110        }
111
112        // Build a random binary tree by inserting leaves one at a time.
113        // Start with root having first two leaves.
114        nodes[0].left = Child::Leaf(perm[0]);
115        nodes[0].right = Child::Leaf(perm[1]);
116        leaf_parent[perm[0] as usize] = 0;
117        leaf_side[perm[0] as usize] = LEFT;
118        leaf_parent[perm[1] as usize] = 0;
119        leaf_side[perm[1] as usize] = RIGHT;
120
121        // For each remaining leaf, pick a random existing leaf and
122        // replace it with a new internal node that has the old leaf
123        // and new leaf as children.
124        for (active_internal, leaf_idx) in (1usize..).zip(2..n) {
125            let new_leaf = perm[leaf_idx];
126            // Pick a random existing leaf to split
127            let target_leaf = perm[rng.gen_index(leaf_idx)];
128            let target_parent = leaf_parent[target_leaf as usize];
129            let target_side = leaf_side[target_leaf as usize];
130
131            let new_internal = active_internal;
132
133            // New internal node gets target_leaf and new_leaf as children
134            nodes[new_internal].left = Child::Leaf(target_leaf);
135            nodes[new_internal].right = Child::Leaf(new_leaf);
136            nodes[new_internal].parent = Some(target_parent);
137
138            // Replace the target_leaf in its parent with the new internal node
139            if target_side == LEFT {
140                nodes[target_parent].left = Child::Internal(new_internal);
141            } else {
142                nodes[target_parent].right = Child::Internal(new_internal);
143            }
144
145            // Update leaf parents
146            leaf_parent[target_leaf as usize] = new_internal;
147            leaf_side[target_leaf as usize] = LEFT;
148            leaf_parent[new_leaf as usize] = new_internal;
149            leaf_side[new_leaf as usize] = RIGHT;
150        }
151
152        let mut dendro = Dendro {
153            nodes,
154            leaf_parent,
155            leaf_side,
156            adj,
157            total_log_l: 0.0,
158            n,
159        };
160
161        dendro.recompute_all();
162        Ok(dendro)
163    }
164
165    /// Build dendro from an existing HRG tree + graph.
166    #[allow(clippy::needless_range_loop, clippy::cast_possible_truncation)]
167    fn from_hrg(graph: &Graph, hrg: &HrgTree) -> IgraphResult<Self> {
168        let n = graph.vcount() as usize;
169        if n < 3 {
170            return Err(IgraphError::InvalidArgument(
171                "HRG fit requires at least 3 vertices".into(),
172            ));
173        }
174        if hrg.size() as usize != n {
175            return Err(IgraphError::InvalidArgument(
176                "HRG size does not match graph vertex count".into(),
177            ));
178        }
179
180        let adj = build_adjacency(graph)?;
181        let num_internal = n - 1;
182
183        let mut nodes: Vec<DendroNode> = Vec::with_capacity(num_internal);
184        let mut leaf_parent = vec![0usize; n];
185        let mut leaf_side = vec![LEFT; n];
186
187        for i in 0..num_internal {
188            let lc = hrg.left[i];
189            let rc = hrg.right[i];
190
191            let left = if lc < 0 {
192                #[allow(clippy::cast_sign_loss)]
193                let idx = (-lc - 1) as usize;
194                Child::Internal(idx)
195            } else {
196                #[allow(clippy::cast_sign_loss)]
197                let idx = lc as u32;
198                leaf_parent[idx as usize] = i;
199                leaf_side[idx as usize] = LEFT;
200                Child::Leaf(idx)
201            };
202
203            let right = if rc < 0 {
204                #[allow(clippy::cast_sign_loss)]
205                let idx = (-rc - 1) as usize;
206                Child::Internal(idx)
207            } else {
208                #[allow(clippy::cast_sign_loss)]
209                let idx = rc as u32;
210                leaf_parent[idx as usize] = i;
211                leaf_side[idx as usize] = RIGHT;
212                Child::Leaf(idx)
213            };
214
215            nodes.push(DendroNode {
216                left,
217                right,
218                parent: None,
219                n: hrg.vertices[i] as u32,
220                e: hrg.edges[i] as u32,
221                p: hrg.prob[i],
222                log_l: 0.0,
223                label: 0,
224            });
225        }
226
227        // Set parent pointers
228        for i in 0..num_internal {
229            match nodes[i].left {
230                Child::Internal(c) => nodes[c].parent = Some(i),
231                Child::Leaf(_) => {}
232            }
233            match nodes[i].right {
234                Child::Internal(c) => nodes[c].parent = Some(i),
235                Child::Leaf(_) => {}
236            }
237        }
238
239        let mut dendro = Dendro {
240            nodes,
241            leaf_parent,
242            leaf_side,
243            adj,
244            total_log_l: 0.0,
245            n,
246        };
247
248        dendro.recompute_all();
249        Ok(dendro)
250    }
251
252    /// Recompute all statistics from scratch (subtree sizes, edge counts, likelihood).
253    #[allow(clippy::needless_range_loop)]
254    fn recompute_all(&mut self) {
255        let num_internal = self.n - 1;
256
257        // Compute subtree sizes bottom-up
258        for i in (0..num_internal).rev() {
259            let nl = self.subtree_size(self.nodes[i].left);
260            let nr = self.subtree_size(self.nodes[i].right);
261            self.nodes[i].n = nl + nr;
262        }
263
264        // Compute labels (minimum leaf in subtree)
265        for i in (0..num_internal).rev() {
266            let ll = self.min_label(self.nodes[i].left);
267            let rl = self.min_label(self.nodes[i].right);
268            self.nodes[i].label = ll.min(rl);
269        }
270
271        // Compute edge counts using leaf enumeration + adjacency
272        for i in 0..num_internal {
273            let left_leaves = self.collect_leaves(self.nodes[i].left);
274            let right_leaves = self.collect_leaves(self.nodes[i].right);
275            let mut e = 0u32;
276            for &lv in &left_leaves {
277                for &rv in &right_leaves {
278                    if self.adj[lv as usize].contains(&rv) {
279                        e += 1;
280                    }
281                }
282            }
283            self.nodes[i].e = e;
284        }
285
286        // Compute probabilities and log-likelihood
287        self.refresh_likelihood();
288    }
289
290    /// Refresh log-likelihood and probabilities from current e and n values.
291    fn refresh_likelihood(&mut self) {
292        self.total_log_l = 0.0;
293        let num_internal = self.n - 1;
294        for i in 0..num_internal {
295            let nl = self.subtree_size(self.nodes[i].left);
296            let nr = self.subtree_size(self.nodes[i].right);
297            #[allow(clippy::cast_possible_truncation)]
298            let nl_nr = (nl as u64) * (nr as u64);
299            let ei = self.nodes[i].e as u64;
300
301            if nl_nr == 0 {
302                self.nodes[i].p = 0.0;
303                self.nodes[i].log_l = 0.0;
304            } else {
305                #[allow(clippy::cast_precision_loss)]
306                let p = ei as f64 / nl_nr as f64;
307                self.nodes[i].p = p;
308
309                if ei == 0 || ei == nl_nr {
310                    self.nodes[i].log_l = 0.0;
311                } else {
312                    #[allow(clippy::cast_precision_loss)]
313                    let dl = (ei as f64) * p.ln() + ((nl_nr - ei) as f64) * (1.0 - p).ln();
314                    self.nodes[i].log_l = dl;
315                }
316            }
317            self.total_log_l += self.nodes[i].log_l;
318        }
319    }
320
321    fn subtree_size(&self, child: Child) -> u32 {
322        match child {
323            Child::Leaf(_) => 1,
324            Child::Internal(idx) => self.nodes[idx].n,
325        }
326    }
327
328    fn min_label(&self, child: Child) -> u32 {
329        match child {
330            Child::Leaf(v) => v,
331            Child::Internal(idx) => self.nodes[idx].label,
332        }
333    }
334
335    fn collect_leaves(&self, child: Child) -> Vec<u32> {
336        let mut result = Vec::new();
337        let mut stack = vec![child];
338        while let Some(c) = stack.pop() {
339            match c {
340                Child::Leaf(v) => result.push(v),
341                Child::Internal(idx) => {
342                    stack.push(self.nodes[idx].left);
343                    stack.push(self.nodes[idx].right);
344                }
345            }
346        }
347        result
348    }
349
350    /// Count edges between two subtrees.
351    fn compute_edge_count(&self, a: Child, b: Child) -> u32 {
352        let leaves_a = self.collect_leaves(a);
353        let leaves_b = self.collect_leaves(b);
354        let mut count = 0u32;
355        // Use the smaller set for outer loop
356        if leaves_a.len() <= leaves_b.len() {
357            for &va in &leaves_a {
358                for &vb in &leaves_b {
359                    if self.adj[va as usize].contains(&vb) {
360                        count += 1;
361                    }
362                }
363            }
364        } else {
365            for &vb in &leaves_b {
366                for &va in &leaves_a {
367                    if self.adj[va as usize].contains(&vb) {
368                        count += 1;
369                    }
370                }
371            }
372        }
373        count
374    }
375
376    /// Compute log-likelihood for given (e, n_L*n_R) pair.
377    #[allow(clippy::cast_precision_loss)]
378    fn node_log_likelihood(e: u32, nl_nr: u64) -> f64 {
379        let ei = e as u64;
380        if ei == 0 || ei == nl_nr || nl_nr == 0 {
381            return 0.0;
382        }
383        let p = ei as f64 / nl_nr as f64;
384        (ei as f64) * p.ln() + ((nl_nr - ei) as f64) * (1.0 - p).ln()
385    }
386
387    /// Perform a single MCMC move (tree rotation).
388    fn mcmc_move(&mut self, rng: &mut SplitMix64) -> f64 {
389        let num_internal = self.n - 1;
390        if num_internal < 2 {
391            return 0.0;
392        }
393
394        // Pick a random non-root internal node y, get its parent x
395        let (x, y) = loop {
396            let idx = rng.gen_index(num_internal);
397            if let Some(p) = self.nodes[idx].parent {
398                break (p, idx);
399            }
400        };
401
402        // Determine which side y is on
403        let side = match self.nodes[x].left {
404            Child::Internal(c) if c == y => LEFT,
405            _ => RIGHT,
406        };
407
408        if side == LEFT {
409            if rng.gen_unit() < 0.5 {
410                self.try_left_alpha(x, y, rng)
411            } else {
412                self.try_left_beta(x, y, rng)
413            }
414        } else if rng.gen_unit() < 0.5 {
415            self.try_right_alpha(x, y, rng)
416        } else {
417            self.try_right_beta(x, y, rng)
418        }
419    }
420
421    /// LEFT ALPHA: ((i,j),k) -> ((i,k),j)
422    #[allow(clippy::many_single_char_names)]
423    fn try_left_alpha(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
424        let i = self.nodes[y].left;
425        let j = self.nodes[y].right;
426        let k = self.nodes[x].right;
427
428        let n_i = self.subtree_size(i);
429        let n_j = self.subtree_size(j);
430        let n_k = self.subtree_size(k);
431
432        // New y: (i, k) -> e_ik, n_y = n_i * n_k
433        let e_y = self.compute_edge_count(i, k);
434        let nl_nr_y = (n_i as u64) * (n_k as u64);
435        let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
436
437        // New x: (y_new, j) -> e_x = old_e_x + old_e_y - e_ik
438        let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
439        let nl_nr_x = ((n_i + n_k) as u64) * (n_j as u64);
440        let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
441
442        let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
443
444        if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
445            // Accept: swap j and k
446            self.nodes[y].right = k;
447            self.nodes[x].right = j;
448            self.update_child_parent(k, y, RIGHT);
449            self.update_child_parent(j, x, RIGHT);
450            self.nodes[y].n = n_i + n_k;
451            self.nodes[y].e = e_y;
452            #[allow(clippy::cast_precision_loss)]
453            {
454                self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
455            }
456            self.nodes[y].log_l = log_l_y;
457            self.nodes[x].e = e_x;
458            #[allow(clippy::cast_precision_loss)]
459            {
460                self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
461            }
462            self.nodes[x].log_l = log_l_x;
463            self.nodes[x].n = n_i + n_k + n_j;
464            self.total_log_l += d_log_l;
465            d_log_l
466        } else {
467            0.0
468        }
469    }
470
471    /// LEFT BETA: ((i,j),k) -> (i,(j,k))
472    #[allow(clippy::many_single_char_names)]
473    fn try_left_beta(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
474        let i = self.nodes[y].left;
475        let j = self.nodes[y].right;
476        let k = self.nodes[x].right;
477
478        let n_i = self.subtree_size(i);
479        let n_j = self.subtree_size(j);
480        let n_k = self.subtree_size(k);
481
482        // New y: (j, k) -> e_jk
483        let e_y = self.compute_edge_count(j, k);
484        let nl_nr_y = (n_j as u64) * (n_k as u64);
485        let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
486
487        // New x: (i, y_new) -> e_x = old_e_x + old_e_y - e_jk
488        let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
489        let nl_nr_x = (n_i as u64) * ((n_j + n_k) as u64);
490        let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
491
492        let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
493
494        if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
495            // Accept: restructure
496            // y becomes right child of x; y holds (j, k); x holds (i, y)
497            self.nodes[y].left = j;
498            self.nodes[y].right = k;
499            self.nodes[x].left = i;
500            self.nodes[x].right = Child::Internal(y);
501            self.update_child_parent(j, y, LEFT);
502            self.update_child_parent(k, y, RIGHT);
503            self.update_child_parent(i, x, LEFT);
504            // y is already child of x, update edge list side
505
506            self.nodes[y].n = n_j + n_k;
507            self.nodes[y].e = e_y;
508            #[allow(clippy::cast_precision_loss)]
509            {
510                self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
511            }
512            self.nodes[y].log_l = log_l_y;
513            self.nodes[x].e = e_x;
514            #[allow(clippy::cast_precision_loss)]
515            {
516                self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
517            }
518            self.nodes[x].log_l = log_l_x;
519            self.nodes[x].n = n_i + n_j + n_k;
520            self.total_log_l += d_log_l;
521            d_log_l
522        } else {
523            0.0
524        }
525    }
526
527    /// RIGHT ALPHA: (i,(j,k)) -> ((i,k),j)
528    #[allow(clippy::many_single_char_names)]
529    fn try_right_alpha(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
530        let i = self.nodes[x].left;
531        let j = self.nodes[y].left;
532        let k = self.nodes[y].right;
533
534        let n_i = self.subtree_size(i);
535        let n_j = self.subtree_size(j);
536        let n_k = self.subtree_size(k);
537
538        // New y: (i, k)
539        let e_y = self.compute_edge_count(i, k);
540        let nl_nr_y = (n_i as u64) * (n_k as u64);
541        let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
542
543        // New x: (y_new, j)
544        let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
545        let nl_nr_x = ((n_i + n_k) as u64) * (n_j as u64);
546        let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
547
548        let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
549
550        if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
551            // Accept: y becomes left child of x; y holds (i,k); x holds (y,j)
552            self.nodes[y].left = i;
553            self.nodes[y].right = k;
554            self.nodes[x].left = Child::Internal(y);
555            self.nodes[x].right = j;
556            self.update_child_parent(i, y, LEFT);
557            self.update_child_parent(k, y, RIGHT);
558            self.update_child_parent(j, x, RIGHT);
559
560            self.nodes[y].n = n_i + n_k;
561            self.nodes[y].e = e_y;
562            #[allow(clippy::cast_precision_loss)]
563            {
564                self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
565            }
566            self.nodes[y].log_l = log_l_y;
567            self.nodes[x].e = e_x;
568            #[allow(clippy::cast_precision_loss)]
569            {
570                self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
571            }
572            self.nodes[x].log_l = log_l_x;
573            self.nodes[x].n = n_i + n_k + n_j;
574            self.total_log_l += d_log_l;
575            d_log_l
576        } else {
577            0.0
578        }
579    }
580
581    /// RIGHT BETA: (i,(j,k)) -> ((i,j),k)
582    #[allow(clippy::many_single_char_names)]
583    fn try_right_beta(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
584        let i = self.nodes[x].left;
585        let j = self.nodes[y].left;
586        let k = self.nodes[y].right;
587
588        let n_i = self.subtree_size(i);
589        let n_j = self.subtree_size(j);
590        let n_k = self.subtree_size(k);
591
592        // New y: (i, j)
593        let e_y = self.compute_edge_count(i, j);
594        let nl_nr_y = (n_i as u64) * (n_j as u64);
595        let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
596
597        // New x: (y_new, k)
598        let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
599        let nl_nr_x = ((n_i + n_j) as u64) * (n_k as u64);
600        let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
601
602        let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
603
604        if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
605            // Accept: y becomes left child of x; y holds (i,j); x holds (y,k)
606            self.nodes[y].left = i;
607            self.nodes[y].right = j;
608            self.nodes[x].left = Child::Internal(y);
609            self.nodes[x].right = k;
610            self.update_child_parent(i, y, LEFT);
611            self.update_child_parent(j, y, RIGHT);
612            self.update_child_parent(k, x, RIGHT);
613
614            self.nodes[y].n = n_i + n_j;
615            self.nodes[y].e = e_y;
616            #[allow(clippy::cast_precision_loss)]
617            {
618                self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
619            }
620            self.nodes[y].log_l = log_l_y;
621            self.nodes[x].e = e_x;
622            #[allow(clippy::cast_precision_loss)]
623            {
624                self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
625            }
626            self.nodes[x].log_l = log_l_x;
627            self.nodes[x].n = n_i + n_j + n_k;
628            self.total_log_l += d_log_l;
629            d_log_l
630        } else {
631            0.0
632        }
633    }
634
635    /// Update the parent pointer of a child (leaf or internal).
636    fn update_child_parent(&mut self, child: Child, new_parent: usize, side: u8) {
637        match child {
638            Child::Leaf(v) => {
639                self.leaf_parent[v as usize] = new_parent;
640                self.leaf_side[v as usize] = side;
641            }
642            Child::Internal(idx) => {
643                self.nodes[idx].parent = Some(new_parent);
644            }
645        }
646    }
647
648    /// Export current dendrogram to an [`HrgTree`].
649    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
650    fn export_hrg(&self) -> HrgTree {
651        let n = self.n as u32;
652        let mut hrg = HrgTree::new(n);
653        let num_internal = self.n - 1;
654
655        for i in 0..num_internal {
656            hrg.left[i] = match self.nodes[i].left {
657                Child::Leaf(v) => v as i32,
658                Child::Internal(idx) => -(idx as i32 + 1),
659            };
660            hrg.right[i] = match self.nodes[i].right {
661                Child::Leaf(v) => v as i32,
662                Child::Internal(idx) => -(idx as i32 + 1),
663            };
664            hrg.prob[i] = self.nodes[i].p;
665            hrg.vertices[i] = self.nodes[i].n as i32;
666            hrg.edges[i] = self.nodes[i].e as i32;
667        }
668        hrg
669    }
670
671    /// Get the accumulated split probability for each non-existing edge.
672    #[allow(clippy::needless_range_loop, clippy::cast_possible_truncation)]
673    fn predict_edges(&self, adj_counts: &[Vec<f64>], num_samples: f64) -> Vec<(u32, u32, f64)> {
674        let n = self.n;
675        let mut predictions = Vec::new();
676        for i in 0..n {
677            for j in (i + 1)..n {
678                let iv = i as u32;
679                let jv = j as u32;
680                if !self.adj[i].contains(&jv) {
681                    let prob = adj_counts[i][j] / num_samples;
682                    predictions.push((iv, jv, prob));
683                }
684            }
685        }
686        predictions.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
687        predictions
688    }
689
690    /// Accumulate connection probabilities for current tree into counts matrix.
691    #[allow(clippy::needless_range_loop, clippy::cast_possible_truncation)]
692    fn accumulate_probabilities(&self, counts: &mut [Vec<f64>]) {
693        for i in 0..self.n {
694            for j in (i + 1)..self.n {
695                let lca = self.find_lca(i as u32, j as u32);
696                counts[i][j] += self.nodes[lca].p;
697                counts[j][i] += self.nodes[lca].p;
698            }
699        }
700    }
701
702    /// Find LCA of two leaves.
703    fn find_lca(&self, leaf_a: u32, leaf_b: u32) -> usize {
704        let num_internal = self.n - 1;
705        let mut visited = vec![false; num_internal];
706
707        // Walk from leaf_a to root, marking
708        let mut cur = self.leaf_parent[leaf_a as usize];
709        loop {
710            visited[cur] = true;
711            match self.nodes[cur].parent {
712                Some(p) => cur = p,
713                None => break,
714            }
715        }
716
717        // Walk from leaf_b to root, find first marked
718        let mut cur = self.leaf_parent[leaf_b as usize];
719        loop {
720            if visited[cur] {
721                return cur;
722            }
723            match self.nodes[cur].parent {
724                Some(p) => cur = p,
725                None => return cur, // root
726            }
727        }
728    }
729}
730
731/// Build a symmetric adjacency list from a graph (ignoring direction and self-loops).
732#[allow(clippy::cast_possible_truncation)]
733fn build_adjacency(graph: &Graph) -> IgraphResult<Vec<Vec<u32>>> {
734    let n = graph.vcount() as usize;
735    let mut adj: Vec<Vec<u32>> = vec![Vec::new(); n];
736
737    for eid in 0..graph.ecount() {
738        let (from, to) = graph.edge(eid as u32)?;
739        if from == to {
740            continue;
741        }
742        if !adj[from as usize].contains(&to) {
743            adj[from as usize].push(to);
744        }
745        if !adj[to as usize].contains(&from) {
746            adj[to as usize].push(from);
747        }
748    }
749
750    Ok(adj)
751}
752
753// ── Public API ──────────────────────────────────────────────────────────
754
755/// Fit a hierarchical random graph model to a network using MCMC.
756///
757/// If `start_hrg` is provided, the MCMC begins from that dendrogram
758/// structure; otherwise a random initial tree is used.
759///
760/// `steps` controls the MCMC budget:
761/// - If `steps > 0`, exactly that many MCMC moves are performed.
762/// - If `steps == 0`, the chain runs until convergence (mean
763///   log-likelihood stabilizes over 65536-step windows).
764///
765/// # Errors
766///
767/// Returns an error if the graph has fewer than 3 vertices.
768///
769/// # Example
770///
771/// ```
772/// use rust_igraph::{Graph, hrg_fit};
773///
774/// let g = Graph::from_edges(
775///     &[(0,1),(1,2),(2,3),(3,4),(4,0),(0,2),(1,3)],
776///     false, Some(5)
777/// ).unwrap();
778/// let hrg = hrg_fit(&g, None, 1000, 42).unwrap();
779/// assert_eq!(hrg.size(), 5);
780/// ```
781pub fn hrg_fit(
782    graph: &Graph,
783    start_hrg: Option<&HrgTree>,
784    steps: u64,
785    seed: u64,
786) -> IgraphResult<HrgTree> {
787    let mut rng = SplitMix64::new(seed);
788
789    let mut dendro = match start_hrg {
790        Some(hrg) => Dendro::from_hrg(graph, hrg)?,
791        None => Dendro::from_graph(graph, &mut rng)?,
792    };
793
794    if steps > 0 {
795        // Fixed number of steps
796        let mut best_l = dendro.total_log_l;
797        let mut best_hrg = dendro.export_hrg();
798
799        for _ in 0..steps {
800            dendro.mcmc_move(&mut rng);
801            if dendro.total_log_l > best_l {
802                best_l = dendro.total_log_l;
803                best_hrg = dendro.export_hrg();
804            }
805        }
806        // Periodically refresh to correct FP drift
807        dendro.refresh_likelihood();
808        if dendro.total_log_l > best_l {
809            best_hrg = dendro.export_hrg();
810        }
811        Ok(best_hrg)
812    } else {
813        // Run until convergence
814        mcmc_equilibrium(&mut dendro, &mut rng);
815        Ok(dendro.export_hrg())
816    }
817}
818
819/// Run MCMC until convergence (mean log-likelihood stabilizes).
820fn mcmc_equilibrium(dendro: &mut Dendro, rng: &mut SplitMix64) {
821    let window = 65536u64;
822    let mut old_mean = f64::NEG_INFINITY;
823
824    loop {
825        let mut sum = 0.0;
826        for _ in 0..window {
827            dendro.mcmc_move(rng);
828            sum += dendro.total_log_l;
829        }
830        dendro.refresh_likelihood();
831
832        #[allow(clippy::cast_precision_loss)]
833        let new_mean = sum / window as f64;
834
835        if (new_mean - old_mean).abs() < 1.0 {
836            break;
837        }
838        old_mean = new_mean;
839    }
840}
841
842/// Predict missing edges based on HRG ensemble sampling.
843///
844/// Returns a list of `(from, to, probability)` tuples sorted by
845/// probability (highest first). Only non-existing edges are included.
846///
847/// The MCMC chain first runs a burn-in of `200*n` steps, then samples
848/// with probability `1/(50*n)` per step until `num_samples` trees have
849/// been collected.
850///
851/// # Errors
852///
853/// Returns an error if the graph has fewer than 3 vertices.
854///
855/// # Example
856///
857/// ```
858/// use rust_igraph::{Graph, hrg_predict};
859///
860/// let g = Graph::from_edges(
861///     &[(0,1),(1,2),(2,3),(3,4),(4,0)],
862///     false, Some(5)
863/// ).unwrap();
864/// let predictions = hrg_predict(&g, None, 10, 42).unwrap();
865/// assert!(!predictions.is_empty());
866/// for &(from, to, prob) in &predictions {
867///     assert!(prob >= 0.0 && prob <= 1.0);
868///     assert!(from < to);
869/// }
870/// ```
871#[allow(clippy::cast_precision_loss)]
872pub fn hrg_predict(
873    graph: &Graph,
874    start_hrg: Option<&HrgTree>,
875    num_samples: u64,
876    seed: u64,
877) -> IgraphResult<Vec<(u32, u32, f64)>> {
878    let n = graph.vcount() as usize;
879    let mut rng = SplitMix64::new(seed);
880
881    let mut dendro = match start_hrg {
882        Some(hrg) => Dendro::from_hrg(graph, hrg)?,
883        None => Dendro::from_graph(graph, &mut rng)?,
884    };
885
886    // Burn-in phase
887    let burn_in = 200 * n;
888    for _ in 0..burn_in {
889        dendro.mcmc_move(&mut rng);
890    }
891    dendro.refresh_likelihood();
892
893    // Sample and accumulate probabilities
894    let mut counts: Vec<Vec<f64>> = vec![vec![0.0; n]; n];
895    let sample_interval = 50 * n;
896    let mut samples_taken = 0u64;
897
898    while samples_taken < num_samples {
899        for _ in 0..sample_interval {
900            dendro.mcmc_move(&mut rng);
901        }
902        dendro.accumulate_probabilities(&mut counts);
903        samples_taken += 1;
904    }
905
906    let result = dendro.predict_edges(&counts, num_samples as f64);
907    Ok(result)
908}
909
910/// Compute a consensus tree from MCMC samples of HRG models.
911///
912/// Returns `(parents, weights)`:
913/// - `parents[i]` is the parent of vertex i in the consensus tree
914///   (-1 for root). Vertex IDs 0..n are leaves, n..2n-1 are internal.
915/// - `weights[i]` is the frequency (0..1) of each internal split.
916///
917/// # Errors
918///
919/// Returns an error if the graph has fewer than 3 vertices.
920///
921/// # Example
922///
923/// ```
924/// use rust_igraph::{Graph, hrg_consensus};
925///
926/// let g = Graph::from_edges(
927///     &[(0,1),(1,2),(2,3),(3,4),(4,0),(0,2)],
928///     false, Some(5)
929/// ).unwrap();
930/// let (parents, weights) = hrg_consensus(&g, None, 10, 42).unwrap();
931/// assert_eq!(parents.len(), 2 * 5 - 1);
932/// assert_eq!(weights.len(), 4); // n-1 internal nodes
933/// ```
934#[allow(
935    clippy::cast_possible_truncation,
936    clippy::cast_possible_wrap,
937    clippy::cast_precision_loss
938)]
939pub fn hrg_consensus(
940    graph: &Graph,
941    start_hrg: Option<&HrgTree>,
942    num_samples: u64,
943    seed: u64,
944) -> IgraphResult<(Vec<i32>, Vec<f64>)> {
945    let n = graph.vcount() as usize;
946    let mut rng = SplitMix64::new(seed);
947
948    let mut dendro = match start_hrg {
949        Some(hrg) => Dendro::from_hrg(graph, hrg)?,
950        None => Dendro::from_graph(graph, &mut rng)?,
951    };
952
953    // Burn-in
954    let burn_in = 200 * n;
955    for _ in 0..burn_in {
956        dendro.mcmc_move(&mut rng);
957    }
958    dendro.refresh_likelihood();
959
960    // Sample split frequencies
961    let mut split_counts: std::collections::HashMap<Vec<u32>, u64> =
962        std::collections::HashMap::new();
963
964    let sample_interval = 50 * n;
965    let num_internal = n - 1;
966    let mut samples_taken = 0u64;
967
968    while samples_taken < num_samples {
969        for _ in 0..sample_interval {
970            dendro.mcmc_move(&mut rng);
971        }
972
973        // Record splits from current tree
974        for i in 0..num_internal {
975            let mut left_leaves = dendro.collect_leaves(dendro.nodes[i].left);
976            left_leaves.sort_unstable();
977            *split_counts.entry(left_leaves).or_insert(0) += 1;
978        }
979        samples_taken += 1;
980    }
981
982    // Build consensus tree from most frequent splits
983    let mut splits: Vec<(Vec<u32>, u64)> = split_counts.into_iter().collect();
984    splits.sort_by_key(|b| std::cmp::Reverse(b.1));
985
986    // Build the consensus tree as parent array
987    let total_nodes = 2 * n - 1;
988    let mut parents = vec![-1i32; total_nodes];
989    let mut weights = vec![0.0f64; n - 1];
990
991    // Start with all leaves under root (internal node index 0 = position n)
992    let root_pos = n;
993    for i in 0..n {
994        parents[i] = root_pos as i32;
995    }
996
997    // Greedily assign the top compatible splits
998    let mut used_internal = 0usize;
999
1000    for (split, count) in splits.iter().take(n - 1) {
1001        if split.len() < 2 || split.len() >= n {
1002            continue;
1003        }
1004        if used_internal >= n - 1 {
1005            break;
1006        }
1007
1008        let internal_idx = n + used_internal;
1009        weights[used_internal] = *count as f64 / (num_samples * num_internal as u64) as f64;
1010
1011        // Assign leaves in split to this internal node
1012        for &leaf in split {
1013            parents[leaf as usize] = internal_idx as i32;
1014        }
1015
1016        // This internal node is child of root
1017        if used_internal > 0 {
1018            parents[internal_idx] = root_pos as i32;
1019        }
1020
1021        used_internal += 1;
1022    }
1023
1024    Ok((parents, weights))
1025}
1026
#[cfg(test)]
mod tests {
    use super::*;

    /// Five-cycle 0-1-2-3-4-0 with the extra chord (0, 2): 5 vertices, 6 edges.
    fn pentagon_with_chord() -> Graph {
        Graph::from_edges(
            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (0, 2)],
            false,
            Some(5),
        )
        .expect("graph creation")
    }

    /// Fresh dendrogram for the test graph, plus the RNG that built it.
    fn seeded_dendro() -> (Dendro, SplitMix64) {
        let mut rng = SplitMix64::new(42);
        let dendro = Dendro::from_graph(&pentagon_with_chord(), &mut rng).expect("dendro");
        (dendro, rng)
    }

    #[test]
    fn hrg_fit_returns_correct_size() {
        let hrg = hrg_fit(&pentagon_with_chord(), None, 500, 42).expect("hrg_fit");
        assert_eq!(hrg.size(), 5);
        assert_eq!(hrg.num_internal(), 4);
    }

    #[test]
    fn hrg_fit_deterministic() {
        let g = pentagon_with_chord();
        let run = || hrg_fit(&g, None, 200, 99).expect("hrg_fit");
        let (first, second) = (run(), run());
        for idx in 0..first.num_internal() {
            assert_eq!(first.left[idx], second.left[idx]);
            assert_eq!(first.right[idx], second.right[idx]);
            assert!((first.prob[idx] - second.prob[idx]).abs() < 1e-10);
        }
    }

    #[test]
    fn hrg_fit_rejects_small_graph() {
        let tiny = Graph::from_edges(&[(0, 1)], false, Some(2)).expect("graph");
        assert!(hrg_fit(&tiny, None, 100, 0).is_err());
    }

    #[test]
    fn hrg_fit_from_start_hrg() {
        let g = pentagon_with_chord();
        let seed_tree = hrg_fit(&g, None, 200, 42).expect("hrg_fit");
        let refit = hrg_fit(&g, Some(&seed_tree), 200, 77).expect("hrg_fit from start");
        assert_eq!(refit.size(), 5);
    }

    #[test]
    fn hrg_fit_probs_valid() {
        let hrg = hrg_fit(&pentagon_with_chord(), None, 500, 42).expect("hrg_fit");
        let all_in_unit_interval =
            (0..hrg.num_internal()).all(|idx| hrg.prob[idx] >= 0.0 && hrg.prob[idx] <= 1.0);
        assert!(all_in_unit_interval);
    }

    #[test]
    fn hrg_predict_returns_results() {
        let preds = hrg_predict(&pentagon_with_chord(), None, 20, 42).expect("hrg_predict");
        // 6 of the 10 possible vertex pairs are edges, leaving 4 candidates.
        assert_eq!(preds.len(), 4);
        assert!(preds.iter().all(|&(from, to, _)| from < to));
        assert!(preds.iter().all(|&(_, _, prob)| (0.0..=1.0).contains(&prob)));
    }

    #[test]
    fn hrg_predict_sorted_by_prob() {
        let preds = hrg_predict(&pentagon_with_chord(), None, 20, 42).expect("hrg_predict");
        assert!(preds.windows(2).all(|pair| pair[0].2 >= pair[1].2));
    }

    #[test]
    fn hrg_consensus_returns_valid_tree() {
        let (parents, weights) =
            hrg_consensus(&pentagon_with_chord(), None, 20, 42).expect("hrg_consensus");
        assert_eq!(parents.len(), 2 * 5 - 1);
        assert_eq!(weights.len(), 5 - 1);
        // The root is marked by a parent of -1.
        assert!(parents.contains(&-1));
        assert!(weights.iter().all(|w| (0.0..=1.0).contains(w)));
    }

    #[test]
    fn dendro_from_graph_correct_structure() {
        let (dendro, _) = seeded_dendro();
        assert_eq!(dendro.n, 5);
        // One internal node fewer than there are leaves.
        assert_eq!(dendro.nodes.len(), 4);
        // The root's subtree spans every vertex.
        assert_eq!(dendro.nodes[0].n, 5);
    }

    #[test]
    fn dendro_likelihood_finite() {
        let (dendro, _) = seeded_dendro();
        let log_l = dendro.total_log_l;
        assert!(log_l.is_finite());
        assert!(log_l <= 0.0);
    }

    #[test]
    fn dendro_mcmc_move_changes_likelihood() {
        let (mut dendro, mut rng) = seeded_dendro();
        let initial_l = dendro.total_log_l;
        let changed = (0..100).any(|_| {
            dendro.mcmc_move(&mut rng);
            (dendro.total_log_l - initial_l).abs() > 1e-15
        });
        assert!(changed, "MCMC should change likelihood within 100 moves");
    }

    #[test]
    fn dendro_export_roundtrip() {
        let (dendro, _) = seeded_dendro();
        let hrg = dendro.export_hrg();
        assert_eq!(hrg.size(), 5);
        assert_eq!(hrg.num_internal(), 4);
    }
}