Skip to main content

rust_igraph/algorithms/embedding/
dim_select.rs

1//! Dimensionality selection via profile likelihood (ALGO-EM-001).
2//!
3//! Counterpart of `igraph_dim_select()` in
4//! `references/igraph/src/misc/embedding.c:1130-1195`.
5//!
6//! Given an ordered vector of "importance" measures (for spectral
7//! embedding, the singular values of the adjacency matrix), the input is
8//! modelled as a two-component Gaussian mixture with different means and a
9//! shared variance. The returned dimension `d` is the split point — the
10//! first `d` values assigned to one component and the remaining values to
11//! the other — that maximises the profile log-likelihood. This is the
12//! Zhu & Ghodsi (2006) "elbow of the scree plot" criterion.
13
14use crate::core::{IgraphError, IgraphResult};
15
16/// Select the number of significant values by profile likelihood.
17///
18/// The slice `sv` holds the ordered values (e.g. singular values, largest
19/// first). The returned `d` (a count in `1..=sv.len()`) is the split that
20/// maximises the two-component equal-variance Gaussian-mixture profile
21/// log-likelihood. The values are used in the given order; this routine
22/// does not sort them.
23///
24/// # Errors
25///
26/// * [`IgraphError::InvalidArgument`] — if `sv` is empty (at least one
27///   value is required).
28///
29/// # Examples
30///
31/// ```
32/// use rust_igraph::dim_select;
33///
34/// // A clean two-level ramp: the elbow sits at the midpoint.
35/// let sv: Vec<f64> = (1..=100).map(|i| i as f64).collect();
36/// assert_eq!(dim_select(&sv).unwrap(), 50);
37///
38/// // A single value is trivially one-dimensional.
39/// assert_eq!(dim_select(&[3.5]).unwrap(), 1);
40/// ```
41// This is a faithful running-sum translation of the C reference: the loop
42// index `i` is needed for the group-size arithmetic (not just indexing), the
43// `usize`→`f64` casts are over small counts (well within mantissa range), and
44// the `(... ) / 2.0` likelihood term is not a midpoint despite clippy's guess.
45#[allow(
46    clippy::cast_precision_loss,
47    clippy::needless_range_loop,
48    unknown_lints,
49    clippy::manual_midpoint
50)]
51pub fn dim_select(sv: &[f64]) -> IgraphResult<usize> {
52    let n = sv.len();
53
54    if n == 0 {
55        return Err(IgraphError::InvalidArgument(
56            "Need at least one singular value for dimensionality selection".to_string(),
57        ));
58    }
59
60    if n == 1 {
61        return Ok(1);
62    }
63
64    let nf = n as f64;
65
66    let mut sum1 = 0.0_f64;
67    let mut sum2: f64 = sv.iter().sum();
68    let mut sumsq1 = 0.0_f64;
69    let mut sumsq2 = 0.0_f64;
70    let mut mean1 = 0.0_f64;
71    let mut mean2 = sum2 / nf;
72    let mut varsq1 = 0.0_f64;
73    let mut varsq2 = 0.0_f64;
74
75    for &x in sv {
76        sumsq2 += x * x;
77        varsq2 += (mean2 - x) * (mean2 - x);
78    }
79
80    let mut max = f64::NEG_INFINITY;
81    let mut dim = n; // matches the C "all in one group" fallback
82
83    for i in 0..n - 1 {
84        let n1 = (i + 1) as f64;
85        let n2 = (n - i - 1) as f64;
86        let n1m1 = n1 - 1.0;
87        let n2m1 = n2 - 1.0;
88
89        let x = sv[i];
90        let x2 = x * x;
91        sum1 += x;
92        sum2 -= x;
93        sumsq1 += x2;
94        sumsq2 -= x2;
95        let oldmean1 = mean1;
96        let oldmean2 = mean2;
97        mean1 = sum1 / n1;
98        mean2 = sum2 / n2;
99        varsq1 += (x - oldmean1) * (x - mean1);
100        varsq2 -= (x - oldmean2) * (x - mean2);
101        let var1 = if i == 0 { 0.0 } else { varsq1 / n1m1 };
102        let var2 = if i == n - 2 { 0.0 } else { varsq2 / n2m1 };
103        let sd = ((n1m1 * var1 + n2m1 * var2) / (nf - 2.0)).sqrt();
104        let profile = -nf * sd.ln()
105            - ((sumsq1 - 2.0 * mean1 * sum1 + n1 * mean1 * mean1)
106                + (sumsq2 - 2.0 * mean2 * sum2 + n2 * mean2 * mean2))
107                / 2.0
108                / sd
109                / sd;
110        if profile > max {
111            max = profile;
112            dim = i + 1;
113        }
114    }
115
116    // The remaining case: every element in a single group.
117    let x = sv[n - 1];
118    sum1 += x;
119    let oldmean1 = mean1;
120    mean1 = sum1 / nf;
121    sumsq1 += x * x;
122    varsq1 += (x - oldmean1) * (x - mean1);
123    let var1 = varsq1 / (nf - 1.0);
124    let sd = var1.sqrt();
125    let profile =
126        -nf * sd.ln() - (sumsq1 - 2.0 * mean1 * sum1 + nf * mean1 * mean1) / 2.0 / sd / sd;
127    if profile > max {
128        dim = n;
129    }
130
131    Ok(dim)
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_input_errors() {
        assert!(dim_select(&[]).is_err());
    }

    #[test]
    fn empty_input_is_invalid_argument() {
        assert!(matches!(
            dim_select(&[]),
            Err(IgraphError::InvalidArgument(_))
        ));
    }

    #[test]
    fn single_value_is_one() {
        assert_eq!(dim_select(&[42.0]).unwrap(), 1);
    }

    #[test]
    fn ascending_ramp_splits_at_midpoint() {
        // Authentic rigraph anchor: dim_select(1:100) == 50.
        let sv: Vec<f64> = (1..=100).map(f64::from).collect();
        assert_eq!(dim_select(&sv).unwrap(), 50);
    }

    #[test]
    fn small_ramp_anchor() {
        // dim_select(1:10): the symmetric ramp elbow sits at the midpoint.
        let sv: Vec<f64> = (1..=10).map(f64::from).collect();
        assert_eq!(dim_select(&sv).unwrap(), 5);
    }

    #[test]
    fn descending_ramp_splits_at_midpoint() {
        // Within-group scatter of a ramp does not depend on its direction,
        // so the descending ramp has the same elbow as the ascending one.
        let sv: Vec<f64> = (1..=10).rev().map(f64::from).collect();
        assert_eq!(dim_select(&sv).unwrap(), 5);
    }

    #[test]
    fn clear_gap_is_detected() {
        // Three large leading values, then a cluster of tiny ones: the
        // dimension should be the count of the dominant block.
        let sv = [100.0, 99.0, 98.0, 1.0, 0.9, 0.8, 0.7, 0.6];
        assert_eq!(dim_select(&sv).unwrap(), 3);
    }

    #[test]
    fn two_values_returns_two() {
        // n == 2: the loop's combined SD is degenerate, so the single-group
        // model decides and returns the full dimension.
        assert_eq!(dim_select(&[2.0, 1.0]).unwrap(), 2);
    }

    #[test]
    fn two_equal_values_returns_two() {
        assert_eq!(dim_select(&[3.0, 3.0]).unwrap(), 2);
    }

    #[test]
    fn constant_input_has_no_elbow() {
        // Identical values: no split is preferred over the single group.
        assert_eq!(dim_select(&[7.0; 6]).unwrap(), 6);
    }

    #[test]
    fn nan_input_falls_back_to_full_dimension() {
        // NaN poisons every score, so no split is ever selected.
        let sv = [3.0, f64::NAN, 1.0, 0.5];
        assert_eq!(dim_select(&sv).unwrap(), sv.len());
    }

    #[test]
    fn result_within_bounds() {
        let sv = [5.0, 4.0, 3.0, 2.0, 1.0];
        let d = dim_select(&sv).unwrap();
        assert!((1..=sv.len()).contains(&d));
    }
}