Skip to main content

rust_igraph/algorithms/paths/
random_walk_node2vec.rs

1//! Second-order biased random walk (`Node2Vec`) on a graph (ALGO-TR-004).
2//!
//! Implements the biased walk described in Grover & Leskovec (2016) "node2vec:
//! Scalable Feature Learning for Networks". At each step, having arrived at
//! vertex `v` from vertex `t`, the probability of moving on to a neighbor `x`
//! of `v` is proportional to `α(t, x) · w(v, x)`, where `w(v, x)` is the edge
//! weight and `α` is:
7//!
8//! - `1/p` if `x == t`  (return to previous vertex — controlled by `p`)
9//! - `1`   if `x` is a neighbor of `t` (stay close — BFS-like when `q > 1`)
10//! - `1/q` if `x` is NOT a neighbor of `t` (move away — DFS-like when `q < 1`)
11//!
12//! When `p = q = 1` this reduces to a standard (first-order) random walk.
13
14use crate::algorithms::paths::dijkstra::DijkstraMode;
15use crate::core::graph::EdgeId;
16use crate::core::rng::SplitMix64;
17use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
18
19use super::random_walk::validate_weights;
20
/// Result of a `Node2Vec` random walk: `(vertex_chain, edge_chain)`.
///
/// `edge_chain[i]` is the edge traversed from `vertex_chain[i]` to
/// `vertex_chain[i + 1]`, so `edge_chain.len() == vertex_chain.len() - 1`.
pub type Node2VecWalkResult = (Vec<VertexId>, Vec<EdgeId>);
23
24/// Performs a second-order biased random walk (`Node2Vec`) starting at `start`
25/// for up to `steps` transitions.
26///
27/// # Parameters
28///
29/// - `p` — Return parameter. Higher `p` makes it less likely to return to the
30///   previously visited vertex. `p > 1` discourages backtracking; `p < 1`
31///   encourages it.
32/// - `q` — In-out parameter. Higher `q` biases towards BFS-like exploration
33///   (staying close to `t`); `q < 1` biases towards DFS-like exploration
34///   (moving away from `t`).
35/// - `weights` — Optional edge weights (positive). `None` for unweighted.
36/// - `mode` — Direction mode for directed graphs (`Out`, `In`, `All`).
37/// - `seed` — Deterministic RNG seed.
38///
39/// # Returns
40///
41/// `(vertex_chain, edge_chain)` where `vertex_chain[0] == start`. If the walk
42/// gets stuck (no admissible outgoing edges), the result is truncated.
43///
44/// # Errors
45///
46/// Returns [`IgraphError::InvalidArgument`] if `p` or `q` are not positive
47/// finite, if `start` is out of range, or if weights are invalid.
48///
49/// # Examples
50///
51/// ```
52/// use rust_igraph::{Graph, random_walk_node2vec, DijkstraMode};
53///
54/// let mut g = Graph::with_vertices(6);
55/// g.add_edge(0, 1).unwrap();
56/// g.add_edge(1, 2).unwrap();
57/// g.add_edge(2, 3).unwrap();
58/// g.add_edge(3, 4).unwrap();
59/// g.add_edge(4, 5).unwrap();
60/// g.add_edge(0, 2).unwrap(); // shortcut
61///
62/// // With q < 1: biased towards DFS (exploring further away)
63/// let (vs, es) = random_walk_node2vec(
64///     &g, None, 0, DijkstraMode::Out, 10, 1.0, 0.5, 42
65/// ).unwrap();
66/// assert_eq!(vs[0], 0);
67/// assert!(vs.len() <= 11);
68/// ```
69#[allow(clippy::too_many_arguments)]
70pub fn random_walk_node2vec(
71    graph: &Graph,
72    weights: Option<&[f64]>,
73    start: VertexId,
74    mode: DijkstraMode,
75    steps: u32,
76    p: f64,
77    q: f64,
78    seed: u64,
79) -> IgraphResult<Node2VecWalkResult> {
80    let n = graph.vcount();
81    if start >= n {
82        return Err(IgraphError::VertexOutOfRange { id: start, n });
83    }
84    if !p.is_finite() || p <= 0.0 {
85        return Err(IgraphError::InvalidArgument(format!(
86            "p must be positive and finite, got {p}"
87        )));
88    }
89    if !q.is_finite() || q <= 0.0 {
90        return Err(IgraphError::InvalidArgument(format!(
91            "q must be positive and finite, got {q}"
92        )));
93    }
94    validate_weights(graph, weights)?;
95
96    let mut rng = SplitMix64::new(seed);
97    let mut vs: Vec<VertexId> = Vec::with_capacity(steps as usize + 1);
98    let mut es: Vec<EdgeId> = Vec::with_capacity(steps as usize);
99    vs.push(start);
100
101    if steps == 0 {
102        return Ok((vs, es));
103    }
104
105    // First step: standard (unbiased or weight-proportional) since there's
106    // no "previous" vertex yet.
107    let first_next = pick_neighbor(graph, start, weights, mode, &mut rng)?;
108    let Some((first_eid, first_v)) = first_next else {
109        return Ok((vs, es));
110    };
111    es.push(first_eid);
112    vs.push(first_v);
113
114    // Subsequent steps: second-order biased
115    let inv_p = 1.0 / p;
116    let inv_q = 1.0 / q;
117
118    for _ in 1..steps {
119        let prev = vs[vs.len() - 2]; // t
120        let current = *vs.last().unwrap(); // v
121
122        let next =
123            pick_biased_neighbor(graph, prev, current, weights, mode, inv_p, inv_q, &mut rng)?;
124        let Some((eid, next_v)) = next else {
125            break;
126        };
127        es.push(eid);
128        vs.push(next_v);
129    }
130
131    Ok((vs, es))
132}
133
134/// Pick a neighbor uniformly (unweighted) or proportional to weight.
135fn pick_neighbor(
136    graph: &Graph,
137    v: VertexId,
138    weights: Option<&[f64]>,
139    mode: DijkstraMode,
140    rng: &mut SplitMix64,
141) -> IgraphResult<Option<(EdgeId, VertexId)>> {
142    let incidents = incident_for_mode(graph, v, mode)?;
143    if incidents.is_empty() {
144        return Ok(None);
145    }
146
147    let eid = match weights {
148        None => {
149            let idx = rng.gen_index(incidents.len());
150            incidents[idx]
151        }
152        Some(ws) => {
153            let chosen = weighted_pick(&incidents, ws, rng);
154            let Some(e) = chosen else {
155                return Ok(None);
156            };
157            e
158        }
159    };
160    let next = graph.edge_other(eid, v)?;
161    Ok(Some((eid, next)))
162}
163
164/// Pick a neighbor with `Node2Vec` second-order bias.
165#[allow(clippy::too_many_arguments)]
166fn pick_biased_neighbor(
167    graph: &Graph,
168    prev: VertexId,
169    current: VertexId,
170    weights: Option<&[f64]>,
171    mode: DijkstraMode,
172    inv_p: f64,
173    inv_q: f64,
174    rng: &mut SplitMix64,
175) -> IgraphResult<Option<(EdgeId, VertexId)>> {
176    let incidents = incident_for_mode(graph, current, mode)?;
177    if incidents.is_empty() {
178        return Ok(None);
179    }
180
181    // Collect neighbors of `prev` for the distance check.
182    let prev_neighbors = neighbor_set(graph, prev, mode)?;
183
184    // Compute biased weights for each candidate edge.
185    let mut biased_weights: Vec<f64> = Vec::with_capacity(incidents.len());
186    let mut total = 0.0_f64;
187
188    for &eid in &incidents {
189        let base_weight = match weights {
190            None => 1.0,
191            Some(ws) => {
192                let w = ws[eid as usize];
193                if !(w.is_finite() && w > 0.0) {
194                    biased_weights.push(0.0);
195                    continue;
196                }
197                w
198            }
199        };
200
201        let neighbor = graph.edge_other(eid, current)?;
202
203        // Apply `Node2Vec` bias based on distance from `prev` to `neighbor`
204        let alpha = if neighbor == prev {
205            inv_p // d_tx = 0: returning to previous
206        } else if prev_neighbors.contains(&neighbor) {
207            1.0 // d_tx = 1: neighbor of prev
208        } else {
209            inv_q // d_tx = 2: not neighbor of prev
210        };
211
212        let w = alpha * base_weight;
213        biased_weights.push(w);
214        total += w;
215    }
216
217    if total <= 0.0 {
218        return Ok(None);
219    }
220
221    // Weighted random selection
222    let target = rng.gen_unit() * total;
223    let mut acc = 0.0_f64;
224    for (i, &w) in biased_weights.iter().enumerate() {
225        if w <= 0.0 {
226            continue;
227        }
228        acc += w;
229        if acc >= target {
230            let eid = incidents[i];
231            let next = graph.edge_other(eid, current)?;
232            return Ok(Some((eid, next)));
233        }
234    }
235
236    // Floating-point fallback: pick last positive-weight edge
237    for (i, &w) in biased_weights.iter().enumerate().rev() {
238        if w > 0.0 {
239            let eid = incidents[i];
240            let next = graph.edge_other(eid, current)?;
241            return Ok(Some((eid, next)));
242        }
243    }
244
245    Ok(None)
246}
247
248/// Get a set of neighbors for efficient lookup.
249fn neighbor_set(graph: &Graph, v: VertexId, mode: DijkstraMode) -> IgraphResult<Vec<VertexId>> {
250    let incidents = incident_for_mode(graph, v, mode)?;
251    let mut neighbors: Vec<VertexId> = Vec::with_capacity(incidents.len());
252    for &eid in &incidents {
253        let other = graph.edge_other(eid, v)?;
254        neighbors.push(other);
255    }
256    neighbors.sort_unstable();
257    neighbors.dedup();
258    Ok(neighbors)
259}
260
261fn incident_for_mode(graph: &Graph, v: VertexId, mode: DijkstraMode) -> IgraphResult<Vec<EdgeId>> {
262    if !graph.is_directed() {
263        return graph.incident(v);
264    }
265    match mode {
266        DijkstraMode::Out => graph.incident(v),
267        DijkstraMode::In => graph.incident_in(v),
268        DijkstraMode::All => {
269            let mut out = graph.incident(v)?;
270            out.extend(graph.incident_in(v)?);
271            Ok(out)
272        }
273    }
274}
275
276fn weighted_pick(incidents: &[EdgeId], ws: &[f64], rng: &mut SplitMix64) -> Option<EdgeId> {
277    let mut total = 0.0_f64;
278    for &eid in incidents {
279        let w = ws[eid as usize];
280        if w.is_finite() && w > 0.0 {
281            total += w;
282        }
283    }
284    if total <= 0.0 {
285        return None;
286    }
287    let target = rng.gen_unit() * total;
288    let mut acc = 0.0_f64;
289    for &eid in incidents {
290        let w = ws[eid as usize];
291        if !(w.is_finite() && w > 0.0) {
292            continue;
293        }
294        acc += w;
295        if acc >= target {
296            return Some(eid);
297        }
298    }
299    // Fallback
300    for &eid in incidents.iter().rev() {
301        let w = ws[eid as usize];
302        if w.is_finite() && w > 0.0 {
303            return Some(eid);
304        }
305    }
306    None
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Undirected path `0 - 1 - ... - (n - 1)`, edges added in index order.
    fn path_graph(n: u32) -> Graph {
        let mut g = Graph::with_vertices(n);
        for i in 1..n {
            g.add_edge(i - 1, i).unwrap();
        }
        g
    }

    fn grid_graph() -> Graph {
        // 3x3 grid: 0-1-2 / 3-4-5 / 6-7-8
        let mut g = Graph::with_vertices(9);
        let edges = [
            (0, 1),
            (1, 2),
            (3, 4),
            (4, 5),
            (6, 7),
            (7, 8),
            (0, 3),
            (1, 4),
            (2, 5),
            (3, 6),
            (4, 7),
            (5, 8),
        ];
        for (u, v) in edges {
            g.add_edge(u, v).unwrap();
        }
        g
    }

    #[test]
    fn unit_basic_walk_length() {
        // An undirected path has no dead ends, so the walk always takes every step.
        let g = path_graph(10);
        let (vs, es) =
            random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, 1.0, 42).unwrap();
        assert_eq!(vs[0], 0);
        assert_eq!(vs.len(), 6);
        assert_eq!(es.len(), vs.len() - 1);
    }

    #[test]
    fn unit_p_q_one_reduces_to_standard() {
        // With p=q=1 every alpha is 1.0, so this is a standard random walk.
        // On a path graph each hop must move to an adjacent index.
        let g = path_graph(5);
        let (vs, _) =
            random_walk_node2vec(&g, None, 2, DijkstraMode::Out, 20, 1.0, 1.0, 123).unwrap();
        assert_eq!(vs[0], 2);
        assert_eq!(vs.len(), 21);
        for v in &vs {
            assert!(*v < 5);
        }
        for pair in vs.windows(2) {
            assert_eq!(pair[0].abs_diff(pair[1]), 1);
        }
    }

    #[test]
    fn unit_edge_chain_matches_vertex_chain() {
        let g = grid_graph();
        let (vs, es) =
            random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 30, 2.0, 0.5, 7).unwrap();
        assert_eq!(vs.len(), 31);
        assert_eq!(es.len(), vs.len() - 1);
        for (i, &eid) in es.iter().enumerate() {
            assert_eq!(g.edge_other(eid, vs[i]).unwrap(), vs[i + 1]);
        }
    }

    #[test]
    fn unit_high_p_discourages_return() {
        // With very high p, the walk should rarely return immediately.
        // On a grid starting from center (4), test over many walks.
        let g = grid_graph();
        let mut immediate_returns = 0;
        for seed in 0..100 {
            let (vs, _) =
                random_walk_node2vec(&g, None, 4, DijkstraMode::Out, 3, 100.0, 1.0, seed).unwrap();
            if vs.len() >= 3 && vs[2] == vs[0] {
                immediate_returns += 1;
            }
        }
        // With p=100 a return has probability ~0.5%; allow a loose margin
        // (fewer than 15 of 100 walks).
        assert!(
            immediate_returns < 15,
            "expected few immediate returns with high p, got {immediate_returns}/100"
        );
    }

    #[test]
    fn unit_low_p_encourages_return() {
        // With very low p, the walk should often return immediately.
        let g = grid_graph();
        let mut immediate_returns = 0;
        for seed in 0..100 {
            let (vs, _) =
                random_walk_node2vec(&g, None, 4, DijkstraMode::Out, 3, 0.01, 1.0, seed).unwrap();
            if vs.len() >= 3 && vs[2] == vs[0] {
                immediate_returns += 1;
            }
        }
        // With p=0.01 a return has probability ~98%; allow a loose margin
        // (more than 40 of 100 walks).
        assert!(
            immediate_returns > 40,
            "expected many immediate returns with low p, got {immediate_returns}/100"
        );
    }

    #[test]
    fn unit_invalid_p() {
        let g = path_graph(5);
        let result = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 0.0, 1.0, 42);
        assert!(result.is_err());

        let result = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, -1.0, 1.0, 42);
        assert!(result.is_err());

        let result = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, f64::NAN, 1.0, 42);
        assert!(result.is_err());
    }

    #[test]
    fn unit_invalid_q() {
        let g = path_graph(5);
        let result = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, 0.0, 42);
        assert!(result.is_err());

        let result = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, -2.0, 42);
        assert!(result.is_err());

        let result = random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, f64::NAN, 42);
        assert!(result.is_err());

        let result =
            random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, f64::INFINITY, 42);
        assert!(result.is_err());
    }

    #[test]
    fn unit_start_out_of_range() {
        let g = path_graph(5);
        let result = random_walk_node2vec(&g, None, 10, DijkstraMode::Out, 5, 1.0, 1.0, 42);
        assert!(matches!(
            result,
            Err(IgraphError::VertexOutOfRange { .. })
        ));
    }

    #[test]
    fn unit_zero_steps() {
        let g = path_graph(5);
        let (vs, es) =
            random_walk_node2vec(&g, None, 2, DijkstraMode::Out, 0, 1.0, 1.0, 42).unwrap();
        assert_eq!(vs, vec![2]);
        assert!(es.is_empty());
    }

    #[test]
    fn unit_stuck_at_leaf() {
        // Directed path: 0→1→2→3. Starting from 3, walk gets stuck immediately.
        let mut g = Graph::new(4, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 3).unwrap();
        let (vs, es) =
            random_walk_node2vec(&g, None, 3, DijkstraMode::Out, 10, 1.0, 1.0, 42).unwrap();
        assert_eq!(vs, vec![3]);
        assert!(es.is_empty());
    }

    #[test]
    fn unit_directed_in_mode_walks_backwards() {
        // Directed path 0→1→2→3 walked against edge direction from 3. Each vertex
        // has a single in-edge, so the walk is forced and stops at 0.
        let mut g = Graph::new(4, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 3).unwrap();
        let (vs, es) =
            random_walk_node2vec(&g, None, 3, DijkstraMode::In, 10, 1.0, 1.0, 42).unwrap();
        assert_eq!(vs, vec![3, 2, 1, 0]);
        assert_eq!(es.len(), 3);
    }

    #[test]
    fn unit_weighted_walk() {
        // Triangle 0-1-2 with weights favoring 0→1
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap(); // edge 0, weight 10
        g.add_edge(1, 2).unwrap(); // edge 1, weight 1
        g.add_edge(0, 2).unwrap(); // edge 2, weight 1

        let weights = vec![10.0, 1.0, 1.0];
        let (vs, es) =
            random_walk_node2vec(&g, Some(&weights), 0, DijkstraMode::Out, 1, 1.0, 1.0, 42)
                .unwrap();
        assert_eq!(vs[0], 0);
        assert_eq!(vs.len(), 2);
        assert_eq!(es.len(), 1);
    }

    #[test]
    fn unit_weighted_first_step_prefers_heavy_edge() {
        // From 0, edge 0-1 (weight 10) should be taken ~10/11 of the time
        // over edge 0-2 (weight 1).
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(0, 2).unwrap();
        let weights = vec![10.0, 1.0, 1.0];

        let mut to_one = 0;
        for seed in 0..200 {
            let (vs, _) =
                random_walk_node2vec(&g, Some(&weights), 0, DijkstraMode::Out, 1, 1.0, 1.0, seed)
                    .unwrap();
            if vs[1] == 1 {
                to_one += 1;
            }
        }
        // Expected ~182/200; use a loose bound.
        assert!(
            to_one > 140,
            "expected the heavy edge to dominate, got {to_one}/200"
        );
    }

    #[test]
    fn unit_deterministic() {
        let g = grid_graph();
        let r1 = random_walk_node2vec(&g, None, 4, DijkstraMode::Out, 20, 2.0, 0.5, 99).unwrap();
        let r2 = random_walk_node2vec(&g, None, 4, DijkstraMode::Out, 20, 2.0, 0.5, 99).unwrap();
        assert_eq!(r1, r2);
    }

    #[test]
    fn unit_single_vertex_graph() {
        let g = Graph::with_vertices(1);
        let (vs, es) =
            random_walk_node2vec(&g, None, 0, DijkstraMode::Out, 5, 1.0, 1.0, 42).unwrap();
        assert_eq!(vs, vec![0]);
        assert!(es.is_empty());
    }
}