1#![allow(
22 clippy::cast_precision_loss,
23 clippy::cast_possible_truncation,
24 clippy::cast_sign_loss
25)]
26
27use std::collections::HashMap;
28
29use crate::core::error::{IgraphError, IgraphResult};
30
31use super::reindex_membership::reindex_membership;
32
/// Selects the measure computed by [`compare_communities`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CommunityComparison {
    /// Variation of information, `H1 + H2 - 2 * MI`, using natural logarithms.
    ///
    /// A distance: two empty membership vectors compare as `0.0`.
    VariationOfInformation,
    /// Normalized mutual information, `2 * MI / (H1 + H2)`.
    ///
    /// A similarity: reported as `1.0` when both entropies are zero, and for
    /// two empty membership vectors.
    NormalizedMutualInformation,
    /// Split-join distance: the sum of the two projection distances between
    /// the partitions.
    ///
    /// Two empty membership vectors compare as `0.0`.
    SplitJoin,
    /// Rand index.
    ///
    /// Rejected with an error when fewer than two elements are given.
    Rand,
    /// Rand index corrected by its expected value.
    ///
    /// Rejected with an error when fewer than two elements are given.
    AdjustedRand,
}
64
65pub fn compare_communities(
100 comm1: &[u32],
101 comm2: &[u32],
102 method: CommunityComparison,
103) -> IgraphResult<f64> {
104 if comm1.len() != comm2.len() {
105 return Err(IgraphError::InvalidArgument(format!(
106 "community membership vectors have different lengths: {} and {}",
107 comm1.len(),
108 comm2.len(),
109 )));
110 }
111
112 let n = comm1.len();
113
114 if n == 0 {
115 return match method {
116 CommunityComparison::NormalizedMutualInformation => Ok(1.0),
117 CommunityComparison::VariationOfInformation | CommunityComparison::SplitJoin => Ok(0.0),
118 CommunityComparison::Rand | CommunityComparison::AdjustedRand => {
119 Err(IgraphError::InvalidArgument(format!(
120 "Rand indices not defined for zero or one vertices. \
121 Found membership vector of size {n}.",
122 )))
123 }
124 };
125 }
126
127 let c1 = reindex_membership(comm1)?;
129 let c2 = reindex_membership(comm2)?;
130
131 match method {
132 CommunityComparison::VariationOfInformation => {
133 let (h1, h2, mi) = entropy_and_mutual_information(&c1.membership, &c2.membership, n);
134 Ok(h1 + h2 - 2.0 * mi)
135 }
136 CommunityComparison::NormalizedMutualInformation => {
137 let (h1, h2, mi) = entropy_and_mutual_information(&c1.membership, &c2.membership, n);
138 if h1 == 0.0 && h2 == 0.0 {
139 Ok(1.0)
140 } else {
141 Ok(2.0 * mi / (h1 + h2))
142 }
143 }
144 CommunityComparison::SplitJoin => {
145 let (d12, d21) = split_join_distances(&c1.membership, &c2.membership, n);
146 Ok((d12 + d21) as f64)
150 }
151 CommunityComparison::Rand | CommunityComparison::AdjustedRand => {
152 if n < 2 {
153 return Err(IgraphError::InvalidArgument(format!(
154 "Rand indices not defined for zero or one vertices. \
155 Found membership vector of size {n}.",
156 )));
157 }
158 Ok(rand_index(
159 &c1.membership,
160 &c2.membership,
161 n,
162 matches!(method, CommunityComparison::AdjustedRand),
163 ))
164 }
165 }
166}
167
168fn entropy_and_mutual_information(v1: &[u32], v2: &[u32], n: usize) -> (f64, f64, f64) {
172 let k1 = max_plus_one(v1);
173 let k2 = max_plus_one(v2);
174 let n_f = n as f64;
175
176 let mut p1: Vec<f64> = vec![0.0; k1];
178 let mut p2: Vec<f64> = vec![0.0; k2];
179 for &c in v1 {
180 p1[c as usize] += 1.0;
181 }
182 for &c in v2 {
183 p2[c as usize] += 1.0;
184 }
185
186 let mut h1 = 0.0;
190 for x in &mut p1 {
191 *x /= n_f;
192 h1 -= *x * x.ln();
193 }
194 let mut h2 = 0.0;
195 for x in &mut p2 {
196 *x /= n_f;
197 h2 -= *x * x.ln();
198 }
199
200 let log_p1: Vec<f64> = p1.iter().map(|&p| p.ln()).collect();
202 let log_p2: Vec<f64> = p2.iter().map(|&p| p.ln()).collect();
203
204 let mut counts: HashMap<(u32, u32), u32> = HashMap::new();
206 for i in 0..n {
207 *counts.entry((v1[i], v2[i])).or_insert(0) += 1;
208 }
209
210 let mut mut_inf = 0.0;
211 for (&(r, c), &cnt) in &counts {
212 let p = f64::from(cnt) / n_f;
213 mut_inf += p * (p.ln() - log_p1[r as usize] - log_p2[c as usize]);
214 }
215
216 (h1, h2, mut_inf)
217}
218
219pub(crate) fn split_join_distances(v1: &[u32], v2: &[u32], n: usize) -> (u64, u64) {
224 let k1 = max_plus_one(v1);
225 let k2 = max_plus_one(v2);
226
227 let mut counts: HashMap<(u32, u32), u32> = HashMap::new();
228 for i in 0..n {
229 *counts.entry((v1[i], v2[i])).or_insert(0) += 1;
230 }
231
232 let mut row_max: Vec<u32> = vec![0; k1];
233 let mut col_max: Vec<u32> = vec![0; k2];
234 for (&(r, c), &cnt) in &counts {
235 let r_slot = &mut row_max[r as usize];
236 if cnt > *r_slot {
237 *r_slot = cnt;
238 }
239 let c_slot = &mut col_max[c as usize];
240 if cnt > *c_slot {
241 *c_slot = cnt;
242 }
243 }
244
245 let sum_row: u64 = row_max.iter().map(|&x| u64::from(x)).sum();
246 let sum_col: u64 = col_max.iter().map(|&x| u64::from(x)).sum();
247
248 let n_u64 = n as u64;
249 (n_u64 - sum_row, n_u64 - sum_col)
250}
251
252fn rand_index(v1: &[u32], v2: &[u32], n: usize, adjust: bool) -> f64 {
255 let k1 = max_plus_one(v1);
256 let k2 = max_plus_one(v2);
257 let n_f = n as f64;
258
259 let mut counts: HashMap<(u32, u32), u32> = HashMap::new();
260 for i in 0..n {
261 *counts.entry((v1[i], v2[i])).or_insert(0) += 1;
262 }
263
264 let mut row_sums: Vec<f64> = vec![0.0; k1];
265 let mut col_sums: Vec<f64> = vec![0.0; k2];
266 for (&(r, c), &cnt) in &counts {
267 row_sums[r as usize] += f64::from(cnt);
268 col_sums[c as usize] += f64::from(cnt);
269 }
270
271 let mut joint = 0.0;
273 for &cnt in counts.values() {
274 let v = f64::from(cnt);
275 joint += (v / n_f) * (v - 1.0) / (n_f - 1.0);
276 }
277
278 let mut frac_in_1 = 0.0;
279 for &v in &row_sums {
280 frac_in_1 += (v / n_f) * (v - 1.0) / (n_f - 1.0);
281 }
282 let mut frac_in_2 = 0.0;
283 for &v in &col_sums {
284 frac_in_2 += (v / n_f) * (v - 1.0) / (n_f - 1.0);
285 }
286
287 let rand = 1.0 + 2.0 * joint - frac_in_1 - frac_in_2;
289
290 if adjust {
291 let expected = frac_in_1 * frac_in_2 + (1.0 - frac_in_1) * (1.0 - frac_in_2);
292 let denom = 1.0 - expected;
293 if denom == 0.0 {
298 1.0
299 } else {
300 (rand - expected) / denom
301 }
302 } else {
303 rand
304 }
305}
306
/// Returns the largest value in `v` plus one, i.e. the length a vector needs
/// so that every value in `v` is a valid index into it.
///
/// An empty slice is treated as if its largest value were `0`, giving `1`.
fn max_plus_one(v: &[u32]) -> usize {
    let largest = v.iter().fold(0u32, |best, &label| best.max(label));
    largest as usize + 1
}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` when `a` and `b` differ by less than `tol`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Vectors of different length are rejected with `InvalidArgument`.
    #[test]
    fn err_on_length_mismatch() {
        let err = compare_communities(&[0, 1], &[0], CommunityComparison::VariationOfInformation)
            .unwrap_err();
        match err {
            IgraphError::InvalidArgument(_) => (),
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    /// Empty input has a fixed answer per method; the Rand measures refuse it.
    #[test]
    fn empty_input_returns_method_defaults() {
        for (m, expected) in [
            (CommunityComparison::VariationOfInformation, 0.0),
            (CommunityComparison::NormalizedMutualInformation, 1.0),
            (CommunityComparison::SplitJoin, 0.0),
        ] {
            let q = compare_communities(&[], &[], m).unwrap();
            assert!(close(q, expected, 1e-12), "method {m:?} got {q}");
        }
        for m in [CommunityComparison::Rand, CommunityComparison::AdjustedRand] {
            assert!(compare_communities(&[], &[], m).is_err());
        }
    }

    /// Comparing a partition with itself gives the extremal value of every measure.
    #[test]
    fn identical_partitions_have_nmi_1_and_vi_0() {
        let v = [0, 0, 1, 1, 2, 2];
        assert!(close(
            compare_communities(&v, &v, CommunityComparison::NormalizedMutualInformation).unwrap(),
            1.0,
            1e-12,
        ));
        assert!(close(
            compare_communities(&v, &v, CommunityComparison::VariationOfInformation).unwrap(),
            0.0,
            1e-12,
        ));
        assert!(close(
            compare_communities(&v, &v, CommunityComparison::Rand).unwrap(),
            1.0,
            1e-12,
        ));
        assert!(close(
            compare_communities(&v, &v, CommunityComparison::AdjustedRand).unwrap(),
            1.0,
            1e-12,
        ));
        assert!(close(
            compare_communities(&v, &v, CommunityComparison::SplitJoin).unwrap(),
            0.0,
            1e-12,
        ));
    }

    /// Renaming the community ids must not change any measure.
    #[test]
    fn relabel_invariance() {
        let a = [0, 0, 1, 1, 2, 2];
        // Same grouping as `a`, but with arbitrary, non-contiguous ids.
        let b = [7, 7, 3, 3, 9, 9];
        for m in [
            CommunityComparison::VariationOfInformation,
            CommunityComparison::NormalizedMutualInformation,
            CommunityComparison::SplitJoin,
            CommunityComparison::Rand,
            CommunityComparison::AdjustedRand,
        ] {
            let q1 = compare_communities(&a, &a, m).unwrap();
            let q2 = compare_communities(&a, &b, m).unwrap();
            assert!(close(q1, q2, 1e-12), "method {m:?}: {q1} vs {q2}");
        }
    }

    /// All-singleton partitions agree perfectly, whatever ids they use.
    #[test]
    fn singletons_vs_singletons() {
        let v: Vec<u32> = (0..6).collect();
        assert!(close(
            compare_communities(&v, &v, CommunityComparison::NormalizedMutualInformation).unwrap(),
            1.0,
            1e-12,
        ));
        // Reversed ids still describe the all-singleton partition.
        let w: Vec<u32> = (0..6).rev().collect();
        assert!(close(
            compare_communities(&v, &w, CommunityComparison::Rand).unwrap(),
            1.0,
            1e-12,
        ));
    }

    /// Two single-cluster partitions hit the zero-entropy special case of NMI.
    #[test]
    fn one_cluster_each_side_is_nmi_one_per_spec() {
        let v = [0u32; 5];
        let w = [9u32; 5];
        assert!(close(
            compare_communities(&v, &w, CommunityComparison::NormalizedMutualInformation).unwrap(),
            1.0,
            1e-12,
        ));
        assert!(close(
            compare_communities(&v, &w, CommunityComparison::VariationOfInformation).unwrap(),
            0.0,
            1e-12,
        ));
        assert!(close(
            compare_communities(&v, &w, CommunityComparison::SplitJoin).unwrap(),
            0.0,
            1e-12,
        ));
        assert!(close(
            compare_communities(&v, &w, CommunityComparison::Rand).unwrap(),
            1.0,
            1e-12,
        ));
    }

    /// Two crossing two-cluster partitions of four elements, checked against
    /// hand-computed values for every measure.
    #[test]
    fn full_disagreement_two_clusters() {
        let a = [0u32, 0, 1, 1];
        let b = [0u32, 1, 0, 1];
        let nmi =
            compare_communities(&a, &b, CommunityComparison::NormalizedMutualInformation).unwrap();
        assert!(close(nmi, 0.0, 1e-12), "NMI = {nmi}");
        // H1 = H2 = ln 2 and MI = 0, so VI = 2 ln 2.
        let vi = compare_communities(&a, &b, CommunityComparison::VariationOfInformation).unwrap();
        assert!(close(vi, 2.0 * 2f64.ln(), 1e-12), "VI = {vi}");
        // Every contingency cell holds one element, so each projection distance is 4 - 2.
        let sj = compare_communities(&a, &b, CommunityComparison::SplitJoin).unwrap();
        assert!(close(sj, 4.0, 1e-12), "SJ = {sj}");
        let rand = compare_communities(&a, &b, CommunityComparison::Rand).unwrap();
        assert!(close(rand, 1.0 / 3.0, 1e-12), "Rand = {rand}");
        let ar = compare_communities(&a, &b, CommunityComparison::AdjustedRand).unwrap();
        assert!(close(ar, -0.5, 1e-12), "AR = {ar}");
    }

    /// When `b` refines `a`, one projection distance is zero and the other is not.
    #[test]
    fn split_join_is_zero_for_subpartition() {
        let a = [0u32, 0, 0, 1, 1, 1];
        // Each cluster of `a` is split into a pair and a singleton.
        let b = [5u32, 5, 6, 7, 7, 8];
        let r1 = reindex_membership(&a).unwrap();
        let r2 = reindex_membership(&b).unwrap();
        let (d12, d21) = split_join_distances(&r1.membership, &r2.membership, a.len());
        // Best overlap per cluster of `a` is 2 + 2, leaving 6 - 4 = 2.
        assert_eq!(d12, 2);
        // Every cluster of `b` lies entirely inside one cluster of `a`.
        assert_eq!(d21, 0);
        let sj = compare_communities(&a, &b, CommunityComparison::SplitJoin).unwrap();
        assert!(close(sj, 2.0, 1e-12));
    }

    /// NMI does not depend on argument order.
    #[test]
    fn nmi_is_symmetric() {
        let a = [0u32, 0, 1, 1, 2, 2, 0, 1];
        let b = [3u32, 4, 4, 3, 3, 4, 4, 3];
        let n_ab =
            compare_communities(&a, &b, CommunityComparison::NormalizedMutualInformation).unwrap();
        let n_ba =
            compare_communities(&b, &a, CommunityComparison::NormalizedMutualInformation).unwrap();
        assert!(close(n_ab, n_ba, 1e-12));
    }

    /// A single element is not enough for either Rand measure.
    #[test]
    fn rand_requires_at_least_two_vertices() {
        let v = [0u32];
        assert!(compare_communities(&v, &v, CommunityComparison::Rand).is_err());
        assert!(compare_communities(&v, &v, CommunityComparison::AdjustedRand).is_err());
    }

    /// VI is zero for the same partition (even under relabeling) and strictly
    /// positive for a different one.
    #[test]
    fn variation_of_information_zero_iff_same_partition() {
        let a = [0u32, 0, 1, 1];
        // Same grouping as `a` with the two ids swapped.
        let b = [1u32, 1, 0, 0];
        let vi = compare_communities(&a, &b, CommunityComparison::VariationOfInformation).unwrap();
        assert!(close(vi, 0.0, 1e-12));

        // The "only if" direction: a different grouping is at positive distance.
        let c = [0u32, 1, 0, 1];
        let vi_diff =
            compare_communities(&a, &c, CommunityComparison::VariationOfInformation).unwrap();
        assert!(vi_diff > 1e-12, "VI = {vi_diff}");
    }

    // Property tests, compiled only when the `proptest-harness` feature is enabled.
    #[cfg(all(test, feature = "proptest-harness"))]
    mod prop {
        use super::*;
        use proptest::prelude::*;

        prop_compose! {
            // Two membership vectors of equal length `n`, with ids drawn from
            // `0..k1` and `0..k2` by a small multiplicative generator seeded by proptest.
            fn arb_pair()(
                n in 2usize..=24,
                k1 in 1u32..=5,
                k2 in 1u32..=5,
                seed in any::<u64>(),
            ) -> (Vec<u32>, Vec<u32>) {
                let mut rng: u64 = seed.wrapping_add(0xDEAD_BEEF_C0FF_EE00);
                let mut step = || -> u32 {
                    rng = rng.wrapping_mul(0x9E37_79B9_7F4A_7C15).wrapping_add(1);
                    (rng >> 32) as u32
                };
                let v1: Vec<u32> = (0..n).map(|_| step() % k1).collect();
                let v2: Vec<u32> = (0..n).map(|_| step() % k2).collect();
                (v1, v2)
            }
        }

        proptest! {
            #![proptest_config(ProptestConfig { cases: 60, ..ProptestConfig::default() })]

            #[test]
            fn nmi_in_unit_interval((a, b) in arb_pair()) {
                let q = compare_communities(
                    &a, &b, CommunityComparison::NormalizedMutualInformation,
                ).unwrap();
                prop_assert!((-1e-9..=1.0 + 1e-9).contains(&q), "NMI out of [0,1]: {}", q);
            }

            #[test]
            fn vi_non_negative((a, b) in arb_pair()) {
                let q = compare_communities(
                    &a, &b, CommunityComparison::VariationOfInformation,
                ).unwrap();
                prop_assert!(q >= -1e-9, "VI < 0: {}", q);
            }

            #[test]
            fn rand_in_unit_interval((a, b) in arb_pair()) {
                let q = compare_communities(
                    &a, &b, CommunityComparison::Rand,
                ).unwrap();
                prop_assert!((-1e-9..=1.0 + 1e-9).contains(&q), "Rand out of [0,1]: {}", q);
            }

            #[test]
            fn adjusted_rand_capped_at_one((a, b) in arb_pair()) {
                let q = compare_communities(
                    &a, &b, CommunityComparison::AdjustedRand,
                ).unwrap();
                prop_assert!(q <= 1.0 + 1e-9, "AR > 1: {}", q);
            }

            #[test]
            fn measures_are_relabel_invariant((a, b) in arb_pair()) {
                // Adding an offset and multiplying by an odd constant (wrapping)
                // maps distinct ids to distinct ids, so the grouping is unchanged.
                let bump = |v: &[u32], offset: u32| -> Vec<u32> {
                    v.iter().map(|&x| x.wrapping_add(offset).wrapping_mul(7)).collect()
                };
                let a2 = bump(&a, 100);
                let b2 = bump(&b, 50);
                for m in [
                    CommunityComparison::VariationOfInformation,
                    CommunityComparison::NormalizedMutualInformation,
                    CommunityComparison::SplitJoin,
                    CommunityComparison::Rand,
                    CommunityComparison::AdjustedRand,
                ] {
                    let q1 = compare_communities(&a, &b, m).unwrap();
                    let q2 = compare_communities(&a2, &b2, m).unwrap();
                    prop_assert!((q1 - q2).abs() < 1e-9, "method {:?}: {} vs {}", m, q1, q2);
                }
            }

            #[test]
            fn nmi_symmetric((a, b) in arb_pair()) {
                let ab = compare_communities(
                    &a, &b, CommunityComparison::NormalizedMutualInformation,
                ).unwrap();
                let ba = compare_communities(
                    &b, &a, CommunityComparison::NormalizedMutualInformation,
                ).unwrap();
                prop_assert!((ab - ba).abs() < 1e-9);
            }

            #[test]
            fn identical_partition_is_extremal((a, _b) in arb_pair()) {
                for (m, expected) in [
                    (CommunityComparison::VariationOfInformation, 0.0_f64),
                    (CommunityComparison::NormalizedMutualInformation, 1.0),
                    (CommunityComparison::SplitJoin, 0.0),
                    (CommunityComparison::Rand, 1.0),
                    (CommunityComparison::AdjustedRand, 1.0),
                ] {
                    let q = compare_communities(&a, &a, m).unwrap();
                    prop_assert!(
                        (q - expected).abs() < 1e-9,
                        "method {:?}: {} vs {}", m, q, expected
                    );
                }
            }
        }
    }
}