// rust_igraph/algorithms/nongraph/sampling.rs

1//! Sphere-surface, sphere-volume, and Dirichlet sampling (ALGO-NG-002..004).
2//!
3//! Counterpart of `igraph_rng_sample_sphere_surface()`,
4//! `igraph_rng_sample_sphere_volume()`, and `igraph_rng_sample_dirichlet()`
5//! from `references/igraph/src/random/sampling.c`.
6
7use crate::core::IgraphResult;
8use crate::core::error::IgraphError;
9use crate::core::rng::SplitMix64;
10
11/// Sample points uniformly from the surface of a sphere.
12///
13/// Generates `n` points uniformly distributed on the surface of a
14/// `dim`-dimensional sphere of the given `radius`, centered at the
15/// origin. Uses the Muller (1959) method: generate `dim` independent
16/// standard normals, normalize to unit length, then scale by `radius`.
17///
18/// If `positive` is `true`, all coordinates are mapped to their
19/// absolute values (restricting to the positive orthant).
20///
21/// Returns a `Vec<Vec<f64>>` where each inner `Vec` has `dim` elements.
22///
23/// # Errors
24///
25/// Returns `InvalidArgument` if:
26/// - `dim < 2`
27/// - `radius <= 0`
28///
29/// # Examples
30///
31/// ```
32/// use rust_igraph::sample_sphere_surface;
33///
34/// let points = sample_sphere_surface(3, 2, 1.0, false, 42).unwrap();
35/// assert_eq!(points.len(), 2);
36/// assert_eq!(points[0].len(), 3);
37/// // Each point lies on the unit sphere
38/// for p in &points {
39///     let r2: f64 = p.iter().map(|&x| x * x).sum();
40///     assert!((r2 - 1.0).abs() < 1e-10);
41/// }
42/// ```
43pub fn sample_sphere_surface(
44    dim: usize,
45    n: usize,
46    radius: f64,
47    positive: bool,
48    seed: u64,
49) -> IgraphResult<Vec<Vec<f64>>> {
50    if dim < 2 {
51        return Err(IgraphError::InvalidArgument(
52            "sample_sphere_surface: dimension must be at least 2".to_string(),
53        ));
54    }
55    if radius <= 0.0 {
56        return Err(IgraphError::InvalidArgument(
57            "sample_sphere_surface: radius must be positive".to_string(),
58        ));
59    }
60
61    let mut rng = SplitMix64::new(seed);
62    let mut result = Vec::with_capacity(n);
63
64    for _ in 0..n {
65        let mut point = Vec::with_capacity(dim);
66        let mut sum_sq = 0.0_f64;
67        for _ in 0..dim {
68            let z = rng.gen_normal();
69            sum_sq += z * z;
70            point.push(z);
71        }
72        let norm = sum_sq.sqrt();
73        for c in &mut point {
74            *c = radius * *c / norm;
75            if positive {
76                *c = c.abs();
77            }
78        }
79        result.push(point);
80    }
81
82    Ok(result)
83}
84
85/// Sample points uniformly from the volume of a sphere.
86///
87/// Generates `n` points uniformly distributed inside a
88/// `dim`-dimensional ball of the given `radius`, centered at the
89/// origin. Uses sphere-surface sampling followed by radial scaling
90/// `U^(1/dim)` where `U ~ U(0,1)`.
91///
92/// If `positive` is `true`, all coordinates are mapped to their
93/// absolute values (restricting to the positive orthant).
94///
95/// Returns a `Vec<Vec<f64>>` where each inner `Vec` has `dim` elements.
96///
97/// # Errors
98///
99/// Returns `InvalidArgument` if:
100/// - `dim < 2`
101/// - `radius <= 0`
102///
103/// # Examples
104///
105/// ```
106/// use rust_igraph::sample_sphere_volume;
107///
108/// let points = sample_sphere_volume(2, 100, 1.0, false, 42).unwrap();
109/// assert_eq!(points.len(), 100);
110/// // All points inside the unit disk
111/// for p in &points {
112///     let r2: f64 = p.iter().map(|&x| x * x).sum();
113///     assert!(r2 <= 1.0 + 1e-10);
114/// }
115/// ```
116pub fn sample_sphere_volume(
117    dim: usize,
118    n: usize,
119    radius: f64,
120    positive: bool,
121    seed: u64,
122) -> IgraphResult<Vec<Vec<f64>>> {
123    if dim < 2 {
124        return Err(IgraphError::InvalidArgument(
125            "sample_sphere_volume: dimension must be at least 2".to_string(),
126        ));
127    }
128    if radius <= 0.0 {
129        return Err(IgraphError::InvalidArgument(
130            "sample_sphere_volume: radius must be positive".to_string(),
131        ));
132    }
133
134    let mut rng = SplitMix64::new(seed);
135    let inv_dim = 1.0 / dim as f64;
136    let mut result = Vec::with_capacity(n);
137
138    for _ in 0..n {
139        let mut point = Vec::with_capacity(dim);
140        let mut sum_sq = 0.0_f64;
141        for _ in 0..dim {
142            let z = rng.gen_normal();
143            sum_sq += z * z;
144            point.push(z);
145        }
146        let norm = sum_sq.sqrt();
147        let u = rng.gen_unit().powf(inv_dim);
148        for c in &mut point {
149            *c = radius * u * *c / norm;
150            if positive {
151                *c = c.abs();
152            }
153        }
154        result.push(point);
155    }
156
157    Ok(result)
158}
159
160/// Sample points from a Dirichlet distribution.
161///
162/// Generates `n` vectors drawn from a Dirichlet distribution with
163/// concentration parameters `alpha`. Each sample sums to 1.0. Uses
164/// the Gamma-based method: draw independent `Gamma(alpha_i, 1)` values
165/// and normalize.
166///
167/// Returns a `Vec<Vec<f64>>` where each inner `Vec` has `alpha.len()`
168/// elements summing to 1.0.
169///
170/// # Errors
171///
172/// Returns `InvalidArgument` if:
173/// - `alpha.len() < 2`
174/// - Any element of `alpha` is non-positive.
175///
176/// # Examples
177///
178/// ```
179/// use rust_igraph::sample_dirichlet;
180///
181/// let alpha = [1.0, 2.0, 3.0];
182/// let samples = sample_dirichlet(5, &alpha, 42).unwrap();
183/// assert_eq!(samples.len(), 5);
184/// for s in &samples {
185///     assert_eq!(s.len(), 3);
186///     let sum: f64 = s.iter().sum();
187///     assert!((sum - 1.0).abs() < 1e-10);
188///     for &v in s {
189///         assert!(v >= 0.0);
190///     }
191/// }
192/// ```
193pub fn sample_dirichlet(n: usize, alpha: &[f64], seed: u64) -> IgraphResult<Vec<Vec<f64>>> {
194    let dim = alpha.len();
195    if dim < 2 {
196        return Err(IgraphError::InvalidArgument(
197            "sample_dirichlet: alpha must have at least 2 entries".to_string(),
198        ));
199    }
200
201    for (i, &a) in alpha.iter().enumerate() {
202        if a <= 0.0 {
203            return Err(IgraphError::InvalidArgument(format!(
204                "sample_dirichlet: alpha[{i}] = {a}, must be positive"
205            )));
206        }
207    }
208
209    let mut rng = SplitMix64::new(seed);
210    let mut result = Vec::with_capacity(n);
211
212    for _ in 0..n {
213        let mut sample = Vec::with_capacity(dim);
214        let mut sum = 0.0_f64;
215        for &a in alpha {
216            let g = rng.gen_gamma(a);
217            sum += g;
218            sample.push(g);
219        }
220        for v in &mut sample {
221            *v /= sum;
222        }
223        result.push(sample);
224    }
225
226    Ok(result)
227}
228
#[cfg(test)]
mod tests {
    //! Unit tests for the sphere-surface, sphere-volume, and Dirichlet
    //! samplers: geometric invariants, argument validation, and
    //! seed determinism.

