1#![allow(
50 unknown_lints,
51 clippy::cast_possible_truncation,
52 clippy::cast_precision_loss,
53 clippy::cast_sign_loss,
54 clippy::float_cmp,
55 clippy::too_many_arguments,
56 clippy::similar_names,
57 clippy::manual_midpoint
58)]
59
60use crate::core::rng::SplitMix64;
61use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
62
/// Layout of the linear slot space used to enumerate candidate vertex pairs
/// for one (from-block, to-block) combination.
///
/// A slot index is turned back into a block-local `(from, to)` pair by
/// [`PairShape::decode`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum PairShape {
    /// Full rectangle: `from = idx % fromsize`, `to = idx / fromsize`.
    Rect,
    /// Square without its diagonal (`fromsize * (fromsize - 1)` slots); a slot
    /// that would land on the diagonal is redirected to the last column.
    RectNoDiag,
    /// Triangle including the diagonal (`from <= to`).
    TriInclDiag,
    /// Triangle excluding the diagonal (`from < to`).
    TriExclDiag,
}

impl PairShape {
    /// Decodes slot `idx` into a block-local `(from, to)` pair.
    ///
    /// `fromsize` is the number of vertices in the source block. The
    /// rectangular shapes divide by it; the triangular shapes only use it in
    /// debug assertions.
    ///
    /// The triangular shapes start from a floating-point square-root estimate
    /// of the row and then correct it with exact integer arithmetic in *both*
    /// directions. Above 2^53 the `f64` estimate can be one too high as well
    /// as too low, and an uncorrected overestimate would underflow the
    /// `idx - row_start` subtraction.
    ///
    /// # Panics
    ///
    /// Panics on division by zero when `fromsize == 0` for `Rect` and
    /// `RectNoDiag`. Debug builds additionally assert that the decoded pair
    /// lies inside the shape.
    pub(crate) fn decode(self, idx: u64, fromsize: u32) -> (u32, u32) {
        let fs = u64::from(fromsize);
        match self {
            Self::Rect => {
                let vto = idx / fs;
                let vfrom = idx % fs;
                (vfrom as u32, vto as u32)
            }
            Self::RectNoDiag => {
                let column = idx / fs;
                let vfrom = idx % fs;
                // Only columns 0..fs-1 are enumerated, so the last column is
                // free to absorb the slots that fall on the diagonal.
                let vto = if vfrom == column { fs - 1 } else { column };
                debug_assert!(vfrom < fs && vto < fs && vfrom != vto);
                (vfrom as u32, vto as u32)
            }
            Self::TriInclDiag => {
                // Row `r` starts at r(r+1)/2, hence r ~ (sqrt(8*idx + 1) - 1) / 2.
                let idx_f = idx as f64;
                let estimate = ((8.0 * idx_f + 1.0).sqrt() - 1.0) / 2.0;
                let vto = Self::settle_row(estimate, 0, idx, Self::incl_row_start);
                let vfrom = (u128::from(idx) - Self::incl_row_start(vto)) as u64;
                debug_assert!(vfrom <= vto && vto < fs);
                (vfrom as u32, vto as u32)
            }
            Self::TriExclDiag => {
                // Row `r` starts at r(r-1)/2, hence r ~ (sqrt(8*idx + 1) + 1) / 2.
                let idx_f = idx as f64;
                let estimate = ((8.0 * idx_f + 1.0).sqrt() + 1.0) / 2.0;
                let vto = Self::settle_row(estimate, 1, idx, Self::excl_row_start);
                let vfrom = (u128::from(idx) - Self::excl_row_start(vto)) as u64;
                debug_assert!(vfrom < vto && vto < fs);
                (vfrom as u32, vto as u32)
            }
        }
    }

    /// First slot of row `row` in the diagonal-inclusive triangle.
    ///
    /// Computed in `u128` so the product cannot overflow for any `u64` row.
    fn incl_row_start(row: u64) -> u128 {
        let r = u128::from(row);
        r * (r + 1) / 2
    }

