rust_igraph/algorithms/nongraph/
sampling.rs1use crate::core::IgraphResult;
8use crate::core::error::IgraphError;
9use crate::core::rng::SplitMix64;
10
11pub fn sample_sphere_surface(
44 dim: usize,
45 n: usize,
46 radius: f64,
47 positive: bool,
48 seed: u64,
49) -> IgraphResult<Vec<Vec<f64>>> {
50 if dim < 2 {
51 return Err(IgraphError::InvalidArgument(
52 "sample_sphere_surface: dimension must be at least 2".to_string(),
53 ));
54 }
55 if radius <= 0.0 {
56 return Err(IgraphError::InvalidArgument(
57 "sample_sphere_surface: radius must be positive".to_string(),
58 ));
59 }
60
61 let mut rng = SplitMix64::new(seed);
62 let mut result = Vec::with_capacity(n);
63
64 for _ in 0..n {
65 let mut point = Vec::with_capacity(dim);
66 let mut sum_sq = 0.0_f64;
67 for _ in 0..dim {
68 let z = rng.gen_normal();
69 sum_sq += z * z;
70 point.push(z);
71 }
72 let norm = sum_sq.sqrt();
73 for c in &mut point {
74 *c = radius * *c / norm;
75 if positive {
76 *c = c.abs();
77 }
78 }
79 result.push(point);
80 }
81
82 Ok(result)
83}
84
85pub fn sample_sphere_volume(
117 dim: usize,
118 n: usize,
119 radius: f64,
120 positive: bool,
121 seed: u64,
122) -> IgraphResult<Vec<Vec<f64>>> {
123 if dim < 2 {
124 return Err(IgraphError::InvalidArgument(
125 "sample_sphere_volume: dimension must be at least 2".to_string(),
126 ));
127 }
128 if radius <= 0.0 {
129 return Err(IgraphError::InvalidArgument(
130 "sample_sphere_volume: radius must be positive".to_string(),
131 ));
132 }
133
134 let mut rng = SplitMix64::new(seed);
135 let inv_dim = 1.0 / dim as f64;
136 let mut result = Vec::with_capacity(n);
137
138 for _ in 0..n {
139 let mut point = Vec::with_capacity(dim);
140 let mut sum_sq = 0.0_f64;
141 for _ in 0..dim {
142 let z = rng.gen_normal();
143 sum_sq += z * z;
144 point.push(z);
145 }
146 let norm = sum_sq.sqrt();
147 let u = rng.gen_unit().powf(inv_dim);
148 for c in &mut point {
149 *c = radius * u * *c / norm;
150 if positive {
151 *c = c.abs();
152 }
153 }
154 result.push(point);
155 }
156
157 Ok(result)
158}
159
160pub fn sample_dirichlet(n: usize, alpha: &[f64], seed: u64) -> IgraphResult<Vec<Vec<f64>>> {
194 let dim = alpha.len();
195 if dim < 2 {
196 return Err(IgraphError::InvalidArgument(
197 "sample_dirichlet: alpha must have at least 2 entries".to_string(),
198 ));
199 }
200
201 for (i, &a) in alpha.iter().enumerate() {
202 if a <= 0.0 {
203 return Err(IgraphError::InvalidArgument(format!(
204 "sample_dirichlet: alpha[{i}] = {a}, must be positive"
205 )));
206 }
207 }
208
209 let mut rng = SplitMix64::new(seed);
210 let mut result = Vec::with_capacity(n);
211
212 for _ in 0..n {
213 let mut sample = Vec::with_capacity(dim);
214 let mut sum = 0.0_f64;
215 for &a in alpha {
216 let g = rng.gen_gamma(a);
217 sum += g;
218 sample.push(g);
219 }
220 for v in &mut sample {
221 *v /= sum;
222 }
223 result.push(sample);
224 }
225
226 Ok(result)
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Squared Euclidean length of a point.
    fn squared_norm(p: &[f64]) -> f64 {
        p.iter().map(|&x| x * x).sum()
    }

    // ---- sample_sphere_surface -------------------------------------------

    /// Every generated point has the requested dimension and unit length.
    #[test]
    fn sphere_surface_on_unit_sphere() {
        let cloud = sample_sphere_surface(3, 100, 1.0, false, 42).unwrap();
        assert_eq!(cloud.len(), 100);
        for p in &cloud {
            assert_eq!(p.len(), 3);
            let r2 = squared_norm(p);
            assert!(
                (r2 - 1.0).abs() < 1e-10,
                "point not on unit sphere: r²={r2}"
            );
        }
    }

    /// A non-unit radius scales the points accordingly.
    #[test]
    fn sphere_surface_scaled_radius() {
        let radius = 5.0;
        let target = radius * radius;
        let cloud = sample_sphere_surface(2, 50, radius, false, 99).unwrap();
        for p in &cloud {
            let r2 = squared_norm(p);
            assert!(
                (r2 - target).abs() < 1e-8,
                "point not on sphere of radius {radius}: r²={r2}"
            );
        }
    }

    /// With `positive = true` no coordinate is negative.
    #[test]
    fn sphere_surface_positive_orthant() {
        let cloud = sample_sphere_surface(3, 100, 1.0, true, 42).unwrap();
        for &c in cloud.iter().flatten() {
            assert!(c >= 0.0, "expected non-negative, got {c}");
        }
    }

    /// Dimension 1 is rejected.
    #[test]
    fn sphere_surface_dim_1_error() {
        let outcome = sample_sphere_surface(1, 10, 1.0, false, 42);
        assert!(outcome.is_err());
    }

    /// A negative radius is rejected.
    #[test]
    fn sphere_surface_negative_radius_error() {
        let outcome = sample_sphere_surface(2, 10, -1.0, false, 42);
        assert!(outcome.is_err());
    }

    /// Asking for zero points succeeds with an empty result.
    #[test]
    fn sphere_surface_zero_samples() {
        let cloud = sample_sphere_surface(3, 0, 1.0, false, 42).unwrap();
        assert!(cloud.is_empty());
    }

    // ---- sample_sphere_volume --------------------------------------------

    /// No point falls outside the unit ball.
    #[test]
    fn sphere_volume_inside_ball() {
        let cloud = sample_sphere_volume(3, 200, 1.0, false, 42).unwrap();
        assert_eq!(cloud.len(), 200);
        for p in &cloud {
            let r2 = squared_norm(p);
            assert!(r2 <= 1.0 + 1e-10, "point outside unit ball: r²={r2}");
        }
    }

    /// Volume sampling must not degenerate into surface sampling.
    #[test]
    fn sphere_volume_not_all_on_surface() {
        let cloud = sample_sphere_volume(3, 100, 1.0, false, 42).unwrap();
        let mut on_surface = 0_usize;
        for p in &cloud {
            if (squared_norm(p) - 1.0).abs() < 0.01 {
                on_surface += 1;
            }
        }
        assert!(
            on_surface < 100,
            "all points on surface — volume sampling likely broken"
        );
    }

    /// With `positive = true` no coordinate is negative.
    #[test]
    fn sphere_volume_positive() {
        let cloud = sample_sphere_volume(2, 100, 2.0, true, 42).unwrap();
        for &c in cloud.iter().flatten() {
            assert!(c >= 0.0);
        }
    }

    /// Points stay inside a ball of non-unit radius.
    #[test]
    fn sphere_volume_scaled() {
        let radius = 3.0;
        let cloud = sample_sphere_volume(2, 200, radius, false, 42).unwrap();
        for p in &cloud {
            let r2 = squared_norm(p);
            assert!(r2 <= radius * radius + 1e-8);
        }
    }

    // ---- sample_dirichlet ------------------------------------------------

    /// Each sample has one component per alpha entry and sums to one.
    #[test]
    fn dirichlet_sums_to_one() {
        let alpha = [1.0, 2.0, 3.0];
        let samples = sample_dirichlet(100, &alpha, 42).unwrap();
        assert_eq!(samples.len(), 100);
        for s in &samples {
            assert_eq!(s.len(), 3);
            let sum: f64 = s.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "Dirichlet sample doesn't sum to 1: {sum}"
            );
        }
    }

    /// No component of any sample is negative.
    #[test]
    fn dirichlet_all_positive() {
        let alpha = [0.5, 0.5, 0.5];
        let samples = sample_dirichlet(200, &alpha, 42).unwrap();
        for &v in samples.iter().flatten() {
            assert!(v >= 0.0, "negative value in Dirichlet sample: {v}");
        }
    }

    /// The empirical mean of each component is close to alpha_j / sum(alpha).
    #[test]
    fn dirichlet_mean_matches() {
        let alpha = [1.0, 2.0, 3.0];
        let alpha_sum: f64 = alpha.iter().sum();
        let n = 50_000;
        let samples = sample_dirichlet(n, &alpha, 42).unwrap();
        for (j, &a) in alpha.iter().enumerate() {
            let mut column_total = 0.0_f64;
            for s in &samples {
                column_total += s[j];
            }
            let mean = column_total / n as f64;
            let expected = a / alpha_sum;
            assert!(
                (mean - expected).abs() < 0.02,
                "dim {j}: mean={mean}, expected={expected}"
            );
        }
    }

    /// A single-entry alpha is rejected.
    #[test]
    fn dirichlet_short_alpha_error() {
        let outcome = sample_dirichlet(10, &[1.0], 42);
        assert!(outcome.is_err());
    }

    /// Negative and zero alpha entries are rejected.
    #[test]
    fn dirichlet_non_positive_alpha_error() {
        let negative = sample_dirichlet(10, &[1.0, -0.5], 42);
        assert!(negative.is_err());
        let zero = sample_dirichlet(10, &[1.0, 0.0], 42);
        assert!(zero.is_err());
    }

    /// Asking for zero samples succeeds with an empty result.
    #[test]
    fn dirichlet_zero_samples() {
        let samples = sample_dirichlet(0, &[1.0, 2.0], 42).unwrap();
        assert!(samples.is_empty());
    }

    // ---- determinism -----------------------------------------------------

    /// Two runs with the same seed produce matching coordinates.
    #[test]
    fn deterministic_same_seed() {
        let first = sample_sphere_surface(3, 10, 1.0, false, 42).unwrap();
        let second = sample_sphere_surface(3, 10, 1.0, false, 42).unwrap();
        for (row_a, row_b) in first.iter().zip(second.iter()) {
            for (x, y) in row_a.iter().zip(row_b.iter()) {
                assert!((x - y).abs() < 1e-15);
            }
        }
    }
}