1#![allow(
11 clippy::doc_markdown,
12 clippy::cast_possible_truncation,
13 clippy::cast_possible_wrap,
14 clippy::cast_sign_loss,
15 clippy::cast_precision_loss,
16 clippy::cast_lossless,
17 clippy::needless_range_loop,
18 clippy::many_single_char_names
19)]
20
21use crate::core::rng::SplitMix64;
22use crate::core::{Graph, IgraphError, IgraphResult};
23
24use super::HrgTree;
25
/// Side tag stored in `Dendro::leaf_side`: the leaf sits in its parent's `left` slot.
const LEFT: u8 = 0;
/// Side tag stored in `Dendro::leaf_side`: the leaf sits in its parent's `right` slot.
const RIGHT: u8 = 1;
30
/// Contents of one child slot of an internal dendrogram node.
#[derive(Debug, Clone, Copy)]
enum Child {
    /// A leaf, identified by its vertex id (used to index `Dendro::adj`).
    Leaf(u32),
    /// Another internal node, identified by its index into `Dendro::nodes`.
    Internal(usize),
}
37
/// One internal node of the dendrogram together with cached statistics about
/// the two subtrees hanging below it.
#[derive(Debug, Clone)]
struct DendroNode {
    /// Left child slot.
    left: Child,
    /// Right child slot.
    right: Child,
    /// Index of the parent internal node; `None` for a node without a parent (the root).
    parent: Option<usize>,
    /// Number of leaves below this node (left subtree size + right subtree size).
    n: u32,
    /// Number of graph edges with one endpoint under `left` and the other under `right`.
    e: u32,
    /// `e` divided by the number of (left leaf, right leaf) pairs; `0.0` when there are no pairs.
    p: f64,
    /// This node's contribution to the dendrogram log-likelihood.
    log_l: f64,
    /// Smallest leaf id below this node, as computed by `recompute_all`.
    label: u32,
}
55
/// A binary dendrogram over the vertices of a graph, plus the bookkeeping the
/// MCMC moves need to update it incrementally.
struct Dendro {
    /// The `n - 1` internal nodes.
    nodes: Vec<DendroNode>,
    /// For each leaf, the index of the internal node that is its parent.
    leaf_parent: Vec<usize>,
    /// For each leaf, whether it is its parent's `LEFT` or `RIGHT` child.
    leaf_side: Vec<u8>,
    /// Adjacency lists of the graph, without self-loops or duplicate entries.
    adj: Vec<Vec<u32>>,
    /// Sum of `log_l` over all internal nodes.
    total_log_l: f64,
    /// Number of leaves (graph vertices).
    n: usize,
}
70
71impl Dendro {
72 #[allow(clippy::cast_possible_truncation)]
74 fn from_graph(graph: &Graph, rng: &mut SplitMix64) -> IgraphResult<Self> {
75 let n = graph.vcount() as usize;
76 if n < 3 {
77 return Err(IgraphError::InvalidArgument(
78 "HRG fit requires at least 3 vertices".into(),
79 ));
80 }
81
82 let adj = build_adjacency(graph)?;
84
85 let num_internal = n - 1;
88 let mut nodes: Vec<DendroNode> = Vec::with_capacity(num_internal);
89 for _ in 0..num_internal {
90 nodes.push(DendroNode {
91 left: Child::Leaf(0),
92 right: Child::Leaf(0),
93 parent: None,
94 n: 0,
95 e: 0,
96 p: 0.0,
97 log_l: 0.0,
98 label: 0,
99 });
100 }
101
102 let mut leaf_parent = vec![0usize; n];
103 let mut leaf_side = vec![LEFT; n];
104
105 let mut perm: Vec<u32> = (0..n as u32).collect();
107 for i in (1..n).rev() {
108 let j = rng.gen_index(i + 1);
109 perm.swap(i, j);
110 }
111
112 nodes[0].left = Child::Leaf(perm[0]);
115 nodes[0].right = Child::Leaf(perm[1]);
116 leaf_parent[perm[0] as usize] = 0;
117 leaf_side[perm[0] as usize] = LEFT;
118 leaf_parent[perm[1] as usize] = 0;
119 leaf_side[perm[1] as usize] = RIGHT;
120
121 for (active_internal, leaf_idx) in (1usize..).zip(2..n) {
125 let new_leaf = perm[leaf_idx];
126 let target_leaf = perm[rng.gen_index(leaf_idx)];
128 let target_parent = leaf_parent[target_leaf as usize];
129 let target_side = leaf_side[target_leaf as usize];
130
131 let new_internal = active_internal;
132
133 nodes[new_internal].left = Child::Leaf(target_leaf);
135 nodes[new_internal].right = Child::Leaf(new_leaf);
136 nodes[new_internal].parent = Some(target_parent);
137
138 if target_side == LEFT {
140 nodes[target_parent].left = Child::Internal(new_internal);
141 } else {
142 nodes[target_parent].right = Child::Internal(new_internal);
143 }
144
145 leaf_parent[target_leaf as usize] = new_internal;
147 leaf_side[target_leaf as usize] = LEFT;
148 leaf_parent[new_leaf as usize] = new_internal;
149 leaf_side[new_leaf as usize] = RIGHT;
150 }
151
152 let mut dendro = Dendro {
153 nodes,
154 leaf_parent,
155 leaf_side,
156 adj,
157 total_log_l: 0.0,
158 n,
159 };
160
161 dendro.recompute_all();
162 Ok(dendro)
163 }
164
165 #[allow(clippy::needless_range_loop, clippy::cast_possible_truncation)]
167 fn from_hrg(graph: &Graph, hrg: &HrgTree) -> IgraphResult<Self> {
168 let n = graph.vcount() as usize;
169 if n < 3 {
170 return Err(IgraphError::InvalidArgument(
171 "HRG fit requires at least 3 vertices".into(),
172 ));
173 }
174 if hrg.size() as usize != n {
175 return Err(IgraphError::InvalidArgument(
176 "HRG size does not match graph vertex count".into(),
177 ));
178 }
179
180 let adj = build_adjacency(graph)?;
181 let num_internal = n - 1;
182
183 let mut nodes: Vec<DendroNode> = Vec::with_capacity(num_internal);
184 let mut leaf_parent = vec![0usize; n];
185 let mut leaf_side = vec![LEFT; n];
186
187 for i in 0..num_internal {
188 let lc = hrg.left[i];
189 let rc = hrg.right[i];
190
191 let left = if lc < 0 {
192 #[allow(clippy::cast_sign_loss)]
193 let idx = (-lc - 1) as usize;
194 Child::Internal(idx)
195 } else {
196 #[allow(clippy::cast_sign_loss)]
197 let idx = lc as u32;
198 leaf_parent[idx as usize] = i;
199 leaf_side[idx as usize] = LEFT;
200 Child::Leaf(idx)
201 };
202
203 let right = if rc < 0 {
204 #[allow(clippy::cast_sign_loss)]
205 let idx = (-rc - 1) as usize;
206 Child::Internal(idx)
207 } else {
208 #[allow(clippy::cast_sign_loss)]
209 let idx = rc as u32;
210 leaf_parent[idx as usize] = i;
211 leaf_side[idx as usize] = RIGHT;
212 Child::Leaf(idx)
213 };
214
215 nodes.push(DendroNode {
216 left,
217 right,
218 parent: None,
219 n: hrg.vertices[i] as u32,
220 e: hrg.edges[i] as u32,
221 p: hrg.prob[i],
222 log_l: 0.0,
223 label: 0,
224 });
225 }
226
227 for i in 0..num_internal {
229 match nodes[i].left {
230 Child::Internal(c) => nodes[c].parent = Some(i),
231 Child::Leaf(_) => {}
232 }
233 match nodes[i].right {
234 Child::Internal(c) => nodes[c].parent = Some(i),
235 Child::Leaf(_) => {}
236 }
237 }
238
239 let mut dendro = Dendro {
240 nodes,
241 leaf_parent,
242 leaf_side,
243 adj,
244 total_log_l: 0.0,
245 n,
246 };
247
248 dendro.recompute_all();
249 Ok(dendro)
250 }
251
252 #[allow(clippy::needless_range_loop)]
254 fn recompute_all(&mut self) {
255 let num_internal = self.n - 1;
256
257 for i in (0..num_internal).rev() {
259 let nl = self.subtree_size(self.nodes[i].left);
260 let nr = self.subtree_size(self.nodes[i].right);
261 self.nodes[i].n = nl + nr;
262 }
263
264 for i in (0..num_internal).rev() {
266 let ll = self.min_label(self.nodes[i].left);
267 let rl = self.min_label(self.nodes[i].right);
268 self.nodes[i].label = ll.min(rl);
269 }
270
271 for i in 0..num_internal {
273 let left_leaves = self.collect_leaves(self.nodes[i].left);
274 let right_leaves = self.collect_leaves(self.nodes[i].right);
275 let mut e = 0u32;
276 for &lv in &left_leaves {
277 for &rv in &right_leaves {
278 if self.adj[lv as usize].contains(&rv) {
279 e += 1;
280 }
281 }
282 }
283 self.nodes[i].e = e;
284 }
285
286 self.refresh_likelihood();
288 }
289
290 fn refresh_likelihood(&mut self) {
292 self.total_log_l = 0.0;
293 let num_internal = self.n - 1;
294 for i in 0..num_internal {
295 let nl = self.subtree_size(self.nodes[i].left);
296 let nr = self.subtree_size(self.nodes[i].right);
297 #[allow(clippy::cast_possible_truncation)]
298 let nl_nr = (nl as u64) * (nr as u64);
299 let ei = self.nodes[i].e as u64;
300
301 if nl_nr == 0 {
302 self.nodes[i].p = 0.0;
303 self.nodes[i].log_l = 0.0;
304 } else {
305 #[allow(clippy::cast_precision_loss)]
306 let p = ei as f64 / nl_nr as f64;
307 self.nodes[i].p = p;
308
309 if ei == 0 || ei == nl_nr {
310 self.nodes[i].log_l = 0.0;
311 } else {
312 #[allow(clippy::cast_precision_loss)]
313 let dl = (ei as f64) * p.ln() + ((nl_nr - ei) as f64) * (1.0 - p).ln();
314 self.nodes[i].log_l = dl;
315 }
316 }
317 self.total_log_l += self.nodes[i].log_l;
318 }
319 }
320
321 fn subtree_size(&self, child: Child) -> u32 {
322 match child {
323 Child::Leaf(_) => 1,
324 Child::Internal(idx) => self.nodes[idx].n,
325 }
326 }
327
328 fn min_label(&self, child: Child) -> u32 {
329 match child {
330 Child::Leaf(v) => v,
331 Child::Internal(idx) => self.nodes[idx].label,
332 }
333 }
334
335 fn collect_leaves(&self, child: Child) -> Vec<u32> {
336 let mut result = Vec::new();
337 let mut stack = vec![child];
338 while let Some(c) = stack.pop() {
339 match c {
340 Child::Leaf(v) => result.push(v),
341 Child::Internal(idx) => {
342 stack.push(self.nodes[idx].left);
343 stack.push(self.nodes[idx].right);
344 }
345 }
346 }
347 result
348 }
349
350 fn compute_edge_count(&self, a: Child, b: Child) -> u32 {
352 let leaves_a = self.collect_leaves(a);
353 let leaves_b = self.collect_leaves(b);
354 let mut count = 0u32;
355 if leaves_a.len() <= leaves_b.len() {
357 for &va in &leaves_a {
358 for &vb in &leaves_b {
359 if self.adj[va as usize].contains(&vb) {
360 count += 1;
361 }
362 }
363 }
364 } else {
365 for &vb in &leaves_b {
366 for &va in &leaves_a {
367 if self.adj[va as usize].contains(&vb) {
368 count += 1;
369 }
370 }
371 }
372 }
373 count
374 }
375
376 #[allow(clippy::cast_precision_loss)]
378 fn node_log_likelihood(e: u32, nl_nr: u64) -> f64 {
379 let ei = e as u64;
380 if ei == 0 || ei == nl_nr || nl_nr == 0 {
381 return 0.0;
382 }
383 let p = ei as f64 / nl_nr as f64;
384 (ei as f64) * p.ln() + ((nl_nr - ei) as f64) * (1.0 - p).ln()
385 }
386
387 fn mcmc_move(&mut self, rng: &mut SplitMix64) -> f64 {
389 let num_internal = self.n - 1;
390 if num_internal < 2 {
391 return 0.0;
392 }
393
394 let (x, y) = loop {
396 let idx = rng.gen_index(num_internal);
397 if let Some(p) = self.nodes[idx].parent {
398 break (p, idx);
399 }
400 };
401
402 let side = match self.nodes[x].left {
404 Child::Internal(c) if c == y => LEFT,
405 _ => RIGHT,
406 };
407
408 if side == LEFT {
409 if rng.gen_unit() < 0.5 {
410 self.try_left_alpha(x, y, rng)
411 } else {
412 self.try_left_beta(x, y, rng)
413 }
414 } else if rng.gen_unit() < 0.5 {
415 self.try_right_alpha(x, y, rng)
416 } else {
417 self.try_right_beta(x, y, rng)
418 }
419 }
420
421 #[allow(clippy::many_single_char_names)]
423 fn try_left_alpha(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
424 let i = self.nodes[y].left;
425 let j = self.nodes[y].right;
426 let k = self.nodes[x].right;
427
428 let n_i = self.subtree_size(i);
429 let n_j = self.subtree_size(j);
430 let n_k = self.subtree_size(k);
431
432 let e_y = self.compute_edge_count(i, k);
434 let nl_nr_y = (n_i as u64) * (n_k as u64);
435 let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
436
437 let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
439 let nl_nr_x = ((n_i + n_k) as u64) * (n_j as u64);
440 let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
441
442 let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
443
444 if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
445 self.nodes[y].right = k;
447 self.nodes[x].right = j;
448 self.update_child_parent(k, y, RIGHT);
449 self.update_child_parent(j, x, RIGHT);
450 self.nodes[y].n = n_i + n_k;
451 self.nodes[y].e = e_y;
452 #[allow(clippy::cast_precision_loss)]
453 {
454 self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
455 }
456 self.nodes[y].log_l = log_l_y;
457 self.nodes[x].e = e_x;
458 #[allow(clippy::cast_precision_loss)]
459 {
460 self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
461 }
462 self.nodes[x].log_l = log_l_x;
463 self.nodes[x].n = n_i + n_k + n_j;
464 self.total_log_l += d_log_l;
465 d_log_l
466 } else {
467 0.0
468 }
469 }
470
471 #[allow(clippy::many_single_char_names)]
473 fn try_left_beta(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
474 let i = self.nodes[y].left;
475 let j = self.nodes[y].right;
476 let k = self.nodes[x].right;
477
478 let n_i = self.subtree_size(i);
479 let n_j = self.subtree_size(j);
480 let n_k = self.subtree_size(k);
481
482 let e_y = self.compute_edge_count(j, k);
484 let nl_nr_y = (n_j as u64) * (n_k as u64);
485 let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
486
487 let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
489 let nl_nr_x = (n_i as u64) * ((n_j + n_k) as u64);
490 let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
491
492 let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
493
494 if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
495 self.nodes[y].left = j;
498 self.nodes[y].right = k;
499 self.nodes[x].left = i;
500 self.nodes[x].right = Child::Internal(y);
501 self.update_child_parent(j, y, LEFT);
502 self.update_child_parent(k, y, RIGHT);
503 self.update_child_parent(i, x, LEFT);
504 self.nodes[y].n = n_j + n_k;
507 self.nodes[y].e = e_y;
508 #[allow(clippy::cast_precision_loss)]
509 {
510 self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
511 }
512 self.nodes[y].log_l = log_l_y;
513 self.nodes[x].e = e_x;
514 #[allow(clippy::cast_precision_loss)]
515 {
516 self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
517 }
518 self.nodes[x].log_l = log_l_x;
519 self.nodes[x].n = n_i + n_j + n_k;
520 self.total_log_l += d_log_l;
521 d_log_l
522 } else {
523 0.0
524 }
525 }
526
527 #[allow(clippy::many_single_char_names)]
529 fn try_right_alpha(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
530 let i = self.nodes[x].left;
531 let j = self.nodes[y].left;
532 let k = self.nodes[y].right;
533
534 let n_i = self.subtree_size(i);
535 let n_j = self.subtree_size(j);
536 let n_k = self.subtree_size(k);
537
538 let e_y = self.compute_edge_count(i, k);
540 let nl_nr_y = (n_i as u64) * (n_k as u64);
541 let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
542
543 let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
545 let nl_nr_x = ((n_i + n_k) as u64) * (n_j as u64);
546 let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
547
548 let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
549
550 if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
551 self.nodes[y].left = i;
553 self.nodes[y].right = k;
554 self.nodes[x].left = Child::Internal(y);
555 self.nodes[x].right = j;
556 self.update_child_parent(i, y, LEFT);
557 self.update_child_parent(k, y, RIGHT);
558 self.update_child_parent(j, x, RIGHT);
559
560 self.nodes[y].n = n_i + n_k;
561 self.nodes[y].e = e_y;
562 #[allow(clippy::cast_precision_loss)]
563 {
564 self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
565 }
566 self.nodes[y].log_l = log_l_y;
567 self.nodes[x].e = e_x;
568 #[allow(clippy::cast_precision_loss)]
569 {
570 self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
571 }
572 self.nodes[x].log_l = log_l_x;
573 self.nodes[x].n = n_i + n_k + n_j;
574 self.total_log_l += d_log_l;
575 d_log_l
576 } else {
577 0.0
578 }
579 }
580
581 #[allow(clippy::many_single_char_names)]
583 fn try_right_beta(&mut self, x: usize, y: usize, rng: &mut SplitMix64) -> f64 {
584 let i = self.nodes[x].left;
585 let j = self.nodes[y].left;
586 let k = self.nodes[y].right;
587
588 let n_i = self.subtree_size(i);
589 let n_j = self.subtree_size(j);
590 let n_k = self.subtree_size(k);
591
592 let e_y = self.compute_edge_count(i, j);
594 let nl_nr_y = (n_i as u64) * (n_j as u64);
595 let log_l_y = Self::node_log_likelihood(e_y, nl_nr_y);
596
597 let e_x = self.nodes[x].e + self.nodes[y].e - e_y;
599 let nl_nr_x = ((n_i + n_j) as u64) * (n_k as u64);
600 let log_l_x = Self::node_log_likelihood(e_x, nl_nr_x);
601
602 let d_log_l = (log_l_x - self.nodes[x].log_l) + (log_l_y - self.nodes[y].log_l);
603
604 if d_log_l > 0.0 || rng.gen_unit() < d_log_l.exp() {
605 self.nodes[y].left = i;
607 self.nodes[y].right = j;
608 self.nodes[x].left = Child::Internal(y);
609 self.nodes[x].right = k;
610 self.update_child_parent(i, y, LEFT);
611 self.update_child_parent(j, y, RIGHT);
612 self.update_child_parent(k, x, RIGHT);
613
614 self.nodes[y].n = n_i + n_j;
615 self.nodes[y].e = e_y;
616 #[allow(clippy::cast_precision_loss)]
617 {
618 self.nodes[y].p = e_y as f64 / nl_nr_y.max(1) as f64;
619 }
620 self.nodes[y].log_l = log_l_y;
621 self.nodes[x].e = e_x;
622 #[allow(clippy::cast_precision_loss)]
623 {
624 self.nodes[x].p = e_x as f64 / nl_nr_x.max(1) as f64;
625 }
626 self.nodes[x].log_l = log_l_x;
627 self.nodes[x].n = n_i + n_j + n_k;
628 self.total_log_l += d_log_l;
629 d_log_l
630 } else {
631 0.0
632 }
633 }
634
635 fn update_child_parent(&mut self, child: Child, new_parent: usize, side: u8) {
637 match child {
638 Child::Leaf(v) => {
639 self.leaf_parent[v as usize] = new_parent;
640 self.leaf_side[v as usize] = side;
641 }
642 Child::Internal(idx) => {
643 self.nodes[idx].parent = Some(new_parent);
644 }
645 }
646 }
647
648 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
650 fn export_hrg(&self) -> HrgTree {
651 let n = self.n as u32;
652 let mut hrg = HrgTree::new(n);
653 let num_internal = self.n - 1;
654
655 for i in 0..num_internal {
656 hrg.left[i] = match self.nodes[i].left {
657 Child::Leaf(v) => v as i32,
658 Child::Internal(idx) => -(idx as i32 + 1),
659 };
660 hrg.right[i] = match self.nodes[i].right {
661 Child::Leaf(v) => v as i32,
662 Child::Internal(idx) => -(idx as i32 + 1),
663 };
664 hrg.prob[i] = self.nodes[i].p;
665 hrg.vertices[i] = self.nodes[i].n as i32;
666 hrg.edges[i] = self.nodes[i].e as i32;
667 }
668 hrg
669 }
670
671 #[allow(clippy::needless_range_loop, clippy::cast_possible_truncation)]
673 fn predict_edges(&self, adj_counts: &[Vec<f64>], num_samples: f64) -> Vec<(u32, u32, f64)> {
674 let n = self.n;
675 let mut predictions = Vec::new();
676 for i in 0..n {
677 for j in (i + 1)..n {
678 let iv = i as u32;
679 let jv = j as u32;
680 if !self.adj[i].contains(&jv) {
681 let prob = adj_counts[i][j] / num_samples;
682 predictions.push((iv, jv, prob));
683 }
684 }
685 }
686 predictions.sort_by(|a, b| b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal));
687 predictions
688 }
689
690 #[allow(clippy::needless_range_loop, clippy::cast_possible_truncation)]
692 fn accumulate_probabilities(&self, counts: &mut [Vec<f64>]) {
693 for i in 0..self.n {
694 for j in (i + 1)..self.n {
695 let lca = self.find_lca(i as u32, j as u32);
696 counts[i][j] += self.nodes[lca].p;
697 counts[j][i] += self.nodes[lca].p;
698 }
699 }
700 }
701
702 fn find_lca(&self, leaf_a: u32, leaf_b: u32) -> usize {
704 let num_internal = self.n - 1;
705 let mut visited = vec![false; num_internal];
706
707 let mut cur = self.leaf_parent[leaf_a as usize];
709 loop {
710 visited[cur] = true;
711 match self.nodes[cur].parent {
712 Some(p) => cur = p,
713 None => break,
714 }
715 }
716
717 let mut cur = self.leaf_parent[leaf_b as usize];
719 loop {
720 if visited[cur] {
721 return cur;
722 }
723 match self.nodes[cur].parent {
724 Some(p) => cur = p,
725 None => return cur, }
727 }
728 }
729}
730
731#[allow(clippy::cast_possible_truncation)]
733fn build_adjacency(graph: &Graph) -> IgraphResult<Vec<Vec<u32>>> {
734 let n = graph.vcount() as usize;
735 let mut adj: Vec<Vec<u32>> = vec![Vec::new(); n];
736
737 for eid in 0..graph.ecount() {
738 let (from, to) = graph.edge(eid as u32)?;
739 if from == to {
740 continue;
741 }
742 if !adj[from as usize].contains(&to) {
743 adj[from as usize].push(to);
744 }
745 if !adj[to as usize].contains(&from) {
746 adj[to as usize].push(from);
747 }
748 }
749
750 Ok(adj)
751}
752
753pub fn hrg_fit(
782 graph: &Graph,
783 start_hrg: Option<&HrgTree>,
784 steps: u64,
785 seed: u64,
786) -> IgraphResult<HrgTree> {
787 let mut rng = SplitMix64::new(seed);
788
789 let mut dendro = match start_hrg {
790 Some(hrg) => Dendro::from_hrg(graph, hrg)?,
791 None => Dendro::from_graph(graph, &mut rng)?,
792 };
793
794 if steps > 0 {
795 let mut best_l = dendro.total_log_l;
797 let mut best_hrg = dendro.export_hrg();
798
799 for _ in 0..steps {
800 dendro.mcmc_move(&mut rng);
801 if dendro.total_log_l > best_l {
802 best_l = dendro.total_log_l;
803 best_hrg = dendro.export_hrg();
804 }
805 }
806 dendro.refresh_likelihood();
808 if dendro.total_log_l > best_l {
809 best_hrg = dendro.export_hrg();
810 }
811 Ok(best_hrg)
812 } else {
813 mcmc_equilibrium(&mut dendro, &mut rng);
815 Ok(dendro.export_hrg())
816 }
817}
818
819fn mcmc_equilibrium(dendro: &mut Dendro, rng: &mut SplitMix64) {
821 let window = 65536u64;
822 let mut old_mean = f64::NEG_INFINITY;
823
824 loop {
825 let mut sum = 0.0;
826 for _ in 0..window {
827 dendro.mcmc_move(rng);
828 sum += dendro.total_log_l;
829 }
830 dendro.refresh_likelihood();
831
832 #[allow(clippy::cast_precision_loss)]
833 let new_mean = sum / window as f64;
834
835 if (new_mean - old_mean).abs() < 1.0 {
836 break;
837 }
838 old_mean = new_mean;
839 }
840}
841
842#[allow(clippy::cast_precision_loss)]
872pub fn hrg_predict(
873 graph: &Graph,
874 start_hrg: Option<&HrgTree>,
875 num_samples: u64,
876 seed: u64,
877) -> IgraphResult<Vec<(u32, u32, f64)>> {
878 let n = graph.vcount() as usize;
879 let mut rng = SplitMix64::new(seed);
880
881 let mut dendro = match start_hrg {
882 Some(hrg) => Dendro::from_hrg(graph, hrg)?,
883 None => Dendro::from_graph(graph, &mut rng)?,
884 };
885
886 let burn_in = 200 * n;
888 for _ in 0..burn_in {
889 dendro.mcmc_move(&mut rng);
890 }
891 dendro.refresh_likelihood();
892
893 let mut counts: Vec<Vec<f64>> = vec![vec![0.0; n]; n];
895 let sample_interval = 50 * n;
896 let mut samples_taken = 0u64;
897
898 while samples_taken < num_samples {
899 for _ in 0..sample_interval {
900 dendro.mcmc_move(&mut rng);
901 }
902 dendro.accumulate_probabilities(&mut counts);
903 samples_taken += 1;
904 }
905
906 let result = dendro.predict_edges(&counts, num_samples as f64);
907 Ok(result)
908}
909
910#[allow(
935 clippy::cast_possible_truncation,
936 clippy::cast_possible_wrap,
937 clippy::cast_precision_loss
938)]
939pub fn hrg_consensus(
940 graph: &Graph,
941 start_hrg: Option<&HrgTree>,
942 num_samples: u64,
943 seed: u64,
944) -> IgraphResult<(Vec<i32>, Vec<f64>)> {
945 let n = graph.vcount() as usize;
946 let mut rng = SplitMix64::new(seed);
947
948 let mut dendro = match start_hrg {
949 Some(hrg) => Dendro::from_hrg(graph, hrg)?,
950 None => Dendro::from_graph(graph, &mut rng)?,
951 };
952
953 let burn_in = 200 * n;
955 for _ in 0..burn_in {
956 dendro.mcmc_move(&mut rng);
957 }
958 dendro.refresh_likelihood();
959
960 let mut split_counts: std::collections::HashMap<Vec<u32>, u64> =
962 std::collections::HashMap::new();
963
964 let sample_interval = 50 * n;
965 let num_internal = n - 1;
966 let mut samples_taken = 0u64;
967
968 while samples_taken < num_samples {
969 for _ in 0..sample_interval {
970 dendro.mcmc_move(&mut rng);
971 }
972
973 for i in 0..num_internal {
975 let mut left_leaves = dendro.collect_leaves(dendro.nodes[i].left);
976 left_leaves.sort_unstable();
977 *split_counts.entry(left_leaves).or_insert(0) += 1;
978 }
979 samples_taken += 1;
980 }
981
982 let mut splits: Vec<(Vec<u32>, u64)> = split_counts.into_iter().collect();
984 splits.sort_by_key(|b| std::cmp::Reverse(b.1));
985
986 let total_nodes = 2 * n - 1;
988 let mut parents = vec![-1i32; total_nodes];
989 let mut weights = vec![0.0f64; n - 1];
990
991 let root_pos = n;
993 for i in 0..n {
994 parents[i] = root_pos as i32;
995 }
996
997 let mut used_internal = 0usize;
999
1000 for (split, count) in splits.iter().take(n - 1) {
1001 if split.len() < 2 || split.len() >= n {
1002 continue;
1003 }
1004 if used_internal >= n - 1 {
1005 break;
1006 }
1007
1008 let internal_idx = n + used_internal;
1009 weights[used_internal] = *count as f64 / (num_samples * num_internal as u64) as f64;
1010
1011 for &leaf in split {
1013 parents[leaf as usize] = internal_idx as i32;
1014 }
1015
1016 if used_internal > 0 {
1018 parents[internal_idx] = root_pos as i32;
1019 }
1020
1021 used_internal += 1;
1022 }
1023
1024 Ok((parents, weights))
1025}
1026
// Unit tests for HRG fitting, prediction and consensus on a small fixed graph.
#[cfg(test)]
mod tests {
    use super::*;

