Skip to main content

rust_igraph/algorithms/
epidemics.rs

//! Epidemic models on graphs.
2//!
3//! Currently provides the **SIR** (susceptible–infected–recovered)
4//! stochastic model, a direct translation of `igraph_sir`
5//! (`references/igraph/src/misc/sir.c`).
6//!
7//! The simulation is a continuous-time Gillespie process. Every
8//! individual is in one of three states:
9//!
10//! * **S** — susceptible (can catch the disease),
11//! * **I** — infected (spreads the disease and may recover),
12//! * **R** — recovered (immune, inert).
13//!
14//! A susceptible vertex with `k` infected neighbours becomes infected at
15//! rate `k · beta`; an infected vertex recovers at rate `gamma`. Each run
16//! starts from a single, uniformly random infected vertex and stops when
17//! no infected individuals remain. Event times and the S/I/R population
18//! sizes are recorded after every state transition.
19//!
20//! Determinism: randomness comes from the project's `SplitMix64` PRNG
21//! seeded by the caller, so a given `seed` reproduces the same trajectory
22//! bit-for-bit. (This means trajectories will *not* coincide with
23//! upstream igraph, which uses a different RNG — only the statistical
24//! behaviour matches.)
25
26use crate::algorithms::properties::is_simple::{SimpleMode, is_simple_with_mode};
27use crate::core::rng::SplitMix64;
28use crate::core::{Graph, IgraphError, IgraphResult};
29
/// Result of a single SIR simulation run.
///
/// All four vectors have the same length: one entry for the initial
/// state plus one entry per recorded state transition. Index `k` of every
/// vector describes the same instant, so `no_s[k] + no_i[k] + no_r[k]`
/// equals the number of vertices for every `k`.
#[derive(Debug, Clone, PartialEq)]
pub struct Sir {
    /// Cumulative event times. `times[0] == 0.0`; every later entry adds a
    /// non-negative exponential waiting time to its predecessor, so the
    /// sequence is non-decreasing (strictly increasing barring
    /// floating-point ties).
    pub times: Vec<f64>,
    /// Number of susceptible individuals at each recorded time.
    pub no_s: Vec<usize>,
    /// Number of infected individuals at each recorded time. The final
    /// entry is `0`, because a run stops once nobody is infected.
    pub no_i: Vec<usize>,
    /// Number of recovered individuals at each recorded time.
    pub no_r: Vec<usize>,
}
45
/// Fenwick (binary-indexed) tree holding per-vertex event rates, with
/// O(log n) point update and O(log n) cumulative-rate search.
///
/// Mirrors `igraph_psumtree`: `search(r)` returns the vertex whose rate
/// interval contains the cumulative target `r ∈ [0, total)`.
struct PsumTree {
    /// Number of slots (one per vertex).
    n: usize,
    /// 1-based Fenwick array of partial sums; `bit[0]` is unused.
    bit: Vec<f64>,
    /// Current rate of every slot, kept so point reads are O(1).
    values: Vec<f64>,
    /// Running sum of all rates, updated incrementally by `set`.
    total: f64,
}

impl PsumTree {
    /// Creates a tree with `n` slots, all at rate zero.
    fn new(n: usize) -> Self {
        Self {
            n,
            bit: vec![0.0; n + 1],
            values: vec![0.0; n],
            total: 0.0,
        }
    }

    /// Current rate stored in slot `i`. Panics if `i >= n`.
    fn get(&self, i: usize) -> f64 {
        self.values[i]
    }

    /// Sum of all rates.
    fn total(&self) -> f64 {
        self.total
    }

    /// Overwrites the rate of slot `i` with `v`. Panics if `i >= n`.
    fn set(&mut self, i: usize, v: f64) {
        let delta = v - self.values[i];
        self.values[i] = v;
        self.total += delta;
        // Walk up the Fenwick tree: `k & -k` isolates the lowest set bit.
        let mut k = i + 1;
        while k <= self.n {
            self.bit[k] += delta;
            k += k & k.wrapping_neg();
        }
    }

    /// Sets every rate (and the total) back to zero, keeping the allocation.
    fn reset(&mut self) {
        self.bit.fill(0.0);
        self.values.fill(0.0);
        self.total = 0.0;
    }

