// rust_igraph/algorithms/embedding/dim_select.rs
use crate::core::{IgraphError, IgraphResult};

16#[allow(
46 clippy::cast_precision_loss,
47 clippy::needless_range_loop,
48 unknown_lints,
49 clippy::manual_midpoint
50)]
51pub fn dim_select(sv: &[f64]) -> IgraphResult<usize> {
52 let n = sv.len();
53
54 if n == 0 {
55 return Err(IgraphError::InvalidArgument(
56 "Need at least one singular value for dimensionality selection".to_string(),
57 ));
58 }
59
60 if n == 1 {
61 return Ok(1);
62 }
63
64 let nf = n as f64;
65
66 let mut sum1 = 0.0_f64;
67 let mut sum2: f64 = sv.iter().sum();
68 let mut sumsq1 = 0.0_f64;
69 let mut sumsq2 = 0.0_f64;
70 let mut mean1 = 0.0_f64;
71 let mut mean2 = sum2 / nf;
72 let mut varsq1 = 0.0_f64;
73 let mut varsq2 = 0.0_f64;
74
75 for &x in sv {
76 sumsq2 += x * x;
77 varsq2 += (mean2 - x) * (mean2 - x);
78 }
79
80 let mut max = f64::NEG_INFINITY;
81 let mut dim = n; for i in 0..n - 1 {
84 let n1 = (i + 1) as f64;
85 let n2 = (n - i - 1) as f64;
86 let n1m1 = n1 - 1.0;
87 let n2m1 = n2 - 1.0;
88
89 let x = sv[i];
90 let x2 = x * x;
91 sum1 += x;
92 sum2 -= x;
93 sumsq1 += x2;
94 sumsq2 -= x2;
95 let oldmean1 = mean1;
96 let oldmean2 = mean2;
97 mean1 = sum1 / n1;
98 mean2 = sum2 / n2;
99 varsq1 += (x - oldmean1) * (x - mean1);
100 varsq2 -= (x - oldmean2) * (x - mean2);
101 let var1 = if i == 0 { 0.0 } else { varsq1 / n1m1 };
102 let var2 = if i == n - 2 { 0.0 } else { varsq2 / n2m1 };
103 let sd = ((n1m1 * var1 + n2m1 * var2) / (nf - 2.0)).sqrt();
104 let profile = -nf * sd.ln()
105 - ((sumsq1 - 2.0 * mean1 * sum1 + n1 * mean1 * mean1)
106 + (sumsq2 - 2.0 * mean2 * sum2 + n2 * mean2 * mean2))
107 / 2.0
108 / sd
109 / sd;
110 if profile > max {
111 max = profile;
112 dim = i + 1;
113 }
114 }
115
116 let x = sv[n - 1];
118 sum1 += x;
119 let oldmean1 = mean1;
120 mean1 = sum1 / nf;
121 sumsq1 += x * x;
122 varsq1 += (x - oldmean1) * (x - mean1);
123 let var1 = varsq1 / (nf - 1.0);
124 let sd = var1.sqrt();
125 let profile =
126 -nf * sd.ln() - (sumsq1 - 2.0 * mean1 * sum1 + nf * mean1 * mean1) / 2.0 / sd / sd;
127 if profile > max {
128 dim = n;
129 }
130
131 Ok(dim)
132}
133
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_input_errors() {
        assert!(matches!(dim_select(&[]), Err(IgraphError::InvalidArgument(_))));
    }

    #[test]
    fn single_value_is_one() {
        assert_eq!(dim_select(&[42.0]).unwrap(), 1);
    }

    #[test]
    fn ascending_ramp_splits_at_midpoint() {
        let sv: Vec<f64> = (1..=100).map(f64::from).collect();
        assert_eq!(dim_select(&sv).unwrap(), 50);
    }

    #[test]
    fn small_ramp_anchor() {
        let sv: Vec<f64> = (1..=10).map(f64::from).collect();
        assert_eq!(dim_select(&sv).unwrap(), 5);
    }

    #[test]
    fn clear_gap_is_detected() {
        let sv = [100.0, 99.0, 98.0, 1.0, 0.9, 0.8, 0.7, 0.6];
        assert_eq!(dim_select(&sv).unwrap(), 3);
    }

    #[test]
    fn two_values_returns_two() {
        // The only split's pooled variance divides by `n - 2 == 0` (NaN), so
        // the single-group candidate decides.
        assert_eq!(dim_select(&[2.0, 1.0]).unwrap(), 2);
    }

    #[test]
    fn constant_input_falls_back_to_n() {
        // Every candidate has zero spread and scores NaN, so `n` is returned.
        assert_eq!(dim_select(&[1.0, 1.0, 1.0]).unwrap(), 3);
    }

    #[test]
    fn result_within_bounds() {
        let sv = [5.0, 4.0, 3.0, 2.0, 1.0];
        let d = dim_select(&sv).unwrap();
        assert!((1..=sv.len()).contains(&d));
    }
}