    // Five vertices: the cycle 0-1-2-3-4-0 plus the chord 0-2 (six edges).
    // NOTE(review): `false` is presumably the "directed" flag and `Some(5)`
    // the vertex count — confirm against `Graph::from_edges`.
    fn make_test_graph() -> Graph {
        Graph::from_edges(
            &[(0, 1), (1, 2), (2, 3), (3, 4), (4, 0), (0, 2)],
            false,
            Some(5),
        )
        .expect("graph creation")
    }

    // A fitted tree over 5 vertices has 5 leaves and 4 internal nodes.
    #[test]
    fn hrg_fit_returns_correct_size() {
        let g = make_test_graph();
        let hrg = hrg_fit(&g, None, 500, 42).expect("hrg_fit");
        assert_eq!(hrg.size(), 5);
        assert_eq!(hrg.num_internal(), 4);
    }

    // Same graph, step count and seed must give the same tree.
    #[test]
    fn hrg_fit_deterministic() {
        let g = make_test_graph();
        let h1 = hrg_fit(&g, None, 200, 99).expect("hrg_fit");
        let h2 = hrg_fit(&g, None, 200, 99).expect("hrg_fit");
        for i in 0..h1.num_internal() {
            assert_eq!(h1.left[i], h2.left[i]);
            assert_eq!(h1.right[i], h2.right[i]);
            assert!((h1.prob[i] - h2.prob[i]).abs() < 1e-10);
        }
    }

