1use crate::algorithms::paths::dijkstra::DijkstraMode;
15use crate::core::graph::EdgeId;
16use crate::core::rng::SplitMix64;
17use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
18
19use super::random_walk::validate_weights;
20
21pub type Node2VecWalkResult = (Vec<VertexId>, Vec<EdgeId>);
23
24#[allow(clippy::too_many_arguments)]
70pub fn random_walk_node2vec(
71 graph: &Graph,
72 weights: Option<&[f64]>,
73 start: VertexId,
74 mode: DijkstraMode,
75 steps: u32,
76 p: f64,
77 q: f64,
78 seed: u64,
79) -> IgraphResult<Node2VecWalkResult> {
80 let n = graph.vcount();
81 if start >= n {
82 return Err(IgraphError::VertexOutOfRange { id: start, n });
83 }
84 if !p.is_finite() || p <= 0.0 {
85 return Err(IgraphError::InvalidArgument(format!(
86 "p must be positive and finite, got {p}"
87 )));
88 }
89 if !q.is_finite() || q <= 0.0 {
90 return Err(IgraphError::InvalidArgument(format!(
91 "q must be positive and finite, got {q}"
92 )));
93 }
94 validate_weights(graph, weights)?;
95
96 let mut rng = SplitMix64::new(seed);
97 let mut vs: Vec<VertexId> = Vec::with_capacity(steps as usize + 1);
98 let mut es: Vec<EdgeId> = Vec::with_capacity(steps as usize);
99 vs.push(start);
100
101 if steps == 0 {
102 return Ok((vs, es));
103 }
104
105 let first_next = pick_neighbor(graph, start, weights, mode, &mut rng)?;
108 let Some((first_eid, first_v)) = first_next else {
109 return Ok((vs, es));
110 };
111 es.push(first_eid);
112 vs.push(first_v);
113
114 let inv_p = 1.0 / p;
116 let inv_q = 1.0 / q;
117
118 for _ in 1..steps {
119 let prev = vs[vs.len() - 2]; let current = *vs.last().unwrap(); let next =
123 pick_biased_neighbor(graph, prev, current, weights, mode, inv_p, inv_q, &mut rng)?;
124 let Some((eid, next_v)) = next else {
125 break;
126 };
127 es.push(eid);
128 vs.push(next_v);
129 }
130
131 Ok((vs, es))
132}
133
134fn pick_neighbor(
136 graph: &Graph,
137 v: VertexId,
138 weights: Option<&[f64]>,
139 mode: DijkstraMode,
140 rng: &mut SplitMix64,
141) -> IgraphResult<Option<(EdgeId, VertexId)>> {
142 let incidents = incident_for_mode(graph, v, mode)?;
143 if incidents.is_empty() {
144 return Ok(None);
145 }
146
147 let eid = match weights {
148 None => {
149 let idx = rng.gen_index(incidents.len());
150 incidents[idx]
151 }
152 Some(ws) => {
153 let chosen = weighted_pick(&incidents, ws, rng);
154 let Some(e) = chosen else {
155 return Ok(None);
156 };
157 e
158 }
159 };
160 let next = graph.edge_other(eid, v)?;
161 Ok(Some((eid, next)))
162}
163
164#[allow(clippy::too_many_arguments)]
166fn pick_biased_neighbor(
167 graph: &Graph,
168 prev: VertexId,
169 current: VertexId,
170 weights: Option<&[f64]>,
171 mode: DijkstraMode,
172 inv_p: f64,
173 inv_q: f64,
174 rng: &mut SplitMix64,
175) -> IgraphResult<Option<(EdgeId, VertexId)>> {
176 let incidents = incident_for_mode(graph, current, mode)?;
177 if incidents.is_empty() {
178 return Ok(None);
179 }
180
181 let prev_neighbors = neighbor_set(graph, prev, mode)?;
183
184 let mut biased_weights: Vec<f64> = Vec::with_capacity(incidents.len());
186 let mut total = 0.0_f64;
187
188 for &eid in &incidents {
189 let base_weight = match weights {
190 None => 1.0,
191 Some(ws) => {
192 let w = ws[eid as usize];
193 if !(w.is_finite() && w > 0.0) {
194 biased_weights.push(0.0);
195 continue;
196 }
197 w
198 }
199 };
200
201 let neighbor = graph.edge_other(eid, current)?;
202
203 let alpha = if neighbor == prev {
205 inv_p } else if prev_neighbors.contains(&neighbor) {
207 1.0 } else {
209 inv_q };
211
212 let w = alpha * base_weight;
213 biased_weights.push(w);
214 total += w;
215 }
216
217 if total <= 0.0 {
218 return Ok(None);
219 }
220
221 let target = rng.gen_unit() * total;
223 let mut acc = 0.0_f64;
224 for (i, &w) in biased_weights.iter().enumerate() {
225 if w <= 0.0 {
226 continue;
227 }
228 acc += w;
229 if acc >= target {
230 let eid = incidents[i];
231 let next = graph.edge_other(eid, current)?;
232 return Ok(Some((eid, next)));
233 }
234 }
235
236 for (i, &w) in biased_weights.iter().enumerate().rev() {
238 if w > 0.0 {
239 let eid = incidents[i];
240 let next = graph.edge_other(eid, current)?;
241 return Ok(Some((eid, next)));
242 }
243 }
244
245 Ok(None)
246}
247
248fn neighbor_set(graph: &Graph, v: VertexId, mode: DijkstraMode) -> IgraphResult<Vec<VertexId>> {
250 let incidents = incident_for_mode(graph, v, mode)?;
251 let mut neighbors: Vec<VertexId> = Vec::with_capacity(incidents.len());
252 for &eid in &incidents {
253 let other = graph.edge_other(eid, v)?;
254 neighbors.push(other);
255 }
256 neighbors.sort_unstable();
257 neighbors.dedup();
258 Ok(neighbors)
259}
260
261fn incident_for_mode(graph: &Graph, v: VertexId, mode: DijkstraMode) -> IgraphResult<Vec<EdgeId>> {
262 if !graph.is_directed() {
263 return graph.incident(v);
264 }
265 match mode {
266 DijkstraMode::Out => graph.incident(v),
267 DijkstraMode::In => graph.incident_in(v),
268 DijkstraMode::All => {
269 let mut out = graph.incident(v)?;
270 out.extend(graph.incident_in(v)?);
271 Ok(out)
272 }
273 }
274}
275
276fn weighted_pick(incidents: &[EdgeId], ws: &[f64], rng: &mut SplitMix64) -> Option<EdgeId> {
277 let mut total = 0.0_f64;
278 for &eid in incidents {
279 let w = ws[eid as usize];
280 if w.is_finite() && w > 0.0 {
281 total += w;
282 }
283 }
284 if total <= 0.0 {
285 return None;
286 }
287 let target = rng.gen_unit() * total;
288 let mut acc = 0.0_f64;
289 for &eid in incidents {
290 let w = ws[eid as usize];
291 if !(w.is_finite() && w > 0.0) {
292 continue;
293 }
294 acc += w;
295 if acc >= target {
296 return Some(eid);
297 }
298 }
299 for &eid in incidents.iter().rev() {
301 let w = ws[eid as usize];
302 if w.is_finite() && w > 0.0 {
303 return Some(eid);
304 }
305 }
306 None
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the path `0 - 1 - ... - (n - 1)` with edges added in ascending order.
    fn path_graph(n: u32) -> Graph {
        let mut g = Graph::with_vertices(n);
        for lo in 0..n - 1 {
            let hi = lo + 1;
            g.add_edge(lo, hi).unwrap();
        }
        g
    }

    /// Builds a 3x3 grid (vertex `3 * row + col`): all horizontal edges first, then all
    /// vertical edges, each group in row-major order.
    fn grid_graph() -> Graph {
        let mut g = Graph::with_vertices(9);
        for row in 0..3 {
            for col in 0..2 {
                let left = 3 * row + col;
                g.add_edge(left, left + 1).unwrap();
            }
        }
        for row in 0..2 {
            for col in 0..3 {
                let top = 3 * row + col;
                g.add_edge(top, top + 3).unwrap();
            }
        }
        g
    }

    /// Runs 100 unweighted three-step walks from the grid centre (seeds `0..100`, `q = 1`)
    /// and counts how many return to the start vertex on the second step.
    fn count_immediate_returns(g: &Graph, p: f64) -> usize {
        (0..100_u64)
            .filter(|&seed| {
                let (vs, _) =
                    random_walk_node2vec(g, None, 4, DijkstraMode::Out, 3, p, 1.0, seed).unwrap();
                vs.len() >= 3 && vs[2] == vs[0]
            })
            .count()
    }

    #[test]
    fn unit_basic_walk_length() {
        let g = path_graph(10);
        let walk = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, 1.0, 42);
        let (vs, es) = walk.unwrap();
        assert_eq!(vs[0], 0);
        assert!(vs.len() <= 6);
        assert_eq!(es.len(), vs.len() - 1);
    }

    #[test]
    fn unit_p_q_one_reduces_to_standard() {
        let g = path_graph(5);
        let walk = random_walk_node2vec(&g, None, 2, DijkstraMode::Out, 20, 1.0, 1.0, 123);
        let (vs, _) = walk.unwrap();
        assert_eq!(vs[0], 2);
        // Every visited vertex must be a valid vertex of the five-vertex path.
        assert!(vs.iter().all(|&v| v < 5));
    }

    #[test]
    fn unit_high_p_discourages_return() {
        let g = grid_graph();
        let immediate_returns = count_immediate_returns(&g, 100.0);
        assert!(
            immediate_returns < 15,
            "expected few immediate returns with high p, got {immediate_returns}/100"
        );
    }

    #[test]
    fn unit_low_p_encourages_return() {
        let g = grid_graph();
        let immediate_returns = count_immediate_returns(&g, 0.01);
        assert!(
            immediate_returns > 40,
            "expected many immediate returns with low p, got {immediate_returns}/100"
        );
    }

    #[test]
    fn unit_invalid_p() {
        let g = path_graph(5);
        for bad_p in [0.0, -1.0, f64::NAN] {
            let outcome = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, bad_p, 1.0, 42);
            assert!(outcome.is_err());
        }
    }

    #[test]
    fn unit_invalid_q() {
        let g = path_graph(5);
        for bad_q in [0.0, f64::INFINITY] {
            let outcome = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, bad_q, 42);
            assert!(outcome.is_err());
        }
    }

    #[test]
    fn unit_start_out_of_range() {
        let g = path_graph(5);
        let outcome = random_walk_node2vec(&g, None, 10, DijkstraMode::Out, 5, 1.0, 1.0, 42);
        assert!(outcome.is_err());
    }

    #[test]
    fn unit_zero_steps() {
        let g = path_graph(5);
        let walk = random_walk_node2vec(&g, None, 2, DijkstraMode::Out, 0, 1.0, 1.0, 42);
        let (vs, es) = walk.unwrap();
        assert_eq!(vs, vec![2]);
        assert!(es.is_empty());
    }

    #[test]
    fn unit_stuck_at_leaf() {
        // Directed chain 0 -> 1 -> 2 -> 3: vertex 3 has no outgoing edge.
        let mut g = Graph::new(4, true).unwrap();
        for tail in 0..3 {
            g.add_edge(tail, tail + 1).unwrap();
        }
        let walk = random_walk_node2vec(&g, None, 3, DijkstraMode::Out, 10, 1.0, 1.0, 42);
        let (vs, es) = walk.unwrap();
        assert_eq!(vs, vec![3]);
        assert!(es.is_empty());
    }

    #[test]
    fn unit_weighted_walk() {
        // Triangle whose first edge (0, 1) is ten times heavier than the other two.
        let mut g = Graph::with_vertices(3);
        for (u, v) in [(0, 1), (1, 2), (0, 2)] {
            g.add_edge(u, v).unwrap();
        }
        let weights = vec![10.0, 1.0, 1.0];
        let walk = random_walk_node2vec(&g, Some(&weights), 0, DijkstraMode::Out, 1, 1.0, 1.0, 42);
        let (vs, _) = walk.unwrap();
        assert_eq!(vs[0], 0);
        assert_eq!(vs.len(), 2);
    }

    #[test]
    fn unit_deterministic() {
        let g = grid_graph();
        let run = || random_walk_node2vec(&g, None, 4, DijkstraMode::Out, 20, 2.0, 0.5, 99).unwrap();
        let r1 = run();
        let r2 = run();
        assert_eq!(r1, r2);
    }

    #[test]
    fn unit_single_vertex_graph() {
        let g = Graph::with_vertices(1);
        let walk = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, 1.0, 42);
        let (vs, es) = walk.unwrap();
        assert_eq!(vs, vec![0]);
        assert!(es.is_empty());
    }
}