Skip to main content

rust_igraph/algorithms/traversal/
neighbor_sample.rs

1//! K-hop neighborhood sampling for mini-batch GNN training (ALGO-TR-006).
2//!
3//! Implements layer-wise neighbor sampling as used by `GraphSAGE`, `PinSAGE`,
4//! and similar inductive GNN architectures. Given a batch of seed vertices,
5//! samples a fixed number of neighbors per vertex at each hop, producing a
6//! computation graph (subgraph) suitable for message-passing.
7
8use crate::core::rng::SplitMix64;
9use crate::core::{Graph, IgraphResult, VertexId};
10
11/// Result of k-hop neighborhood sampling.
12#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct NeighborSampleResult {
14    /// Sampled vertices at each layer, from outermost (k-hop) to innermost
15    /// (seed). `layers[0]` are the seeds, `layers[1]` are their sampled
16    /// neighbors, etc.
17    pub layers: Vec<Vec<VertexId>>,
18    /// Edges connecting each layer pair. `edges[i]` contains `(src, dst)`
19    /// pairs where `src` is in `layers[i+1]` and `dst` is in `layers[i]`.
20    /// Vertex ids are original graph ids.
21    pub edges: Vec<Vec<(VertexId, VertexId)>>,
22}
23
24/// Sample k-hop neighborhoods around seed vertices.
25///
26/// For each hop from 1 to `fan_out.len()`, samples at most `fan_out[hop-1]`
27/// neighbors per frontier vertex. Sampling is uniform without replacement
28/// when the number of neighbors exceeds the fan-out; otherwise all
29/// neighbors are included.
30///
31/// # Parameters
32///
33/// - `graph` — The input graph.
34/// - `seeds` — Starting vertices (the "batch").
35/// - `fan_out` — Number of neighbors to sample at each hop. Length
36///   determines the number of hops.
37/// - `seed` — PRNG seed for deterministic sampling.
38///
39/// # Returns
40///
41/// A [`NeighborSampleResult`] with layers and inter-layer edges.
42///
43/// # Examples
44///
45/// ```
46/// use rust_igraph::{Graph, neighbor_sample};
47///
48/// // Star graph: center 0 connected to 1,2,3,4
49/// let g = Graph::from_edges(
50///     &[(0,1),(0,2),(0,3),(0,4)], false, Some(5)
51/// ).unwrap();
52///
53/// // 1-hop sampling from vertex 0, fan_out=2
54/// let result = neighbor_sample(&g, &[0], &[2], 42).unwrap();
55/// assert_eq!(result.layers[0], vec![0]);
56/// assert_eq!(result.layers[1].len(), 2); // sampled 2 of 4 neighbors
57/// assert_eq!(result.edges[0].len(), 2);
58/// ```
59pub fn neighbor_sample(
60    graph: &Graph,
61    seeds: &[VertexId],
62    fan_out: &[usize],
63    seed: u64,
64) -> IgraphResult<NeighborSampleResult> {
65    let n = graph.vcount();
66
67    for &s in seeds {
68        if s >= n {
69            return Err(crate::core::IgraphError::VertexOutOfRange { id: s, n });
70        }
71    }
72
73    if seeds.is_empty() || fan_out.is_empty() {
74        return Ok(NeighborSampleResult {
75            layers: vec![seeds.to_vec()],
76            edges: Vec::new(),
77        });
78    }
79
80    let mut rng = SplitMix64::new(seed);
81    let mut layers: Vec<Vec<VertexId>> = Vec::with_capacity(fan_out.len() + 1);
82    let mut edges: Vec<Vec<(VertexId, VertexId)>> = Vec::with_capacity(fan_out.len());
83
84    layers.push(seeds.to_vec());
85
86    for &num_samples in fan_out {
87        let frontier = layers.last().unwrap();
88        let mut next_layer: Vec<VertexId> = Vec::new();
89        let mut layer_edges: Vec<(VertexId, VertexId)> = Vec::new();
90
91        for &v in frontier {
92            let neighbors = graph.neighbors(v)?;
93            if neighbors.is_empty() {
94                continue;
95            }
96
97            let sampled = if neighbors.len() <= num_samples {
98                neighbors
99            } else {
100                sample_without_replacement(&neighbors, num_samples, &mut rng)
101            };
102
103            for &u in &sampled {
104                layer_edges.push((u, v));
105                next_layer.push(u);
106            }
107        }
108
109        next_layer.sort_unstable();
110        next_layer.dedup();
111        edges.push(layer_edges);
112        layers.push(next_layer);
113    }
114
115    Ok(NeighborSampleResult { layers, edges })
116}
117
118/// Sample k-hop neighborhoods with importance-weighted sampling.
119///
120/// Like [`neighbor_sample`] but samples neighbors proportional to edge
121/// weights. Higher-weight edges are more likely to be sampled.
122///
123/// # Examples
124///
125/// ```
126/// use rust_igraph::{Graph, neighbor_sample_weighted};
127///
128/// let g = Graph::from_edges(
129///     &[(0,1),(0,2),(0,3)], false, Some(4)
130/// ).unwrap();
131/// let weights = vec![10.0, 1.0, 1.0]; // edge 0→1 has much higher weight
132///
133/// let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42).unwrap();
134/// assert_eq!(result.layers[0], vec![0]);
135/// assert_eq!(result.layers[1].len(), 2);
136/// ```
137pub fn neighbor_sample_weighted(
138    graph: &Graph,
139    seeds: &[VertexId],
140    fan_out: &[usize],
141    weights: &[f64],
142    seed: u64,
143) -> IgraphResult<NeighborSampleResult> {
144    let n = graph.vcount();
145
146    for &s in seeds {
147        if s >= n {
148            return Err(crate::core::IgraphError::VertexOutOfRange { id: s, n });
149        }
150    }
151
152    if weights.len() != graph.ecount() {
153        return Err(crate::core::IgraphError::InvalidArgument(format!(
154            "weights length {} != ecount {}",
155            weights.len(),
156            graph.ecount()
157        )));
158    }
159
160    for (i, &w) in weights.iter().enumerate() {
161        if w < 0.0 || w.is_nan() {
162            return Err(crate::core::IgraphError::InvalidArgument(format!(
163                "weight[{i}] = {w} is invalid (must be non-negative and finite)"
164            )));
165        }
166    }
167
168    if seeds.is_empty() || fan_out.is_empty() {
169        return Ok(NeighborSampleResult {
170            layers: vec![seeds.to_vec()],
171            edges: Vec::new(),
172        });
173    }
174
175    let mut rng = SplitMix64::new(seed);
176    let mut layers: Vec<Vec<VertexId>> = Vec::with_capacity(fan_out.len() + 1);
177    let mut edges: Vec<Vec<(VertexId, VertexId)>> = Vec::with_capacity(fan_out.len());
178
179    layers.push(seeds.to_vec());
180
181    for &num_samples in fan_out {
182        let frontier = layers.last().unwrap();
183        let mut next_layer: Vec<VertexId> = Vec::new();
184        let mut layer_edges: Vec<(VertexId, VertexId)> = Vec::new();
185
186        for &v in frontier {
187            let incident = graph.incident(v)?;
188            if incident.is_empty() {
189                continue;
190            }
191
192            let mut neighbor_weights: Vec<(VertexId, f64)> = Vec::with_capacity(incident.len());
193            for &eid in &incident {
194                let neighbor = graph.edge_other(eid, v)?;
195                neighbor_weights.push((neighbor, weights[eid as usize]));
196            }
197
198            let sampled = if neighbor_weights.len() <= num_samples {
199                neighbor_weights.iter().map(|&(u, _)| u).collect()
200            } else {
201                weighted_sample_without_replacement(&neighbor_weights, num_samples, &mut rng)
202            };
203
204            for &u in &sampled {
205                layer_edges.push((u, v));
206                next_layer.push(u);
207            }
208        }
209
210        next_layer.sort_unstable();
211        next_layer.dedup();
212        edges.push(layer_edges);
213        layers.push(next_layer);
214    }
215
216    Ok(NeighborSampleResult { layers, edges })
217}
218
219// --- Internal helpers ---
220
221fn sample_without_replacement(items: &[VertexId], k: usize, rng: &mut SplitMix64) -> Vec<VertexId> {
222    let n = items.len();
223    if k >= n {
224        return items.to_vec();
225    }
226
227    let mut pool: Vec<VertexId> = items.to_vec();
228    for i in 0..k {
229        let j = i + rng.gen_index(n - i);
230        pool.swap(i, j);
231    }
232    pool.truncate(k);
233    pool
234}
235
236fn weighted_sample_without_replacement(
237    items: &[(VertexId, f64)],
238    k: usize,
239    rng: &mut SplitMix64,
240) -> Vec<VertexId> {
241    let n = items.len();
242    if k >= n {
243        return items.iter().map(|&(v, _)| v).collect();
244    }
245
246    let mut keys: Vec<(f64, VertexId)> = items
247        .iter()
248        .map(|&(v, w)| {
249            let u = rng.gen_unit();
250            let key = if w > 0.0 { u.powf(1.0 / w) } else { 0.0 };
251            (key, v)
252        })
253        .collect();
254
255    keys.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
256    keys.iter().take(k).map(|&(_, v)| v).collect()
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected star: center 0 joined to 1..=4.
    fn star5() -> Graph {
        Graph::from_edges(&[(0, 1), (0, 2), (0, 3), (0, 4)], false, Some(5)).unwrap()
    }

    /// Undirected path 0-1-2-3-4.
    fn path5() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4)], false, Some(5)).unwrap()
    }

    #[test]
    fn basic_one_hop() {
        let g = star5();
        let result = neighbor_sample(&g, &[0], &[2], 42).unwrap();
        assert_eq!(result.layers[0], vec![0]);
        assert_eq!(result.layers[1].len(), 2);
        assert_eq!(result.edges[0].len(), 2);
        for &(src, dst) in &result.edges[0] {
            assert_eq!(dst, 0);
            assert!((1..=4).contains(&src));
        }
    }

    #[test]
    fn all_neighbors_when_fan_out_large() {
        let g = star5();
        let result = neighbor_sample(&g, &[0], &[10], 42).unwrap();
        assert_eq!(result.layers[1], vec![1, 2, 3, 4]);
        assert_eq!(result.edges[0].len(), 4);
    }

    #[test]
    fn two_hop_sampling() {
        let g = path5();
        let result = neighbor_sample(&g, &[0], &[2, 2], 42).unwrap();
        assert_eq!(result.layers.len(), 3);
        assert_eq!(result.edges.len(), 2);
        assert_eq!(result.layers[0], vec![0]);
        // Every degree on the path is <= 2, so nothing is actually sampled away.
        assert_eq!(result.layers[1], vec![1]);
        assert_eq!(result.layers[2], vec![0, 2]);
        assert_eq!(result.edges[0], vec![(1, 0)]);
        let mut hop2 = result.edges[1].clone();
        hop2.sort_unstable();
        assert_eq!(hop2, vec![(0, 1), (2, 1)]);
    }

    #[test]
    fn edges_connect_adjacent_layers() {
        let g = star5();
        let result = neighbor_sample(&g, &[1, 3], &[2, 3], 7).unwrap();
        assert_eq!(result.layers.len(), result.edges.len() + 1);
        for (i, hop) in result.edges.iter().enumerate() {
            for &(src, dst) in hop {
                assert!(result.layers[i + 1].contains(&src));
                assert!(result.layers[i].contains(&dst));
            }
        }
    }

    #[test]
    fn multiple_seeds() {
        let g = path5();
        let result = neighbor_sample(&g, &[0, 4], &[2], 42).unwrap();
        assert_eq!(result.layers[0], vec![0, 4]);
        assert_eq!(result.layers[1], vec![1, 3]);
    }

    #[test]
    fn isolated_vertex() {
        let g = Graph::with_vertices(3);
        let result = neighbor_sample(&g, &[0, 1], &[5], 42).unwrap();
        assert_eq!(result.layers[0], vec![0, 1]);
        assert!(result.layers[1].is_empty());
        assert!(result.edges[0].is_empty());
    }

    #[test]
    fn deterministic() {
        let g = star5();
        let r1 = neighbor_sample(&g, &[0], &[2], 99).unwrap();
        let r2 = neighbor_sample(&g, &[0], &[2], 99).unwrap();
        assert_eq!(r1, r2);
    }

    #[test]
    fn different_seeds_different_results() {
        let g = star5();
        let r1 = neighbor_sample(&g, &[0], &[2], 1).unwrap();
        let r2 = neighbor_sample(&g, &[0], &[2], 2).unwrap();
        // For two arbitrary seeds the chance of drawing the same 2-of-4 subset
        // is 1/C(4,2) = 1/6. Seeds 1 and 2 are fixed and known to differ, so
        // this test is deterministic. It guards against the seed being ignored.
        let mut s1 = r1.layers[1].clone();
        let mut s2 = r2.layers[1].clone();
        s1.sort_unstable();
        s2.sort_unstable();
        assert_ne!(s1, s2);
    }

    #[test]
    fn empty_seeds() {
        let g = star5();
        let result = neighbor_sample(&g, &[], &[2], 42).unwrap();
        assert_eq!(result.layers.len(), 1);
        assert!(result.layers[0].is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn empty_fan_out() {
        let g = star5();
        let result = neighbor_sample(&g, &[0], &[], 42).unwrap();
        assert_eq!(result.layers.len(), 1);
        assert_eq!(result.layers[0], vec![0]);
        assert!(result.edges.is_empty());
    }

    #[test]
    fn invalid_seed_vertex() {
        let g = star5();
        let result = neighbor_sample(&g, &[10], &[2], 42);
        assert!(result.is_err());
    }

    #[test]
    fn weighted_basic() {
        let g = star5();
        let weights = vec![10.0, 1.0, 1.0, 1.0]; // 4 edges
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42).unwrap();
        assert_eq!(result.layers[0], vec![0]);
        assert_eq!(result.layers[1].len(), 2);
    }

    #[test]
    fn weighted_all_neighbors_when_fan_out_large() {
        let g = star5();
        let weights = vec![1.0, 0.0, 2.0, 3.0];
        let result = neighbor_sample_weighted(&g, &[0], &[10], &weights, 42).unwrap();
        // With degree <= fan-out, every neighbor is kept, even across a zero-weight edge.
        assert_eq!(result.layers[1], vec![1, 2, 3, 4]);
        assert_eq!(result.edges[0].len(), 4);
    }

    #[test]
    fn weighted_high_weight_preferred() {
        let g = star5();
        // Edge 0→1 has very high weight, others nearly zero
        let weights = vec![1000.0, 0.001, 0.001, 0.001];
        let mut vertex1_count = 0;
        for trial in 0..20u64 {
            let result = neighbor_sample_weighted(&g, &[0], &[1], &weights, trial * 137).unwrap();
            if result.layers[1].contains(&1) {
                vertex1_count += 1;
            }
        }
        // Vertex 1 should be selected in most trials
        assert!(vertex1_count >= 15);
    }

    #[test]
    fn weighted_invalid_weights_length() {
        let g = star5();
        let weights = vec![1.0, 2.0]; // wrong length
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42);
        assert!(result.is_err());
    }

    #[test]
    fn weighted_negative_weight() {
        let g = star5();
        let weights = vec![1.0, -1.0, 1.0, 1.0];
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42);
        assert!(result.is_err());
    }

    #[test]
    fn weighted_nan_weight() {
        let g = star5();
        let weights = vec![1.0, f64::NAN, 1.0, 1.0];
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42);
        assert!(result.is_err());
    }

    #[test]
    fn deduplication_across_frontier() {
        // Triangle: 0-1, 1-2, 0-2. Seeds=[0,1], fan_out=[3]
        // Both 0 and 1 have vertex 2 as neighbor
        let g = Graph::from_edges(&[(0, 1), (1, 2), (0, 2)], false, Some(3)).unwrap();
        let result = neighbor_sample(&g, &[0, 1], &[3], 42).unwrap();
        // Vertex 2 is reached twice but listed once. Each seed also reaches the other.
        assert_eq!(result.layers[1], vec![0, 1, 2]);
        // Edges are not deduplicated: 2 per seed.
        assert_eq!(result.edges[0].len(), 4);
    }
}