Skip to main content

rust_igraph/algorithms/eigen/
general.rs

1//! General (non-symmetric) eigenvalue solver via Arnoldi iteration.
2//!
3//! Computes selected eigenvalues and eigenvectors of a general real
4//! matrix defined implicitly by a matrix-vector product closure.
5//!
6//! For non-symmetric matrices, eigenvalues may be complex. This module
7//! returns eigenvalues as `(real, imag)` pairs and eigenvectors as
8//! pairs of real/imaginary component vectors.
9//!
10//! ```
11//! use rust_igraph::eigen_matrix;
12//! use rust_igraph::GeneralEigenWhich;
13//!
14//! // 2×2 rotation matrix [[0, -1], [1, 0]]: eigenvalues are ±i
15//! let result = eigen_matrix(
16//!     2,
17//!     |x, y| { y[0] = -x[1]; y[1] = x[0]; },
18//!     2,
19//!     GeneralEigenWhich::LargestMagnitude,
20//! ).unwrap();
21//!
22//! assert_eq!(result.eigenvalues.len(), 2);
23//! for &(re, im) in &result.eigenvalues {
24//!     let mag = (re * re + im * im).sqrt();
25//!     assert!((mag - 1.0).abs() < 0.05);
26//! }
27//! ```
28
29// Numerical linear algebra routines use index-heavy loops and many local
30// variables that are standard in the literature (h, q, k, m, n, etc.).
31#![allow(
32    clippy::many_single_char_names,
33    clippy::needless_range_loop,
34    clippy::too_many_lines,
35    clippy::cast_precision_loss,
36    clippy::manual_memcpy,
37    clippy::manual_swap,
38    clippy::similar_names,
39    unknown_lints,
40    clippy::manual_midpoint
41)]
42
43use crate::core::error::{IgraphError, IgraphResult};
44
/// Selection criterion for the part of a non-symmetric spectrum to return.
///
/// [`eigen_matrix`] orders the computed eigenvalues by the chosen criterion
/// and keeps the first `nev` of them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GeneralEigenWhich {
    /// Decreasing modulus `|λ|` — largest magnitude first.
    LargestMagnitude,
    /// Increasing modulus `|λ|` — smallest magnitude first.
    SmallestMagnitude,
    /// Decreasing real part `Re(λ)`.
    LargestReal,
    /// Increasing real part `Re(λ)` — most negative first.
    SmallestReal,
}
57
/// Result of a general (non-symmetric) eigenvalue decomposition.
///
/// `eigenvalues[i]` and `eigenvectors[i]` always describe the same eigenpair,
/// so both vectors have the same length. Eigenvalues may be complex and are
/// stored as `(real, imaginary)` pairs. When both members of a complex
/// conjugate pair are selected, each gets its own entry with its own
/// eigenvector; the pair is not collapsed into a single entry.
#[derive(Debug, Clone)]
pub struct GeneralEigenDecomposition {
    /// Computed eigenvalues as `(real, imaginary)` pairs, in the order
    /// requested via [`GeneralEigenWhich`].
    pub eigenvalues: Vec<(f64, f64)>,
    /// Eigenvectors as `(real_part, imag_part)`; both parts have length `n`,
    /// and the complex vector `real_part + i * imag_part` is scaled to unit
    /// Euclidean norm (unless it came out numerically zero). For a real
    /// eigenvalue the imaginary part is all zeros.
    pub eigenvectors: Vec<(Vec<f64>, Vec<f64>)>,
}
73
74/// Compute selected eigenpairs of a general real matrix.
75///
76/// The matrix is defined implicitly: `matvec(x, y)` must compute
77/// `y = A * x` where both slices have length `n`.
78///
79/// Uses Arnoldi iteration to build an upper Hessenberg projection,
80/// then QR iteration on the small Hessenberg matrix to extract
81/// eigenvalues.
82///
83/// # Arguments
84///
85/// * `n` — matrix dimension
86/// * `matvec` — closure computing the matrix-vector product `y = A x`
87/// * `nev` — number of eigenvalues to compute (clamped to `n`)
88/// * `which` — which part of the spectrum to target
89///
90/// # Errors
91///
92/// Returns [`IgraphError::InvalidArgument`] if `n == 0`.
93///
94/// # Examples
95///
96/// ```
97/// use rust_igraph::eigen_matrix;
98/// use rust_igraph::GeneralEigenWhich;
99///
100/// // diag(3, 2, 1): top eigenvalue by magnitude is 3.0
101/// let result = eigen_matrix(
102///     3,
103///     |x, y| { y[0] = 3.0*x[0]; y[1] = 2.0*x[1]; y[2] = x[2]; },
104///     2,
105///     GeneralEigenWhich::LargestMagnitude,
106/// ).unwrap();
107///
108/// let (re, im) = result.eigenvalues[0];
109/// assert!((re - 3.0).abs() < 0.05);
110/// assert!(im.abs() < 0.05);
111/// ```
112pub fn eigen_matrix<F>(
113    n: usize,
114    matvec: F,
115    nev: usize,
116    which: GeneralEigenWhich,
117) -> IgraphResult<GeneralEigenDecomposition>
118where
119    F: Fn(&[f64], &mut [f64]),
120{
121    if n == 0 {
122        return Err(IgraphError::InvalidArgument(
123            "eigen_matrix: matrix dimension must be > 0".into(),
124        ));
125    }
126
127    let nev = nev.min(n);
128    if nev == 0 {
129        return Ok(GeneralEigenDecomposition {
130            eigenvalues: Vec::new(),
131            eigenvectors: Vec::new(),
132        });
133    }
134
135    if n == 1 {
136        let mut y = vec![0.0];
137        matvec(&[1.0], &mut y);
138        return Ok(GeneralEigenDecomposition {
139            eigenvalues: vec![(y[0], 0.0)],
140            eigenvectors: vec![(vec![1.0], vec![0.0])],
141        });
142    }
143
144    // Arnoldi subspace dimension: larger than nev to get good convergence
145    let m = n.min(nev.saturating_mul(2).max(20).min(n));
146
147    // Run Arnoldi iteration
148    let (q_basis, h_matrix) = arnoldi_iteration(n, &matvec, m);
149    let actual_m = q_basis.len();
150    if actual_m == 0 {
151        return Ok(GeneralEigenDecomposition {
152            eigenvalues: Vec::new(),
153            eigenvectors: Vec::new(),
154        });
155    }
156
157    // Compute eigenvalues of the upper Hessenberg matrix H via QR
158    let (eig_re, eig_im) = hessenberg_qr_eigenvalues(&h_matrix, actual_m);
159
160    // Sort eigenvalues by the requested criterion, pair with indices
161    let mut indexed: Vec<(usize, f64, f64)> = eig_re
162        .iter()
163        .zip(eig_im.iter())
164        .enumerate()
165        .map(|(i, (&re, &im))| (i, re, im))
166        .collect();
167
168    match which {
169        GeneralEigenWhich::LargestMagnitude => {
170            indexed.sort_by(|a, b| {
171                let ma = (a.1 * a.1 + a.2 * a.2).sqrt();
172                let mb = (b.1 * b.1 + b.2 * b.2).sqrt();
173                mb.partial_cmp(&ma).unwrap_or(std::cmp::Ordering::Equal)
174            });
175        }
176        GeneralEigenWhich::SmallestMagnitude => {
177            indexed.sort_by(|a, b| {
178                let ma = (a.1 * a.1 + a.2 * a.2).sqrt();
179                let mb = (b.1 * b.1 + b.2 * b.2).sqrt();
180                ma.partial_cmp(&mb).unwrap_or(std::cmp::Ordering::Equal)
181            });
182        }
183        GeneralEigenWhich::LargestReal => {
184            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
185        }
186        GeneralEigenWhich::SmallestReal => {
187            indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
188        }
189    }
190
191    // Take top nev
192    let selected: Vec<_> = indexed.into_iter().take(nev).collect();
193
194    // Compute Ritz vectors: for each eigenvalue of H, solve (H - λI)z = 0,
195    // then x = Q * z gives the approximate eigenvector
196    let mut eigenvalues = Vec::with_capacity(nev);
197    let mut eigenvectors = Vec::with_capacity(nev);
198
199    for &(idx, re, im) in &selected {
200        eigenvalues.push((re, im));
201
202        if im.abs() < 1e-14 {
203            // Real eigenvalue: compute real eigenvector of H
204            let z = hessenberg_eigenvector_real(&h_matrix, actual_m, re);
205            let mut x = vec![0.0; n];
206            for (j, zj) in z.iter().enumerate() {
207                if j < q_basis.len() {
208                    for i in 0..n {
209                        x[i] += zj * q_basis[j][i];
210                    }
211                }
212            }
213            normalize_vec(&mut x);
214            eigenvectors.push((x, vec![0.0; n]));
215        } else {
216            // Complex eigenvalue: compute complex eigenvector
217            let (z_re, z_im) = hessenberg_eigenvector_complex(&h_matrix, actual_m, re, im);
218            let mut x_re = vec![0.0; n];
219            let mut x_im = vec![0.0; n];
220            for j in 0..z_re.len().min(q_basis.len()) {
221                for i in 0..n {
222                    x_re[i] += z_re[j] * q_basis[j][i];
223                    x_im[i] += z_im[j] * q_basis[j][i];
224                }
225            }
226            let norm = {
227                let s: f64 = x_re.iter().zip(&x_im).map(|(r, i)| r * r + i * i).sum();
228                s.sqrt()
229            };
230            if norm > 1e-30 {
231                for i in 0..n {
232                    x_re[i] /= norm;
233                    x_im[i] /= norm;
234                }
235            }
236            eigenvectors.push((x_re, x_im));
237        }
238        let _ = idx;
239    }
240
241    Ok(GeneralEigenDecomposition {
242        eigenvalues,
243        eigenvectors,
244    })
245}
246
247// ---------------------------------------------------------------
248// Arnoldi iteration
249// ---------------------------------------------------------------
250
251fn arnoldi_iteration<F>(n: usize, matvec: &F, m: usize) -> (Vec<Vec<f64>>, Vec<Vec<f64>>)
252where
253    F: Fn(&[f64], &mut [f64]),
254{
255    let mut q: Vec<Vec<f64>> = Vec::with_capacity(m + 1);
256    // H is (m+1) x m upper Hessenberg stored as Vec of columns
257    let mut h: Vec<Vec<f64>> = vec![vec![0.0; m + 1]; m];
258
259    // Start with a random-ish vector
260    let mut q0 = vec![0.0; n];
261    for i in 0..n {
262        // Deterministic pseudo-random start via simple hash
263        let seed = (i as u64)
264            .wrapping_mul(6_364_136_223_846_793_005)
265            .wrapping_add(1);
266        q0[i] = ((seed >> 33) as f64) / (1u64 << 31) as f64 - 0.5;
267    }
268    normalize_vec(&mut q0);
269    q.push(q0);
270
271    let mut w = vec![0.0; n];
272
273    for j in 0..m {
274        matvec(&q[j], &mut w);
275
276        // Modified Gram-Schmidt orthogonalization
277        for i in 0..=j {
278            let hij = dot_vecs(&q[i], &w);
279            h[j][i] = hij;
280            for k in 0..n {
281                w[k] -= hij * q[i][k];
282            }
283        }
284
285        // Re-orthogonalize for numerical stability
286        for i in 0..=j {
287            let corr = dot_vecs(&q[i], &w);
288            h[j][i] += corr;
289            for k in 0..n {
290                w[k] -= corr * q[i][k];
291            }
292        }
293
294        let beta = norm_vec(&w);
295        if beta < 1e-14 {
296            // Invariant subspace found
297            return (q, h);
298        }
299
300        h[j][j + 1] = beta;
301        let inv_beta = 1.0 / beta;
302        let q_next: Vec<f64> = w.iter().map(|&v| v * inv_beta).collect();
303        q.push(q_next);
304    }
305
306    (q, h)
307}
308
309// ---------------------------------------------------------------
310// QR iteration for upper Hessenberg eigenvalues
311// ---------------------------------------------------------------
312
313fn hessenberg_qr_eigenvalues(h: &[Vec<f64>], m: usize) -> (Vec<f64>, Vec<f64>) {
314    if m == 0 {
315        return (Vec::new(), Vec::new());
316    }
317    if m == 1 {
318        return (vec![h[0][0]], vec![0.0]);
319    }
320
321    // Copy H into a dense row-major m×m matrix
322    let mut a = vec![0.0; m * m];
323    for j in 0..m {
324        for i in 0..m {
325            a[i * m + j] = h[j][i];
326        }
327    }
328
329    // Implicit double-shift QR (Francis algorithm)
330    let max_iter = m.saturating_mul(100).max(1000);
331    let mut eig_re = vec![0.0; m];
332    let mut eig_im = vec![0.0; m];
333
334    francis_qr(&mut a, m, &mut eig_re, &mut eig_im, max_iter);
335
336    (eig_re, eig_im)
337}
338
339/// Francis double-shift QR iteration on an upper Hessenberg matrix.
340fn francis_qr(a: &mut [f64], n: usize, eig_re: &mut [f64], eig_im: &mut [f64], max_iter: usize) {
341    let mut p = n;
342    let mut iter_count = 0;
343
344    while p > 0 && iter_count < max_iter {
345        // Find the lowest active subdiagonal element that is negligible
346        let mut q = p;
347        while q > 1 {
348            let h_qq = a[(q - 1) * n + (q - 1)];
349            let h_q1q1 = a[(q - 2) * n + (q - 2)];
350            let threshold = 1e-14 * (h_qq.abs() + h_q1q1.abs()).max(1e-30);
351            if a[(q - 1) * n + (q - 2)].abs() <= threshold {
352                a[(q - 1) * n + (q - 2)] = 0.0;
353                break;
354            }
355            q -= 1;
356        }
357
358        if q == p {
359            // 1×1 block: real eigenvalue
360            eig_re[p - 1] = a[(p - 1) * n + (p - 1)];
361            eig_im[p - 1] = 0.0;
362            p -= 1;
363        } else if q == p - 1 {
364            // 2×2 block: extract eigenvalues
365            let a11 = a[(p - 2) * n + (p - 2)];
366            let a12 = a[(p - 2) * n + (p - 1)];
367            let a21 = a[(p - 1) * n + (p - 2)];
368            let a22 = a[(p - 1) * n + (p - 1)];
369            let (e1r, e1i, e2r, e2i) = eigenvalues_2x2(a11, a12, a21, a22);
370            eig_re[p - 2] = e1r;
371            eig_im[p - 2] = e1i;
372            eig_re[p - 1] = e2r;
373            eig_im[p - 1] = e2i;
374            p -= 2;
375        } else {
376            // Perform one implicit double-shift QR step
377            implicit_double_shift_step(a, n, q.saturating_sub(1), p);
378            iter_count += 1;
379        }
380    }
381
382    // If we didn't converge, extract whatever diagonal we have
383    if p > 0 && iter_count >= max_iter {
384        for i in 0..p {
385            eig_re[i] = a[i * n + i];
386            eig_im[i] = 0.0;
387        }
388    }
389}
390
391fn implicit_double_shift_step(a: &mut [f64], n: usize, lo: usize, hi: usize) {
392    if hi <= lo + 1 {
393        return;
394    }
395
396    // Wilkinson shift: eigenvalues of the trailing 2×2 block
397    let s = a[(hi - 2) * n + (hi - 2)] + a[(hi - 1) * n + (hi - 1)];
398    let t = a[(hi - 2) * n + (hi - 2)] * a[(hi - 1) * n + (hi - 1)]
399        - a[(hi - 2) * n + (hi - 1)] * a[(hi - 1) * n + (hi - 2)];
400
401    // First column of (A - σ₁I)(A - σ₂I) = A² - sA + tI
402    let mut x = a[lo * n + lo] * a[lo * n + lo] + a[lo * n + (lo + 1)] * a[(lo + 1) * n + lo]
403        - s * a[lo * n + lo]
404        + t;
405    let mut y = a[(lo + 1) * n + lo] * (a[lo * n + lo] + a[(lo + 1) * n + (lo + 1)] - s);
406    let mut z = if lo + 2 < hi {
407        a[(lo + 1) * n + lo] * a[(lo + 2) * n + (lo + 1)]
408    } else {
409        0.0
410    };
411
412    for k in lo..hi.saturating_sub(1) {
413        // Householder reflector to zero out y, z
414        let (v0, v1, v2, tau) = householder3(x, y, z);
415
416        let r = k.saturating_sub(1).max(lo);
417        let end = hi.min(k + 4);
418
419        // Apply from left: rows k, k+1, k+2
420        for j in r..n.min(hi) {
421            let sum = v0 * a[k * n + j]
422                + v1 * a[(k + 1).min(hi - 1) * n + j]
423                + if k + 2 < hi {
424                    v2 * a[(k + 2) * n + j]
425                } else {
426                    0.0
427                };
428            a[k * n + j] -= tau * v0 * sum;
429            a[(k + 1).min(hi - 1) * n + j] -= tau * v1 * sum;
430            if k + 2 < hi {
431                a[(k + 2) * n + j] -= tau * v2 * sum;
432            }
433        }
434
435        // Apply from right: columns k, k+1, k+2
436        for i in 0..end.min(n) {
437            let sum = v0 * a[i * n + k]
438                + v1 * a[i * n + (k + 1).min(hi - 1)]
439                + if k + 2 < hi {
440                    v2 * a[i * n + k + 2]
441                } else {
442                    0.0
443                };
444            a[i * n + k] -= tau * v0 * sum;
445            a[i * n + (k + 1).min(hi - 1)] -= tau * v1 * sum;
446            if k + 2 < hi {
447                a[i * n + k + 2] -= tau * v2 * sum;
448            }
449        }
450
451        // Prepare for next iteration
452        if k + 2 < hi.saturating_sub(1) {
453            x = a[(k + 1) * n + k];
454            y = a[(k + 2) * n + k];
455            z = if k + 3 < hi { a[(k + 3) * n + k] } else { 0.0 };
456        } else {
457            x = a[(k + 1).min(hi - 1) * n + k];
458            y = 0.0;
459            z = 0.0;
460        }
461    }
462}
463
464fn householder3(x: f64, y: f64, z: f64) -> (f64, f64, f64, f64) {
465    let norm = (x * x + y * y + z * z).sqrt();
466    if norm < 1e-30 {
467        return (1.0, 0.0, 0.0, 0.0);
468    }
469    let sign_x = if x >= 0.0 { 1.0 } else { -1.0 };
470    let v0 = x + sign_x * norm;
471    let v1 = y;
472    let v2 = z;
473    let v_norm_sq = v0 * v0 + v1 * v1 + v2 * v2;
474    if v_norm_sq < 1e-30 {
475        return (1.0, 0.0, 0.0, 0.0);
476    }
477    let tau = 2.0 / v_norm_sq;
478    (v0, v1, v2, tau)
479}
480
481fn eigenvalues_2x2(a: f64, b: f64, c: f64, d: f64) -> (f64, f64, f64, f64) {
482    let trace = a + d;
483    let det = a * d - b * c;
484    let disc = trace * trace - 4.0 * det;
485    if disc >= 0.0 {
486        let sqrt_disc = disc.sqrt();
487        let e1 = (trace + sqrt_disc) / 2.0;
488        let e2 = (trace - sqrt_disc) / 2.0;
489        (e1, 0.0, e2, 0.0)
490    } else {
491        let sqrt_disc = (-disc).sqrt();
492        let re = trace / 2.0;
493        (re, sqrt_disc / 2.0, re, -sqrt_disc / 2.0)
494    }
495}
496
497// ---------------------------------------------------------------
498// Eigenvector computation from Hessenberg Schur form
499// ---------------------------------------------------------------
500
501fn hessenberg_eigenvector_real(h: &[Vec<f64>], m: usize, lambda: f64) -> Vec<f64> {
502    // Inverse iteration: solve (H - λI)z = 0 approximately
503    let mut z = vec![0.0; m];
504
505    // Build (H - λI) in row-major
506    let mut mat = vec![0.0; m * m];
507    for j in 0..m {
508        for i in 0..m {
509            mat[i * m + j] = h[j][i];
510        }
511        mat[j * m + j] -= lambda;
512    }
513
514    // Inverse iteration with a random starting vector
515    let mut x = vec![0.0; m];
516    for i in 0..m {
517        let seed = (i as u64)
518            .wrapping_mul(2_862_933_555_777_941_757)
519            .wrapping_add(3);
520        x[i] = ((seed >> 33) as f64) / (1u64 << 31) as f64 - 0.5;
521    }
522
523    // Shift slightly to avoid singularity
524    let shift = 1e-10 * (1.0 + lambda.abs());
525    for i in 0..m {
526        mat[i * m + i] += shift;
527    }
528
529    // LU factorize (partial pivoting)
530    let mut lu = mat.clone();
531    let mut piv = vec![0usize; m];
532    for i in 0..m {
533        piv[i] = i;
534    }
535    for k in 0..m.saturating_sub(1) {
536        let mut max_val = lu[k * m + k].abs();
537        let mut max_row = k;
538        for i in (k + 1)..m {
539            if lu[i * m + k].abs() > max_val {
540                max_val = lu[i * m + k].abs();
541                max_row = i;
542            }
543        }
544        if max_row != k {
545            piv.swap(k, max_row);
546            for j in 0..m {
547                let tmp = lu[k * m + j];
548                lu[k * m + j] = lu[max_row * m + j];
549                lu[max_row * m + j] = tmp;
550            }
551        }
552        if lu[k * m + k].abs() < 1e-30 {
553            continue;
554        }
555        for i in (k + 1)..m {
556            lu[i * m + k] /= lu[k * m + k];
557            for j in (k + 1)..m {
558                lu[i * m + j] -= lu[i * m + k] * lu[k * m + j];
559            }
560        }
561    }
562
563    // 3 rounds of inverse iteration
564    for _ in 0..3 {
565        // Apply permutation
566        let mut b = vec![0.0; m];
567        for i in 0..m {
568            b[i] = x[piv[i]];
569        }
570        // Forward substitution
571        for i in 1..m {
572            for j in 0..i {
573                b[i] -= lu[i * m + j] * b[j];
574            }
575        }
576        // Back substitution
577        for i in (0..m).rev() {
578            for j in (i + 1)..m {
579                b[i] -= lu[i * m + j] * b[j];
580            }
581            if lu[i * m + i].abs() > 1e-30 {
582                b[i] /= lu[i * m + i];
583            }
584        }
585        x = b;
586        normalize_vec(&mut x);
587    }
588
589    z.copy_from_slice(&x);
590    z
591}
592
593fn hessenberg_eigenvector_complex(
594    h: &[Vec<f64>],
595    m: usize,
596    lambda_re: f64,
597    lambda_im: f64,
598) -> (Vec<f64>, Vec<f64>) {
599    // For complex eigenvalues, we solve (H - λI)z = 0 in complex arithmetic
600    // using inverse iteration with complex vectors
601    let mut z_re = vec![0.0; m];
602    let mut z_im = vec![0.0; m];
603
604    // Build (H - λ_re·I) and store -λ_im for imaginary part
605    let mut mat_re = vec![0.0; m * m];
606    for j in 0..m {
607        for i in 0..m {
608            mat_re[i * m + j] = h[j][i];
609        }
610        mat_re[j * m + j] -= lambda_re;
611    }
612
613    // Start with a random complex vector
614    let mut x_re = vec![0.0; m];
615    let mut x_im = vec![0.0; m];
616    for i in 0..m {
617        let seed1 = (i as u64)
618            .wrapping_mul(6_364_136_223_846_793_005)
619            .wrapping_add(7);
620        let seed2 = (i as u64)
621            .wrapping_mul(2_862_933_555_777_941_757)
622            .wrapping_add(13);
623        x_re[i] = ((seed1 >> 33) as f64) / (1u64 << 31) as f64 - 0.5;
624        x_im[i] = ((seed2 >> 33) as f64) / (1u64 << 31) as f64 - 0.5;
625    }
626
627    // Use real Schur form approach: build the 2m×2m real system
628    // [ (H-re·I)  im·I ] [x_re]   [0]
629    // [ -im·I   (H-re·I)] [x_im] = [0]
630    // and solve via inverse iteration on the real system
631    let shift = 1e-10 * (1.0 + lambda_re.abs() + lambda_im.abs());
632    let n2 = 2 * m;
633    let mut big_mat = vec![0.0; n2 * n2];
634    for i in 0..m {
635        for j in 0..m {
636            big_mat[i * n2 + j] = mat_re[i * m + j];
637            big_mat[(i + m) * n2 + (j + m)] = mat_re[i * m + j];
638        }
639        big_mat[i * n2 + (i + m)] = lambda_im;
640        big_mat[(i + m) * n2 + i] = -lambda_im;
641        big_mat[i * n2 + i] += shift;
642        big_mat[(i + m) * n2 + (i + m)] += shift;
643    }
644
645    // LU factorize big_mat
646    let mut lu = big_mat;
647    let mut piv = vec![0usize; n2];
648    for i in 0..n2 {
649        piv[i] = i;
650    }
651    for k in 0..n2.saturating_sub(1) {
652        let mut max_val = lu[k * n2 + k].abs();
653        let mut max_row = k;
654        for i in (k + 1)..n2 {
655            if lu[i * n2 + k].abs() > max_val {
656                max_val = lu[i * n2 + k].abs();
657                max_row = i;
658            }
659        }
660        if max_row != k {
661            piv.swap(k, max_row);
662            for j in 0..n2 {
663                let tmp = lu[k * n2 + j];
664                lu[k * n2 + j] = lu[max_row * n2 + j];
665                lu[max_row * n2 + j] = tmp;
666            }
667        }
668        if lu[k * n2 + k].abs() < 1e-30 {
669            continue;
670        }
671        for i in (k + 1)..n2 {
672            lu[i * n2 + k] /= lu[k * n2 + k];
673            for j in (k + 1)..n2 {
674                lu[i * n2 + j] -= lu[i * n2 + k] * lu[k * n2 + j];
675            }
676        }
677    }
678
679    // Combined start vector
680    let mut big_x = vec![0.0; n2];
681    for i in 0..m {
682        big_x[i] = x_re[i];
683        big_x[i + m] = x_im[i];
684    }
685
686    // 3 rounds of inverse iteration
687    for _ in 0..3 {
688        let mut b = vec![0.0; n2];
689        for i in 0..n2 {
690            b[i] = big_x[piv[i]];
691        }
692        for i in 1..n2 {
693            for j in 0..i {
694                b[i] -= lu[i * n2 + j] * b[j];
695            }
696        }
697        for i in (0..n2).rev() {
698            for j in (i + 1)..n2 {
699                b[i] -= lu[i * n2 + j] * b[j];
700            }
701            if lu[i * n2 + i].abs() > 1e-30 {
702                b[i] /= lu[i * n2 + i];
703            }
704        }
705        let norm: f64 = b.iter().map(|v| v * v).sum::<f64>().sqrt();
706        if norm > 1e-30 {
707            for v in &mut b {
708                *v /= norm;
709            }
710        }
711        big_x = b;
712    }
713
714    for i in 0..m {
715        z_re[i] = big_x[i];
716        z_im[i] = big_x[i + m];
717    }
718
719    (z_re, z_im)
720}
721
722// ---------------------------------------------------------------
723// Utility functions
724// ---------------------------------------------------------------
725
726fn dot_vecs(a: &[f64], b: &[f64]) -> f64 {
727    a.iter().zip(b).map(|(x, y)| x * y).sum()
728}
729
730fn norm_vec(v: &[f64]) -> f64 {
731    dot_vecs(v, v).sqrt()
732}
733
734fn normalize_vec(v: &mut [f64]) {
735    let n = norm_vec(v);
736    if n > 1e-30 {
737        for x in v.iter_mut() {
738            *x /= n;
739        }
740    }
741}
742
743#[cfg(test)]
744mod tests {
745    use super::*;
746
747    #[test]
748    fn diagonal_real_eigenvalues() {
749        let result = eigen_matrix(
750            3,
751            |x, y| {
752                y[0] = 3.0 * x[0];
753                y[1] = 2.0 * x[1];
754                y[2] = x[2];
755            },
756            2,
757            GeneralEigenWhich::LargestMagnitude,
758        )
759        .unwrap();
760
761        assert_eq!(result.eigenvalues.len(), 2);
762        let (re0, im0) = result.eigenvalues[0];
763        assert!(
764            (re0 - 3.0).abs() < 0.05,
765            "top eigenvalue should be ~3.0, got {re0}"
766        );
767        assert!(im0.abs() < 0.05, "should be real, got im={im0}");
768    }
769
770    #[test]
771    fn rotation_matrix_complex_eigenvalues() {
772        // [[0, -1], [1, 0]] has eigenvalues ±i
773        let result = eigen_matrix(
774            2,
775            |x, y| {
776                y[0] = -x[1];
777                y[1] = x[0];
778            },
779            2,
780            GeneralEigenWhich::LargestMagnitude,
781        )
782        .unwrap();
783
784        assert_eq!(result.eigenvalues.len(), 2);
785        for &(re, im) in &result.eigenvalues {
786            let mag = (re * re + im * im).sqrt();
787            assert!(
788                (mag - 1.0).abs() < 1e-4,
789                "eigenvalue magnitude should be 1.0, got {mag} (re={re}, im={im})"
790            );
791            assert!(re.abs() < 1e-4, "real part should be ~0, got {re}");
792        }
793    }
794
795    #[test]
796    fn upper_triangular_eigenvalues() {
797        // [[2, 1], [0, 5]]: eigenvalues are 2 and 5
798        let result = eigen_matrix(
799            2,
800            |x, y| {
801                y[0] = 2.0 * x[0] + x[1];
802                y[1] = 5.0 * x[1];
803            },
804            2,
805            GeneralEigenWhich::LargestMagnitude,
806        )
807        .unwrap();
808
809        assert_eq!(result.eigenvalues.len(), 2);
810        let (re0, _) = result.eigenvalues[0];
811        let (re1, _) = result.eigenvalues[1];
812        let mut vals = [re0, re1];
813        vals.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
814        assert!(
815            (vals[0] - 5.0).abs() < 0.1,
816            "largest eigenvalue should be ~5.0, got {}",
817            vals[0]
818        );
819        assert!(
820            (vals[1] - 2.0).abs() < 0.1,
821            "second eigenvalue should be ~2.0, got {}",
822            vals[1]
823        );
824    }
825
826    #[test]
827    fn nev_zero_returns_empty() {
828        let result = eigen_matrix(3, |_x, _y| {}, 0, GeneralEigenWhich::LargestMagnitude).unwrap();
829        assert!(result.eigenvalues.is_empty());
830    }
831
832    #[test]
833    fn n_zero_returns_error() {
834        let result = eigen_matrix(0, |_x, _y| {}, 1, GeneralEigenWhich::LargestMagnitude);
835        assert!(result.is_err());
836    }
837
838    #[test]
839    fn n_one_returns_single() {
840        let result = eigen_matrix(
841            1,
842            |x, y| {
843                y[0] = 7.0 * x[0];
844            },
845            1,
846            GeneralEigenWhich::LargestMagnitude,
847        )
848        .unwrap();
849
850        assert_eq!(result.eigenvalues.len(), 1);
851        let (re, im) = result.eigenvalues[0];
852        assert!(
853            (re - 7.0).abs() < 1e-6,
854            "eigenvalue should be 7.0, got {re}"
855        );
856        assert!(im.abs() < 1e-10);
857    }
858
859    #[test]
860    fn smallest_real_selection() {
861        let result = eigen_matrix(
862            3,
863            |x, y| {
864                y[0] = -5.0 * x[0];
865                y[1] = 2.0 * x[1];
866                y[2] = 10.0 * x[2];
867            },
868            1,
869            GeneralEigenWhich::SmallestReal,
870        )
871        .unwrap();
872
873        assert_eq!(result.eigenvalues.len(), 1);
874        let (re, _) = result.eigenvalues[0];
875        assert!(
876            (re - (-5.0)).abs() < 0.1,
877            "smallest real eigenvalue should be ~-5.0, got {re}"
878        );
879    }
880
881    #[test]
882    fn largest_real_selection() {
883        let result = eigen_matrix(
884            3,
885            |x, y| {
886                y[0] = -5.0 * x[0];
887                y[1] = 2.0 * x[1];
888                y[2] = 10.0 * x[2];
889            },
890            1,
891            GeneralEigenWhich::LargestReal,
892        )
893        .unwrap();
894
895        assert_eq!(result.eigenvalues.len(), 1);
896        let (re, _) = result.eigenvalues[0];
897        assert!(
898            (re - 10.0).abs() < 0.1,
899            "largest real eigenvalue should be ~10.0, got {re}"
900        );
901    }
902}