rust_igraph/algorithms/properties/
neighbor_agg.rs1#![allow(
15 clippy::cast_possible_truncation,
16 clippy::cast_precision_loss,
17 clippy::needless_range_loop
18)]
19
20use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
21
/// How [`neighbor_aggregate`] reduces the signal values of a vertex's neighbours.
///
/// In every mode, a vertex with no incident edges receives `0.0`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggMode {
    /// Sum of the neighbour values divided by the vertex degree.
    Mean,
    /// Plain sum of the neighbour values.
    Sum,
    /// Largest neighbour value.
    Max,
    /// Smallest neighbour value.
    Min,
}
34
35pub fn neighbor_aggregate(graph: &Graph, signal: &[f64], mode: AggMode) -> IgraphResult<Vec<f64>> {
63 let nv = graph.vcount() as usize;
64
65 if signal.len() != nv {
66 return Err(IgraphError::InvalidArgument(format!(
67 "signal length {} does not match vcount {nv}",
68 signal.len()
69 )));
70 }
71
72 if graph.is_directed() {
73 return Err(IgraphError::InvalidArgument(
74 "neighbor_aggregate is defined for undirected graphs only".to_string(),
75 ));
76 }
77
78 let mut result = match mode {
79 AggMode::Mean | AggMode::Sum => vec![0.0_f64; nv],
80 AggMode::Max => vec![f64::NEG_INFINITY; nv],
81 AggMode::Min => vec![f64::INFINITY; nv],
82 };
83
84 for (u, v) in graph.edges() {
85 let ui = u as usize;
86 let vi = v as usize;
87
88 match mode {
89 AggMode::Mean | AggMode::Sum => {
90 result[ui] += signal[vi];
91 result[vi] += signal[ui];
92 }
93 AggMode::Max => {
94 if signal[vi] > result[ui] {
95 result[ui] = signal[vi];
96 }
97 if signal[ui] > result[vi] {
98 result[vi] = signal[ui];
99 }
100 }
101 AggMode::Min => {
102 if signal[vi] < result[ui] {
103 result[ui] = signal[vi];
104 }
105 if signal[ui] < result[vi] {
106 result[vi] = signal[ui];
107 }
108 }
109 }
110 }
111
112 if mode == AggMode::Mean {
113 for v in 0..nv {
114 let deg = graph.degree(v as VertexId)?;
115 if deg > 0 {
116 result[v] /= deg as f64;
117 }
118 }
119 }
120
121 if matches!(mode, AggMode::Max | AggMode::Min) {
123 for v in 0..nv {
124 let deg = graph.degree(v as VertexId)?;
125 if deg == 0 {
126 result[v] = 0.0;
127 }
128 }
129 }
130
131 Ok(result)
132}
133
134pub fn attention_aggregate(
159 graph: &Graph,
160 signal: &[f64],
161 attention: &[f64],
162) -> IgraphResult<Vec<f64>> {
163 let nv = graph.vcount() as usize;
164 let ne = graph.ecount();
165
166 if signal.len() != nv {
167 return Err(IgraphError::InvalidArgument(format!(
168 "signal length {} does not match vcount {nv}",
169 signal.len()
170 )));
171 }
172
173 if attention.len() != ne {
174 return Err(IgraphError::InvalidArgument(format!(
175 "attention length {} does not match ecount {ne}",
176 attention.len()
177 )));
178 }
179
180 if graph.is_directed() {
181 return Err(IgraphError::InvalidArgument(
182 "attention_aggregate is defined for undirected graphs only".to_string(),
183 ));
184 }
185
186 let mut neighbor_scores: Vec<Vec<(usize, f64, f64)>> = vec![Vec::new(); nv];
188 for (eid, (u, v)) in graph.edges().enumerate() {
189 let ui = u as usize;
190 let vi = v as usize;
191 let attn = attention[eid];
192 neighbor_scores[ui].push((vi, attn, signal[vi]));
193 neighbor_scores[vi].push((ui, attn, signal[ui]));
194 }
195
196 let mut result = vec![0.0_f64; nv];
198 for v in 0..nv {
199 let neighbors = &neighbor_scores[v];
200 if neighbors.is_empty() {
201 continue;
202 }
203
204 let max_attn = neighbors
206 .iter()
207 .map(|&(_, a, _)| a)
208 .fold(f64::NEG_INFINITY, f64::max);
209
210 let mut sum_exp = 0.0_f64;
211 let exps: Vec<f64> = neighbors
212 .iter()
213 .map(|&(_, a, _)| {
214 let e = (a - max_attn).exp();
215 sum_exp += e;
216 e
217 })
218 .collect();
219
220 if sum_exp > 0.0 {
221 for (i, &(_, _, sig)) in neighbors.iter().enumerate() {
222 result[v] += (exps[i] / sum_exp) * sig;
223 }
224 }
225 }
226
227 Ok(result)
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for expectations that are exact in real arithmetic.
    const EPS: f64 = 1e-10;

    /// Panics unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} ± {tol}, got {actual}"
        );
    }

    /// Checks `actual[i] ≈ expected[i]` (within `EPS`) for every index of `expected`.
    fn assert_prefix_close(actual: &[f64], expected: &[f64]) {
        for (i, &want) in expected.iter().enumerate() {
            assert_close(actual[i], want, EPS);
        }
    }

    fn triangle() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (0, 2)], false, Some(3)).unwrap()
    }

    fn path4() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3)], false, Some(4)).unwrap()
    }

    fn star4() -> Graph {
        Graph::from_edges(&[(0, 1), (0, 2), (0, 3)], false, Some(4)).unwrap()
    }

    // ---- neighbor_aggregate: Mean ----

    #[test]
    fn mean_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 2.0, 3.0], AggMode::Mean).unwrap();
        assert_prefix_close(&r, &[2.5, 2.0, 1.5]);
    }

    #[test]
    fn mean_isolated() {
        let g = Graph::with_vertices(3);
        let r = neighbor_aggregate(&g, &[1.0, 2.0, 3.0], AggMode::Mean).unwrap();
        for &x in &r {
            assert_close(x, 0.0, EPS);
        }
    }

    #[test]
    fn mean_star() {
        let r = neighbor_aggregate(&star4(), &[0.0, 1.0, 2.0, 3.0], AggMode::Mean).unwrap();
        assert_prefix_close(&r, &[2.0, 0.0, 0.0, 0.0]);
    }

    // ---- neighbor_aggregate: Sum ----

    #[test]
    fn sum_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 2.0, 3.0], AggMode::Sum).unwrap();
        assert_prefix_close(&r, &[5.0, 4.0, 3.0]);
    }

    #[test]
    fn sum_path() {
        let r = neighbor_aggregate(&path4(), &[1.0, 2.0, 3.0, 4.0], AggMode::Sum).unwrap();
        assert_prefix_close(&r, &[2.0, 4.0, 6.0, 3.0]);
    }

    // ---- neighbor_aggregate: Max / Min ----

    #[test]
    fn max_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 5.0, 3.0], AggMode::Max).unwrap();
        assert_prefix_close(&r, &[5.0, 3.0, 5.0]);
    }

    #[test]
    fn max_isolated() {
        let g = Graph::with_vertices(2);
        let r = neighbor_aggregate(&g, &[10.0, 20.0], AggMode::Max).unwrap();
        assert_prefix_close(&r, &[0.0, 0.0]);
    }

    #[test]
    fn min_triangle() {
        let r = neighbor_aggregate(&triangle(), &[1.0, 5.0, 3.0], AggMode::Min).unwrap();
        assert_prefix_close(&r, &[3.0, 1.0, 1.0]);
    }

    #[test]
    fn min_isolated() {
        let g = Graph::with_vertices(2);
        let r = neighbor_aggregate(&g, &[10.0, 20.0], AggMode::Min).unwrap();
        assert_prefix_close(&r, &[0.0, 0.0]);
    }

    // ---- neighbor_aggregate: argument validation ----

    #[test]
    fn agg_invalid_signal() {
        let g = triangle();
        assert!(neighbor_aggregate(&g, &[1.0], AggMode::Mean).is_err());
    }

    #[test]
    fn agg_directed_error() {
        let g = Graph::from_edges(&[(0, 1)], true, Some(2)).unwrap();
        assert!(neighbor_aggregate(&g, &[1.0, 2.0], AggMode::Sum).is_err());
    }

    // ---- attention_aggregate ----

    #[test]
    fn attn_equal_weights() {
        let g = Graph::from_edges(&[(0, 1), (0, 2)], false, Some(3)).unwrap();
        let r = attention_aggregate(&g, &[0.0, 1.0, 2.0], &[0.0, 0.0]).unwrap();
        assert_close(r[0], 1.5, EPS);
    }

    #[test]
    fn attn_dominant_weight() {
        let g = Graph::from_edges(&[(0, 1), (0, 2)], false, Some(3)).unwrap();
        let r = attention_aggregate(&g, &[0.0, 1.0, 2.0], &[100.0, 0.0]).unwrap();
        assert_close(r[0], 1.0, 0.01);
    }

    #[test]
    fn attn_isolated() {
        let g = Graph::with_vertices(2);
        let r = attention_aggregate(&g, &[1.0, 2.0], &[]).unwrap();
        assert_prefix_close(&r, &[0.0, 0.0]);
    }

    #[test]
    fn attn_invalid_signal() {
        let g = triangle();
        assert!(attention_aggregate(&g, &[1.0], &[0.0; 3]).is_err());
    }

    #[test]
    fn attn_invalid_attention() {
        let g = triangle();
        assert!(attention_aggregate(&g, &[0.0; 3], &[1.0]).is_err());
    }

    #[test]
    fn attn_directed_error() {
        let g = Graph::from_edges(&[(0, 1)], true, Some(2)).unwrap();
        assert!(attention_aggregate(&g, &[1.0, 2.0], &[0.0]).is_err());
    }

    // ---- cross-mode properties ----

    #[test]
    fn sum_is_mean_times_degree() {
        let g = triangle();
        let s = [1.0, 2.0, 3.0];
        let mean = neighbor_aggregate(&g, &s, AggMode::Mean).unwrap();
        let sum = neighbor_aggregate(&g, &s, AggMode::Sum).unwrap();
        for v in 0..3 {
            let deg = g.degree(v as VertexId).unwrap() as f64;
            assert_close(sum[v], mean[v] * deg, EPS);
        }
    }

    #[test]
    fn constant_signal_mean_equals_constant() {
        let g = star4();
        let c = 7.0;
        let r = neighbor_aggregate(&g, &[c; 4], AggMode::Mean).unwrap();
        for v in 0..4 {
            if g.degree(v as VertexId).unwrap() > 0 {
                assert_close(r[v], c, EPS);
            }
        }
    }
}