1use crate::algorithms::properties::is_simple::{SimpleMode, is_simple_with_mode};
27use crate::core::rng::SplitMix64;
28use crate::core::{Graph, IgraphError, IgraphResult};
29
/// History of a single stochastic SIR simulation run.
///
/// The four vectors are parallel and always have the same length. Entry `k`
/// describes the state right after the `k`-th recorded event. Entry `0` is the
/// initial state at time `0.0`, with exactly one infected vertex.
#[derive(Debug, Clone, PartialEq)]
pub struct Sir {
    /// Time of each recorded state; the first entry is `0.0`.
    pub times: Vec<f64>,
    /// Number of susceptible vertices at each recorded time.
    pub no_s: Vec<usize>,
    /// Number of infected vertices at each recorded time.
    pub no_i: Vec<usize>,
    /// Number of recovered vertices at each recorded time.
    pub no_r: Vec<usize>,
}
45
/// Fenwick (binary indexed) tree over `n` non-negative weights.
///
/// It supports O(log n) point updates and O(log n) queries of the form "which
/// slot does cumulative mass `target` fall into?". That is exactly the query
/// needed to draw a slot with probability proportional to its weight. The raw
/// weights are mirrored in `values`, so `get` is O(1). The grand total is kept
/// as a running sum of all updates.
struct PsumTree {
    /// Number of slots.
    n: usize,
    /// 1-based Fenwick array. For `k >= 1`, `bit[k]` holds the sum of the
    /// `k & k.wrapping_neg()` weights ending at slot `k - 1`. `bit[0]` is unused.
    bit: Vec<f64>,
    /// Current weight of every slot.
    values: Vec<f64>,
    /// Running sum of all weights, updated incrementally by `set`.
    total: f64,
}

impl PsumTree {
    /// Creates a tree with `n` slots, all of weight zero.
    fn new(n: usize) -> Self {
        Self {
            n,
            bit: vec![0.0; n + 1],
            values: vec![0.0; n],
            total: 0.0,
        }
    }

    /// Returns the current weight of slot `i`.
    ///
    /// Panics if `i >= n`.
    fn get(&self, i: usize) -> f64 {
        self.values[i]
    }

    /// Returns the running sum of all weights.
    fn total(&self) -> f64 {
        self.total
    }

    /// Sets the weight of slot `i` to `v`.
    ///
    /// The difference is propagated to every Fenwick node whose range covers
    /// slot `i`.
    ///
    /// Panics if `i >= n`.
    fn set(&mut self, i: usize, v: f64) {
        let delta = v - self.values[i];
        self.values[i] = v;
        self.total += delta;
        let mut k = i + 1;
        while k <= self.n {
            self.bit[k] += delta;
            // Step to the next node whose range also contains slot `i`.
            k += k & k.wrapping_neg();
        }
    }

    /// Sets every weight back to zero, keeping the allocations.
    fn reset(&mut self) {
        self.bit.fill(0.0);
        self.values.fill(0.0);
        self.total = 0.0;
    }

    /// Finds the slot that cumulative mass `target` falls into.
    ///
    /// This is the slot `i` whose preceding weights sum to at most `target`
    /// while the weights up to and including `i` sum to more than `target`.
    /// For a uniform draw in `[0, total)`, slot `i` is returned with
    /// probability `values[i] / total`.
    ///
    /// `total` is a running sum that can drift from the tree's own prefix sums
    /// by a few ulps. The descent can therefore land on a zero-weight slot,
    /// most commonly by running past the end when `target` is at or above the
    /// tree's actual sum. In that case the nearest positive-weight slot is
    /// returned instead, so a zero-weight slot is returned only if every
    /// weight is zero. Returns `0` when `n == 0`.
    fn search(&self, target: f64) -> usize {
        if self.n == 0 {
            return 0;
        }
        // The first descent step is the largest power of two not exceeding `n`.
        let mut step = 1usize;
        while step.saturating_mul(2) <= self.n {
            step *= 2;
        }
        // Binary descent: `idx` ends as the longest prefix whose sum is <= `target`.
        let mut idx = 0usize;
        let mut remaining = target;
        while step > 0 {
            let next = idx + step;
            if next <= self.n && self.bit[next] <= remaining {
                idx = next;
                remaining -= self.bit[next];
            }
            step >>= 1;
        }
        if idx < self.n && self.values[idx] > 0.0 {
            idx
        } else {
            self.nearest_positive(idx)
        }
    }

