Skip to main content

rust_igraph/algorithms/properties/
label_spread.rs

1//! Label spreading for semi-supervised node classification (ALGO-TR-012).
2//!
3//! Given a graph where some vertices have known labels and others are
4//! unlabeled, propagates labels through the graph structure to predict
5//! labels for unlabeled vertices. Implements the iterative label spreading
6//! algorithm (Zhou et al., 2004) which balances between smoothness over
7//! the graph and fitting the initial labels.
8//!
9//! Used as a baseline for graph semi-supervised learning and as a
10//! post-processing step in GNN pipelines.
11
12#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
13
14use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
15
/// Output of [`label_spread`]: a hard class assignment per vertex together
/// with the per-class scores the assignment was derived from.
#[derive(Clone, Debug, PartialEq)]
pub struct LabelSpreadResult {
    /// `labels[v]` is the class with the highest score in `confidence[v]`.
    pub labels: Vec<u32>,
    /// Per-vertex class scores: `confidence[v][c]` is the final propagated
    /// score of class `c` at vertex `v`. Rows are not re-normalized, so they
    /// are not guaranteed to sum to exactly 1.
    pub confidence: Vec<Vec<f64>>,
}
25
26/// Predict labels for unlabeled vertices using label spreading.
27///
28/// Iteratively propagates label information from labeled vertices to
29/// their neighbors via the graph structure. At each step:
30/// `Y_{t+1} = α · S · Y_t + (1-α) · Y_0`
31///
32/// where `S = D^{-1/2} A D^{-1/2}` is the symmetric normalized adjacency,
33/// `Y_0` is the initial label matrix, and `α` controls the balance between
34/// propagation and clamping to initial labels.
35///
36/// # Parameters
37///
38/// - `graph` — Undirected graph.
39/// - `labels` — Label for each vertex: `Some(class_id)` for labeled vertices,
40///   `None` for unlabeled vertices to predict.
41/// - `alpha` — Propagation strength (0 < alpha < 1). Higher = more propagation.
42///   Typical: 0.2–0.8.
43/// - `max_iter` — Maximum iterations.
44/// - `tol` — Convergence tolerance on max label probability change.
45///
46/// # Returns
47///
48/// A [`LabelSpreadResult`] with predicted labels and confidence scores.
49///
50/// # Examples
51///
52/// ```
53/// use rust_igraph::{Graph, label_spread};
54///
55/// // Path 0-1-2-3: label vertex 0 as class 0, vertex 3 as class 1
56/// let g = Graph::from_edges(&[(0,1),(1,2),(2,3)], false, Some(4)).unwrap();
57/// let labels = vec![Some(0), None, None, Some(1)];
58/// let result = label_spread(&g, &labels, 0.5, 50, 1e-6).unwrap();
59/// // Vertex 1 should be closer to class 0, vertex 2 closer to class 1
60/// assert_eq!(result.labels[0], 0);
61/// assert_eq!(result.labels[3], 1);
62/// assert_eq!(result.labels[1], 0);
63/// assert_eq!(result.labels[2], 1);
64/// ```
65pub fn label_spread(
66    graph: &Graph,
67    labels: &[Option<u32>],
68    alpha: f64,
69    max_iter: usize,
70    tol: f64,
71) -> IgraphResult<LabelSpreadResult> {
72    let nv = graph.vcount() as usize;
73
74    if labels.len() != nv {
75        return Err(IgraphError::InvalidArgument(format!(
76            "labels length {} does not match vcount {}",
77            labels.len(),
78            nv
79        )));
80    }
81
82    if alpha <= 0.0 || alpha >= 1.0 {
83        return Err(IgraphError::InvalidArgument(format!(
84            "alpha must be in (0.0, 1.0), got {alpha}"
85        )));
86    }
87
88    if graph.is_directed() {
89        return Err(IgraphError::InvalidArgument(
90            "label_spread is defined for undirected graphs only".to_string(),
91        ));
92    }
93
94    // Determine number of classes
95    let num_classes = labels.iter().filter_map(|l| *l).max().map_or(0, |m| m + 1) as usize;
96
97    if num_classes == 0 {
98        return Err(IgraphError::InvalidArgument(
99            "at least one labeled vertex is required".to_string(),
100        ));
101    }
102
103    // Compute degrees and D^{-1/2}
104    let mut degrees = Vec::with_capacity(nv);
105    for v in 0..nv {
106        degrees.push(graph.degree(v as VertexId)?);
107    }
108    let inv_sqrt_deg: Vec<f64> = degrees
109        .iter()
110        .map(|&d| if d == 0 { 0.0 } else { 1.0 / (d as f64).sqrt() })
111        .collect();
112
113    // Initialize Y_0: one-hot for labeled, uniform for unlabeled
114    let mut y_init: Vec<Vec<f64>> = Vec::with_capacity(nv);
115    for label in labels {
116        let mut row = vec![0.0; num_classes];
117        if let Some(c) = label {
118            let c_idx = *c as usize;
119            if c_idx < num_classes {
120                row[c_idx] = 1.0;
121            }
122        } else {
123            let uniform = 1.0 / num_classes as f64;
124            row.fill(uniform);
125        }
126        y_init.push(row);
127    }
128
129    let one_minus_alpha = 1.0 - alpha;
130    let mut y_current = y_init.clone();
131
132    // Iterate: Y_{t+1} = α · S · Y_t + (1-α) · Y_0
133    for _ in 0..max_iter {
134        let mut y_next: Vec<Vec<f64>> = vec![vec![0.0; num_classes]; nv];
135        let mut max_diff = 0.0_f64;
136
137        // Apply S = D^{-1/2} A D^{-1/2} to y_current
138        for v in 0..nv {
139            if degrees[v] == 0 {
140                // Isolated: keep initial label
141                for c in 0..num_classes {
142                    y_next[v][c] = y_init[v][c];
143                }
144                continue;
145            }
146
147            let neighbors = graph.neighbors(v as VertexId)?;
148            for c in 0..num_classes {
149                let mut propagated = 0.0;
150                for &u in &neighbors {
151                    let u_idx = u as usize;
152                    propagated += inv_sqrt_deg[u_idx] * y_current[u_idx][c];
153                }
154                propagated *= inv_sqrt_deg[v];
155
156                let new_val = alpha * propagated + one_minus_alpha * y_init[v][c];
157                let diff = (new_val - y_current[v][c]).abs();
158                if diff > max_diff {
159                    max_diff = diff;
160                }
161                y_next[v][c] = new_val;
162            }
163        }
164
165        y_current = y_next;
166
167        if max_diff < tol {
168            break;
169        }
170    }
171
172    // Extract predictions
173    let mut predicted_labels = Vec::with_capacity(nv);
174    for row in &y_current {
175        let mut best_class = 0u32;
176        let mut best_prob = f64::NEG_INFINITY;
177        for (c, &prob) in row.iter().enumerate() {
178            if prob > best_prob {
179                best_prob = prob;
180                best_class = c as u32;
181            }
182        }
183        predicted_labels.push(best_class);
184    }
185
186    Ok(LabelSpreadResult {
187        labels: predicted_labels,
188        confidence: y_current,
189    })
190}
191
192/// Predict labels using simple majority voting from labeled neighbors.
193///
194/// A non-iterative baseline: each unlabeled vertex adopts the most
195/// common label among its labeled neighbors. If no labeled neighbor
196/// exists, assigns label 0 (or keeps the existing label if provided).
197///
198/// # Examples
199///
200/// ```
201/// use rust_igraph::{Graph, label_propagate_predict};
202///
203/// let g = Graph::from_edges(&[(0,1),(1,2),(0,2),(2,3)], false, Some(4)).unwrap();
204/// let labels = vec![Some(0), Some(0), None, Some(1)];
205/// let predicted = label_propagate_predict(&g, &labels).unwrap();
206/// // Vertex 2 has neighbors: 0(class 0), 1(class 0), 3(class 1) → majority = 0
207/// assert_eq!(predicted[2], 0);
208/// ```
209pub fn label_propagate_predict(graph: &Graph, labels: &[Option<u32>]) -> IgraphResult<Vec<u32>> {
210    let nv = graph.vcount() as usize;
211
212    if labels.len() != nv {
213        return Err(IgraphError::InvalidArgument(format!(
214            "labels length {} does not match vcount {}",
215            labels.len(),
216            nv
217        )));
218    }
219
220    if graph.is_directed() {
221        return Err(IgraphError::InvalidArgument(
222            "label_propagate_predict is defined for undirected graphs only".to_string(),
223        ));
224    }
225
226    let num_classes = labels.iter().filter_map(|l| *l).max().map_or(0, |m| m + 1) as usize;
227
228    let mut result: Vec<u32> = Vec::with_capacity(nv);
229
230    for (v, label) in labels.iter().enumerate() {
231        if let Some(c) = label {
232            result.push(*c);
233        } else {
234            // Count labeled neighbors by class
235            let neighbors = graph.neighbors(v as VertexId)?;
236            let mut counts = vec![0u32; num_classes.max(1)];
237            for &u in &neighbors {
238                if let Some(c) = labels[u as usize] {
239                    if (c as usize) < counts.len() {
240                        counts[c as usize] += 1;
241                    }
242                }
243            }
244
245            let best_class = counts
246                .iter()
247                .enumerate()
248                .max_by_key(|(_, cnt)| *cnt)
249                .map_or(0, |(c, _)| c as u32);
250            result.push(best_class);
251        }
252    }
253
254    Ok(result)
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Path graph 0-1-2-3.
    fn path4() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3)], false, Some(4)).unwrap()
    }

    /// Triangle 0-1-2-0 with a pendant edge 2-3.
    fn triangle_with_tail() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (0, 2), (2, 3)], false, Some(4)).unwrap()
    }

    // --- label_spread tests ---

    #[test]
    fn spread_basic_path() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(1)];
        let result = label_spread(&g, &labels, 0.5, 100, 1e-8).unwrap();
        assert_eq!(result.labels[0], 0);
        assert_eq!(result.labels[3], 1);
        // Middle vertices: vertex 1 closer to 0, vertex 2 closer to 1
        assert_eq!(result.labels[1], 0);
        assert_eq!(result.labels[2], 1);
    }

    #[test]
    fn spread_all_labeled() {
        let g = path4();
        let labels = vec![Some(0), Some(1), Some(0), Some(1)];
        let result = label_spread(&g, &labels, 0.3, 50, 1e-6).unwrap();
        // With strong clamping, labels should stay close to initial
        assert_eq!(result.labels[0], 0);
        assert_eq!(result.labels[1], 1);
        assert_eq!(result.labels[2], 0);
        assert_eq!(result.labels[3], 1);
    }

    #[test]
    fn spread_single_class() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(0)];
        let result = label_spread(&g, &labels, 0.5, 50, 1e-6).unwrap();
        for &l in &result.labels {
            assert_eq!(l, 0);
        }
    }

    #[test]
    fn spread_confidence_sums_reasonable() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(1)];
        let result = label_spread(&g, &labels, 0.5, 50, 1e-6).unwrap();
        for row in &result.confidence {
            // Scores are not re-normalized, so only positivity is checked.
            let sum: f64 = row.iter().sum();
            assert!(sum > 0.0);
            for &p in row {
                assert!(p >= 0.0);
            }
        }
    }

    #[test]
    fn spread_invalid_alpha() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(1)];
        assert!(label_spread(&g, &labels, 0.0, 50, 1e-6).is_err());
        assert!(label_spread(&g, &labels, 1.0, 50, 1e-6).is_err());
        assert!(label_spread(&g, &labels, -0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_invalid_labels_length() {
        let g = path4();
        assert!(label_spread(&g, &[Some(0)], 0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_no_labeled_vertices() {
        let g = path4();
        let labels = vec![None, None, None, None];
        assert!(label_spread(&g, &labels, 0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_directed_error() {
        let g = Graph::from_edges(&[(0, 1), (1, 2)], true, Some(3)).unwrap();
        let labels = vec![Some(0), None, Some(1)];
        assert!(label_spread(&g, &labels, 0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_isolated_vertex() {
        // Path 0-1-2-3 plus vertex 4, which has no edges.
        let g = Graph::from_edges(&[(0, 1), (1, 2), (2, 3)], false, Some(5)).unwrap();
        let labels = vec![Some(0), None, None, Some(1), None];
        let result = label_spread(&g, &labels, 0.5, 50, 1e-6).unwrap();
        assert_eq!(result.labels.len(), 5);
        assert_eq!(result.confidence.len(), 5);
        // An isolated unlabeled vertex keeps its uniform initial row, and the
        // resulting tie is resolved towards the lowest class id.
        assert_eq!(result.confidence[4].len(), 2);
        for &p in &result.confidence[4] {
            assert!((p - 0.5).abs() < 1e-12);
        }
        assert_eq!(result.labels[4], 0);
    }

    #[test]
    fn spread_multiclass() {
        let g =
            Graph::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], false, Some(6)).unwrap();
        let labels = vec![Some(0), None, Some(1), None, Some(2), None];
        let result = label_spread(&g, &labels, 0.5, 100, 1e-8).unwrap();
        assert_eq!(result.labels[0], 0);
        assert_eq!(result.labels[2], 1);
        assert_eq!(result.labels[4], 2);
        assert_eq!(result.confidence[0].len(), 3);
    }

    // --- label_propagate_predict tests ---

    #[test]
    fn predict_majority_vote() {
        let g = triangle_with_tail();
        let labels = vec![Some(0), Some(0), None, Some(1)];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        // Vertex 2: neighbors are 0(class 0), 1(class 0), 3(class 1) → majority = 0
        assert_eq!(predicted[2], 0);
    }

    #[test]
    fn predict_all_labeled() {
        let g = path4();
        let labels = vec![Some(0), Some(1), Some(0), Some(1)];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        assert_eq!(predicted, vec![0, 1, 0, 1]);
    }

    #[test]
    fn predict_single_labeled_neighbor() {
        let g = Graph::from_edges(&[(0, 1), (2, 3)], false, Some(4)).unwrap();
        let labels = vec![Some(0), None, None, Some(1)];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        // Vertex 1: only neighbor is vertex 0 (class 0) → class 0
        assert_eq!(predicted[1], 0);
        // Vertex 2: only neighbor is vertex 3 (class 1) → class 1
        assert_eq!(predicted[2], 1);
    }

    #[test]
    fn predict_no_labeled_neighbors() {
        let g = Graph::from_edges(&[(0, 1), (2, 3)], false, Some(4)).unwrap();
        let labels = vec![Some(0), None, None, None];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        // Vertices 2 and 3 only see each other (both unlabeled) → fallback 0
        assert_eq!(predicted, vec![0, 0, 0, 0]);
    }

    #[test]
    fn predict_invalid_length() {
        let g = path4();
        assert!(label_propagate_predict(&g, &[Some(0)]).is_err());
    }

    #[test]
    fn predict_directed_error() {
        let g = Graph::from_edges(&[(0, 1)], true, Some(2)).unwrap();
        assert!(label_propagate_predict(&g, &[Some(0), None]).is_err());
    }
}