    // Graphs with fewer than 3 vertices are rejected.
    #[test]
    fn hrg_fit_rejects_small_graph() {
        let g = Graph::from_edges(&[(0, 1)], false, Some(2)).expect("graph");
        assert!(hrg_fit(&g, None, 100, 0).is_err());
    }

    // A previously fitted tree can seed a second fit.
    #[test]
    fn hrg_fit_from_start_hrg() {
        let g = make_test_graph();
        let h1 = hrg_fit(&g, None, 200, 42).expect("hrg_fit");
        let h2 = hrg_fit(&g, Some(&h1), 200, 77).expect("hrg_fit from start");
        assert_eq!(h2.size(), 5);
    }

    // Every exported node probability lies in [0, 1].
    #[test]
    fn hrg_fit_probs_valid() {
        let g = make_test_graph();
        let hrg = hrg_fit(&g, None, 500, 42).expect("hrg_fit");
        for i in 0..hrg.num_internal() {
            assert!(hrg.prob[i] >= 0.0 && hrg.prob[i] <= 1.0);
        }
    }

    // 5 vertices give 10 pairs; 6 are edges, so 4 non-edges are predicted.
    #[test]
    fn hrg_predict_returns_results() {
        let g = make_test_graph();
        let preds = hrg_predict(&g, None, 20, 42).expect("hrg_predict");
        assert_eq!(preds.len(), 4);
        for &(from, to, prob) in &preds {
            assert!(from < to);
            assert!((0.0..=1.0).contains(&prob));
        }
    }