    use super::*;

    /// Squared Euclidean norm of a point.
    fn squared_norm(p: &[f64]) -> f64 {
        p.iter().map(|&x| x * x).sum()
    }

    #[test]
    fn sphere_surface_on_unit_sphere() {
        let points = sample_sphere_surface(3, 100, 1.0, false, 42).unwrap();
        assert_eq!(points.len(), 100);
        for p in &points {
            assert_eq!(p.len(), 3);
            let r2 = squared_norm(p);
            assert!(
                (r2 - 1.0).abs() < 1e-10,
                "point not on unit sphere: r²={r2}"
            );
        }
    }

    #[test]
    fn sphere_surface_scaled_radius() {
        let radius = 5.0;
        let points = sample_sphere_surface(2, 50, radius, false, 99).unwrap();
        for p in &points {
            let r2 = squared_norm(p);
            assert!(
                (r2 - radius * radius).abs() < 1e-8,
                "point not on sphere of radius {radius}: r²={r2}"
            );
        }
    }

    #[test]
    fn sphere_surface_positive_orthant() {
        let points = sample_sphere_surface(3, 100, 1.0, true, 42).unwrap();
        for p in &points {
            for &c in p {
                assert!(c >= 0.0, "expected non-negative, got {c}");
            }
        }
    }

    #[test]
    fn sphere_surface_dim_1_error() {
        assert!(sample_sphere_surface(1, 10, 1.0, false, 42).is_err());
    }

    #[test]
    fn sphere_surface_negative_radius_error() {
        assert!(sample_sphere_surface(2, 10, -1.0, false, 42).is_err());
    }

    #[test]
    fn sphere_surface_zero_radius_error() {
        assert!(sample_sphere_surface(2, 10, 0.0, false, 42).is_err());
    }

