1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
13
14use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
15
/// Output of [`label_spread`].
#[derive(Debug, Clone, PartialEq)]
pub struct LabelSpreadResult {
    /// Predicted class for every vertex, indexed by vertex id: the column of
    /// the vertex's `confidence` row holding the largest score (the lowest
    /// class id wins when several columns share the maximum).
    pub labels: Vec<u32>,
    /// Final score matrix: one row per vertex and one column per class id in
    /// `0..=max_label`. Rows are the raw scores of the last iteration; they
    /// are not renormalized to sum to 1.
    pub confidence: Vec<Vec<f64>>,
}
25
26pub fn label_spread(
66 graph: &Graph,
67 labels: &[Option<u32>],
68 alpha: f64,
69 max_iter: usize,
70 tol: f64,
71) -> IgraphResult<LabelSpreadResult> {
72 let nv = graph.vcount() as usize;
73
74 if labels.len() != nv {
75 return Err(IgraphError::InvalidArgument(format!(
76 "labels length {} does not match vcount {}",
77 labels.len(),
78 nv
79 )));
80 }
81
82 if alpha <= 0.0 || alpha >= 1.0 {
83 return Err(IgraphError::InvalidArgument(format!(
84 "alpha must be in (0.0, 1.0), got {alpha}"
85 )));
86 }
87
88 if graph.is_directed() {
89 return Err(IgraphError::InvalidArgument(
90 "label_spread is defined for undirected graphs only".to_string(),
91 ));
92 }
93
94 let num_classes = labels.iter().filter_map(|l| *l).max().map_or(0, |m| m + 1) as usize;
96
97 if num_classes == 0 {
98 return Err(IgraphError::InvalidArgument(
99 "at least one labeled vertex is required".to_string(),
100 ));
101 }
102
103 let mut degrees = Vec::with_capacity(nv);
105 for v in 0..nv {
106 degrees.push(graph.degree(v as VertexId)?);
107 }
108 let inv_sqrt_deg: Vec<f64> = degrees
109 .iter()
110 .map(|&d| if d == 0 { 0.0 } else { 1.0 / (d as f64).sqrt() })
111 .collect();
112
113 let mut y_init: Vec<Vec<f64>> = Vec::with_capacity(nv);
115 for label in labels {
116 let mut row = vec![0.0; num_classes];
117 if let Some(c) = label {
118 let c_idx = *c as usize;
119 if c_idx < num_classes {
120 row[c_idx] = 1.0;
121 }
122 } else {
123 let uniform = 1.0 / num_classes as f64;
124 row.fill(uniform);
125 }
126 y_init.push(row);
127 }
128
129 let one_minus_alpha = 1.0 - alpha;
130 let mut y_current = y_init.clone();
131
132 for _ in 0..max_iter {
134 let mut y_next: Vec<Vec<f64>> = vec![vec![0.0; num_classes]; nv];
135 let mut max_diff = 0.0_f64;
136
137 for v in 0..nv {
139 if degrees[v] == 0 {
140 for c in 0..num_classes {
142 y_next[v][c] = y_init[v][c];
143 }
144 continue;
145 }
146
147 let neighbors = graph.neighbors(v as VertexId)?;
148 for c in 0..num_classes {
149 let mut propagated = 0.0;
150 for &u in &neighbors {
151 let u_idx = u as usize;
152 propagated += inv_sqrt_deg[u_idx] * y_current[u_idx][c];
153 }
154 propagated *= inv_sqrt_deg[v];
155
156 let new_val = alpha * propagated + one_minus_alpha * y_init[v][c];
157 let diff = (new_val - y_current[v][c]).abs();
158 if diff > max_diff {
159 max_diff = diff;
160 }
161 y_next[v][c] = new_val;
162 }
163 }
164
165 y_current = y_next;
166
167 if max_diff < tol {
168 break;
169 }
170 }
171
172 let mut predicted_labels = Vec::with_capacity(nv);
174 for row in &y_current {
175 let mut best_class = 0u32;
176 let mut best_prob = f64::NEG_INFINITY;
177 for (c, &prob) in row.iter().enumerate() {
178 if prob > best_prob {
179 best_prob = prob;
180 best_class = c as u32;
181 }
182 }
183 predicted_labels.push(best_class);
184 }
185
186 Ok(LabelSpreadResult {
187 labels: predicted_labels,
188 confidence: y_current,
189 })
190}
191
192pub fn label_propagate_predict(graph: &Graph, labels: &[Option<u32>]) -> IgraphResult<Vec<u32>> {
210 let nv = graph.vcount() as usize;
211
212 if labels.len() != nv {
213 return Err(IgraphError::InvalidArgument(format!(
214 "labels length {} does not match vcount {}",
215 labels.len(),
216 nv
217 )));
218 }
219
220 if graph.is_directed() {
221 return Err(IgraphError::InvalidArgument(
222 "label_propagate_predict is defined for undirected graphs only".to_string(),
223 ));
224 }
225
226 let num_classes = labels.iter().filter_map(|l| *l).max().map_or(0, |m| m + 1) as usize;
227
228 let mut result: Vec<u32> = Vec::with_capacity(nv);
229
230 for (v, label) in labels.iter().enumerate() {
231 if let Some(c) = label {
232 result.push(*c);
233 } else {
234 let neighbors = graph.neighbors(v as VertexId)?;
236 let mut counts = vec![0u32; num_classes.max(1)];
237 for &u in &neighbors {
238 if let Some(c) = labels[u as usize] {
239 if (c as usize) < counts.len() {
240 counts[c as usize] += 1;
241 }
242 }
243 }
244
245 let best_class = counts
246 .iter()
247 .enumerate()
248 .max_by_key(|(_, cnt)| *cnt)
249 .map_or(0, |(c, _)| c as u32);
250 result.push(best_class);
251 }
252 }
253
254 Ok(result)
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected path `0 - 1 - 2 - 3`.
    fn path4() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3)], false, Some(4)).unwrap()
    }

    /// Undirected triangle `0 - 1 - 2` with vertex 3 attached to vertex 2.
    fn triangle_with_tail() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (0, 2), (2, 3)], false, Some(4)).unwrap()
    }

    // ---- label_spread ----------------------------------------------------

    #[test]
    fn spread_basic_path() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(1)];
        let result = label_spread(&g, &labels, 0.5, 100, 1e-8).unwrap();
        // Labeled endpoints keep their class.
        assert_eq!(result.labels[0], 0);
        assert_eq!(result.labels[3], 1);
        // Each inner vertex follows the labeled endpoint it is closer to.
        assert_eq!(result.labels[1], 0);
        assert_eq!(result.labels[2], 1);
    }

    #[test]
    fn spread_all_labeled() {
        let g = path4();
        let labels = vec![Some(0), Some(1), Some(0), Some(1)];
        // A small alpha keeps every vertex close to its own initial label.
        let result = label_spread(&g, &labels, 0.3, 50, 1e-6).unwrap();
        assert_eq!(result.labels[0], 0);
        assert_eq!(result.labels[1], 1);
        assert_eq!(result.labels[2], 0);
        assert_eq!(result.labels[3], 1);
    }

    #[test]
    fn spread_single_class() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(0)];
        let result = label_spread(&g, &labels, 0.5, 50, 1e-6).unwrap();
        for &l in &result.labels {
            assert_eq!(l, 0);
        }
    }

    #[test]
    fn spread_confidence_sums_reasonable() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(1)];
        let result = label_spread(&g, &labels, 0.5, 50, 1e-6).unwrap();
        for row in &result.confidence {
            let sum: f64 = row.iter().sum();
            assert!(sum > 0.0);
            for &p in row {
                assert!(p >= 0.0);
            }
        }
    }

    #[test]
    fn spread_invalid_alpha() {
        let g = path4();
        let labels = vec![Some(0), None, None, Some(1)];
        // Both interval bounds are excluded.
        assert!(label_spread(&g, &labels, 0.0, 50, 1e-6).is_err());
        assert!(label_spread(&g, &labels, 1.0, 50, 1e-6).is_err());
        assert!(label_spread(&g, &labels, -0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_invalid_labels_length() {
        let g = path4();
        assert!(label_spread(&g, &[Some(0)], 0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_no_labeled_vertices() {
        let g = path4();
        let labels = vec![None, None, None, None];
        assert!(label_spread(&g, &labels, 0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_directed_error() {
        let g = Graph::from_edges(&[(0, 1), (1, 2)], true, Some(3)).unwrap();
        let labels = vec![Some(0), None, Some(1)];
        assert!(label_spread(&g, &labels, 0.5, 50, 1e-6).is_err());
    }

    #[test]
    fn spread_isolated_vertex() {
        // Vertex 4 has no incident edge.
        let g = Graph::from_edges(&[(0, 1), (1, 2), (2, 3)], false, Some(5)).unwrap();
        let labels = vec![Some(0), None, None, Some(1), None];
        let result = label_spread(&g, &labels, 0.5, 50, 1e-6).unwrap();
        assert_eq!(result.labels.len(), 5);
        assert_eq!(result.confidence.len(), 5);
        // An isolated unlabeled vertex keeps its uniform initial scores, and
        // the arg-max over equal scores resolves to class 0.
        assert_eq!(result.confidence[4].len(), 2);
        for &p in &result.confidence[4] {
            assert!((p - 0.5).abs() < 1e-12);
        }
        assert_eq!(result.labels[4], 0);
    }

    #[test]
    fn spread_multiclass() {
        let g =
            Graph::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)], false, Some(6)).unwrap();
        let labels = vec![Some(0), None, Some(1), None, Some(2), None];
        let result = label_spread(&g, &labels, 0.5, 100, 1e-8).unwrap();
        assert_eq!(result.labels[0], 0);
        assert_eq!(result.labels[2], 1);
        assert_eq!(result.labels[4], 2);
        // One score column per class id 0..=2.
        assert_eq!(result.confidence[0].len(), 3);
    }

    // ---- label_propagate_predict -----------------------------------------

    #[test]
    fn predict_majority_vote() {
        let g = triangle_with_tail();
        let labels = vec![Some(0), Some(0), None, Some(1)];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        // Vertex 2 sees two votes for class 0 and one vote for class 1.
        assert_eq!(predicted[2], 0);
    }

    #[test]
    fn predict_all_labeled() {
        let g = path4();
        let labels = vec![Some(0), Some(1), Some(0), Some(1)];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        assert_eq!(predicted, vec![0, 1, 0, 1]);
    }

    #[test]
    fn predict_no_labeled_neighbors() {
        // Two disjoint edges: every unlabeled vertex has exactly one
        // neighbour, and that neighbour is labeled, so it simply inherits
        // that label.
        let g = Graph::from_edges(&[(0, 1), (2, 3)], false, Some(4)).unwrap();
        let labels = vec![Some(0), None, None, Some(1)];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        assert_eq!(predicted[1], 0);
        assert_eq!(predicted[2], 1);
    }

    #[test]
    fn predict_isolated_unlabeled_vertex() {
        // Vertex 2 has no neighbours at all, hence no votes.
        let g = Graph::from_edges(&[(0, 1)], false, Some(3)).unwrap();
        let labels = vec![Some(0), None, None];
        let predicted = label_propagate_predict(&g, &labels).unwrap();
        assert_eq!(predicted, vec![0, 0, 0]);
    }

    #[test]
    fn predict_invalid_length() {
        let g = path4();
        assert!(label_propagate_predict(&g, &[Some(0)]).is_err());
    }

    #[test]
    fn predict_directed_error() {
        let g = Graph::from_edges(&[(0, 1)], true, Some(2)).unwrap();
        assert!(label_propagate_predict(&g, &[Some(0), None]).is_err());
    }
}