    // Predictions come back in non-increasing probability order.
    #[test]
    fn hrg_predict_sorted_by_prob() {
        let g = make_test_graph();
        let preds = hrg_predict(&g, None, 20, 42).expect("hrg_predict");
        for w in preds.windows(2) {
            assert!(w[0].2 >= w[1].2);
        }
    }

    // Consensus output has 2n - 1 = 9 parent slots, n - 1 = 4 weights, and at
    // least one slot marked -1 (no parent).
    #[test]
    fn hrg_consensus_returns_valid_tree() {
        let g = make_test_graph();
        let (parents, weights) = hrg_consensus(&g, None, 20, 42).expect("hrg_consensus");
        assert_eq!(parents.len(), 9); assert_eq!(weights.len(), 4); assert!(parents.contains(&-1));
        for &w in &weights {
            assert!((0.0..=1.0).contains(&w));
        }
    }

    // A random dendrogram has n - 1 internal nodes and node 0 spans all leaves.
    #[test]
    fn dendro_from_graph_correct_structure() {
        let g = make_test_graph();
        let mut rng = SplitMix64::new(42);
        let d = Dendro::from_graph(&g, &mut rng).expect("dendro");
        assert_eq!(d.n, 5);
        assert_eq!(d.nodes.len(), 4); assert_eq!(d.nodes[0].n, 5);
    }

