1use crate::core::rng::SplitMix64;
9use crate::core::{Graph, IgraphResult, VertexId};
10
11#[derive(Debug, Clone, PartialEq, Eq)]
13pub struct NeighborSampleResult {
14 pub layers: Vec<Vec<VertexId>>,
18 pub edges: Vec<Vec<(VertexId, VertexId)>>,
22}
23
24pub fn neighbor_sample(
60 graph: &Graph,
61 seeds: &[VertexId],
62 fan_out: &[usize],
63 seed: u64,
64) -> IgraphResult<NeighborSampleResult> {
65 let n = graph.vcount();
66
67 for &s in seeds {
68 if s >= n {
69 return Err(crate::core::IgraphError::VertexOutOfRange { id: s, n });
70 }
71 }
72
73 if seeds.is_empty() || fan_out.is_empty() {
74 return Ok(NeighborSampleResult {
75 layers: vec![seeds.to_vec()],
76 edges: Vec::new(),
77 });
78 }
79
80 let mut rng = SplitMix64::new(seed);
81 let mut layers: Vec<Vec<VertexId>> = Vec::with_capacity(fan_out.len() + 1);
82 let mut edges: Vec<Vec<(VertexId, VertexId)>> = Vec::with_capacity(fan_out.len());
83
84 layers.push(seeds.to_vec());
85
86 for &num_samples in fan_out {
87 let frontier = layers.last().unwrap();
88 let mut next_layer: Vec<VertexId> = Vec::new();
89 let mut layer_edges: Vec<(VertexId, VertexId)> = Vec::new();
90
91 for &v in frontier {
92 let neighbors = graph.neighbors(v)?;
93 if neighbors.is_empty() {
94 continue;
95 }
96
97 let sampled = if neighbors.len() <= num_samples {
98 neighbors
99 } else {
100 sample_without_replacement(&neighbors, num_samples, &mut rng)
101 };
102
103 for &u in &sampled {
104 layer_edges.push((u, v));
105 next_layer.push(u);
106 }
107 }
108
109 next_layer.sort_unstable();
110 next_layer.dedup();
111 edges.push(layer_edges);
112 layers.push(next_layer);
113 }
114
115 Ok(NeighborSampleResult { layers, edges })
116}
117
118pub fn neighbor_sample_weighted(
138 graph: &Graph,
139 seeds: &[VertexId],
140 fan_out: &[usize],
141 weights: &[f64],
142 seed: u64,
143) -> IgraphResult<NeighborSampleResult> {
144 let n = graph.vcount();
145
146 for &s in seeds {
147 if s >= n {
148 return Err(crate::core::IgraphError::VertexOutOfRange { id: s, n });
149 }
150 }
151
152 if weights.len() != graph.ecount() {
153 return Err(crate::core::IgraphError::InvalidArgument(format!(
154 "weights length {} != ecount {}",
155 weights.len(),
156 graph.ecount()
157 )));
158 }
159
160 for (i, &w) in weights.iter().enumerate() {
161 if w < 0.0 || w.is_nan() {
162 return Err(crate::core::IgraphError::InvalidArgument(format!(
163 "weight[{i}] = {w} is invalid (must be non-negative and finite)"
164 )));
165 }
166 }
167
168 if seeds.is_empty() || fan_out.is_empty() {
169 return Ok(NeighborSampleResult {
170 layers: vec![seeds.to_vec()],
171 edges: Vec::new(),
172 });
173 }
174
175 let mut rng = SplitMix64::new(seed);
176 let mut layers: Vec<Vec<VertexId>> = Vec::with_capacity(fan_out.len() + 1);
177 let mut edges: Vec<Vec<(VertexId, VertexId)>> = Vec::with_capacity(fan_out.len());
178
179 layers.push(seeds.to_vec());
180
181 for &num_samples in fan_out {
182 let frontier = layers.last().unwrap();
183 let mut next_layer: Vec<VertexId> = Vec::new();
184 let mut layer_edges: Vec<(VertexId, VertexId)> = Vec::new();
185
186 for &v in frontier {
187 let incident = graph.incident(v)?;
188 if incident.is_empty() {
189 continue;
190 }
191
192 let mut neighbor_weights: Vec<(VertexId, f64)> = Vec::with_capacity(incident.len());
193 for &eid in &incident {
194 let neighbor = graph.edge_other(eid, v)?;
195 neighbor_weights.push((neighbor, weights[eid as usize]));
196 }
197
198 let sampled = if neighbor_weights.len() <= num_samples {
199 neighbor_weights.iter().map(|&(u, _)| u).collect()
200 } else {
201 weighted_sample_without_replacement(&neighbor_weights, num_samples, &mut rng)
202 };
203
204 for &u in &sampled {
205 layer_edges.push((u, v));
206 next_layer.push(u);
207 }
208 }
209
210 next_layer.sort_unstable();
211 next_layer.dedup();
212 edges.push(layer_edges);
213 layers.push(next_layer);
214 }
215
216 Ok(NeighborSampleResult { layers, edges })
217}
218
219fn sample_without_replacement(items: &[VertexId], k: usize, rng: &mut SplitMix64) -> Vec<VertexId> {
222 let n = items.len();
223 if k >= n {
224 return items.to_vec();
225 }
226
227 let mut pool: Vec<VertexId> = items.to_vec();
228 for i in 0..k {
229 let j = i + rng.gen_index(n - i);
230 pool.swap(i, j);
231 }
232 pool.truncate(k);
233 pool
234}
235
236fn weighted_sample_without_replacement(
237 items: &[(VertexId, f64)],
238 k: usize,
239 rng: &mut SplitMix64,
240) -> Vec<VertexId> {
241 let n = items.len();
242 if k >= n {
243 return items.iter().map(|&(v, _)| v).collect();
244 }
245
246 let mut keys: Vec<(f64, VertexId)> = items
247 .iter()
248 .map(|&(v, w)| {
249 let u = rng.gen_unit();
250 let key = if w > 0.0 { u.powf(1.0 / w) } else { 0.0 };
251 (key, v)
252 })
253 .collect();
254
255 keys.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
256 keys.iter().take(k).map(|&(_, v)| v).collect()
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    fn star5() -> Graph {
        Graph::from_edges(&[(0, 1), (0, 2), (0, 3), (0, 4)], false, Some(5)).unwrap()
    }

    fn path5() -> Graph {
        Graph::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4)], false, Some(5)).unwrap()
    }

    /// Checks the structural invariant shared by both samplers: one edge list
    /// per hop, and every edge of hop `h` runs from `layers[h + 1]` into
    /// `layers[h]`.
    fn assert_edges_match_layers(result: &NeighborSampleResult) {
        assert_eq!(result.edges.len() + 1, result.layers.len());
        for (hop, hop_edges) in result.edges.iter().enumerate() {
            for &(src, dst) in hop_edges {
                assert!(result.layers[hop + 1].contains(&src));
                assert!(result.layers[hop].contains(&dst));
            }
        }
    }

    #[test]
    fn basic_one_hop() {
        let g = star5();
        let result = neighbor_sample(&g, &[0], &[2], 42).unwrap();
        assert_eq!(result.layers[0], vec![0]);
        assert_eq!(result.layers[1].len(), 2);
        assert_eq!(result.edges[0].len(), 2);
        for &(src, dst) in &result.edges[0] {
            assert_eq!(dst, 0);
            assert!((1..=4).contains(&src));
        }
        assert_edges_match_layers(&result);
    }

    #[test]
    fn all_neighbors_when_fan_out_large() {
        let g = star5();
        let result = neighbor_sample(&g, &[0], &[10], 42).unwrap();
        assert_eq!(result.layers[1].len(), 4);
    }

    #[test]
    fn zero_fan_out_samples_nothing() {
        let g = star5();
        let result = neighbor_sample(&g, &[0], &[0], 42).unwrap();
        assert_eq!(result.layers.len(), 2);
        assert!(result.layers[1].is_empty());
        assert!(result.edges[0].is_empty());
    }

    #[test]
    fn two_hop_sampling() {
        let g = path5();
        let result = neighbor_sample(&g, &[0], &[2, 2], 42).unwrap();
        assert_eq!(result.layers.len(), 3);
        assert_eq!(result.edges.len(), 2);
        assert_eq!(result.layers[0], vec![0]);
        // Vertex 0 has one neighbour and vertex 1 has two, so nothing is
        // subsampled and the layers are fully determined.
        assert_eq!(result.layers[1], vec![1]);
        assert_eq!(result.layers[2], vec![0, 2]);
        assert_edges_match_layers(&result);
    }

    #[test]
    fn multiple_seeds() {
        let g = path5();
        let result = neighbor_sample(&g, &[0, 4], &[2], 42).unwrap();
        assert_eq!(result.layers[0], vec![0, 4]);
        assert!(result.layers[1].len() >= 2);
        assert_edges_match_layers(&result);
    }

    #[test]
    fn isolated_vertex() {
        let g = Graph::with_vertices(3);
        let result = neighbor_sample(&g, &[0, 1], &[5], 42).unwrap();
        assert_eq!(result.layers[0], vec![0, 1]);
        assert!(result.layers[1].is_empty());
    }

    #[test]
    fn deterministic() {
        let g = star5();
        let r1 = neighbor_sample(&g, &[0], &[2], 99).unwrap();
        let r2 = neighbor_sample(&g, &[0], &[2], 99).unwrap();
        assert_eq!(r1, r2);
    }

    #[test]
    fn different_seeds_different_results() {
        let g = star5();
        let r1 = neighbor_sample(&g, &[0], &[2], 1).unwrap();
        let r2 = neighbor_sample(&g, &[0], &[2], 2).unwrap();
        let mut s1 = r1.layers[1].clone();
        let mut s2 = r2.layers[1].clone();
        s1.sort_unstable();
        s2.sort_unstable();
        assert_ne!(s1, s2);
    }

    #[test]
    fn empty_seeds() {
        let g = star5();
        let result = neighbor_sample(&g, &[], &[2], 42).unwrap();
        assert_eq!(result.layers.len(), 1);
        assert!(result.layers[0].is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn empty_fan_out() {
        let g = star5();
        let result = neighbor_sample(&g, &[0], &[], 42).unwrap();
        assert_eq!(result.layers.len(), 1);
        assert_eq!(result.layers[0], vec![0]);
        assert!(result.edges.is_empty());
    }

    #[test]
    fn invalid_seed_vertex() {
        let g = star5();
        let result = neighbor_sample(&g, &[10], &[2], 42);
        assert!(result.is_err());
    }

    #[test]
    fn weighted_basic() {
        let g = star5();
        let weights = vec![10.0, 1.0, 1.0, 1.0];
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42).unwrap();
        assert_eq!(result.layers[0], vec![0]);
        assert_eq!(result.layers[1].len(), 2);
        assert_edges_match_layers(&result);
    }

    #[test]
    fn weighted_all_neighbors_when_fan_out_large() {
        let g = star5();
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let result = neighbor_sample_weighted(&g, &[0], &[10], &weights, 42).unwrap();
        assert_eq!(result.layers[1], vec![1, 2, 3, 4]);
        assert_eq!(result.edges[0].len(), 4);
    }

    #[test]
    fn weighted_deterministic() {
        let g = path5();
        let weights = vec![1.0, 2.0, 3.0, 4.0];
        let r1 = neighbor_sample_weighted(&g, &[2], &[1, 1], &weights, 7).unwrap();
        let r2 = neighbor_sample_weighted(&g, &[2], &[1, 1], &weights, 7).unwrap();
        assert_eq!(r1, r2);
        assert_edges_match_layers(&r1);
    }

    #[test]
    fn weighted_high_weight_preferred() {
        let g = star5();
        let weights = vec![1000.0, 0.001, 0.001, 0.001];
        let mut vertex1_count = 0;
        for trial in 0..20u64 {
            let result = neighbor_sample_weighted(&g, &[0], &[1], &weights, trial * 137).unwrap();
            if result.layers[1].contains(&1) {
                vertex1_count += 1;
            }
        }
        assert!(vertex1_count >= 15);
    }

    #[test]
    fn weighted_invalid_weights_length() {
        let g = star5();
        let weights = vec![1.0, 2.0];
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42);
        assert!(result.is_err());
    }

    #[test]
    fn weighted_negative_weight() {
        let g = star5();
        let weights = vec![1.0, -1.0, 1.0, 1.0];
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42);
        assert!(result.is_err());
    }

    #[test]
    fn weighted_nan_weight() {
        let g = star5();
        let weights = vec![1.0, f64::NAN, 1.0, 1.0];
        let result = neighbor_sample_weighted(&g, &[0], &[2], &weights, 42);
        assert!(result.is_err());
    }

    #[test]
    fn weighted_invalid_seed_vertex() {
        let g = star5();
        let weights = vec![1.0; 4];
        let result = neighbor_sample_weighted(&g, &[10], &[2], &weights, 42);
        assert!(result.is_err());
    }

    #[test]
    fn deduplication_across_frontier() {
        let g = Graph::from_edges(&[(0, 1), (1, 2), (0, 2)], false, Some(3)).unwrap();
        let result = neighbor_sample(&g, &[0, 1], &[3], 42).unwrap();
        let mut sorted = result.layers[1].clone();
        sorted.sort_unstable();
        let mut deduped = sorted.clone();
        deduped.dedup();
        assert_eq!(sorted, deduped);
    }
}