    /// First slot of row `row` in the diagonal-exclusive triangle.
    fn excl_row_start(row: u64) -> u128 {
        let r = u128::from(row);
        r * r.saturating_sub(1) / 2
    }

    /// Turns the floating-point row `estimate` into the exact row holding
    /// `idx`: the largest `row >= min_row` with `row_start(row) <= idx`.
    ///
    /// `row_start` must be non-decreasing with `row_start(min_row) == 0`.
    fn settle_row(estimate: f64, min_row: u64, idx: u64, row_start: fn(u64) -> u128) -> u64 {
        let target = u128::from(idx);
        // `as` truncates toward zero and saturates, matching `trunc()`.
        let mut row = (estimate as u64).max(min_row);
        // The estimate may be too high (rounding up near a row boundary)...
        while row > min_row && row_start(row) > target {
            row -= 1;
        }
        // ...or too low.
        while row_start(row + 1) <= target {
            row += 1;
        }
        row
    }
}
136
137fn pair_shape(directed: bool, loops: bool, on_diagonal: bool) -> PairShape {
140 match (directed, loops, on_diagonal) {
141 (true, false, true) => PairShape::RectNoDiag,
142 (false, true, true) => PairShape::TriInclDiag,
143 (false, false, true) => PairShape::TriExclDiag,
144 _ => PairShape::Rect,
145 }
146}
147
148fn validate(
149 pref_matrix: &[Vec<f64>],
150 block_sizes: &[u32],
151 directed: bool,
152 multiple: bool,
153) -> IgraphResult<()> {
154 let k = pref_matrix.len();
155 if block_sizes.len() != k {
156 return Err(IgraphError::InvalidArgument(format!(
157 "block_sizes length ({}) does not match preference matrix size ({k})",
158 block_sizes.len()
159 )));
160 }
161 for (i, row) in pref_matrix.iter().enumerate() {
162 if row.len() != k {
163 return Err(IgraphError::InvalidArgument(format!(
164 "preference matrix row {i} has length {} (expected {k} for square matrix)",
165 row.len()
166 )));
167 }
168 for (j, &val) in row.iter().enumerate() {
169 if !val.is_finite() {
170 return Err(IgraphError::InvalidArgument(format!(
171 "preference matrix entry [{i}][{j}] must be finite (got {val})"
172 )));
173 }
174 if multiple {
175 if val < 0.0 {
176 return Err(IgraphError::InvalidArgument(format!(
177 "preference matrix entry [{i}][{j}] = {val} must be non-negative \
178 (multigraph SBM uses expected multiplicities)"
179 )));
180 }
181 } else if !(0.0..=1.0).contains(&val) {
182 return Err(IgraphError::InvalidArgument(format!(
183 "preference matrix entry [{i}][{j}] = {val} must lie in [0, 1] \
184 (simple SBM uses connection probabilities)"
185 )));
186 }
187 }
188 }
189 if !directed {
190 for (i, row_i) in pref_matrix.iter().enumerate() {
191 for (j, row_j) in pref_matrix.iter().enumerate().skip(i + 1) {
192 let pij = row_i[j];
193 let pji = row_j[i];
194 if pij != pji {
195 return Err(IgraphError::InvalidArgument(format!(
196 "preference matrix must be symmetric for undirected SBM: \
197 pref[{i}][{j}] = {pij} but pref[{j}][{i}] = {pji}"
198 )));
199 }
200 }
201 }
202 }
203 Ok(())
204}
205
206pub(crate) fn block_offsets(block_sizes: &[u32]) -> IgraphResult<(Vec<u32>, u32)> {
208 let mut offsets: Vec<u32> = Vec::with_capacity(block_sizes.len() + 1);
209 offsets.push(0);
210 let mut acc: u32 = 0;
211 for &s in block_sizes {
212 acc = acc.checked_add(s).ok_or_else(|| {
213 IgraphError::InvalidArgument("sum of block_sizes overflows u32".into())
214 })?;
215 offsets.push(acc);
216 }
217 Ok((offsets, acc))
218}
219
220#[allow(clippy::too_many_arguments)]
224pub(crate) fn sample_pair_with_max(
225 rng: &mut SplitMix64,
226 edges: &mut Vec<(VertexId, VertexId)>,
227 fromsize: u32,
228 fromoff: u32,
229 tooff: u32,
230 shape: PairShape,
231 multiple: bool,
232 prob: f64,
233 maxedges: u64,
234) {
235 if maxedges == 0 || prob <= 0.0 {
236 return;
237 }
238 let prob_step = if multiple { prob / (1.0 + prob) } else { prob };
239 if prob_step <= 0.0 {
240 return;
241 }
242 let max_f = maxedges as f64;
243 let step_extra: f64 = if multiple { 0.0 } else { 1.0 };
246
247 let mut last = rng.gen_geom(prob_step);
248 while last < max_f {
249 let idx = last.trunc() as u64;
250 if idx >= maxedges {
251 break;
252 }
253 let (vfrom, vto) = shape.decode(idx, fromsize);
254 edges.push((fromoff + vfrom, tooff + vto));
255 last += rng.gen_geom(prob_step);
256 last += step_extra;
257 }
258}
259
260pub fn sbm_game(
308 pref_matrix: &[Vec<f64>],
309 block_sizes: &[u32],
310 directed: bool,
311 loops: bool,
312 multiple: bool,
313 seed: u64,
314) -> IgraphResult<Graph> {
315 validate(pref_matrix, block_sizes, directed, multiple)?;
316
317 let no_blocks = block_sizes.len();
318 let (offsets, n) = block_offsets(block_sizes)?;
319
320 if no_blocks == 0 || n == 0 {
321 return Graph::new(n, directed);
322 }
323
324 let mut rng = SplitMix64::new(seed);
325 let mut edges: Vec<(VertexId, VertexId)> = Vec::new();
326
327 for from in 0..no_blocks {
328 let fromsize = block_sizes[from];
329 if fromsize == 0 {
330 continue;
331 }
332 let fromoff = offsets[from];
333 let start = if directed { 0 } else { from };
334 for to in start..no_blocks {
335 let tosize = block_sizes[to];
336 if tosize == 0 {
337 continue;
338 }
339 let tooff = offsets[to];
340 let prob = pref_matrix[from][to];
341 if prob <= 0.0 {
342 continue;
343 }
344
345 let on_diagonal = from == to;
346 let shape = pair_shape(directed, loops, on_diagonal);
347 let maxedges = match shape {
348 PairShape::Rect => u64::from(fromsize) * u64::from(tosize),
349 PairShape::RectNoDiag => {
350 let fs = u64::from(fromsize);
351 fs * fs.saturating_sub(1)
352 }
353 PairShape::TriInclDiag => {
354 let fs = u64::from(fromsize);
355 fs * (fs + 1) / 2
356 }
357 PairShape::TriExclDiag => {
358 let fs = u64::from(fromsize);
359 fs * fs.saturating_sub(1) / 2
360 }
361 };
362
363 sample_pair_with_max(
364 &mut rng, &mut edges, fromsize, fromoff, tooff, shape, multiple, prob, maxedges,
365 );
366 }
367 }
368
369 let mut g = Graph::new(n, directed)?;
370 g.add_edges(edges)?;
371 Ok(g)
372}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects every edge of `g` in edge-id order.
    fn edge_list(g: &Graph) -> Vec<(u32, u32)> {
        (0..g.ecount() as u32).map(|e| g.edge(e).unwrap()).collect()
    }

    /// Asserts that generation was rejected as an invalid argument.
    fn assert_invalid(res: IgraphResult<Graph>) {
        assert!(matches!(res, Err(IgraphError::InvalidArgument(_))));
    }

    /// Index of the block containing vertex `v`, or `block_sizes.len()` if
    /// `v` lies past the last block.
    fn block_of(v: u32, block_sizes: &[u32]) -> usize {
        let mut upper = 0u32;
        block_sizes
            .iter()
            .position(|&s| {
                upper += s;
                v < upper
            })
            .unwrap_or(block_sizes.len())
    }

    #[test]
    fn decode_rect_basic() {
        let cases = [(0, (0, 0)), (1, (1, 0)), (2, (2, 0)), (3, (0, 1)), (8, (2, 2))];
        for (idx, pair) in cases {
            assert_eq!(PairShape::Rect.decode(idx, 3), pair);
        }
    }

    #[test]
    fn decode_rect_no_diag_remaps() {
        let cases = [(0, (0, 2)), (1, (1, 0)), (4, (1, 2)), (5, (2, 1))];
        for (idx, pair) in cases {
            assert_eq!(PairShape::RectNoDiag.decode(idx, 3), pair);
        }
    }

    #[test]
    fn decode_tri_incl_diag_covers_all_pairs() {
        let mut decoded: Vec<(u32, u32)> = Vec::new();
        for slot in 0..10 {
            decoded.push(PairShape::TriInclDiag.decode(slot, 4));
        }
        decoded.sort_unstable();

        let mut wanted: Vec<(u32, u32)> = Vec::new();
        for to in 0..4 {
            for from in 0..=to {
                wanted.push((from, to));
            }
        }
        wanted.sort_unstable();
        assert_eq!(decoded, wanted);
    }

    #[test]
    fn decode_tri_excl_diag_covers_all_pairs() {
        let mut decoded: Vec<(u32, u32)> = Vec::new();
        for slot in 0..6 {
            decoded.push(PairShape::TriExclDiag.decode(slot, 4));
        }
        decoded.sort_unstable();

        let mut wanted: Vec<(u32, u32)> = Vec::new();
        for to in 0..4 {
            for from in 0..to {
                wanted.push((from, to));
            }
        }
        wanted.sort_unstable();
        assert_eq!(decoded, wanted);
    }

    #[test]
    fn empty_pref_and_blocks_gives_empty_graph() {
        let graph = sbm_game(&[], &[], false, false, false, 0).unwrap();
        assert_eq!(graph.vcount(), 0);
        assert_eq!(graph.ecount(), 0);
        assert!(!graph.is_directed());
    }

    #[test]
    fn rejects_non_square_pref() {
        let pref: Vec<Vec<f64>> = vec![vec![0.1, 0.2], vec![0.2]];
        assert_invalid(sbm_game(&pref, &[3, 3], false, false, false, 0));
    }

    #[test]
    fn rejects_mismatched_block_sizes() {
        let pref = vec![vec![0.1, 0.2], vec![0.2, 0.3]];
        assert_invalid(sbm_game(&pref, &[3, 3, 3], false, false, false, 0));
    }

    #[test]
    fn rejects_prob_above_one() {
        let pref = vec![vec![0.5, 1.2], vec![1.2, 0.5]];
        assert_invalid(sbm_game(&pref, &[3, 3], false, false, false, 0));
    }

    #[test]
    fn rejects_negative_prob() {
        let pref = vec![vec![0.5, -0.01], vec![-0.01, 0.5]];
        assert_invalid(sbm_game(&pref, &[3, 3], false, false, false, 0));
    }

    #[test]
    fn rejects_non_finite() {
        let pref = vec![vec![0.5, f64::NAN], vec![f64::NAN, 0.5]];
        assert_invalid(sbm_game(&pref, &[3, 3], false, false, false, 0));
    }

    #[test]
    fn rejects_asymmetric_for_undirected() {
        let pref = vec![vec![0.1, 0.2], vec![0.3, 0.1]];
        assert_invalid(sbm_game(&pref, &[3, 3], false, false, false, 0));
    }

    #[test]
    fn accepts_asymmetric_for_directed() {
        let pref = vec![vec![0.1, 0.2], vec![0.3, 0.1]];
        let outcome = sbm_game(&pref, &[3, 3], true, false, false, 0);
        assert!(outcome.is_ok());
    }

    #[test]
    fn rejects_multiple_negative() {
        let pref = vec![vec![0.5, -0.1], vec![-0.1, 0.5]];
        assert_invalid(sbm_game(&pref, &[3, 3], false, false, true, 0));
    }

    #[test]
    fn accepts_multiple_above_one() {
        let pref = vec![vec![0.5, 2.5], vec![2.5, 0.5]];
        let graph = sbm_game(&pref, &[3, 3], false, false, true, 0).unwrap();
        assert_eq!(graph.vcount(), 6);
    }

    #[test]
    fn single_block_reduces_to_er_undirected() {
        let pref = vec![vec![0.5]];
        let graph = sbm_game(&pref, &[30], false, false, false, 0xA5A5).unwrap();
        assert_eq!(graph.vcount(), 30);
        let m = graph.ecount();
        assert!(
            (100..350).contains(&m),
            "single-block undirected ecount = {m}"
        );
        for (u, v) in edge_list(&graph) {
            assert!(u < 30 && v < 30);
            assert_ne!(u, v, "no self-loops when loops=false");
        }
    }

    #[test]
    fn diagonal_only_pref_gives_per_block_er() {
        let sizes = [20u32, 20, 20];
        let pref = vec![
            vec![0.5, 0.0, 0.0],
            vec![0.0, 0.5, 0.0],
            vec![0.0, 0.0, 0.5],
        ];
        let graph = sbm_game(&pref, &sizes, false, false, false, 0x00C0_FFEE).unwrap();
        assert_eq!(graph.vcount(), 60);
        for (u, v) in edge_list(&graph) {
            let (bu, bv) = (block_of(u, &sizes), block_of(v, &sizes));
            assert_eq!(
                bu, bv,
                "edge {u}-{v} crosses blocks ({bu} vs {bv}) under diagonal pref"
            );
        }
    }

    #[test]
    fn off_diagonal_only_pref_gives_bipartite_blocks() {
        let sizes = [15u32, 15];
        let pref = vec![vec![0.0, 0.4], vec![0.4, 0.0]];
        let graph = sbm_game(&pref, &sizes, false, false, false, 0xDEAD_BEEF).unwrap();
        assert_eq!(graph.vcount(), 30);
        for (u, v) in edge_list(&graph) {
            let (bu, bv) = (block_of(u, &sizes), block_of(v, &sizes));
            assert_ne!(bu, bv, "edge {u}-{v} stays inside block {bu}");
        }
    }

    #[test]
    fn no_self_loops_when_loops_false() {
        let sizes = [5u32, 5];
        let pref = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let graph = sbm_game(&pref, &sizes, false, false, false, 0x42).unwrap();
        for (e, (u, v)) in edge_list(&graph).into_iter().enumerate() {
            assert_ne!(u, v, "self-loop on edge {e} = ({u}, {v})");
        }
        assert_eq!(graph.ecount(), 20);
    }

    #[test]
    fn self_loops_when_loops_true() {
        let pref = vec![vec![1.0]];
        let graph = sbm_game(&pref, &[4], false, true, false, 0).unwrap();
        assert_eq!(graph.ecount(), 10);
        let loop_count = edge_list(&graph)
            .into_iter()
            .filter(|(u, v)| u == v)
            .count();
        assert_eq!(loop_count, 4, "expected one self-loop per vertex");
    }

    #[test]
    fn directed_complete_no_loops() {
        use std::collections::HashSet;

        let pref = vec![vec![1.0]];
        let graph = sbm_game(&pref, &[5], true, false, false, 0).unwrap();
        assert_eq!(graph.ecount(), 5 * 4);
        let mut seen: HashSet<(u32, u32)> = HashSet::new();
        for (u, v) in edge_list(&graph) {
            assert_ne!(u, v);
            assert!(seen.insert((u, v)), "duplicate ordered pair ({u}, {v})");
        }
        assert_eq!(seen.len(), 5 * 4);
    }

    #[test]
    fn directed_asymmetric_pref_only_one_direction() {
        let sizes = [8u32, 8];
        let pref = vec![vec![0.0, 0.8], vec![0.0, 0.0]];
        let graph = sbm_game(&pref, &sizes, true, false, false, 0xBEEF).unwrap();
        for (u, v) in edge_list(&graph) {
            assert!(u < 8 && v >= 8, "edge ({u}, {v}) should run 0→1 only");
        }
    }

    #[test]
    fn determinism_same_seed_same_graph() {
        let sizes = [30u32, 30];
        let pref = vec![vec![0.3, 0.05], vec![0.05, 0.3]];
        let first = sbm_game(&pref, &sizes, false, false, false, 0xAB_CDEF).unwrap();
        let second = sbm_game(&pref, &sizes, false, false, false, 0xAB_CDEF).unwrap();
        assert_eq!(first.ecount(), second.ecount());
        assert_eq!(edge_list(&first), edge_list(&second));
    }

    #[test]
    fn zero_size_block_does_not_break() {
        let pref = vec![vec![0.5, 0.5], vec![0.5, 0.5]];
        let graph = sbm_game(&pref, &[5, 0], false, false, false, 0).unwrap();
        assert_eq!(graph.vcount(), 5);
    }

    #[test]
    fn zero_probability_block_yields_no_edges() {
        let pref = vec![vec![0.0, 0.0], vec![0.0, 0.0]];
        let graph = sbm_game(&pref, &[10, 10], false, false, false, 42).unwrap();
        assert_eq!(graph.vcount(), 20);
        assert_eq!(graph.ecount(), 0);
    }
}
678
#[cfg(all(test, feature = "proptest-harness"))]
mod proptest_invariants {
    use super::*;
    use proptest::prelude::*;