    // The log-likelihood is a finite, non-positive number.
    #[test]
    fn dendro_likelihood_finite() {
        let g = make_test_graph();
        let mut rng = SplitMix64::new(42);
        let d = Dendro::from_graph(&g, &mut rng).expect("dendro");
        assert!(d.total_log_l.is_finite());
        assert!(d.total_log_l <= 0.0);
    }

    // At least one of 100 moves must be accepted and change the likelihood.
    #[test]
    fn dendro_mcmc_move_changes_likelihood() {
        let g = make_test_graph();
        let mut rng = SplitMix64::new(42);
        let mut d = Dendro::from_graph(&g, &mut rng).expect("dendro");
        let initial_l = d.total_log_l;
        let mut changed = false;
        for _ in 0..100 {
            d.mcmc_move(&mut rng);
            if (d.total_log_l - initial_l).abs() > 1e-15 {
                changed = true;
                break;
            }
        }
        assert!(changed, "MCMC should change likelihood within 100 moves");
    }

    // Exporting a dendrogram yields a tree of matching size.
    #[test]
    fn dendro_export_roundtrip() {
        let g = make_test_graph();
        let mut rng = SplitMix64::new(42);
        let d = Dendro::from_graph(&g, &mut rng).expect("dendro");
        let hrg = d.export_hrg();
        assert_eq!(hrg.size(), 5);
        assert_eq!(hrg.num_internal(), 4);
    }
}