// The sampling code below deliberately casts between `u32`, `usize` and `f64`
// (vertex/type counts, RNG indices, cumulative weights), compares floats
// exactly in the preference-matrix symmetry check, and fills the cumulative
// distribution with index-based loops; the corresponding clippy lints are
// silenced for this module. `unknown_lints` presumably keeps toolchains whose
// clippy does not know one of these names from warning — TODO confirm.
#![allow(
    unknown_lints,
    clippy::cast_possible_truncation,
    clippy::cast_precision_loss,
    clippy::cast_sign_loss,
    clippy::float_cmp,
    clippy::too_many_arguments,
    clippy::similar_names,
    clippy::many_single_char_names,
    clippy::needless_range_loop
)]
55
56use crate::core::rng::SplitMix64;
57use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
58
59fn validate(
60 nodes: u32,
61 types: u32,
62 type_dist: Option<&[f64]>,
63 pref_matrix: &[Vec<f64>],
64 directed: bool,
65) -> IgraphResult<()> {
66 if types < 1 {
67 return Err(IgraphError::InvalidArgument(
68 "The number of vertex types must be at least 1.".into(),
69 ));
70 }
71 let k = types as usize;
72 if let Some(td) = type_dist {
73 if td.len() != k {
74 return Err(IgraphError::InvalidArgument(format!(
75 "type_dist length ({}) does not match number of types ({k})",
76 td.len()
77 )));
78 }
79 for (i, &v) in td.iter().enumerate() {
80 if v.is_nan() {
81 return Err(IgraphError::InvalidArgument(format!(
82 "type_dist[{i}] must not be NaN"
83 )));
84 }
85 if v < 0.0 {
86 return Err(IgraphError::InvalidArgument(format!(
87 "type_dist[{i}] = {v} must be non-negative"
88 )));
89 }
90 }
91 }
92 if pref_matrix.len() != k {
93 return Err(IgraphError::InvalidArgument(format!(
94 "preference matrix has {} rows (expected {k})",
95 pref_matrix.len()
96 )));
97 }
98 for (i, row) in pref_matrix.iter().enumerate() {
99 if row.len() != k {
100 return Err(IgraphError::InvalidArgument(format!(
101 "preference matrix row {i} has length {} (expected {k})",
102 row.len()
103 )));
104 }
105 for (j, &p) in row.iter().enumerate() {
106 if p.is_nan() {
107 return Err(IgraphError::InvalidArgument(format!(
108 "preference matrix entry [{i}][{j}] must not be NaN"
109 )));
110 }
111 if !(0.0..=1.0).contains(&p) {
112 return Err(IgraphError::InvalidArgument(format!(
113 "preference matrix entry [{i}][{j}] = {p} must lie in [0, 1]"
114 )));
115 }
116 }
117 }
118 if !directed {
119 for (i, row_i) in pref_matrix.iter().enumerate() {
120 for (j, row_j) in pref_matrix.iter().enumerate().skip(i + 1) {
121 if row_i[j] != row_j[i] {
122 return Err(IgraphError::InvalidArgument(format!(
123 "preference matrix must be symmetric for undirected graphs: \
124 pref[{i}][{j}] = {} but pref[{j}][{i}] = {}",
125 row_i[j], row_j[i],
126 )));
127 }
128 }
129 }
130 }
131 let _ = nodes;
132 Ok(())
133}
134
/// Maps a draw `u` onto a 1-based bucket of the cumulative table `cumdist`.
///
/// `cumdist[0]` is the leading zero and `cumdist[i]` is the running total up
/// to and including bucket `i`. The result is the smallest `i >= 1` with
/// `cumdist[i] > u`; if no entry exceeds `u` (including a NaN draw) the last
/// index is returned. Because the comparison is strict, a bucket whose total
/// equals its predecessor's (a zero-weight bucket) is never chosen for a draw
/// below the final total.
///
/// `cumdist` is expected to be non-decreasing, as built by the caller.
fn cumdist_lookup(cumdist: &[f64], u: f64) -> usize {
    let tail: &[f64] = cumdist.get(1..).unwrap_or_default();
    // "Not strictly greater than u" holds for a prefix of a non-decreasing
    // table; `partial_cmp` keeps NaN draws on that side, like a plain `>` test.
    let not_greater = tail.partition_point(|&c| c.partial_cmp(&u) != Some(std::cmp::Ordering::Greater));
    let bucket = not_greater + 1;
    bucket.min(cumdist.len() - 1).max(1)
}
150
151pub fn callaway_traits_game(
199 nodes: u32,
200 types: u32,
201 edges_per_step: u32,
202 type_dist: Option<&[f64]>,
203 pref_matrix: &[Vec<f64>],
204 directed: bool,
205 seed: u64,
206) -> IgraphResult<(Graph, Vec<u32>)> {
207 validate(nodes, types, type_dist, pref_matrix, directed)?;
208
209 let n_types = types as usize;
210 let mut rng = SplitMix64::new(seed);
211
212 let mut cumdist = vec![0.0f64; n_types + 1];
213 if let Some(td) = type_dist {
214 for i in 0..n_types {
215 cumdist[i + 1] = cumdist[i] + td[i];
216 }
217 } else {
218 for i in 0..n_types {
219 cumdist[i + 1] = (i + 1) as f64;
220 }
221 }
222 let maxcum = cumdist[n_types];
223 if maxcum <= 0.0 {
224 return Err(IgraphError::InvalidArgument(
225 "type_dist must contain at least one positive value".into(),
226 ));
227 }
228
229 let mut node_types = vec![0u32; nodes as usize];
230 for v in 0..(nodes as usize) {
231 let u = rng.gen_unit() * maxcum;
232 let pos = cumdist_lookup(&cumdist, u);
233 let t = (pos - 1).min(n_types - 1);
234 node_types[v] = t as u32;
235 }
236
237 let mut edges: Vec<(VertexId, VertexId)> = Vec::new();
238 if edges_per_step > 0 && nodes >= 2 {
239 for i in 1..nodes {
240 let span = (i as usize) + 1;
242 for _ in 0..edges_per_step {
243 let n1 = rng.gen_index(span) as u32;
244 let n2 = rng.gen_index(span) as u32;
245 let t1 = node_types[n1 as usize] as usize;
246 let t2 = node_types[n2 as usize] as usize;
247 let p = pref_matrix[t1][t2];
248 if p > 0.0 && rng.gen_unit() < p {
249 edges.push((n1, n2));
250 }
251 }
252 }
253 }
254
255 let mut g = Graph::new(nodes, directed)?;
256 g.add_edges(edges)?;
257 Ok((g, node_types))
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet as Set;

    /// `types x types` matrix with `p` on the diagonal and 0 elsewhere, so only
    /// same-type pairs can ever be linked.
    fn diag_pref(types: usize, p: f64) -> Vec<Vec<f64>> {
        (0..types)
            .map(|i| (0..types).map(|j| if i == j { p } else { 0.0 }).collect())
            .collect()
    }

    /// `types x types` matrix filled with the constant `p`.
    fn full_pref(types: usize, p: f64) -> Vec<Vec<f64>> {
        vec![vec![p; types]; types]
    }

    #[test]
    fn nodes_zero_returns_empty_graph() {
        let pref = full_pref(2, 0.5);
        let (g, types) = callaway_traits_game(0, 2, 5, None, &pref, false, 42).unwrap();
        assert_eq!(g.vcount(), 0);
        assert_eq!(g.ecount(), 0);
        assert_eq!(types.len(), 0);
        assert!(!g.is_directed());
    }

    #[test]
    fn nodes_one_returns_no_edges() {
        let pref = full_pref(2, 1.0);
        let (g, types) = callaway_traits_game(1, 2, 10, None, &pref, false, 7).unwrap();
        assert_eq!(g.vcount(), 1);
        assert_eq!(g.ecount(), 0);
        assert_eq!(types.len(), 1);
    }

    #[test]
    fn edges_per_step_zero_yields_no_edges() {
        let pref = full_pref(2, 1.0);
        let (g, _) = callaway_traits_game(50, 2, 0, None, &pref, false, 11).unwrap();
        assert_eq!(g.vcount(), 50);
        assert_eq!(g.ecount(), 0);
    }

    #[test]
    fn zero_pref_matrix_yields_no_edges() {
        let pref = full_pref(3, 0.0);
        let (g, types) = callaway_traits_game(40, 3, 5, None, &pref, false, 99).unwrap();
        assert_eq!(g.vcount(), 40);
        assert_eq!(g.ecount(), 0);
        assert!(types.iter().all(|&t| t < 3));
    }

    /// With p = 1 every candidate pair is accepted, so the edge count is exactly
    /// one edge per draw: `(nodes - 1) * edges_per_step`.
    #[test]
    fn full_pref_p1_undirected_max_ecount() {
        let pref = full_pref(2, 1.0);
        let nodes = 30u32;
        let eps = 4u32;
        let (g, _) = callaway_traits_game(nodes, 2, eps, None, &pref, false, 123).unwrap();
        assert_eq!(g.ecount() as u32, (nodes - 1) * eps);
    }

    #[test]
    fn full_pref_p1_directed_max_ecount() {
        let pref = vec![vec![1.0, 1.0], vec![1.0, 1.0]];
        let nodes = 25u32;
        let eps = 3u32;
        let (g, _) = callaway_traits_game(nodes, 2, eps, None, &pref, true, 456).unwrap();
        assert_eq!(g.ecount() as u32, (nodes - 1) * eps);
        assert!(g.is_directed());
    }

    #[test]
    fn determinism_same_seed_same_graph() {
        let pref = full_pref(3, 0.4);
        let (g1, t1) = callaway_traits_game(50, 3, 5, None, &pref, false, 0xDEAD).unwrap();
        let (g2, t2) = callaway_traits_game(50, 3, 5, None, &pref, false, 0xDEAD).unwrap();
        assert_eq!(g1.ecount(), g2.ecount());
        assert_eq!(t1, t2);
        for eid in 0..g1.ecount() as u32 {
            assert_eq!(g1.edge(eid).unwrap(), g2.edge(eid).unwrap());
        }
    }

    #[test]
    fn different_seeds_diverge() {
        let pref = full_pref(3, 0.4);
        let (g1, _) = callaway_traits_game(80, 3, 5, None, &pref, false, 1).unwrap();
        let (g2, _) = callaway_traits_game(80, 3, 5, None, &pref, false, 2).unwrap();
        // Compare as sorted multisets so only the set of drawn edges matters.
        let differ = g1.ecount() != g2.ecount() || {
            let mut e1: Vec<_> = (0..g1.ecount() as u32)
                .map(|e| g1.edge(e).unwrap())
                .collect();
            let mut e2: Vec<_> = (0..g2.ecount() as u32)
                .map(|e| g2.edge(e).unwrap())
                .collect();
            e1.sort_unstable();
            e2.sort_unstable();
            e1 != e2
        };
        assert!(differ);
    }

    #[test]
    fn type_dist_skewed_one_hot() {
        let pref = full_pref(3, 0.0);
        let dist = vec![1.0, 0.0, 0.0];
        let (_, types) = callaway_traits_game(40, 3, 4, Some(&dist), &pref, false, 12).unwrap();
        assert!(types.iter().all(|&t| t == 0));
    }

    /// A zero-weight type must never be drawn, even when it sits between (or
    /// after) positive-weight types.
    #[test]
    fn type_dist_zero_weight_types_are_never_sampled() {
        let pref = full_pref(3, 0.0);
        let dist = vec![0.0, 1.0, 0.0];
        let (_, types) = callaway_traits_game(40, 3, 4, Some(&dist), &pref, false, 21).unwrap();
        assert!(types.iter().all(|&t| t == 1));
    }

    #[test]
    fn type_dist_zero_everywhere_errors() {
        let pref = full_pref(2, 0.5);
        let dist = vec![0.0, 0.0];
        let err = callaway_traits_game(10, 2, 3, Some(&dist), &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn types_zero_errors() {
        let pref: Vec<Vec<f64>> = Vec::new();
        let err = callaway_traits_game(10, 0, 3, None, &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn pref_wrong_shape_errors() {
        let pref = vec![vec![0.5, 0.5]];
        let err = callaway_traits_game(10, 2, 3, None, &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn pref_row_wrong_length_errors() {
        let pref = vec![vec![0.5, 0.5], vec![0.5]];
        let err = callaway_traits_game(10, 2, 3, None, &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn pref_nan_errors() {
        let pref = vec![vec![0.5, f64::NAN], vec![f64::NAN, 0.5]];
        let err = callaway_traits_game(10, 2, 3, None, &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn pref_out_of_range_errors() {
        let pref = vec![vec![0.5, 1.5], vec![1.5, 0.5]];
        let err = callaway_traits_game(10, 2, 3, None, &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn pref_negative_errors() {
        let pref = vec![vec![0.5, -0.1], vec![-0.1, 0.5]];
        let err = callaway_traits_game(10, 2, 3, None, &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn pref_asymmetric_undirected_errors() {
        let pref = vec![vec![0.5, 0.2], vec![0.7, 0.5]];
        let err = callaway_traits_game(10, 2, 3, None, &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn pref_asymmetric_directed_ok() {
        let pref = vec![vec![0.5, 0.2], vec![0.7, 0.5]];
        let (g, _) = callaway_traits_game(20, 2, 3, None, &pref, true, 1).unwrap();
        assert_eq!(g.vcount(), 20);
        assert!(g.is_directed());
    }

    #[test]
    fn validate_accepts_symmetric_undirected() {
        assert!(validate(10, 3, None, &full_pref(3, 0.2), false).is_ok());
        assert!(validate(10, 3, Some(&[1.0, 0.0, 2.0]), &diag_pref(3, 0.7), false).is_ok());
    }

    #[test]
    fn type_dist_nan_errors() {
        let pref = full_pref(2, 0.5);
        let dist = vec![1.0, f64::NAN];
        let err = callaway_traits_game(10, 2, 3, Some(&dist), &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn type_dist_negative_errors() {
        let pref = full_pref(2, 0.5);
        let dist = vec![1.0, -0.1];
        let err = callaway_traits_game(10, 2, 3, Some(&dist), &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    #[test]
    fn type_dist_wrong_length_errors() {
        let pref = full_pref(3, 0.5);
        let dist = vec![1.0, 1.0];
        let err = callaway_traits_game(10, 3, 3, Some(&dist), &pref, false, 1).unwrap_err();
        assert!(matches!(err, IgraphError::InvalidArgument(_)));
    }

    /// With a diagonal preference matrix only same-type pairs can be accepted.
    #[test]
    fn diag_only_endpoints_share_type_when_edge_present() {
        let pref = diag_pref(3, 1.0);
        let (g, types) = callaway_traits_game(60, 3, 4, None, &pref, false, 7777).unwrap();
        for eid in 0..g.ecount() as u32 {
            let (s, d) = g.edge(eid).unwrap();
            assert_eq!(types[s as usize], types[d as usize]);
        }
    }

    /// With an anti-diagonal preference matrix only cross-type pairs (and hence
    /// no self-loops) can be accepted.
    #[test]
    fn cross_only_endpoints_differ_when_edge_present() {
        let pref = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let (g, types) = callaway_traits_game(60, 2, 4, None, &pref, false, 8888).unwrap();
        for eid in 0..g.ecount() as u32 {
            let (s, d) = g.edge(eid).unwrap();
            assert_ne!(types[s as usize], types[d as usize]);
        }
    }

    #[test]
    fn type_range_is_subset_of_types() {
        let pref = full_pref(4, 0.3);
        let (_, types) = callaway_traits_game(100, 4, 3, None, &pref, false, 314).unwrap();
        let observed: Set<u32> = types.iter().copied().collect();
        assert!(observed.iter().all(|&t| t < 4));
    }

    #[test]
    fn directed_p1_ecount_matches_formula() {
        let pref = vec![vec![1.0; 3]; 3];
        let nodes = 40u32;
        let eps = 2u32;
        let (g, _) =
            callaway_traits_game(nodes, 3, eps, Some(&[1.0, 1.0, 1.0]), &pref, true, 555).unwrap();
        assert_eq!(g.ecount() as u32, (nodes - 1) * eps);
    }

    /// The lookup returns the first 1-based bucket whose cumulative total is
    /// strictly greater than the draw, skipping zero-weight buckets and clamping
    /// to the last bucket when nothing exceeds the draw.
    #[test]
    fn cumdist_lookup_selects_first_strictly_greater_bucket() {
        let cd = [0.0, 1.0, 1.0, 3.0];
        assert_eq!(cumdist_lookup(&cd, 0.0), 1);
        assert_eq!(cumdist_lookup(&cd, 0.5), 1);
        assert_eq!(cumdist_lookup(&cd, 1.0), 3);
        assert_eq!(cumdist_lookup(&cd, 2.5), 3);
        assert_eq!(cumdist_lookup(&cd, 3.0), 3);
    }
}
507
#[cfg(all(test, feature = "proptest-harness"))]
mod proptest_invariants {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(32))]

        // Each step draws `eps` candidate pairs and each is accepted at most
        // once, so `(nodes - 1) * eps` bounds the edge count (0 for no vertices).
        #[test]
        fn ecount_bounded_by_full_accept(
            nodes in 0u32..=80,
            types in 1u32..=5,
            eps in 0u32..=6,
            seed in any::<u64>(),
            directed in any::<bool>(),
        ) {
            let t = types as usize;
            let pref = vec![vec![0.5; t]; t];
            let (g, types_v) = callaway_traits_game(
                nodes, types, eps, None, &pref, directed, seed,
            ).unwrap();
            prop_assert_eq!(g.vcount(), nodes);
            prop_assert_eq!(types_v.len(), nodes as usize);
            let max_e = nodes.saturating_sub(1) * eps;
            prop_assert!((g.ecount() as u32) <= max_e);
        }

        #[test]
        fn types_in_range(
            nodes in 0u32..=60,
            types in 1u32..=5,
            seed in any::<u64>(),
        ) {
            let t = types as usize;
            let pref = vec![vec![0.0; t]; t];
            let (_, types_v) = callaway_traits_game(
                nodes, types, 3, None, &pref, false, seed,
            ).unwrap();
            for &x in &types_v {
                prop_assert!(x < types);
            }
        }

        // Same inputs and seed must reproduce the exact same types and edges
        // (edge-by-edge, not just the count), in both directed modes.
        #[test]
        fn determinism(
            nodes in 0u32..=60,
            types in 1u32..=4,
            eps in 0u32..=4,
            seed in any::<u64>(),
            directed in any::<bool>(),
        ) {
            let t = types as usize;
            let pref = vec![vec![0.3; t]; t];
            let (g1, t1) = callaway_traits_game(
                nodes, types, eps, None, &pref, directed, seed,
            ).unwrap();
            let (g2, t2) = callaway_traits_game(
                nodes, types, eps, None, &pref, directed, seed,
            ).unwrap();
            prop_assert_eq!(g1.ecount(), g2.ecount());
            prop_assert_eq!(t1, t2);
            for eid in 0..g1.ecount() as u32 {
                prop_assert_eq!(g1.edge(eid).unwrap(), g2.edge(eid).unwrap());
            }
        }

        #[test]
        fn p1_full_pref_yields_exact_max_ecount(
            nodes in 1u32..=50,
            types in 1u32..=4,
            eps in 0u32..=5,
            seed in any::<u64>(),
            directed in any::<bool>(),
        ) {
            let t = types as usize;
            let pref = vec![vec![1.0; t]; t];
            let (g, _) = callaway_traits_game(
                nodes, types, eps, None, &pref, directed, seed,
            ).unwrap();
            prop_assert_eq!(g.ecount() as u32, (nodes - 1) * eps);
        }

        #[test]
        fn p0_yields_no_edges(
            nodes in 0u32..=50,
            types in 1u32..=4,
            eps in 0u32..=5,
            seed in any::<u64>(),
            directed in any::<bool>(),
        ) {
            let t = types as usize;
            let pref = vec![vec![0.0; t]; t];
            let (g, _) = callaway_traits_game(
                nodes, types, eps, None, &pref, directed, seed,
            ).unwrap();
            prop_assert_eq!(g.ecount(), 0);
        }
    }
}