    /// `k x k` preference matrix with every entry equal to `p`.
    fn uniform_pref(k: usize, p: f64) -> Vec<Vec<f64>> {
        (0..k).map(|_| vec![p; k]).collect()
    }

    /// `k x k` preference matrix with `p` on the diagonal and zero elsewhere.
    fn diag_pref(k: usize, p: f64) -> Vec<Vec<f64>> {
        (0..k)
            .map(|i| {
                let mut row = vec![0.0; k];
                row[i] = p;
                row
            })
            .collect()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(48))]

        #[test]
        fn vcount_equals_block_sum(
            sizes in prop::collection::vec(1u32..8, 1usize..5),
            seed: u64,
        ) {
            let graph =
                sbm_game(&uniform_pref(sizes.len(), 0.2), &sizes, false, false, false, seed)
                    .unwrap();
            let expected: u32 = sizes.iter().sum();
            prop_assert_eq!(graph.vcount(), expected);
        }

        #[test]
        fn no_self_loops_when_loops_false_undirected(
            sizes in prop::collection::vec(1u32..6, 1usize..4),
            seed: u64,
        ) {
            let graph =
                sbm_game(&uniform_pref(sizes.len(), 0.25), &sizes, false, false, false, seed)
                    .unwrap();
            for e in 0..graph.ecount() as u32 {
                let (u, v) = graph.edge(e).unwrap();
                prop_assert_ne!(u, v);
            }
        }

        #[test]
        fn edges_respect_block_pref_support(
            sizes in prop::collection::vec(1u32..6, 2usize..4),
            seed: u64,
        ) {
            let graph =
                sbm_game(&diag_pref(sizes.len(), 0.3), &sizes, false, false, false, seed)
                    .unwrap();

            // Cumulative block boundaries: 0, s0, s0 + s1, ...
            let mut boundaries = Vec::with_capacity(sizes.len() + 1);
            let mut running = 0u32;
            boundaries.push(running);
            for &s in &sizes {
                running += s;
                boundaries.push(running);
            }
            // The first boundary above `v` closes the block containing `v`.
            let block_of = |v: u32| -> usize {
                match boundaries.iter().position(|&b| v < b) {
                    Some(i) => i - 1,
                    None => 0,
                }
            };

            for e in 0..graph.ecount() as u32 {
                let (u, v) = graph.edge(e).unwrap();
                prop_assert_eq!(block_of(u), block_of(v));
            }
        }
    }
}