Skip to main content

rust_igraph/algorithms/nongraph/
random_sample.rs

1//! Random sampling without replacement (ALGO-NG-001).
2//!
3//! Counterpart of `igraph_random_sample()` from
4//! `references/igraph/src/random/random.c`.
5//!
6//! Generates an increasing (sorted) random sequence of integers from a
7//! given interval using Vitter's Algorithm D (1987), with Algorithm A
8//! as a fallback for the tail. Expected time complexity is O(length).
9
10use crate::core::IgraphResult;
11use crate::core::error::IgraphError;
12use crate::core::rng::SplitMix64;
13
14/// Generate an increasing random sequence of integers from `[l, h]`.
15///
16/// Returns a sorted `Vec<i64>` of `length` distinct integers sampled
17/// uniformly without replacement from the inclusive interval `[l, h]`.
18/// The algorithm is Vitter's Algorithm D (1987), which runs in expected
19/// O(`length`) time regardless of the interval size.
20///
21/// # Arguments
22///
23/// * `l` — lower bound of the sampling interval (inclusive).
24/// * `h` — upper bound of the sampling interval (inclusive).
25/// * `length` — number of integers to sample.
26/// * `seed` — seed for the internal PRNG (deterministic for a given seed).
27///
28/// # Errors
29///
30/// Returns `InvalidArgument` if:
31/// - `l > h` (empty interval).
32/// - `length` exceeds the number of integers in `[l, h]`.
33///
34/// # Examples
35///
36/// ```
37/// use rust_igraph::random_sample;
38///
39/// // Sample 5 integers from [1, 100]
40/// let sample = random_sample(1, 100, 5, 42).unwrap();
41/// assert_eq!(sample.len(), 5);
42/// // Result is sorted ascending
43/// for w in sample.windows(2) {
44///     assert!(w[0] < w[1]);
45/// }
46/// // All values in range
47/// for &v in &sample {
48///     assert!(v >= 1 && v <= 100);
49/// }
50/// ```
51pub fn random_sample(l: i64, h: i64, length: usize, seed: u64) -> IgraphResult<Vec<i64>> {
52    if l > h {
53        return Err(IgraphError::InvalidArgument(
54            "random_sample: lower limit is greater than upper limit".to_string(),
55        ));
56    }
57
58    let pool_size = (h - l).checked_add(1).ok_or_else(|| {
59        IgraphError::InvalidArgument("random_sample: interval overflows".to_string())
60    })?;
61    let pool_size_u = u64::try_from(pool_size).map_err(|_| {
62        IgraphError::InvalidArgument("random_sample: interval overflows u64".to_string())
63    })?;
64
65    if (length as u64) > pool_size_u {
66        return Err(IgraphError::InvalidArgument(
67            "random_sample: sample size exceeds size of candidate pool".to_string(),
68        ));
69    }
70
71    if l == h {
72        return Ok(vec![l]);
73    }
74    if length == 0 {
75        return Ok(Vec::new());
76    }
77    let length_u64 = length as u64;
78    if length_u64 == pool_size_u {
79        return Ok((l..=h).collect());
80    }
81
82    let mut rng = SplitMix64::new(seed);
83    let mut result = Vec::with_capacity(length);
84
85    let mut n_real = length as f64;
86    #[allow(clippy::float_cmp)]
87    let n_inv = if n_real == 0.0 { 0.0 } else { 1.0 / n_real };
88    let mut big_n = pool_size_u;
89    let mut big_n_real = big_n as f64;
90    let mut vprime = rng.gen_unit().powf(n_inv);
91    let mut cur = l - 1;
92    let mut n_remaining = length as u64;
93    let mut qu1 = big_n.wrapping_sub(n_remaining).wrapping_add(1);
94    let mut qu1_real = big_n_real - n_real + 1.0;
95    let neg_alpha_inv: f64 = -13.0;
96    let mut threshold = (neg_alpha_inv * n_real).abs();
97
98    while n_remaining > 1 && threshold < big_n_real {
99        let nmin1inv = 1.0 / (-1.0 + n_real);
100
101        let skip;
102        loop {
103            let mut x;
104            loop {
105                x = big_n_real * (1.0 - vprime);
106                let s_candidate = x as u64;
107                if s_candidate < qu1 {
108                    break;
109                }
110                vprime = rng.gen_unit().powf(n_inv);
111            }
112
113            let s = x as u64;
114            let u = rng.gen_unit();
115            let neg_s_real = -(s as f64);
116
117            let y1 = (u * big_n_real / qu1_real).powf(nmin1inv);
118            vprime = y1 * (1.0 - x / big_n_real) * (qu1_real / (neg_s_real + qu1_real));
119            if vprime <= 1.0 {
120                skip = s;
121                break;
122            }
123
124            let mut y2 = 1.0_f64;
125            let mut top = big_n_real - 1.0;
126            let (mut bottom, limit);
127            if n_remaining - 1 > s {
128                bottom = big_n_real - n_real;
129                limit = big_n - s;
130            } else {
131                bottom = big_n_real - 1.0 + neg_s_real;
132                limit = qu1;
133            }
134
135            let mut t = big_n - 1;
136            while t >= limit {
137                y2 = (y2 * top) / bottom;
138                top -= 1.0;
139                bottom -= 1.0;
140                t -= 1;
141            }
142
143            if big_n_real / (big_n_real - x) >= y1 * y2.powf(nmin1inv) {
144                vprime = rng.gen_unit().powf(nmin1inv);
145                skip = s;
146                break;
147            }
148            vprime = rng.gen_unit().powf(n_inv);
149        }
150
151        cur = cur.checked_add(skip as i64 + 1).ok_or_else(|| {
152            IgraphError::InvalidArgument("random_sample: overflow in position".to_string())
153        })?;
154        result.push(cur);
155
156        big_n -= skip + 1;
157        big_n_real -= skip as f64 + 1.0;
158        n_remaining -= 1;
159        n_real -= 1.0;
160        qu1 -= skip;
161        qu1_real -= skip as f64;
162        threshold -= neg_alpha_inv;
163    }
164
165    if n_remaining > 1 {
166        algorithm_a(&mut rng, &mut result, cur + 1, h, n_remaining);
167    } else {
168        let s = (big_n_real * vprime) as i64;
169        cur = cur.checked_add(s + 1).ok_or_else(|| {
170            IgraphError::InvalidArgument("random_sample: overflow in final position".to_string())
171        })?;
172        result.push(cur);
173    }
174
175    Ok(result)
176}
177
178/// Vitter's Algorithm A — simple sequential sampling fallback.
179fn algorithm_a(rng: &mut SplitMix64, result: &mut Vec<i64>, l: i64, h: i64, length: u64) {
180    let mut big_n = (h - l + 1) as f64;
181    let mut n = length;
182    let mut cur = l - 1;
183
184    while n >= 2 {
185        let v = rng.gen_unit();
186        let mut s: i64 = 1;
187        let mut quot = (big_n - n as f64) / big_n;
188        while quot > v {
189            s += 1;
190            big_n -= 1.0;
191            quot = (quot * (big_n - n as f64)) / big_n;
192        }
193        cur += s;
194        result.push(cur);
195        big_n -= 1.0;
196        n -= 1;
197    }
198
199    let s = (big_n * rng.gen_unit()).trunc() as i64;
200    cur += s + 1;
201    result.push(cur);
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Asking for zero integers yields an empty vector.
    #[test]
    fn empty_sample() {
        let sample = random_sample(1, 100, 0, 42).unwrap();
        assert!(sample.is_empty());
    }

    /// A one-point interval can only produce that point.
    #[test]
    fn single_element_interval() {
        let sample = random_sample(5, 5, 1, 42).unwrap();
        assert_eq!(sample, vec![5]);
    }

    /// Sampling as many integers as the interval holds returns all of them.
    #[test]
    fn full_interval() {
        let sample = random_sample(10, 15, 6, 42).unwrap();
        assert_eq!(sample, vec![10, 11, 12, 13, 14, 15]);
    }

    /// A reversed interval is rejected.
    #[test]
    fn error_l_greater_than_h() {
        let outcome = random_sample(10, 5, 1, 42);
        assert!(outcome.is_err());
    }

    /// Requesting more integers than the interval holds is rejected.
    #[test]
    fn error_length_exceeds_pool() {
        let outcome = random_sample(1, 5, 10, 42);
        assert!(outcome.is_err());
    }

    /// Output is strictly increasing.
    #[test]
    fn result_is_sorted_ascending() {
        let sample = random_sample(0, 1_000_000, 100, 12345).unwrap();
        assert_eq!(sample.len(), 100);
        for pair in sample.windows(2) {
            assert!(pair[0] < pair[1], "not sorted: {} >= {}", pair[0], pair[1]);
        }
    }

    /// Every sampled value lies inside the requested interval.
    #[test]
    fn all_values_in_range() {
        let sample = random_sample(-50, 50, 30, 99).unwrap();
        assert_eq!(sample.len(), 30);
        for &v in &sample {
            assert!(v >= -50 && v <= 50, "value {v} out of range");
        }
    }

    /// Removing adjacent repeats from the output must not shorten it.
    #[test]
    fn no_duplicates() {
        let sample = random_sample(1, 1000, 200, 777).unwrap();
        let without_repeats = {
            let mut copy = sample.clone();
            copy.dedup();
            copy
        };
        assert_eq!(sample.len(), without_repeats.len());
    }

    /// The same seed reproduces the same sample.
    #[test]
    fn deterministic_same_seed() {
        let first = random_sample(0, 999, 50, 42).unwrap();
        let second = random_sample(0, 999, 50, 42).unwrap();
        assert_eq!(first, second);
    }

    /// Two different seeds give two different samples.
    #[test]
    fn different_seeds_differ() {
        let first = random_sample(0, 999, 50, 1).unwrap();
        let second = random_sample(0, 999, 50, 2).unwrap();
        assert_ne!(first, second);
    }

    /// Intervals made only of negative numbers work like any other.
    #[test]
    fn negative_range() {
        let sample = random_sample(-100, -1, 10, 42).unwrap();
        assert_eq!(sample.len(), 10);
        assert!(sample.iter().all(|&value| value >= -100 && value <= -1));
        assert!(sample.windows(2).all(|pair| pair[0] < pair[1]));
    }

    /// A tiny sample from a very large interval is still sorted.
    #[test]
    fn large_interval_small_sample() {
        let sample = random_sample(0, 1_000_000_000, 10, 42).unwrap();
        assert_eq!(sample.len(), 10);
        assert!(sample.windows(2).all(|pair| pair[0] < pair[1]));
    }

    /// A single draw lands inside the interval.
    #[test]
    fn sample_size_one() {
        let sample = random_sample(1, 100, 1, 42).unwrap();
        assert_eq!(sample.len(), 1);
        let only = sample[0];
        assert!(only >= 1 && only <= 100);
    }

    /// Two draws come back in increasing order.
    #[test]
    fn sample_size_two() {
        let sample = random_sample(1, 100, 2, 42).unwrap();
        assert_eq!(sample.len(), 2);
        assert!(sample[0] < sample[1]);
    }

    /// Taking all but one element keeps order and bounds.
    #[test]
    fn sample_nearly_full() {
        let sample = random_sample(1, 20, 19, 42).unwrap();
        assert_eq!(sample.len(), 19);
        assert!(sample.windows(2).all(|pair| pair[0] < pair[1]));
        assert!(sample.iter().all(|&value| value >= 1 && value <= 20));
    }

    /// Drawing one value from [0, 9] under 10 000 different seeds should hit
    /// every value roughly 1000 times (within 10 %).
    #[test]
    fn statistical_uniformity() {
        let mut counts = [0u32; 10];
        (0..10_000u64).for_each(|seed| {
            let drawn = random_sample(0, 9, 1, seed).unwrap();
            counts[drawn[0] as usize] += 1;
        });

        let expected = 1000.0;
        for (i, &c) in counts.iter().enumerate() {
            let deviation = (c as f64 - expected).abs() / expected;
            assert!(
                deviation < 0.1,
                "value {i} appeared {c} times (expected ~{expected})"
            );
        }
    }
}