Skip to main content

rust_igraph/algorithms/motifs/
triad_census.rs

1//! Triad census (ALGO-MO-002).
2//!
3//! Classifies all vertex triples in a directed graph into the 16 isomorphism
4//! classes defined by Davis and Leinhardt (1972).
5//! Counterpart of `igraph_triad_census`.
6
7use crate::core::{Graph, IgraphError, IgraphResult};
8
/// The 16 triad types in Davis-Leinhardt MAN notation.
///
/// Each variant represents one of the 16 isomorphism classes of directed triads.
/// The three digits of a name give the number of Mutual, Asymmetric, and Null
/// dyads in the triad (so `T021D` has 0 mutual, 2 asymmetric and 1 null dyad).
/// Where several classes share the same digits, a letter suffix tells them
/// apart: D = down, U = up, T = transitive, C = cyclic (`030C`) or chain
/// (`021C`, `120C`).
///
/// The explicit discriminants (0 through 15) are the positions used to index
/// `TriadCensus::counts`, so they must stay dense and in this order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TriadType {
    /// 003: A, B, C (empty graph on 3 vertices)
    T003 = 0,
    /// 012: A->B, C
    T012 = 1,
    /// 102: A<->B, C
    T102 = 2,
    /// 021D: A<-B->C
    T021D = 3,
    /// 021U: A->B<-C
    T021U = 4,
    /// 021C: A->B->C
    T021C = 5,
    /// 111D: A<->B<-C
    T111D = 6,
    /// 111U: A<->B->C
    T111U = 7,
    /// 030T: A->B<-C, A->C
    T030T = 8,
    /// 030C: A<-B<-C, A->C
    T030C = 9,
    /// 201: A<->B<->C
    T201 = 10,
    /// 120D: A<-B->C, A<->C
    T120D = 11,
    /// 120U: A->B<-C, A<->C
    T120U = 12,
    /// 120C: A->B->C, A<->C
    T120C = 13,
    /// 210: A->B<->C, A<->C
    T210 = 14,
    /// 300: A<->B<->C, A<->C (complete graph)
    T300 = 15,
}
49
/// Result of a triad census.
///
/// Holds one tally per triad class. Use `TriadCensus::get` to read a class by
/// name instead of by raw array position.
#[derive(Debug, Clone, PartialEq)]
pub struct TriadCensus {
    /// Counts for each of the 16 triad types, indexed by `TriadType` ordinal
    /// (for example `counts[TriadType::T300 as usize]`).
    pub counts: [f64; 16],
}
56
57impl TriadCensus {
58    /// Get the count for a specific triad type.
59    ///
60    /// # Examples
61    ///
62    /// ```
63    /// use rust_igraph::{Graph, triad_census, TriadType};
64    ///
65    /// let mut g = Graph::new(3, true).unwrap();
66    /// g.add_edge(0, 1).unwrap();
67    /// g.add_edge(1, 2).unwrap();
68    /// g.add_edge(2, 0).unwrap();
69    /// let tc = triad_census(&g).unwrap();
70    /// assert_eq!(tc.get(TriadType::T030C), 1.0);
71    /// ```
72    pub fn get(&self, triad_type: TriadType) -> f64 {
73        self.counts[triad_type as usize]
74    }
75}
76
77/// Performs a triad census on a directed graph.
78///
79/// Classifies all `n*(n-1)*(n-2)/6` vertex triples into the 16
80/// Davis-Leinhardt triad types. Returns counts as `f64` to avoid overflow
81/// for large graphs.
82///
83/// For undirected graphs, edges are treated as mutual connections.
84///
85/// # Examples
86///
87/// ```
88/// use rust_igraph::{Graph, triad_census, TriadType};
89///
90/// // Complete directed graph on 3 vertices: all triads are type 300
91/// let mut g = Graph::new(3, true).unwrap();
92/// for i in 0..3u32 {
93///     for j in 0..3u32 {
94///         if i != j { g.add_edge(i, j).unwrap(); }
95///     }
96/// }
97/// let tc = triad_census(&g).unwrap();
98/// assert!((tc.get(TriadType::T300) - 1.0).abs() < 1e-10);
99///
100/// // Directed 3-cycle: 0->1->2->0 — all asymmetric, one 030C triad
101/// let mut g = Graph::new(3, true).unwrap();
102/// g.add_edge(0, 1).unwrap();
103/// g.add_edge(1, 2).unwrap();
104/// g.add_edge(2, 0).unwrap();
105/// let tc = triad_census(&g).unwrap();
106/// assert!((tc.get(TriadType::T030C) - 1.0).abs() < 1e-10);
107/// ```
108pub fn triad_census(graph: &Graph) -> IgraphResult<TriadCensus> {
109    let n = graph.vcount();
110
111    if n < 3 {
112        return Ok(TriadCensus { counts: [0.0; 16] });
113    }
114
115    let adj = build_dyad_matrix(graph)?;
116    let mut counts = [0.0_f64; 16];
117
118    for i in 0..n {
119        for j in (i + 1)..n {
120            for k in (j + 1)..n {
121                let ab = adj[(i as usize) * (n as usize) + (j as usize)];
122                let ac = adj[(i as usize) * (n as usize) + (k as usize)];
123                let bc = adj[(j as usize) * (n as usize) + (k as usize)];
124                let idx = lookup_triad_type(ab, ac, bc);
125                counts[idx] += 1.0;
126            }
127        }
128    }
129
130    Ok(TriadCensus { counts })
131}
132
133/// Dyad code encoding: for each ordered pair (u, v):
134/// - 0 = no edge
135/// - 1 = u->v only
136/// - 2 = v->u only
137/// - 3 = mutual (both directions)
138fn build_dyad_matrix(graph: &Graph) -> IgraphResult<Vec<u8>> {
139    let n = graph.vcount();
140    let size = (n as usize)
141        .checked_mul(n as usize)
142        .ok_or_else(|| IgraphError::InvalidArgument("graph too large for triad census".into()))?;
143    let mut matrix = vec![0u8; size];
144    let nn = n as usize;
145    let ecount = graph.ecount();
146
147    for eid in 0..ecount {
148        #[allow(clippy::cast_possible_truncation)]
149        let (src, tgt) = graph.edge(eid as u32)?;
150        if src == tgt {
151            continue;
152        }
153        let idx_st = (src as usize) * nn + (tgt as usize);
154        let idx_ts = (tgt as usize) * nn + (src as usize);
155        matrix[idx_st] |= 1;
156        matrix[idx_ts] |= 2;
157    }
158
159    if !graph.is_directed() {
160        // Undirected: every edge is mutual
161        for cell in &mut matrix {
162            if *cell != 0 {
163                *cell = 3;
164            }
165        }
166    }
167
168    Ok(matrix)
169}
170
171/// Lookup table mapping three dyad codes to the triad type index (0..15).
172///
173/// Each dyad code is 0-3. We encode the triple as a single index
174/// into a precomputed table. The table is symmetric under vertex permutation:
175/// we canonicalize the triple by sorting the MAN counts.
176fn lookup_triad_type(ab: u8, ac: u8, bc: u8) -> usize {
177    // Count mutual (M), asymmetric (A), null (N) dyads
178    let mut m = 0u8;
179    let mut a = 0u8;
180    let mut n_count = 0u8;
181
182    for &d in &[ab, ac, bc] {
183        match d {
184            0 => n_count += 1,
185            3 => m += 1,
186            _ => a += 1,
187        }
188    }
189
190    match (m, a, n_count) {
191        (0, 1, 2) => 1, // 012
192        (1, 0, 2) => 2, // 102
193        (0, 2, 1) => classify_021(ab, ac, bc),
194        (1, 1, 1) => classify_111(ab, ac, bc),
195        (0, 3, 0) => classify_030(ab, ac, bc),
196        (2, 0, 1) => 10, // 201
197        (1, 2, 0) => classify_120(ab, ac, bc),
198        (2, 1, 0) => 14, // 210
199        (3, 0, 0) => 15, // 300
200        _ => 0,          // 003 and any impossible combos
201    }
202}
203
204/// 021 subtypes: two asymmetric dyads, one null.
205/// 021D (3): out-star from center
206/// 021U (4): in-star to center
207/// 021C (5): directed chain
208fn classify_021(ab: u8, ac: u8, bc: u8) -> usize {
209    // Find the null dyad to identify the two non-adjacent vertices.
210    // The shared vertex (center) connects to both others asymmetrically.
211    let (from_center_1, from_center_2) = if bc == 0 {
212        // Center = A (vertex i), connects to B and C
213        (ab, ac)
214    } else if ac == 0 {
215        // Center = B (vertex j), connects to A and C
216        (flip_dyad(ab), bc)
217    } else {
218        // ab == 0: Center = C (vertex k), connects to A and B
219        (flip_dyad(ac), flip_dyad(bc))
220    };
221
222    // from_center: 1 = center->other, 2 = other->center
223    match (from_center_1, from_center_2) {
224        (1, 1) => 3, // 021D: center->both (out-star)
225        (2, 2) => 4, // 021U: both->center (in-star)
226        _ => 5,      // 021C: chain
227    }
228}
229
230/// 111 subtypes: one mutual, one asymmetric, one null.
231/// 111D (6): third->mutual_vertex
232/// 111U (7): mutual_vertex->third
233fn classify_111(ab: u8, ac: u8, bc: u8) -> usize {
234    // The asymmetric dyad connects the mutual pair to the third vertex.
235    // Direction of the asymmetric edge from the mutual-pair-vertex determines type.
236    let asym_from_mutual_vertex = if ab == 3 {
237        // Mutual: A-B. Third = C. Asymmetric connects mutual pair to C.
238        // ac: from A to C (1=A->C, 2=C->A). bc: from B to C (1=B->C, 2=C->B).
239        if ac != 0 { ac } else { bc }
240    } else if ac == 3 {
241        // Mutual: A-C. Third = B. Asymmetric connects mutual pair to B.
242        // ab: from A to B (1=A->B, 2=B->A). bc: from B to C (need C's view to B = flip).
243        if ab != 0 { ab } else { flip_dyad(bc) }
244    } else {
245        // Mutual: B-C. Third = A. Asymmetric connects mutual pair to A.
246        // ab: from A to B (need B's view to A = flip). ac: from A to C (need C's view to A = flip).
247        if ab != 0 {
248            flip_dyad(ab)
249        } else {
250            flip_dyad(ac)
251        }
252    };
253
254    // 1 = mutual_vertex->third (111U), 2 = third->mutual_vertex (111D)
255    if asym_from_mutual_vertex == 1 {
256        7 // 111U
257    } else {
258        6 // 111D
259    }
260}
261
/// Resolves the 030 subtype: three asymmetric dyads.
///
/// - 030T (8): transitive (some vertex has out-degree 2)
/// - 030C (9): cyclic (every vertex has out-degree 1)
fn classify_030(ab: u8, ac: u8, bc: u8) -> usize {
    // Code 1 means the first vertex of the pair is the sender; any other code
    // is attributed to the second vertex.
    let a_to_b = ab == 1;
    let a_to_c = ac == 1;
    let b_to_c = bc == 1;

    // Out-degrees of A, B and C within the triad.
    let out_degrees = [
        u8::from(a_to_b) + u8::from(a_to_c),
        u8::from(!a_to_b) + u8::from(b_to_c),
        u8::from(!a_to_c) + u8::from(!b_to_c),
    ];

    if out_degrees.contains(&2) {
        8 // 030T: transitive
    } else {
        9 // 030C: cyclic
    }
}
293
294/// 120 subtypes: one mutual, two asymmetric, zero null.
295/// 120D (11): both mutual-pair vertices point OUT to third
296/// 120U (12): third points IN to both mutual-pair vertices
297/// 120C (13): chain through third
298fn classify_120(ab: u8, ac: u8, bc: u8) -> usize {
299    // Find the mutual dyad. The other two dyads are asymmetric, connecting
300    // the mutual-pair vertices to the third vertex.
301    let (to_third_1, to_third_2) = if ab == 3 {
302        // Mutual: A-B. Third = C. Asymmetric: AC and BC.
303        (ac, bc)
304    } else if ac == 3 {
305        // Mutual: A-C. Third = B. Asymmetric: AB and CB=flip(BC).
306        (ab, flip_dyad(bc))
307    } else {
308        // Mutual: B-C. Third = A. Asymmetric: BA=flip(AB) and CA=flip(AC).
309        (flip_dyad(ab), flip_dyad(ac))
310    };
311
312    // to_third: 1 = mutual_vertex->third, 2 = third->mutual_vertex
313    match (to_third_1, to_third_2) {
314        (2, 2) => 11, // 120D: third sends to both mutual vertices
315        (1, 1) => 12, // 120U: both mutual vertices send to third
316        _ => 13,      // 120C: chain
317    }
318}
319
/// Re-expresses a dyad code from the other endpoint's point of view.
///
/// The two asymmetric codes trade places (1 <-> 2); null (0), mutual (3) and
/// any other value are returned unchanged.
fn flip_dyad(d: u8) -> u8 {
    match d {
        // 1 and 2 sum to 3, so subtracting from 3 swaps them.
        1 | 2 => 3 - d,
        _ => d,
    }
}
328
// Unit tests for the triad census. Most tests build a three-vertex graph that
// forms exactly one triad and check that it is counted under the expected
// class; the remaining tests check totals on slightly larger graphs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_graph() {
        // No vertices means no triples: every count stays at zero.
        let g = Graph::new(0, true).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!(tc.counts.iter().all(|&c| c.abs() < 1e-10));
    }

    #[test]
    fn test_two_vertices() {
        // Two vertices cannot form a triple either.
        let g = Graph::new(2, true).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!(tc.counts.iter().all(|&c| c.abs() < 1e-10));
    }

    #[test]
    fn test_three_vertices_no_edges() {
        // The single triple has three null dyads: 003, and nothing else.
        let g = Graph::new(3, true).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T003) - 1.0).abs() < 1e-10);
        assert!(tc.counts[1..].iter().all(|&c| c.abs() < 1e-10));
    }

    #[test]
    fn test_single_edge() {
        // One asymmetric dyad, two null dyads: 012.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T012) - 1.0).abs() < 1e-10);
        assert!((tc.get(TriadType::T003)).abs() < 1e-10);
    }

    #[test]
    fn test_mutual_edge() {
        // One mutual dyad, two null dyads: 102.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T102) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_directed_3_cycle() {
        // 0->1->2->0: three asymmetric dyads forming a cycle: 030C.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T030C) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_transitive_triple() {
        // 0->1, 0->2, 1->2: vertex 0 has out-degree 2: 030T.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T030T) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_complete_directed() {
        // Every ordered pair is connected, so all three dyads are mutual: 300.
        let mut g = Graph::new(3, true).unwrap();
        for i in 0..3u32 {
            for j in 0..3u32 {
                if i != j {
                    g.add_edge(i, j).unwrap();
                }
            }
        }
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T300) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_021d_out_star() {
        // B->A, B->C (out-star from B=1): 021D
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T021D) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_021u_in_star() {
        // A->B, C->B (in-star to B=1): 021U
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(2, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T021U) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_021c_chain() {
        // A->B->C (chain): 021C
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T021C) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_four_vertices_sum() {
        // 4 vertices: C(4,3) = 4 total triples
        let mut g = Graph::new(4, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        let total: f64 = tc.counts.iter().sum();
        assert!((total - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_201_two_mutual() {
        // 0<->1, 0<->2, no edge between 1 and 2: 201
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T201) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_undirected_triangle() {
        // Undirected edges count as mutual, so a triangle is a 300 triad.
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(0, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T300) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_undirected_path() {
        // Undirected path 0-1-2: mutual(0,1), mutual(1,2), null(0,2) -> 201
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T201) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_counts_sum_to_total() {
        // n=5, total triples = C(5,3) = 10
        let mut g = Graph::new(5, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 3).unwrap();
        g.add_edge(3, 4).unwrap();
        g.add_edge(4, 0).unwrap();
        let tc = triad_census(&g).unwrap();
        let total: f64 = tc.counts.iter().sum();
        assert!((total - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_111d() {
        // 111D: A<->B<-C. Mutual (A,B), asymmetric C->B, null (A,C)
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(2, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T111D) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_111u() {
        // 111U: A<->B->C. Mutual (A,B), asymmetric B->C, null (A,C)
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T111U) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_210() {
        // 210: 2 mutual + 1 asymmetric + 0 null
        // Vertices: 0=A, 1=B, 2=C. Mutual (B,C), mutual (A,C), asymmetric A->B.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T210) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_self_loops_ignored() {
        // The self-loop on vertex 0 must not change the class: still 012.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T012) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_120d() {
        // 120D: A<-B->C, A<->C
        // Mutual: A-C. Asymmetric: B->A, B->C. So B points out to both.
        // Vertices: 0=A, 1=B, 2=C
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(1, 0).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T120D) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_120u() {
        // 120U: A->B<-C, A<->C
        // Mutual: A-C. Asymmetric: A->B, C->B. So both point to B.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(2, 1).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T120U) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_120c() {
        // 120C: A->B->C, A<->C
        // Mutual: A-C. Asymmetric: A->B, B->C. Chain through B.
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 0).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        let tc = triad_census(&g).unwrap();
        assert!((tc.get(TriadType::T120C) - 1.0).abs() < 1e-10);
    }
}