Skip to main content

rust_igraph/algorithms/properties/
neighbor_agg.rs

//! Neighborhood aggregation operators (ALGO-TR-020).
//!
//! Core primitives for message-passing / GNN-style computation on graphs.
//! Each operator computes, for every vertex, an aggregate of a signal over
//! its neighbors:
//!
//! - **Mean aggregation**: `agg(v) = (1/deg(v)) · Σ_{u∈N(v)} f(u)`
//! - **Sum aggregation**: `agg(v) = Σ_{u∈N(v)} f(u)`
//! - **Max aggregation**: `agg(v) = max_{u∈N(v)} f(u)`
//! - **Min aggregation**: `agg(v) = min_{u∈N(v)} f(u)`
//! - **Attention-weighted aggregation**: `agg(v) = Σ_{u∈N(v)} α(v,u) · f(u)`,
//!   where `α(v,u)` is obtained by applying a softmax to the raw edge scores
//!   over `N(v)`.
//!
//! All operators require an undirected graph, and vertices without
//! neighbors receive `0.0`.
13
14#![allow(
15    clippy::cast_possible_truncation,
16    clippy::cast_precision_loss,
17    clippy::needless_range_loop
18)]
19
20use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
21
/// Reduction applied by [`neighbor_aggregate`] to the signal values found on
/// a vertex's neighbors.
///
/// Vertices without neighbors receive `0.0` under every mode.
///
/// `Hash` is derived alongside `Eq` so a mode can be used as a key in hashed
/// collections (e.g. caching results per mode).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggMode {
    /// Arithmetic mean over neighbors: `Σ f(u) / deg(v)`.
    Mean,
    /// Plain sum over neighbors: `Σ f(u)`.
    Sum,
    /// Largest neighbor value.
    Max,
    /// Smallest neighbor value.
    Min,
}
34
35/// Aggregate a signal over each vertex's neighborhood.
36///
37/// For each vertex `v`, computes an aggregate of `signal[u]` over all
38/// neighbors `u ∈ N(v)` using the specified mode. Isolated vertices
39/// receive `0.0` for `Mean`/`Sum`, `f64::NEG_INFINITY` for `Max`,
40/// and `f64::INFINITY` for `Min`.
41///
42/// # Parameters
43///
44/// - `graph` — The input graph (undirected).
45/// - `signal` — Input signal of length `vcount`.
46/// - `mode` — Aggregation mode (`Mean`, `Sum`, `Max`, `Min`).
47///
48/// # Examples
49///
50/// ```
51/// use rust_igraph::{Graph, AggMode, neighbor_aggregate};
52///
53/// let g = Graph::from_edges(&[(0,1),(1,2),(0,2)], false, Some(3)).unwrap();
54/// let signal = vec![1.0, 2.0, 3.0];
55/// let mean_agg = neighbor_aggregate(&g, &signal, AggMode::Mean).unwrap();
56/// // Vertex 0: neighbors {1,2}, mean = (2+3)/2 = 2.5
57/// assert!((mean_agg[0] - 2.5).abs() < 1e-10);
58/// let sum_agg = neighbor_aggregate(&g, &signal, AggMode::Sum).unwrap();
59/// // Vertex 0: sum = 2+3 = 5
60/// assert!((sum_agg[0] - 5.0).abs() < 1e-10);
61/// ```
62pub fn neighbor_aggregate(graph: &Graph, signal: &[f64], mode: AggMode) -> IgraphResult<Vec<f64>> {
63    let nv = graph.vcount() as usize;
64
65    if signal.len() != nv {
66        return Err(IgraphError::InvalidArgument(format!(
67            "signal length {} does not match vcount {nv}",
68            signal.len()
69        )));
70    }
71
72    if graph.is_directed() {
73        return Err(IgraphError::InvalidArgument(
74            "neighbor_aggregate is defined for undirected graphs only".to_string(),
75        ));
76    }
77
78    let mut result = match mode {
79        AggMode::Mean | AggMode::Sum => vec![0.0_f64; nv],
80        AggMode::Max => vec![f64::NEG_INFINITY; nv],
81        AggMode::Min => vec![f64::INFINITY; nv],
82    };
83
84    for (u, v) in graph.edges() {
85        let ui = u as usize;
86        let vi = v as usize;
87
88        match mode {
89            AggMode::Mean | AggMode::Sum => {
90                result[ui] += signal[vi];
91                result[vi] += signal[ui];
92            }
93            AggMode::Max => {
94                if signal[vi] > result[ui] {
95                    result[ui] = signal[vi];
96                }
97                if signal[ui] > result[vi] {
98                    result[vi] = signal[ui];
99                }
100            }
101            AggMode::Min => {
102                if signal[vi] < result[ui] {
103                    result[ui] = signal[vi];
104                }
105                if signal[ui] < result[vi] {
106                    result[vi] = signal[ui];
107                }
108            }
109        }
110    }
111
112    if mode == AggMode::Mean {
113        for v in 0..nv {
114            let deg = graph.degree(v as VertexId)?;
115            if deg > 0 {
116                result[v] /= deg as f64;
117            }
118        }
119    }
120
121    // Fix isolated vertices for Max/Min
122    if matches!(mode, AggMode::Max | AggMode::Min) {
123        for v in 0..nv {
124            let deg = graph.degree(v as VertexId)?;
125            if deg == 0 {
126                result[v] = 0.0;
127            }
128        }
129    }
130
131    Ok(result)
132}
133
134/// Aggregate a signal with per-edge attention weights.
135///
136/// For each vertex `v`, computes
137/// `agg(v) = Σ_{u∈N(v)} softmax(attn(v,u)) · signal(u)`
138/// where `softmax` normalizes `attn` scores across `N(v)`.
139///
140/// # Parameters
141///
142/// - `graph` — Undirected graph.
143/// - `signal` — Input signal of length `vcount`.
144/// - `attention` — Raw attention scores, one per edge. Length must equal
145///   `ecount`. For edge `(u,v)`, the score is used in both directions.
146///
147/// # Examples
148///
149/// ```
150/// use rust_igraph::{Graph, attention_aggregate};
151///
152/// let g = Graph::from_edges(&[(0,1),(0,2)], false, Some(3)).unwrap();
153/// let signal = vec![0.0, 1.0, 2.0];
154/// // Equal attention → equivalent to mean
155/// let agg = attention_aggregate(&g, &signal, &[0.0, 0.0]).unwrap();
156/// assert!((agg[0] - 1.5).abs() < 1e-10); // mean(1, 2)
157/// ```
158pub fn attention_aggregate(
159    graph: &Graph,
160    signal: &[f64],
161    attention: &[f64],
162) -> IgraphResult<Vec<f64>> {
163    let nv = graph.vcount() as usize;
164    let ne = graph.ecount();
165
166    if signal.len() != nv {
167        return Err(IgraphError::InvalidArgument(format!(
168            "signal length {} does not match vcount {nv}",
169            signal.len()
170        )));
171    }
172
173    if attention.len() != ne {
174        return Err(IgraphError::InvalidArgument(format!(
175            "attention length {} does not match ecount {ne}",
176            attention.len()
177        )));
178    }
179
180    if graph.is_directed() {
181        return Err(IgraphError::InvalidArgument(
182            "attention_aggregate is defined for undirected graphs only".to_string(),
183        ));
184    }
185
186    // Build per-vertex neighbor lists with attention scores
187    let mut neighbor_scores: Vec<Vec<(usize, f64, f64)>> = vec![Vec::new(); nv];
188    for (eid, (u, v)) in graph.edges().enumerate() {
189        let ui = u as usize;
190        let vi = v as usize;
191        let attn = attention[eid];
192        neighbor_scores[ui].push((vi, attn, signal[vi]));
193        neighbor_scores[vi].push((ui, attn, signal[ui]));
194    }
195
196    // Softmax + weighted sum per vertex
197    let mut result = vec![0.0_f64; nv];
198    for v in 0..nv {
199        let neighbors = &neighbor_scores[v];
200        if neighbors.is_empty() {
201            continue;
202        }
203
204        // Stable softmax: subtract max for numerical stability
205        let max_attn = neighbors
206            .iter()
207            .map(|&(_, a, _)| a)
208            .fold(f64::NEG_INFINITY, f64::max);
209
210        let mut sum_exp = 0.0_f64;
211        let exps: Vec<f64> = neighbors
212            .iter()
213            .map(|&(_, a, _)| {
214                let e = (a - max_attn).exp();
215                sum_exp += e;
216                e
217            })
218            .collect();
219
220        if sum_exp > 0.0 {
221            for (i, &(_, _, sig)) in neighbors.iter().enumerate() {
222                result[v] += (exps[i] / sum_exp) * sig;
223            }
224        }
225    }
226
227    Ok(result)
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Default absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn near(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    /// Asserts `values[i]` is within `TOL` of `expected[i]` for each index `i`
    /// of `expected`.
    fn expect_near(values: &[f64], expected: &[f64]) {
        for (i, &want) in expected.iter().enumerate() {
            let got = values[i];
            assert!(near(got, want, TOL), "entry {i}: got {got}, expected {want}");
        }
    }

    /// Complete graph on 3 vertices.
    fn triangle() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (0, 2)], false, Some(3)).unwrap()
    }

    /// Path 0 - 1 - 2 - 3.
    fn path4() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3)], false, Some(4)).unwrap()
    }

    /// Star centred on vertex 0 with leaves 1, 2, 3.
    fn star4() -> Graph {
        Graph::from_edges(&[(0, 1), (0, 2), (0, 3)], false, Some(4)).unwrap()
    }

    /// Vertex 0 joined to vertices 1 and 2 (edge order: 0-1, then 0-2).
    fn fork() -> Graph {
        Graph::from_edges(&[(0, 1), (0, 2)], false, Some(3)).unwrap()
    }

    // ----- Mean -----

    #[test]
    fn mean_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 2.0, 3.0], AggMode::Mean).unwrap();
        // Each vertex averages the other two: (2+3)/2, (1+3)/2, (1+2)/2.
        expect_near(&r, &[2.5, 2.0, 1.5]);
    }

    #[test]
    fn mean_isolated() {
        let g = Graph::with_vertices(3);
        let r = neighbor_aggregate(&g, &[1.0, 2.0, 3.0], AggMode::Mean).unwrap();
        assert!(r.iter().all(|&x| near(x, 0.0, TOL)));
    }

    #[test]
    fn mean_star() {
        let r = neighbor_aggregate(&star4(), &[0.0, 1.0, 2.0, 3.0], AggMode::Mean).unwrap();
        // Hub averages 1, 2, 3; every leaf only sees the hub's 0.
        expect_near(&r, &[2.0, 0.0, 0.0, 0.0]);
    }

    // ----- Sum -----

    #[test]
    fn sum_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 2.0, 3.0], AggMode::Sum).unwrap();
        expect_near(&r, &[5.0, 4.0, 3.0]);
    }

    #[test]
    fn sum_path() {
        let r = neighbor_aggregate(&path4(), &[1.0, 2.0, 3.0, 4.0], AggMode::Sum).unwrap();
        // Endpoints have a single neighbor; interior vertices have two.
        expect_near(&r, &[2.0, 4.0, 6.0, 3.0]);
    }

    // ----- Max -----

    #[test]
    fn max_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 5.0, 3.0], AggMode::Max).unwrap();
        expect_near(&r, &[5.0, 3.0, 5.0]);
    }

    #[test]
    fn max_isolated() {
        let g = Graph::with_vertices(2);
        let r = neighbor_aggregate(&g, &[10.0, 20.0], AggMode::Max).unwrap();
        expect_near(&r, &[0.0, 0.0]);
    }

    // ----- Min -----

    #[test]
    fn min_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 5.0, 3.0], AggMode::Min).unwrap();
        expect_near(&r, &[3.0, 1.0, 1.0]);
    }

    #[test]
    fn min_isolated() {
        let g = Graph::with_vertices(2);
        let r = neighbor_aggregate(&g, &[10.0, 20.0], AggMode::Min).unwrap();
        expect_near(&r, &[0.0, 0.0]);
    }

    // ----- argument validation -----

    #[test]
    fn agg_invalid_signal() {
        assert!(neighbor_aggregate(&triangle(), &[1.0], AggMode::Mean).is_err());
    }

    #[test]
    fn agg_directed_error() {
        let directed = Graph::from_edges(&[(0, 1)], true, Some(2)).unwrap();
        assert!(neighbor_aggregate(&directed, &[1.0, 2.0], AggMode::Sum).is_err());
    }

    // ----- attention_aggregate -----

    #[test]
    fn attn_equal_weights() {
        let r = attention_aggregate(&fork(), &[0.0, 1.0, 2.0], &[0.0, 0.0]).unwrap();
        // Identical scores make the softmax uniform, i.e. a plain mean.
        assert!(near(r[0], 1.5, TOL));
    }

    #[test]
    fn attn_dominant_weight() {
        // A huge score on edge 0-1 should route almost all weight to vertex 1.
        let r = attention_aggregate(&fork(), &[0.0, 1.0, 2.0], &[100.0, 0.0]).unwrap();
        assert!(near(r[0], 1.0, 0.01));
    }

    #[test]
    fn attn_isolated() {
        let g = Graph::with_vertices(2);
        let r = attention_aggregate(&g, &[1.0, 2.0], &[]).unwrap();
        expect_near(&r, &[0.0, 0.0]);
    }

    #[test]
    fn attn_invalid_signal() {
        assert!(attention_aggregate(&triangle(), &[1.0], &[0.0; 3]).is_err());
    }

    #[test]
    fn attn_invalid_attention() {
        assert!(attention_aggregate(&triangle(), &[0.0; 3], &[1.0]).is_err());
    }

    #[test]
    fn attn_directed_error() {
        let directed = Graph::from_edges(&[(0, 1)], true, Some(2)).unwrap();
        assert!(attention_aggregate(&directed, &[1.0, 2.0], &[0.0]).is_err());
    }

    // ----- cross-mode consistency -----

    #[test]
    fn sum_is_mean_times_degree() {
        let g = triangle();
        let s = [1.0, 2.0, 3.0];
        let mean = neighbor_aggregate(&g, &s, AggMode::Mean).unwrap();
        let sum = neighbor_aggregate(&g, &s, AggMode::Sum).unwrap();
        for (v, (&total, &avg)) in sum.iter().zip(mean.iter()).enumerate() {
            let deg = g.degree(v as VertexId).unwrap() as f64;
            assert!(near(total, avg * deg, TOL));
        }
    }

    #[test]
    fn constant_signal_mean_equals_constant() {
        let g = star4();
        let c = 7.0;
        let r = neighbor_aggregate(&g, &[c; 4], AggMode::Mean).unwrap();
        // Averaging a constant yields the constant wherever there are neighbors.
        (0..4)
            .filter(|&v| g.degree(v as VertexId).unwrap() > 0)
            .for_each(|v| assert!(near(r[v], c, TOL)));
    }
}