1#![allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
59
60use crate::core::rng::SplitMix64;
61use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
62
/// Flags describing which probability clamps fired during a
/// [`dot_product_game_with_warnings`] run.
///
/// Inner products outside `[0.0, 1.0]` are not valid edge probabilities. The
/// generator clamps them and records that it did so here instead of failing.
/// The `Default` value (both flags `false`) means no clamp fired.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct DotProductWarnings {
    /// At least one vertex pair had a negative inner product. Such pairs are
    /// never connected (probability clamped to `0.0`).
    pub had_negative: bool,
    /// At least one vertex pair had an inner product greater than `1.0`. Such
    /// pairs are always connected (probability clamped to `1.0`).
    pub had_over_one: bool,
}
73
74fn validate_vecs(vecs: &[Vec<f64>]) -> IgraphResult<(usize, usize)> {
75 let n = vecs.len();
76 if n == 0 {
77 return Ok((0, 0));
78 }
79 let d = vecs[0].len();
80 for (i, v) in vecs.iter().enumerate() {
81 if v.len() != d {
82 return Err(IgraphError::InvalidArgument(format!(
83 "dot_product_game vecs[{i}] has length {} but vecs[0] has length {d}; \
84 every latent position vector must have the same dimension",
85 v.len()
86 )));
87 }
88 for (k, &x) in v.iter().enumerate() {
89 if !x.is_finite() {
90 return Err(IgraphError::InvalidArgument(format!(
91 "dot_product_game vecs[{i}][{k}] = {x} is not finite; \
92 NaN/±∞ entries are rejected so the inner-product clamp is well-defined"
93 )));
94 }
95 }
96 }
97 Ok((n, d))
98}
99
/// Returns the inner product `Σ a[k] * b[k]` taken over the indices of `a`.
///
/// Terms are accumulated left to right starting from `0.0`, so an empty `a`
/// yields `0.0`. Extra trailing elements of `b` are ignored.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. The only caller in this module passes
/// rows that `validate_vecs` has already checked to have equal lengths.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    // Re-slicing once performs the single bounds check up front. The zipped
    // fold below then runs without per-element checks, in the same order.
    let b = &b[..a.len()];
    a.iter().zip(b).fold(0.0_f64, |acc, (x, y)| acc + x * y)
}
109
110pub fn dot_product_game_with_warnings(
156 vecs: &[Vec<f64>],
157 directed: bool,
158 seed: u64,
159) -> IgraphResult<(Graph, DotProductWarnings)> {
160 let (n, _d) = validate_vecs(vecs)?;
161 let n_u32 = u32::try_from(n).map_err(|_| {
162 IgraphError::InvalidArgument(format!(
163 "dot_product_game vertex count {n} exceeds u32::MAX"
164 ))
165 })?;
166 if n == 0 {
167 return Ok((
168 Graph::new(0, directed)?,
169 DotProductWarnings {
170 had_negative: false,
171 had_over_one: false,
172 },
173 ));
174 }
175
176 let mut rng = SplitMix64::new(seed);
177 let mut edges: Vec<(VertexId, VertexId)> = Vec::new();
178 let mut had_negative = false;
179 let mut had_over_one = false;
180
181 for i in 0..n {
182 let i_id = i as VertexId;
183 let j_start = if directed { 0 } else { i + 1 };
184 for j in j_start..n {
185 if i == j {
186 continue;
187 }
188 let prob = dot(&vecs[i], &vecs[j]);
189 let j_id = j as VertexId;
190 if prob > 1.0 {
191 had_over_one = true;
192 edges.push((i_id, j_id));
193 } else if prob < 0.0 {
194 had_negative = true;
195 } else if rng.gen_unit() < prob {
197 edges.push((i_id, j_id));
198 }
199 }
200 }
201
202 let mut g = Graph::new(n_u32, directed)?;
203 g.add_edges(edges)?;
204 Ok((
205 g,
206 DotProductWarnings {
207 had_negative,
208 had_over_one,
209 },
210 ))
211}
212
213pub fn dot_product_game(vecs: &[Vec<f64>], directed: bool, seed: u64) -> IgraphResult<Graph> {
238 dot_product_game_with_warnings(vecs, directed, seed).map(|(g, _)| g)
239}
240
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Returns `true` if any edge of `g` joins a vertex to itself.
    fn has_self_loop(g: &Graph) -> bool {
        for e in 0..g.ecount() {
            let (u, v) = g.edge(e as u32).unwrap();
            if u == v {
                return true;
            }
        }
        false
    }

    /// Returns `true` if no unordered vertex pair occurs twice in undirected `g`.
    fn is_simple_undirected(g: &Graph) -> bool {
        assert!(!g.is_directed());
        let mut seen: HashSet<(VertexId, VertexId)> = HashSet::new();
        for e in 0..g.ecount() {
            let (u, v) = g.edge(e as u32).unwrap();
            let key = if u <= v { (u, v) } else { (v, u) };
            if !seen.insert(key) {
                return false;
            }
        }
        true
    }

    /// Returns `true` if no ordered vertex pair occurs twice in directed `g`.
    fn is_simple_directed(g: &Graph) -> bool {
        assert!(g.is_directed());
        let mut seen: HashSet<(VertexId, VertexId)> = HashSet::new();
        for e in 0..g.ecount() {
            let (u, v) = g.edge(e as u32).unwrap();
            if !seen.insert((u, v)) {
                return false;
            }
        }
        true
    }

    /// Collects every edge of `g` as an endpoint pair, in edge-id order.
    fn edge_list(g: &Graph) -> Vec<(VertexId, VertexId)> {
        (0..g.ecount()).map(|e| g.edge(e as u32).unwrap()).collect()
    }

    #[test]
    fn empty_vecs_produces_empty_graph() {
        let vecs: Vec<Vec<f64>> = Vec::new();
        let g = dot_product_game(&vecs, false, 0).unwrap();
        assert_eq!(g.vcount(), 0);
        assert_eq!(g.ecount(), 0);
        assert!(!g.is_directed());

        let g_dir = dot_product_game(&vecs, true, 0).unwrap();
        assert!(g_dir.is_directed());
        assert_eq!(g_dir.vcount(), 0);
    }

    #[test]
    fn single_vertex_no_edges() {
        let vecs = vec![vec![0.5, 0.5]];
        let g = dot_product_game(&vecs, false, 7).unwrap();
        assert_eq!(g.vcount(), 1);
        assert_eq!(g.ecount(), 0);
    }

    #[test]
    fn all_zero_probs_gives_empty_edges() {
        let vecs = vec![vec![0.0; 3]; 5];
        let g = dot_product_game(&vecs, false, 99).unwrap();
        assert_eq!(g.vcount(), 5);
        assert_eq!(g.ecount(), 0);
    }

    #[test]
    fn unit_probs_gives_complete_graph_undirected() {
        let n = 6u32;
        let vecs = vec![vec![1.0]; n as usize];
        let g = dot_product_game(&vecs, false, 31).unwrap();
        assert_eq!(g.vcount(), n);
        assert_eq!(g.ecount(), (n as usize) * ((n as usize) - 1) / 2);
        assert!(!has_self_loop(&g));
        assert!(is_simple_undirected(&g));
    }

    #[test]
    fn unit_probs_gives_complete_graph_directed() {
        let n = 5u32;
        let vecs = vec![vec![1.0]; n as usize];
        let g = dot_product_game(&vecs, true, 31).unwrap();
        assert_eq!(g.ecount(), (n as usize) * ((n as usize) - 1));
        assert!(!has_self_loop(&g));
        assert!(is_simple_directed(&g));
    }

    #[test]
    fn over_one_short_circuit_adds_edge_no_warn_negative() {
        let vecs = vec![vec![1.5]; 4];
        let (g, warn) = dot_product_game_with_warnings(&vecs, false, 0).unwrap();
        assert_eq!(g.ecount(), 6);
        assert!(warn.had_over_one);
        assert!(!warn.had_negative);
    }

    #[test]
    fn over_one_directed_adds_every_arc() {
        // Every ordered pair has inner product 2.25 > 1, so all 3 * 2 arcs exist.
        let vecs = vec![vec![1.5]; 3];
        let (g, warn) = dot_product_game_with_warnings(&vecs, true, 0).unwrap();
        assert_eq!(g.ecount(), 6);
        assert!(warn.had_over_one);
        assert!(!warn.had_negative);
        assert!(!has_self_loop(&g));
        assert!(is_simple_directed(&g));
    }

    #[test]
    fn negative_dot_skips_and_warns() {
        // Pair (0, 1) has inner product exactly 1.0. Every pair mixing {0, 1}
        // with {2, 3} has -0.5. Pair (2, 3) has 0.25.
        let vecs = vec![vec![1.0], vec![1.0], vec![-0.5], vec![-0.5]];
        let (g, warn) = dot_product_game_with_warnings(&vecs, false, 11).unwrap();
        assert!(warn.had_negative);
        assert!(!warn.had_over_one);

        let edges = edge_list(&g);
        // The unit-probability test above establishes that p == 1.0 always connects.
        assert!(edges
            .iter()
            .any(|&(u, v)| (u, v) == (0, 1) || (u, v) == (1, 0)));
        for &(u, v) in &edges {
            let (lo, hi) = if u <= v { (u, v) } else { (v, u) };
            assert!(
                !(lo < 2 && hi >= 2),
                "edge ({u}, {v}) joins a negative-probability pair"
            );
        }
    }

    #[test]
    fn in_range_probs_raise_no_warnings() {
        let vecs = vec![vec![0.5]; 4];
        let (_, warn) = dot_product_game_with_warnings(&vecs, false, 5).unwrap();
        assert!(!warn.had_negative);
        assert!(!warn.had_over_one);

        let empty: Vec<Vec<f64>> = Vec::new();
        let (g, warn) = dot_product_game_with_warnings(&empty, true, 5).unwrap();
        assert_eq!(g.vcount(), 0);
        assert!(!warn.had_negative);
        assert!(!warn.had_over_one);
    }

    #[test]
    fn directed_matrix_need_not_be_symmetric() {
        let vecs = vec![vec![0.5]; 8];
        let g = dot_product_game(&vecs, true, 12345).unwrap();
        let n = vecs.len();
        assert!(g.is_directed());
        assert!(g.ecount() <= n * (n - 1));
        assert!(!has_self_loop(&g));
        assert!(is_simple_directed(&g));
        // Each ordered pair is sampled independently with p = 0.25. With 28
        // unordered pairs, at least one realized arc should lack its reverse.
        let arcs: HashSet<(VertexId, VertexId)> = edge_list(&g).into_iter().collect();
        assert!(
            arcs.iter().any(|&(u, v)| !arcs.contains(&(v, u))),
            "expected at least one arc without its reverse"
        );
    }

    #[test]
    fn determinism_same_seed_same_graph() {
        let vecs = vec![
            vec![0.7, 0.2],
            vec![0.3, 0.4],
            vec![0.1, 0.5],
            vec![0.6, 0.6],
        ];
        let g1 = dot_product_game(&vecs, false, 0xDEAD_BEEF).unwrap();
        let g2 = dot_product_game(&vecs, false, 0xDEAD_BEEF).unwrap();
        assert_eq!(g1.ecount(), g2.ecount());
        assert_eq!(edge_list(&g1), edge_list(&g2));
    }

    #[test]
    fn determinism_different_seed_likely_differs() {
        let vecs = vec![
            vec![0.5, 0.4],
            vec![0.3, 0.5],
            vec![0.4, 0.3],
            vec![0.6, 0.2],
            vec![0.2, 0.6],
            vec![0.5, 0.5],
            vec![0.3, 0.3],
            vec![0.4, 0.4],
        ];
        let g1 = dot_product_game(&vecs, false, 1).unwrap();
        let g2 = dot_product_game(&vecs, false, 2).unwrap();
        let sorted_edges = |g: &Graph| {
            let mut v = edge_list(g);
            v.sort_unstable();
            v
        };
        assert_ne!(sorted_edges(&g1), sorted_edges(&g2));
    }

    #[test]
    fn wrapper_matches_with_warnings() {
        let vecs = vec![
            vec![0.6, 0.3],
            vec![0.2, 0.7],
            vec![0.5, 0.5],
            vec![0.9, 0.1],
            vec![0.4, 0.4],
        ];
        for &directed in &[false, true] {
            let plain = dot_product_game(&vecs, directed, 42).unwrap();
            let (with_warn, _) = dot_product_game_with_warnings(&vecs, directed, 42).unwrap();
            assert_eq!(edge_list(&plain), edge_list(&with_warn));
        }
    }

    #[test]
    fn dot_sums_elementwise_products() {
        assert_eq!(dot(&[1.0, 2.0, 3.0], &[4.0, -5.0, 6.0]), 12.0);
        let empty: [f64; 0] = [];
        assert_eq!(dot(&empty, &empty), 0.0);
    }

    #[test]
    fn mismatched_dim_errors() {
        let vecs = vec![vec![0.1, 0.2], vec![0.3]];
        let err = dot_product_game(&vecs, false, 0).unwrap_err();
        match err {
            IgraphError::InvalidArgument(msg) => assert!(msg.contains("dimension")),
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    #[test]
    fn nan_in_vec_errors() {
        let vecs = vec![vec![0.1, f64::NAN], vec![0.2, 0.3]];
        let err = dot_product_game(&vecs, false, 0).unwrap_err();
        match err {
            IgraphError::InvalidArgument(msg) => assert!(msg.contains("finite")),
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    #[test]
    fn inf_in_vec_errors() {
        let vecs = vec![vec![f64::INFINITY], vec![0.5]];
        assert!(dot_product_game(&vecs, false, 0).is_err());
    }

    #[test]
    fn zero_dimension_yields_zero_dot_products() {
        let vecs = vec![Vec::<f64>::new(); 4];
        let g = dot_product_game(&vecs, false, 0).unwrap();
        assert_eq!(g.vcount(), 4);
        assert_eq!(g.ecount(), 0);
    }
}
448
#[cfg(all(test, feature = "proptest-harness"))]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    use std::collections::HashSet;

    /// Generates 0 to 8 latent vectors that share one dimension in `1..=4`.
    /// Entries are drawn from `[-0.5, 1.5)`, so inner products can fall below
    /// 0 or above 1.
    fn vecs_strategy() -> impl Strategy<Value = Vec<Vec<f64>>> {
        (1usize..=4).prop_flat_map(|dim| {
            let row = prop::collection::vec(-0.5f64..1.5, dim..=dim);
            prop::collection::vec(row, 0usize..=8)
        })
    }

    /// Endpoint pairs of every edge in `g`, in edge-id order.
    fn edge_list(g: &Graph) -> Vec<(VertexId, VertexId)> {
        (0..g.ecount()).map(|e| g.edge(e as u32).unwrap()).collect()
    }

    proptest! {
        #[test]
        fn never_self_loop(
            vecs in vecs_strategy(),
            directed in any::<bool>(),
            seed in any::<u64>(),
        ) {
            let g = dot_product_game(&vecs, directed, seed).unwrap();
            for (u, v) in edge_list(&g) {
                prop_assert_ne!(u, v);
            }
        }

        #[test]
        fn always_simple(
            vecs in vecs_strategy(),
            directed in any::<bool>(),
            seed in any::<u64>(),
        ) {
            let g = dot_product_game(&vecs, directed, seed).unwrap();
            let mut seen: HashSet<(VertexId, VertexId)> = HashSet::new();
            for (u, v) in edge_list(&g) {
                // Directed arcs keep their orientation; undirected edges are
                // canonicalized so (u, v) and (v, u) collide.
                let key = if directed || u <= v { (u, v) } else { (v, u) };
                prop_assert!(seen.insert(key));
            }
        }

        #[test]
        fn vcount_matches_input(
            vecs in vecs_strategy(),
            directed in any::<bool>(),
            seed in any::<u64>(),
        ) {
            let g = dot_product_game(&vecs, directed, seed).unwrap();
            prop_assert_eq!(g.vcount() as usize, vecs.len());
            prop_assert_eq!(g.is_directed(), directed);
        }

        #[test]
        fn edge_count_within_bounds(
            vecs in vecs_strategy(),
            directed in any::<bool>(),
            seed in any::<u64>(),
        ) {
            let g = dot_product_game(&vecs, directed, seed).unwrap();
            let n = vecs.len();
            let ordered_pairs = n.saturating_mul(n.saturating_sub(1));
            let bound = if directed { ordered_pairs } else { ordered_pairs / 2 };
            prop_assert!(g.ecount() <= bound);
        }

        #[test]
        fn determinism(
            vecs in vecs_strategy(),
            directed in any::<bool>(),
            seed in any::<u64>(),
        ) {
            let first = dot_product_game(&vecs, directed, seed).unwrap();
            let second = dot_product_game(&vecs, directed, seed).unwrap();
            prop_assert_eq!(first.ecount(), second.ecount());
            prop_assert_eq!(edge_list(&first), edge_list(&second));
        }
    }
}