    /// Smallest index whose inclusive prefix sum first exceeds `target`.
    ///
    /// `target` is expected in `[0, total)`. The result is always in
    /// `[0, n)` (or `0` for an empty tree). If floating-point drift or an
    /// out-of-range `target` makes the descent land on a slot whose rate
    /// is not positive, the nearest slot with a positive rate is returned
    /// instead, so a zero-rate slot is only ever reported when no slot
    /// has a positive rate at all.
    fn search(&self, target: f64) -> usize {
        if self.n == 0 {
            return 0;
        }
        let mut idx: usize = 0;
        let mut remaining = target;
        // Largest power of two not exceeding `n`.
        let mut step = 1usize;
        while step.saturating_mul(2) <= self.n {
            step *= 2;
        }
        // Binary descent: keep every block whose sum still fits in `remaining`.
        while step > 0 {
            let next = idx + step;
            if next <= self.n && self.bit[next] <= remaining {
                idx = next;
                remaining -= self.bit[next];
            }
            step >>= 1;
        }
        let landed = idx.min(self.n - 1);
        if self.values[landed] > 0.0 {
            return landed;
        }
        self.nearest_positive(landed).unwrap_or(landed)
    }

    /// Closest slot to `from` holding a positive rate, preferring lower
    /// indices (an overshooting descent belongs to the last positive slot).
    fn nearest_positive(&self, from: usize) -> Option<usize> {
        (0..from)
            .rev()
            .find(|&j| self.values[j] > 0.0)
            .or_else(|| (from + 1..self.n).find(|&j| self.values[j] > 0.0))
    }
}
122
/// Vertex state code: susceptible (the state every vertex is reset to
/// at the start of a run).
const S_S: u8 = 0;
/// Vertex state code: infected.
const S_I: u8 = 1;
/// Vertex state code: recovered.
const S_R: u8 = 2;
126
127/// Runs `no_sim` independent SIR epidemic simulations on `graph`.
128///
129/// Edge directions are ignored: an edge contributes to both endpoints'
130/// neighbourhoods. The graph must be *simple* in its undirected view
131/// (no self-loops, no parallel or mutual edges).
132///
133/// * `beta` — per-edge infection rate (rate for a susceptible with one
134///   infected neighbour); the rate scales linearly with the number of
135///   infected neighbours. Must be non-negative.
136/// * `gamma` — recovery rate of an infected individual. Must be strictly
137///   positive (otherwise the process would never terminate).
138/// * `no_sim` — number of independent runs. Must be positive.
139/// * `seed` — seed for the deterministic `SplitMix64` PRNG.
140///
141/// Returns one [`Sir`] trajectory per simulation.
142///
143/// # Errors
144///
145/// * The graph is empty (`vcount == 0`).
146/// * `beta < 0`, `gamma <= 0`, or `no_sim == 0`.
147/// * The graph is not simple in its undirected view.
148///
149/// # Examples
150///
151/// ```
152/// use rust_igraph::{Graph, sir};
153///
154/// // A small ring; every run starts with exactly one infected vertex.
155/// let mut g = Graph::with_vertices(5);
156/// g.add_edge(0, 1).unwrap();
157/// g.add_edge(1, 2).unwrap();
158/// g.add_edge(2, 3).unwrap();
159/// g.add_edge(3, 4).unwrap();
160/// g.add_edge(4, 0).unwrap();
161///
162/// let runs = sir(&g, 2.0, 1.0, 3, 0x5152).unwrap();
163/// assert_eq!(runs.len(), 3);
164/// for run in &runs {
165///     // Every trajectory starts at t = 0 with one infected, four susceptible.
166///     assert_eq!(run.times[0], 0.0);
167///     assert_eq!(run.no_i[0], 1);
168///     assert_eq!(run.no_s[0], 4);
169///     assert_eq!(run.no_r[0], 0);
170///     // Population is conserved at every step and ends with no infected.
171///     for k in 0..run.times.len() {
172///         assert_eq!(run.no_s[k] + run.no_i[k] + run.no_r[k], 5);
173///     }
174///     assert_eq!(*run.no_i.last().unwrap(), 0);
175/// }
176/// ```
177pub fn sir(
178    graph: &Graph,
179    beta: f64,
180    gamma: f64,
181    no_sim: usize,
182    seed: u64,
183) -> IgraphResult<Vec<Sir>> {
184    let n = graph.vcount() as usize;
185
186    if n == 0 {
187        return Err(IgraphError::InvalidArgument(
188            "Cannot run SIR model on empty graph.".to_string(),
189        ));
190    }
191    if beta < 0.0 {
192        return Err(IgraphError::InvalidArgument(format!(
193            "The infection rate beta must be non-negative (got {beta})."
194        )));
195    }
196    if gamma <= 0.0 {
197        return Err(IgraphError::InvalidArgument(format!(
198            "The recovery rate gamma must be positive (got {gamma})."
199        )));
200    }
201    if no_sim == 0 {
202        return Err(IgraphError::InvalidArgument(
203            "Number of SIR simulations must be positive.".to_string(),
204        ));
205    }
206    if !is_simple_with_mode(graph, SimpleMode::DirectedAsUndirected)? {
207        return Err(IgraphError::InvalidArgument(
208            "SIR model only works with simple graphs.".to_string(),
209        ));
210    }
211
212    let adj = build_undirected_adj(graph)?;
213    let mut rng = SplitMix64::new(seed);
214    let mut tree = PsumTree::new(n);
215    let mut status = vec![S_S; n];
216
217    let mut result = Vec::with_capacity(no_sim);
218    for _ in 0..no_sim {
219        result.push(run_one(
220            &adj,
221            beta,
222            gamma,
223            n,
224            &mut rng,
225            &mut tree,
226            &mut status,
227        ));
228    }
229    Ok(result)
230}
231
232fn build_undirected_adj(graph: &Graph) -> IgraphResult<Vec<Vec<usize>>> {
233    let n = graph.vcount() as usize;
234    let m = graph.ecount();
235    let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
236    for eid in 0..m {
237        let eid_u32 =
238            u32::try_from(eid).map_err(|_| IgraphError::Internal("ecount exceeds u32::MAX"))?;
239        let (src, tgt) = graph.edge(eid_u32)?;
240        // Simple-graph invariant guarantees src != tgt.
241        adj[src as usize].push(tgt as usize);
242        adj[tgt as usize].push(src as usize);
243    }
244    Ok(adj)
245}
246
247/// One SIR run. `tree` and `status` are reused across runs and reset here.
248fn run_one(
249    adj: &[Vec<usize>],
250    beta: f64,
251    gamma: f64,
252    n: usize,
253    rng: &mut SplitMix64,
254    tree: &mut PsumTree,
255    status: &mut [u8],
256) -> Sir {
257    let infected = rng.gen_index(n);
258
259    for s in status.iter_mut() {
260        *s = S_S;
261    }
262    status[infected] = S_I;
263    let mut ns = n - 1;
264    let mut ni = 1usize;
265    let mut nr = 0usize;
266
267    let mut times = vec![0.0_f64];
268    let mut no_s = vec![ns];
269    let mut no_i = vec![ni];
270    let mut no_r = vec![nr];
271
272    tree.reset();
273    tree.set(infected, gamma);
274    for &nei in &adj[infected] {
275        tree.set(nei, beta);
276    }
277
278    while ni > 0 {
279        let psum = tree.total();
280        // Exponential waiting time with rate `psum`: -ln(1-U)/psum.
281        // `psum > 0` is guaranteed because at least one infected vertex
282        // contributes rate gamma > 0.
283        let tt = -(1.0 - rng.gen_unit()).ln() / psum;
284        let r = rng.gen_unit() * psum;
285        let vchange = tree.search(r);
286
287        if status[vchange] == S_I {
288            status[vchange] = S_R;
289            ni -= 1;
290            nr += 1;
291            tree.set(vchange, 0.0);
292            for &nei in &adj[vchange] {
293                if status[nei] == S_S {
294                    let mut rate = tree.get(nei) - beta;
295                    if rate < 0.0 {
296                        rate = 0.0;
297                    }
298                    tree.set(nei, rate);
299                }
300            }
301        } else {
302            status[vchange] = S_I;
303            ns -= 1;
304            ni += 1;
305            tree.set(vchange, gamma);
306            for &nei in &adj[vchange] {
307                if status[nei] == S_S {
308                    let rate = tree.get(nei) + beta;
309                    tree.set(nei, rate);
310                }
311            }
312        }
313
314        let last = *times.last().unwrap_or(&0.0);
315        times.push(tt + last);
316        no_s.push(ns);
317        no_i.push(ni);
318        no_r.push(nr);
319    }
320
321    Sir {
322        times,
323        no_s,
324        no_i,
325        no_r,
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Cycle on `n` vertices: vertex `v` is joined to `(v + 1) mod n`.
    fn ring(n: u32) -> Graph {
        let mut graph = Graph::with_vertices(n);
        (0..n).for_each(|v| {
            graph.add_edge(v, (v + 1) % n).unwrap();
        });
        graph
    }

    /// Complete graph on `n` vertices, edges added in lexicographic order.
    fn complete(n: u32) -> Graph {
        let mut graph = Graph::with_vertices(n);
        let pairs = (0..n).flat_map(|a| ((a + 1)..n).map(move |b| (a, b)));
        for (a, b) in pairs {
            graph.add_edge(a, b).unwrap();
        }
        graph
    }

    #[test]
    fn empty_graph_errors() {
        let empty = Graph::with_vertices(0);
        let outcome = sir(&empty, 1.0, 1.0, 1, 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn parameter_errors() {
        let graph = ring(5);
        // (beta, gamma, no_sim): each row violates exactly one precondition.
        let invalid: [(f64, f64, usize); 4] = [
            (-0.1, 1.0, 1), // beta < 0
            (1.0, 0.0, 1),  // gamma == 0
            (1.0, -1.0, 1), // gamma < 0
            (1.0, 1.0, 0),  // no_sim == 0
        ];
        for &(beta, gamma, no_sim) in &invalid {
            assert!(sir(&graph, beta, gamma, no_sim, 0).is_err());
        }
    }

    #[test]
    fn non_simple_graph_errors() {
        // Two copies of the same edge.
        let mut parallel = Graph::with_vertices(3);
        parallel.add_edge(0, 1).unwrap();
        parallel.add_edge(0, 1).unwrap();
        assert!(sir(&parallel, 1.0, 1.0, 1, 0).is_err());

        // A self-loop next to an ordinary edge.
        let mut looped = Graph::with_vertices(3);
        looped.add_edge(0, 0).unwrap();
        looped.add_edge(1, 2).unwrap();
        assert!(sir(&looped, 1.0, 1.0, 1, 0).is_err());
    }

    #[test]
    fn produces_requested_number_of_runs() {
        let runs = sir(&ring(10), 2.0, 1.0, 7, 0xABCD).unwrap();
        assert_eq!(runs.len(), 7);
    }

    #[test]
    #[allow(clippy::float_cmp)] // the initial time is stored as the literal 0.0
    fn initial_state_is_consistent() {
        let runs = sir(&complete(6), 1.0, 1.0, 5, 42).unwrap();
        for run in &runs {
            let first = (run.times[0], run.no_i[0], run.no_s[0], run.no_r[0]);
            assert_eq!(first, (0.0, 1, 5, 0));
        }
    }

    #[test]
    fn population_conserved_and_terminates() {
        let runs = sir(&complete(8), 3.0, 1.0, 10, 0x1234_5678).unwrap();
        for run in &runs {
            let len = run.times.len();
            assert_eq!(
                (run.no_s.len(), run.no_i.len(), run.no_r.len()),
                (len, len, len)
            );
            // S + I + R stays equal to the vertex count at every step.
            assert!((0..len).all(|k| run.no_s[k] + run.no_i[k] + run.no_r[k] == 8));
            // Ends with nobody infected.
            assert_eq!(run.no_i.last().copied(), Some(0));
            // S is non-increasing, R is non-decreasing.
            assert!(run.no_s.windows(2).all(|pair| pair[1] <= pair[0]));
            assert!(run.no_r.windows(2).all(|pair| pair[1] >= pair[0]));
        }
    }

    #[test]
    fn times_strictly_increasing() {
        let runs = sir(&complete(7), 2.0, 1.0, 4, 0x9999).unwrap();
        for run in &runs {
            assert!(run.times.windows(2).all(|pair| pair[1] > pair[0]));
        }
    }

    #[test]
    fn deterministic_with_seed() {
        let graph = complete(6);
        let first = sir(&graph, 1.5, 0.7, 5, 0xDEAD_BEEF).unwrap();
        let second = sir(&graph, 1.5, 0.7, 5, 0xDEAD_BEEF).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn different_seeds_differ() {
        let graph = complete(20);
        let with_seed_1 = sir(&graph, 2.0, 0.5, 1, 1).unwrap();
        let with_seed_2 = sir(&graph, 2.0, 0.5, 1, 2).unwrap();
        // Overwhelmingly likely to differ in length or trajectory.
        assert_ne!(with_seed_1, with_seed_2);
    }

    #[test]
    fn zero_beta_recovers_immediately() {
        // With beta == 0 nobody else is infected: the single initial
        // case just recovers, giving exactly one transition.
        let runs = sir(&complete(5), 0.0, 1.0, 6, 0x2468).unwrap();
        for run in &runs {
            assert_eq!(run.times.len(), 2);
            assert_eq!(run.no_r.last().copied(), Some(1));
            assert_eq!(run.no_s.last().copied(), Some(4));
        }
    }

    #[test]
    fn directed_graph_ignores_direction() {
        // A directed ring is simple in undirected view; SIR should run.
        let mut directed = Graph::new(5, true).unwrap();
        (0..5u32).for_each(|v| {
            directed.add_edge(v, (v + 1) % 5).unwrap();
        });
        let runs = sir(&directed, 2.0, 1.0, 3, 0x55).unwrap();
        assert_eq!(runs.len(), 3);
        for run in &runs {
            assert_eq!(run.no_i.last().copied(), Some(0));
        }
    }
}