rust_igraph/algorithms/nongraph/
random_sample.rs1use crate::core::IgraphResult;
11use crate::core::error::IgraphError;
12use crate::core::rng::SplitMix64;
13
14pub fn random_sample(l: i64, h: i64, length: usize, seed: u64) -> IgraphResult<Vec<i64>> {
52 if l > h {
53 return Err(IgraphError::InvalidArgument(
54 "random_sample: lower limit is greater than upper limit".to_string(),
55 ));
56 }
57
58 let pool_size = (h - l).checked_add(1).ok_or_else(|| {
59 IgraphError::InvalidArgument("random_sample: interval overflows".to_string())
60 })?;
61 let pool_size_u = u64::try_from(pool_size).map_err(|_| {
62 IgraphError::InvalidArgument("random_sample: interval overflows u64".to_string())
63 })?;
64
65 if (length as u64) > pool_size_u {
66 return Err(IgraphError::InvalidArgument(
67 "random_sample: sample size exceeds size of candidate pool".to_string(),
68 ));
69 }
70
71 if l == h {
72 return Ok(vec![l]);
73 }
74 if length == 0 {
75 return Ok(Vec::new());
76 }
77 let length_u64 = length as u64;
78 if length_u64 == pool_size_u {
79 return Ok((l..=h).collect());
80 }
81
82 let mut rng = SplitMix64::new(seed);
83 let mut result = Vec::with_capacity(length);
84
85 let mut n_real = length as f64;
86 #[allow(clippy::float_cmp)]
87 let n_inv = if n_real == 0.0 { 0.0 } else { 1.0 / n_real };
88 let mut big_n = pool_size_u;
89 let mut big_n_real = big_n as f64;
90 let mut vprime = rng.gen_unit().powf(n_inv);
91 let mut cur = l - 1;
92 let mut n_remaining = length as u64;
93 let mut qu1 = big_n.wrapping_sub(n_remaining).wrapping_add(1);
94 let mut qu1_real = big_n_real - n_real + 1.0;
95 let neg_alpha_inv: f64 = -13.0;
96 let mut threshold = (neg_alpha_inv * n_real).abs();
97
98 while n_remaining > 1 && threshold < big_n_real {
99 let nmin1inv = 1.0 / (-1.0 + n_real);
100
101 let skip;
102 loop {
103 let mut x;
104 loop {
105 x = big_n_real * (1.0 - vprime);
106 let s_candidate = x as u64;
107 if s_candidate < qu1 {
108 break;
109 }
110 vprime = rng.gen_unit().powf(n_inv);
111 }
112
113 let s = x as u64;
114 let u = rng.gen_unit();
115 let neg_s_real = -(s as f64);
116
117 let y1 = (u * big_n_real / qu1_real).powf(nmin1inv);
118 vprime = y1 * (1.0 - x / big_n_real) * (qu1_real / (neg_s_real + qu1_real));
119 if vprime <= 1.0 {
120 skip = s;
121 break;
122 }
123
124 let mut y2 = 1.0_f64;
125 let mut top = big_n_real - 1.0;
126 let (mut bottom, limit);
127 if n_remaining - 1 > s {
128 bottom = big_n_real - n_real;
129 limit = big_n - s;
130 } else {
131 bottom = big_n_real - 1.0 + neg_s_real;
132 limit = qu1;
133 }
134
135 let mut t = big_n - 1;
136 while t >= limit {
137 y2 = (y2 * top) / bottom;
138 top -= 1.0;
139 bottom -= 1.0;
140 t -= 1;
141 }
142
143 if big_n_real / (big_n_real - x) >= y1 * y2.powf(nmin1inv) {
144 vprime = rng.gen_unit().powf(nmin1inv);
145 skip = s;
146 break;
147 }
148 vprime = rng.gen_unit().powf(n_inv);
149 }
150
151 cur = cur.checked_add(skip as i64 + 1).ok_or_else(|| {
152 IgraphError::InvalidArgument("random_sample: overflow in position".to_string())
153 })?;
154 result.push(cur);
155
156 big_n -= skip + 1;
157 big_n_real -= skip as f64 + 1.0;
158 n_remaining -= 1;
159 n_real -= 1.0;
160 qu1 -= skip;
161 qu1_real -= skip as f64;
162 threshold -= neg_alpha_inv;
163 }
164
165 if n_remaining > 1 {
166 algorithm_a(&mut rng, &mut result, cur + 1, h, n_remaining);
167 } else {
168 let s = (big_n_real * vprime) as i64;
169 cur = cur.checked_add(s + 1).ok_or_else(|| {
170 IgraphError::InvalidArgument("random_sample: overflow in final position".to_string())
171 })?;
172 result.push(cur);
173 }
174
175 Ok(result)
176}
177
178fn algorithm_a(rng: &mut SplitMix64, result: &mut Vec<i64>, l: i64, h: i64, length: u64) {
180 let mut big_n = (h - l + 1) as f64;
181 let mut n = length;
182 let mut cur = l - 1;
183
184 while n >= 2 {
185 let v = rng.gen_unit();
186 let mut s: i64 = 1;
187 let mut quot = (big_n - n as f64) / big_n;
188 while quot > v {
189 s += 1;
190 big_n -= 1.0;
191 quot = (quot * (big_n - n as f64)) / big_n;
192 }
193 cur += s;
194 result.push(cur);
195 big_n -= 1.0;
196 n -= 1;
197 }
198
199 let s = (big_n * rng.gen_unit()).trunc() as i64;
200 cur += s + 1;
201 result.push(cur);
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Requesting zero values from a non-trivial interval gives an empty vector.
    #[test]
    fn empty_sample() {
        let sample = random_sample(1, 100, 0, 42).unwrap();
        assert!(sample.is_empty());
    }

    /// A one-value interval sampled once yields that value.
    #[test]
    fn single_element_interval() {
        let sample = random_sample(5, 5, 1, 42).unwrap();
        assert_eq!(sample, vec![5]);
    }

    /// Sampling as many values as the interval holds returns the interval.
    #[test]
    fn full_interval() {
        let sample = random_sample(10, 15, 6, 42).unwrap();
        assert_eq!(sample, vec![10, 11, 12, 13, 14, 15]);
    }

    /// Reversed bounds are rejected.
    #[test]
    fn error_l_greater_than_h() {
        let outcome = random_sample(10, 5, 1, 42);
        assert!(outcome.is_err());
    }

    /// Asking for more values than the interval holds is rejected.
    #[test]
    fn error_length_exceeds_pool() {
        let outcome = random_sample(1, 5, 10, 42);
        assert!(outcome.is_err());
    }

    /// Output is strictly increasing.
    #[test]
    fn result_is_sorted_ascending() {
        let sample = random_sample(0, 1_000_000, 100, 12345).unwrap();
        assert_eq!(sample.len(), 100);
        for pair in sample.windows(2) {
            let (lo, hi) = (pair[0], pair[1]);
            assert!(lo < hi, "not sorted: {} >= {}", lo, hi);
        }
    }

    /// Every sampled value lies inside the requested bounds.
    #[test]
    fn all_values_in_range() {
        let sample = random_sample(-50, 50, 30, 99).unwrap();
        assert_eq!(sample.len(), 30);
        for v in sample.iter().copied() {
            assert!(v >= -50 && v <= 50, "value {v} out of range");
        }
    }

    /// No value is immediately repeated in the output.
    #[test]
    fn no_duplicates() {
        let sample = random_sample(1, 1000, 200, 777).unwrap();
        // Counts exactly the elements that `dedup` would drop.
        let adjacent_repeats = sample.windows(2).filter(|w| w[0] == w[1]).count();
        assert_eq!(sample.len(), sample.len() - adjacent_repeats);
    }

    /// Equal seeds give equal samples.
    #[test]
    fn deterministic_same_seed() {
        let first = random_sample(0, 999, 50, 42).unwrap();
        let second = random_sample(0, 999, 50, 42).unwrap();
        assert_eq!(first, second);
    }

    /// Seeds 1 and 2 give different samples.
    #[test]
    fn different_seeds_differ() {
        let first = random_sample(0, 999, 50, 1).unwrap();
        let second = random_sample(0, 999, 50, 2).unwrap();
        assert_ne!(first, second);
    }

    /// Intervals lying entirely below zero work like any other.
    #[test]
    fn negative_range() {
        let sample = random_sample(-100, -1, 10, 42).unwrap();
        assert_eq!(sample.len(), 10);
        assert!(sample.iter().all(|&v| v >= -100 && v <= -1));
        assert!(sample.windows(2).all(|w| w[0] < w[1]));
    }

    /// A handful of values from a very wide interval.
    #[test]
    fn large_interval_small_sample() {
        let sample = random_sample(0, 1_000_000_000, 10, 42).unwrap();
        assert_eq!(sample.len(), 10);
        assert!(sample.windows(2).all(|w| w[0] < w[1]));
    }

    /// A sample of one value stays inside the interval.
    #[test]
    fn sample_size_one() {
        let sample = random_sample(1, 100, 1, 42).unwrap();
        assert_eq!(sample.len(), 1);
        let only = sample[0];
        assert!(only >= 1 && only <= 100);
    }

    /// A sample of two values is strictly increasing.
    #[test]
    fn sample_size_two() {
        let sample = random_sample(1, 100, 2, 42).unwrap();
        assert_eq!(sample.len(), 2);
        assert!(sample[0] < sample[1]);
    }

    /// All but one value of a small interval.
    #[test]
    fn sample_nearly_full() {
        let sample = random_sample(1, 20, 19, 42).unwrap();
        assert_eq!(sample.len(), 19);
        assert!(sample.windows(2).all(|w| w[0] < w[1]));
        assert!(sample.iter().all(|&v| v >= 1 && v <= 20));
    }

    /// Over 10 000 seeds, each of the ten values is drawn about equally often.
    #[test]
    fn statistical_uniformity() {
        let mut histogram = [0u32; 10];
        for seed in 0..10_000u64 {
            let drawn = random_sample(0, 9, 1, seed).unwrap()[0];
            histogram[drawn as usize] += 1;
        }

        let expected = 1000.0;
        for (i, &c) in histogram.iter().enumerate() {
            // Relative distance of the observed count from the expected one.
            let deviation = (c as f64 - expected).abs() / expected;
            assert!(
                deviation < 0.1,
                "value {i} appeared {c} times (expected ~{expected})"
            );
        }
    }
}