    /// Fallback for `search` when the descent ends on an empty slot or past
    /// the end.
    ///
    /// Returns the first positive-weight slot at or after `from`. If there is
    /// none, it returns the last positive-weight slot before `from`. If every
    /// weight is zero, it returns `from` clamped to the last slot.
    fn nearest_positive(&self, from: usize) -> usize {
        let start = from.min(self.n);
        self.values[start..]
            .iter()
            .position(|&v| v > 0.0)
            .map(|offset| start + offset)
            .or_else(|| self.values[..start].iter().rposition(|&v| v > 0.0))
            .unwrap_or_else(|| from.min(self.n - 1))
    }
}
122
/// Compartment tag stored in the per-vertex status buffer: susceptible.
const S_S: u8 = 0;
/// Compartment tag stored in the per-vertex status buffer: infected.
const S_I: u8 = S_S + 1;
/// Compartment tag stored in the per-vertex status buffer: recovered.
const S_R: u8 = S_I + 1;
126
127pub fn sir(
178 graph: &Graph,
179 beta: f64,
180 gamma: f64,
181 no_sim: usize,
182 seed: u64,
183) -> IgraphResult<Vec<Sir>> {
184 let n = graph.vcount() as usize;
185
186 if n == 0 {
187 return Err(IgraphError::InvalidArgument(
188 "Cannot run SIR model on empty graph.".to_string(),
189 ));
190 }
191 if beta < 0.0 {
192 return Err(IgraphError::InvalidArgument(format!(
193 "The infection rate beta must be non-negative (got {beta})."
194 )));
195 }
196 if gamma <= 0.0 {
197 return Err(IgraphError::InvalidArgument(format!(
198 "The recovery rate gamma must be positive (got {gamma})."
199 )));
200 }
201 if no_sim == 0 {
202 return Err(IgraphError::InvalidArgument(
203 "Number of SIR simulations must be positive.".to_string(),
204 ));
205 }
206 if !is_simple_with_mode(graph, SimpleMode::DirectedAsUndirected)? {
207 return Err(IgraphError::InvalidArgument(
208 "SIR model only works with simple graphs.".to_string(),
209 ));
210 }
211
212 let adj = build_undirected_adj(graph)?;
213 let mut rng = SplitMix64::new(seed);
214 let mut tree = PsumTree::new(n);
215 let mut status = vec![S_S; n];
216
217 let mut result = Vec::with_capacity(no_sim);
218 for _ in 0..no_sim {
219 result.push(run_one(
220 &adj,
221 beta,
222 gamma,
223 n,
224 &mut rng,
225 &mut tree,
226 &mut status,
227 ));
228 }
229 Ok(result)
230}
231
232fn build_undirected_adj(graph: &Graph) -> IgraphResult<Vec<Vec<usize>>> {
233 let n = graph.vcount() as usize;
234 let m = graph.ecount();
235 let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
236 for eid in 0..m {
237 let eid_u32 =
238 u32::try_from(eid).map_err(|_| IgraphError::Internal("ecount exceeds u32::MAX"))?;
239 let (src, tgt) = graph.edge(eid_u32)?;
240 adj[src as usize].push(tgt as usize);
242 adj[tgt as usize].push(src as usize);
243 }
244 Ok(adj)
245}
246
247fn run_one(
249 adj: &[Vec<usize>],
250 beta: f64,
251 gamma: f64,
252 n: usize,
253 rng: &mut SplitMix64,
254 tree: &mut PsumTree,
255 status: &mut [u8],
256) -> Sir {
257 let infected = rng.gen_index(n);
258
259 for s in status.iter_mut() {
260 *s = S_S;
261 }
262 status[infected] = S_I;
263 let mut ns = n - 1;
264 let mut ni = 1usize;
265 let mut nr = 0usize;
266
267 let mut times = vec![0.0_f64];
268 let mut no_s = vec![ns];
269 let mut no_i = vec![ni];
270 let mut no_r = vec![nr];
271
272 tree.reset();
273 tree.set(infected, gamma);
274 for &nei in &adj[infected] {
275 tree.set(nei, beta);
276 }
277
278 while ni > 0 {
279 let psum = tree.total();
280 let tt = -(1.0 - rng.gen_unit()).ln() / psum;
284 let r = rng.gen_unit() * psum;
285 let vchange = tree.search(r);
286
287 if status[vchange] == S_I {
288 status[vchange] = S_R;
289 ni -= 1;
290 nr += 1;
291 tree.set(vchange, 0.0);
292 for &nei in &adj[vchange] {
293 if status[nei] == S_S {
294 let mut rate = tree.get(nei) - beta;
295 if rate < 0.0 {
296 rate = 0.0;
297 }
298 tree.set(nei, rate);
299 }
300 }
301 } else {
302 status[vchange] = S_I;
303 ns -= 1;
304 ni += 1;
305 tree.set(vchange, gamma);
306 for &nei in &adj[vchange] {
307 if status[nei] == S_S {
308 let rate = tree.get(nei) + beta;
309 tree.set(nei, rate);
310 }
311 }
312 }
313
314 let last = *times.last().unwrap_or(&0.0);
315 times.push(tt + last);
316 no_s.push(ns);
317 no_i.push(ni);
318 no_r.push(nr);
319 }
320
321 Sir {
322 times,
323 no_s,
324 no_i,
325 no_r,
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Cycle graph 0 - 1 - ... - (n-1) - 0.
    fn ring(n: u32) -> Graph {
        let mut graph = Graph::with_vertices(n);
        (0..n).for_each(|v| {
            graph.add_edge(v, (v + 1) % n).unwrap();
        });
        graph
    }

    /// Complete graph on `n` vertices, edges added in lexicographic order.
    fn complete(n: u32) -> Graph {
        let mut graph = Graph::with_vertices(n);
        for (a, b) in (0..n).flat_map(|a| ((a + 1)..n).map(move |b| (a, b))) {
            graph.add_edge(a, b).unwrap();
        }
        graph
    }

    #[test]
    fn empty_graph_errors() {
        assert!(sir(&Graph::with_vertices(0), 1.0, 1.0, 1, 0).is_err());
    }

    #[test]
    fn parameter_errors() {
        let graph = ring(5);
        let bad = [(-0.1, 1.0, 1), (1.0, 0.0, 1), (1.0, -1.0, 1), (1.0, 1.0, 0)];
        for &(beta, gamma, no_sim) in &bad {
            assert!(
                sir(&graph, beta, gamma, no_sim, 0).is_err(),
                "beta={beta} gamma={gamma} no_sim={no_sim}"
            );
        }
    }

    #[test]
    fn non_simple_graph_errors() {
        // Parallel edges.
        let mut multi = Graph::with_vertices(3);
        for _ in 0..2 {
            multi.add_edge(0, 1).unwrap();
        }
        assert!(sir(&multi, 1.0, 1.0, 1, 0).is_err());

        // Self-loop.
        let mut looped = Graph::with_vertices(3);
        looped.add_edge(0, 0).unwrap();
        looped.add_edge(1, 2).unwrap();
        assert!(sir(&looped, 1.0, 1.0, 1, 0).is_err());
    }

    #[test]
    fn produces_requested_number_of_runs() {
        assert_eq!(sir(&ring(10), 2.0, 1.0, 7, 0xABCD).unwrap().len(), 7);
    }

    #[test]
    // The first time stamp is the literal 0.0, so exact float comparison is intended.
    #[allow(clippy::float_cmp)]
    fn initial_state_is_consistent() {
        let runs = sir(&complete(6), 1.0, 1.0, 5, 42).unwrap();
        for run in &runs {
            assert_eq!(
                (run.times[0], run.no_i[0], run.no_s[0], run.no_r[0]),
                (0.0, 1, 5, 0)
            );
        }
    }

    #[test]
    fn population_conserved_and_terminates() {
        let runs = sir(&complete(8), 3.0, 1.0, 10, 0x1234_5678).unwrap();
        for run in &runs {
            let len = run.times.len();
            assert!([run.no_s.len(), run.no_i.len(), run.no_r.len()]
                .iter()
                .all(|&l| l == len));
            assert!(run
                .no_s
                .iter()
                .zip(&run.no_i)
                .zip(&run.no_r)
                .all(|((s, i), r)| s + i + r == 8));
            assert_eq!(run.no_i.last(), Some(&0));
            assert!(run.no_s.windows(2).all(|w| w[1] <= w[0]));
            assert!(run.no_r.windows(2).all(|w| w[1] >= w[0]));
        }
    }

    #[test]
    fn times_strictly_increasing() {
        let runs = sir(&complete(7), 2.0, 1.0, 4, 0x9999).unwrap();
        assert!(runs
            .iter()
            .all(|run| run.times.windows(2).all(|w| w[1] > w[0])));
    }

    #[test]
    fn deterministic_with_seed() {
        let graph = complete(6);
        let first = sir(&graph, 1.5, 0.7, 5, 0xDEAD_BEEF).unwrap();
        let second = sir(&graph, 1.5, 0.7, 5, 0xDEAD_BEEF).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn different_seeds_differ() {
        let graph = complete(20);
        let first = sir(&graph, 2.0, 0.5, 1, 1).unwrap();
        let second = sir(&graph, 2.0, 0.5, 1, 2).unwrap();
        assert_ne!(first, second);
    }

    #[test]
    fn zero_beta_recovers_immediately() {
        // With beta == 0 nobody else can be infected, so the only event is the
        // recovery of the initially infected vertex.
        let runs = sir(&complete(5), 0.0, 1.0, 6, 0x2468).unwrap();
        for run in &runs {
            assert_eq!(run.times.len(), 2);
            assert_eq!(run.no_r.last(), Some(&1));
            assert_eq!(run.no_s.last(), Some(&4));
        }
    }

    #[test]
    fn directed_graph_ignores_direction() {
        let mut directed = Graph::new(5, true).unwrap();
        for v in 0..5u32 {
            directed.add_edge(v, (v + 1) % 5).unwrap();
        }
        let runs = sir(&directed, 2.0, 1.0, 3, 0x55).unwrap();
        assert_eq!(runs.len(), 3);
        assert!(runs.iter().all(|run| run.no_i.last() == Some(&0)));
    }
}
477}