    #[test]
    fn sphere_surface_zero_samples() {
        let points = sample_sphere_surface(3, 0, 1.0, false, 42).unwrap();
        assert!(points.is_empty());
    }

    #[test]
    fn sphere_volume_inside_ball() {
        let points = sample_sphere_volume(3, 200, 1.0, false, 42).unwrap();
        assert_eq!(points.len(), 200);
        for p in &points {
            assert_eq!(p.len(), 3);
            let r2 = squared_norm(p);
            assert!(r2 <= 1.0 + 1e-10, "point outside unit ball: r²={r2}");
        }
    }

    #[test]
    fn sphere_volume_not_all_on_surface() {
        let points = sample_sphere_volume(3, 100, 1.0, false, 42).unwrap();
        let on_surface = points
            .iter()
            .filter(|p| (squared_norm(p) - 1.0).abs() < 0.01)
            .count();
        assert!(
            on_surface < 100,
            "all points on surface — volume sampling likely broken"
        );
    }

    #[test]
    fn sphere_volume_positive() {
        let points = sample_sphere_volume(2, 100, 2.0, true, 42).unwrap();
        for p in &points {
            for &c in p {
                assert!(c >= 0.0);
            }
        }
    }

    #[test]
    fn sphere_volume_scaled() {
        let radius = 3.0;
        let points = sample_sphere_volume(2, 200, radius, false, 42).unwrap();
        for p in &points {
            let r2 = squared_norm(p);
            assert!(r2 <= radius * radius + 1e-8);
        }
    }

    #[test]
    fn sphere_volume_dim_1_error() {
        assert!(sample_sphere_volume(1, 10, 1.0, false, 42).is_err());
    }

    #[test]
    fn sphere_volume_non_positive_radius_error() {
        assert!(sample_sphere_volume(2, 10, -1.0, false, 42).is_err());
        assert!(sample_sphere_volume(2, 10, 0.0, false, 42).is_err());
    }

    #[test]
    fn sphere_volume_zero_samples() {
        let points = sample_sphere_volume(3, 0, 1.0, false, 42).unwrap();
        assert!(points.is_empty());
    }

    #[test]
    fn dirichlet_sums_to_one() {
        let alpha = [1.0, 2.0, 3.0];
        let samples = sample_dirichlet(100, &alpha, 42).unwrap();
        assert_eq!(samples.len(), 100);
        for s in &samples {
            assert_eq!(s.len(), 3);
            let sum: f64 = s.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-10,
                "Dirichlet sample doesn't sum to 1: {sum}"
            );
        }
    }

    #[test]
    fn dirichlet_all_positive() {
        let alpha = [0.5, 0.5, 0.5];
        let samples = sample_dirichlet(200, &alpha, 42).unwrap();
        for s in &samples {
            for &v in s {
                assert!(v >= 0.0, "negative value in Dirichlet sample: {v}");
            }
        }
    }

    #[test]
    fn dirichlet_mean_matches() {
        // E[X_i] = alpha_i / sum(alpha)
        let alpha = [1.0, 2.0, 3.0];
        let alpha_sum: f64 = alpha.iter().sum();
        let n = 50_000;
        let samples = sample_dirichlet(n, &alpha, 42).unwrap();
        for (j, &a) in alpha.iter().enumerate() {
            let mean: f64 = samples.iter().map(|s| s[j]).sum::<f64>() / n as f64;
            let expected = a / alpha_sum;
            assert!(
                (mean - expected).abs() < 0.02,
                "dim {j}: mean={mean}, expected={expected}"
            );
        }
    }

    #[test]
    fn dirichlet_short_alpha_error() {
        assert!(sample_dirichlet(10, &[1.0], 42).is_err());
    }

    #[test]
    fn dirichlet_non_positive_alpha_error() {
        assert!(sample_dirichlet(10, &[1.0, -0.5], 42).is_err());
        assert!(sample_dirichlet(10, &[1.0, 0.0], 42).is_err());
    }

    #[test]
    fn dirichlet_zero_samples() {
        let samples = sample_dirichlet(0, &[1.0, 2.0], 42).unwrap();
        assert!(samples.is_empty());
    }

    #[test]
    fn deterministic_same_seed() {
        // The same seed must reproduce the output bit-for-bit. Comparing the
        // whole result also catches a length mismatch, which element-wise
        // `zip` comparison would silently ignore.
        let a = sample_sphere_surface(3, 10, 1.0, false, 42).unwrap();
        let b = sample_sphere_surface(3, 10, 1.0, false, 42).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn deterministic_same_seed_volume() {
        let a = sample_sphere_volume(3, 10, 1.0, false, 42).unwrap();
        let b = sample_sphere_volume(3, 10, 1.0, false, 42).unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn deterministic_same_seed_dirichlet() {
        let alpha = [1.0, 2.0, 3.0];
        let a = sample_dirichlet(10, &alpha, 42).unwrap();
        let b = sample_dirichlet(10, &alpha, 42).unwrap();
        assert_eq!(a, b);
    }
}