Skip to main content

rust_igraph/algorithms/
matching_lsap.rs

1//! Linear Sum Assignment Problem (ALGO-MA-005) — Hungarian method.
2//!
3//! Counterpart of `igraph_solve_lsap` in
4//! `references/igraph/src/internal/lsap.c:664`.
5//!
6//! Solves the balanced assignment problem: given an n×n cost matrix,
7//! find a permutation p such that Σ C\[i\]\[p\[i\]\] is minimized.
8//!
9//! ## Algorithm
10//!
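//! Classical Hungarian method (Kuhn-Munkres). The cover and reduction steps
//! are recomputed from scratch on every iteration, so the worst case is
//! O(n⁴) rather than the O(n³) of slack-based implementations.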
12//! 1. Subtract row and column minima (preprocessing).
13//! 2. Greedily assign zeros.
14//! 3. Iteratively cover rows/columns and reduce until all rows assigned.
15
16use crate::core::{IgraphError, IgraphResult};
17
18/// Solve a balanced linear sum assignment problem (Hungarian method).
19///
20/// Given an n×n cost matrix (stored row-major as a flat slice of length n²),
21/// find an assignment of each row to exactly one column that minimizes the
22/// total cost.
23///
24/// # Arguments
25///
26/// * `costs` — flat row-major n×n cost matrix (length must equal `n * n`).
27/// * `n` — size of the problem (number of rows = number of columns).
28///
29/// # Returns
30///
31/// A vector `p` of length `n` where `p[i]` is the column assigned to row `i`.
32///
33/// # Errors
34///
35/// Returns an error if `costs.len() != n * n`, or if `n` is zero and costs
36/// is non-empty, or if costs contain NaN.
37///
38/// # Examples
39///
40/// ```
41/// use rust_igraph::solve_lsap;
42///
43/// // 3×3 cost matrix:
44/// // [1, 2, 3]
45/// // [2, 4, 6]
46/// // [3, 6, 9]
47/// // Optimal: row 0→col 2 (3), row 1→col 1 (4), row 2→col 0 (3) = 10
48/// // OR: row 0→col 0 (1), row 1→col 1 (4), row 2→col 2 (9) = 14
49/// // Actually optimal: 0→2(3), 1→0(2), 2→1(6) = 11
50/// // Let's use a simpler example:
51/// let costs = vec![
52///     10.0, 5.0, 13.0,
53///      3.0, 7.0,  2.0,
54///      6.0, 8.0, 12.0,
55/// ];
56/// let p = solve_lsap(&costs, 3).unwrap();
57/// // Verify it's a valid permutation
58/// let mut used = vec![false; 3];
59/// for &col in &p {
60///     assert!(!used[col as usize]);
61///     used[col as usize] = true;
62/// }
63/// ```
64pub fn solve_lsap(costs: &[f64], n: usize) -> IgraphResult<Vec<u32>> {
65    if n == 0 {
66        if costs.is_empty() {
67            return Ok(Vec::new());
68        }
69        return Err(IgraphError::InvalidArgument(
70            "solve_lsap: n=0 but costs is non-empty".into(),
71        ));
72    }
73
74    let expected_len = n
75        .checked_mul(n)
76        .ok_or_else(|| IgraphError::InvalidArgument("solve_lsap: n*n overflows".into()))?;
77    if costs.len() != expected_len {
78        return Err(IgraphError::InvalidArgument(format!(
79            "solve_lsap: costs length {} != n*n = {}",
80            costs.len(),
81            expected_len
82        )));
83    }
84
85    for (i, &v) in costs.iter().enumerate() {
86        if v.is_nan() {
87            return Err(IgraphError::InvalidArgument(format!(
88                "solve_lsap: costs[{i}] is NaN"
89            )));
90        }
91    }
92
93    let assignment = hungarian(costs, n);
94    Ok(assignment)
95}
96
97fn hungarian(costs: &[f64], n: usize) -> Vec<u32> {
98    // Build reduced cost matrix (1-indexed internally for clarity)
99    let mut c = vec![vec![0.0_f64; n + 1]; n + 1];
100    for i in 1..=n {
101        for j in 1..=n {
102            c[i][j] = costs[(i - 1) * n + (j - 1)];
103        }
104    }
105
106    preprocess(&mut c, n);
107
108    // s[i] = column assigned to row i (0 = unassigned)
109    let mut s = vec![0_usize; n + 1];
110    // f[j] = row assigned to column j (0 = unassigned)
111    let mut f = vec![0_usize; n + 1];
112    let mut na = 0_usize;
113
114    preassign(&c, n, &mut s, &mut f, &mut na);
115
116    while na < n {
117        let mut ri = vec![false; n + 1]; // covered rows
118        let mut ci = vec![false; n + 1]; // covered columns
119
120        if cover(&mut c, n, &mut s, &mut f, &mut na, &mut ri, &mut ci) {
121            reduce(&mut c, n, &ri, &ci);
122        }
123    }
124
125    // Convert to 0-based u32
126    (1..=n)
127        .map(|i| u32::try_from(s[i] - 1).unwrap_or(0))
128        .collect()
129}
130
// The column pass indexes each row by column number; the matrix is 1-based
// with an unused row/column 0, so explicit index ranges are clearer here.
#[allow(clippy::needless_range_loop)]
fn preprocess(c: &mut [Vec<f64>], n: usize) {
    /// Left-to-right minimum using strict `<`, seeded with `first`.
    fn leftmost_min(first: f64, rest: impl Iterator<Item = f64>) -> f64 {
        rest.fold(first, |best, x| if x < best { x } else { best })
    }

    // Row pass: shift every row so its smallest entry becomes zero.
    for row in c.iter_mut().skip(1).take(n) {
        let cells = &mut row[1..=n];
        let shift = leftmost_min(cells[0], cells[1..].iter().copied());
        cells.iter_mut().for_each(|x| *x -= shift);
    }

    // Column pass on the row-reduced matrix.
    for j in 1..=n {
        let shift = leftmost_min(c[1][j], (2..=n).map(|i| c[i][j]));
        for row in c[1..=n].iter_mut() {
            row[j] -= shift;
        }
    }
}
159
// Rows and columns are 1-based and the zero counters are updated while
// scanning, so explicit index ranges are kept on purpose.
#[allow(clippy::needless_range_loop)]
fn preassign(c: &[Vec<f64>], n: usize, s: &mut [usize], f: &mut [usize], na: &mut usize) {
    *na = 0;

    // Number of zero entries per row / per column that are still usable.
    let mut row_zeros = vec![0_usize; n + 1];
    let mut col_zeros = vec![0_usize; n + 1];
    for i in 1..=n {
        for j in 1..=n {
            if c[i][j] == 0.0 {
                row_zeros[i] += 1;
                col_zeros[j] += 1;
            }
        }
    }

    let mut row_done = vec![false; n + 1];
    let mut col_done = vec![false; n + 1];

    loop {
        // Most constrained free row: fewest usable zeros, ties go to the
        // lowest index (min_by_key keeps the first minimum).
        let row = match (1..=n)
            .filter(|&i| !row_done[i] && row_zeros[i] > 0)
            .min_by_key(|&i| row_zeros[i])
        {
            Some(i) => i,
            None => break,
        };

        // Among that row's zeros in free columns, prefer the column with the
        // fewest zeros (again the lowest index on ties).
        let col = (1..=n)
            .filter(|&j| !col_done[j] && c[row][j] == 0.0)
            .min_by_key(|&j| col_zeros[j]);

        match col {
            // The row has no zero in a free column: stop considering it.
            None => row_zeros[row] = 0,
            Some(j) => {
                *na += 1;
                s[row] = j;
                f[j] = row;
                row_done[row] = true;
                col_done[j] = true;

                // Zeros in column j are no longer usable by any row.
                for i in 1..=n {
                    if c[i][j] == 0.0 {
                        row_zeros[i] = row_zeros[i].saturating_sub(1);
                    }
                }
                col_zeros[j] = 0;
            }
        }
    }
}
223
/// Try to extend the current assignment along zero entries of `c`.
///
/// Every unassigned row starts uncovered and queued; assigned rows start
/// covered, and all columns start uncovered. Queued rows are explored one at
/// a time, lowest index first, scanning their columns in order.
///
/// * A zero in an uncovered column owned by another row covers that column,
///   uncovers the owner row, and queues it.
/// * A zero in a free column moves the explored row onto that column (freeing
///   its previous column, or incrementing `na` if it had none). The function
///   then returns `false` immediately.
///
/// Returns `true` when the queue empties without such a move. `ri`/`ci` then
/// hold the cover that the reduction step should use.
#[allow(clippy::needless_range_loop, clippy::many_single_char_names)]
fn cover(
    c: &mut [Vec<f64>],
    n: usize,
    s: &mut [usize],
    f: &mut [usize],
    na: &mut usize,
    ri: &mut [bool],
    ci: &mut [bool],
) -> bool {
    // queued[i] == true means row i still has to be explored.
    let mut queued = vec![false; n + 1];
    for i in 1..=n {
        let free_row = s[i] == 0;
        ri[i] = !free_row;
        queued[i] = free_row;
        ci[i] = false;
    }

    loop {
        let r = match (1..=n).find(|&i| queued[i]) {
            Some(row) => row,
            None => return true,
        };

        for j in 1..=n {
            if ci[j] || c[r][j] != 0.0 {
                continue;
            }

            let holder = f[j];
            if holder == 0 {
                // Free column: move row r onto it.
                match s[r] {
                    0 => *na += 1,
                    previous => f[previous] = 0,
                }
                f[j] = r;
                s[r] = j;
                return false;
            }

            // Column j is taken: cover it and explore its owner instead.
            ri[holder] = false;
            queued[holder] = true;
            ci[j] = true;
        }

        queued[r] = false;
    }
}
295
/// Shift the reduced costs by the smallest entry outside the cover.
///
/// `delta` is the minimum of `c[i][j]` over rows with `!ri[i]` and columns
/// with `!ci[j]`, scanned row-major with strict `<` from `f64::MAX`. If
/// nothing is uncovered, `delta` stays `f64::MAX`. `delta` is subtracted from
/// every uncovered entry and added to every entry whose row and column are
/// both covered. Singly-covered entries are left unchanged.
fn reduce(c: &mut [Vec<f64>], n: usize, ri: &[bool], ci: &[bool]) {
    let delta = {
        let grid: &[Vec<f64>] = &*c;
        (1..=n)
            .filter(|&i| !ri[i])
            .flat_map(move |i| {
                (1..=n)
                    .filter(move |&j| !ci[j])
                    .map(move |j| grid[i][j])
            })
            .fold(f64::MAX, |best, v| if v < best { v } else { best })
    };

    for (i, row) in c.iter_mut().enumerate().skip(1).take(n) {
        for (j, cell) in row.iter_mut().enumerate().skip(1).take(n) {
            match (ri[i], ci[j]) {
                (false, false) => *cell -= delta,
                (true, true) => *cell += delta,
                _ => {}
            }
        }
    }
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true if `p` is a permutation of `0..n`.
    fn is_valid_permutation(p: &[u32], n: usize) -> bool {
        if p.len() != n {
            return false;
        }
        let mut used = vec![false; n];
        for &col in p {
            let c = col as usize;
            if c >= n || used[c] {
                return false;
            }
            used[c] = true;
        }
        true
    }

    /// Total cost of assigning row `i` to column `p[i]` for every row.
    fn assignment_cost(costs: &[f64], n: usize, p: &[u32]) -> f64 {
        (0..n).map(|i| costs[i * n + p[i] as usize]).sum()
    }

    /// Minimum total cost over all `n!` assignments (exhaustive; small `n` only).
    fn brute_force_min_cost(costs: &[f64], n: usize) -> f64 {
        let mut perm: Vec<usize> = (0..n).collect();
        let mut min_cost = f64::MAX;
        loop {
            let cost: f64 = (0..n).map(|i| costs[i * n + perm[i]]).sum();
            if cost < min_cost {
                min_cost = cost;
            }
            if !next_permutation(&mut perm) {
                break;
            }
        }
        min_cost
    }

    /// Advances `arr` to the next lexicographic permutation; returns false if
    /// `arr` was already the last one.
    fn next_permutation(arr: &mut [usize]) -> bool {
        let n = arr.len();
        if n < 2 {
            return false;
        }
        let mut i = n - 1;
        while i > 0 && arr[i - 1] >= arr[i] {
            i -= 1;
        }
        if i == 0 {
            return false;
        }
        let mut j = n - 1;
        while arr[j] <= arr[i - 1] {
            j -= 1;
        }
        arr.swap(i - 1, j);
        arr[i..].reverse();
        true
    }

    /// Deterministic xorshift64 step so randomized tests are reproducible.
    fn xorshift64(state: &mut u64) -> u64 {
        let mut x = *state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        *state = x;
        x
    }

    #[test]
    fn lsap_empty() {
        let p = solve_lsap(&[], 0).unwrap();
        assert!(p.is_empty());
    }

    #[test]
    fn lsap_zero_size_with_costs_is_error() {
        assert!(solve_lsap(&[1.0], 0).is_err());
    }

    #[test]
    fn lsap_1x1() {
        let p = solve_lsap(&[42.0], 1).unwrap();
        assert_eq!(p, vec![0]);
    }

    #[test]
    fn lsap_2x2_identity() {
        // [1, 100]
        // [100, 1]
        // Optimal: (0,0) + (1,1) = 2
        let costs = vec![1.0, 100.0, 100.0, 1.0];
        let p = solve_lsap(&costs, 2).unwrap();
        assert!(is_valid_permutation(&p, 2));
        let cost = assignment_cost(&costs, 2, &p);
        assert!((cost - 2.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_2x2_swap() {
        // [100, 1]
        // [1, 100]
        // Optimal: (0,1) + (1,0) = 2
        let costs = vec![100.0, 1.0, 1.0, 100.0];
        let p = solve_lsap(&costs, 2).unwrap();
        assert!(is_valid_permutation(&p, 2));
        let cost = assignment_cost(&costs, 2, &p);
        assert!((cost - 2.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_3x3() {
        // Classic example:
        // [82, 83, 69]
        // [77, 37, 49]
        // [11, 69, 5]
        // Optimal: 0→2(69), 1→1(37), 2→0(11) = 117
        let costs = vec![82.0, 83.0, 69.0, 77.0, 37.0, 49.0, 11.0, 69.0, 5.0];
        let p = solve_lsap(&costs, 3).unwrap();
        assert!(is_valid_permutation(&p, 3));
        let cost = assignment_cost(&costs, 3, &p);
        // Verify against known optimum
        assert!((cost - 117.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_doc_example() {
        // [10, 5, 13]
        // [ 3, 7,  2]
        // [ 6, 8, 12]
        // Unique optimum: 0→1(5), 1→2(2), 2→0(6) = 13
        let costs = vec![10.0, 5.0, 13.0, 3.0, 7.0, 2.0, 6.0, 8.0, 12.0];
        let p = solve_lsap(&costs, 3).unwrap();
        assert_eq!(p, vec![1, 2, 0]);
        assert!((assignment_cost(&costs, 3, &p) - 13.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_4x4() {
        // [10, 5, 13, 15]
        // [ 3, 9, 18,  3]
        // [13, 6, 12, 14]
        // [12, 8, 14,  9]
        // Unique optimum: 0→1(5), 1→0(3), 2→2(12), 3→3(9) = 29.
        // The tempting 1→3(3) only reaches 0→1(5), 2→2(12), 3→0(12) = 32.
        let costs = vec![
            10.0, 5.0, 13.0, 15.0, 3.0, 9.0, 18.0, 3.0, 13.0, 6.0, 12.0, 14.0, 12.0, 8.0, 14.0, 9.0,
        ];
        let p = solve_lsap(&costs, 4).unwrap();
        assert!(is_valid_permutation(&p, 4));
        assert_eq!(p, vec![1, 0, 2, 3]);
        let cost = assignment_cost(&costs, 4, &p);
        assert!((cost - 29.0).abs() < 1e-10);
        // Cross-check against exhaustive search.
        let min_cost = brute_force_min_cost(&costs, 4);
        assert!(
            (cost - min_cost).abs() < 1e-10,
            "Hungarian cost {cost} != brute force min {min_cost}"
        );
    }

    #[test]
    fn lsap_uniform() {
        // All costs equal: any permutation is optimal
        let costs = vec![5.0; 9];
        let p = solve_lsap(&costs, 3).unwrap();
        assert!(is_valid_permutation(&p, 3));
        let cost = assignment_cost(&costs, 3, &p);
        assert!((cost - 15.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_diagonal() {
        // Diagonal is cheapest
        let n = 5;
        let mut costs = vec![100.0; n * n];
        for i in 0..n {
            costs[i * n + i] = 1.0;
        }
        let p = solve_lsap(&costs, n).unwrap();
        assert!(is_valid_permutation(&p, n));
        let cost = assignment_cost(&costs, n, &p);
        assert!((cost - 5.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_negative_costs() {
        // [-1, -5]
        // [-3, -2]
        // Optimal: 0→1(-5), 1→0(-3) = -8
        let costs = vec![-1.0, -5.0, -3.0, -2.0];
        let p = solve_lsap(&costs, 2).unwrap();
        assert_eq!(p, vec![1, 0]);
        assert!((assignment_cost(&costs, 2, &p) + 8.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_infinite_costs_are_avoided() {
        let inf = f64::INFINITY;
        let p = solve_lsap(&[inf, 1.0, 1.0, inf], 2).unwrap();
        assert_eq!(p, vec![1, 0]);

        // Diagonal forbidden; the best derangement is 0→2(3), 1→0(1), 2→1(6) = 10.
        let costs = vec![inf, 2.0, 3.0, 1.0, inf, 4.0, 5.0, 6.0, inf];
        let p = solve_lsap(&costs, 3).unwrap();
        assert_eq!(p, vec![2, 0, 1]);
        assert!((assignment_cost(&costs, 3, &p) - 10.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_matches_brute_force_on_small_random_matrices() {
        let mut state = 0x9E37_79B9_7F4A_7C15_u64;
        for n in 1..=6 {
            for _ in 0..20 {
                // Small integer costs: many ties, exact float arithmetic.
                let costs: Vec<f64> = (0..n * n)
                    .map(|_| (xorshift64(&mut state) % 10) as f64)
                    .collect();
                let p = solve_lsap(&costs, n).unwrap();
                assert!(is_valid_permutation(&p, n), "invalid permutation {p:?}");
                let cost = assignment_cost(&costs, n, &p);
                let best = brute_force_min_cost(&costs, n);
                assert!(
                    (cost - best).abs() < 1e-9,
                    "n={n}: Hungarian cost {cost} != brute force min {best} for {costs:?}"
                );
            }
        }
    }

    #[test]
    fn lsap_invalid_size() {
        assert!(solve_lsap(&[1.0, 2.0], 2).is_err());
    }

    #[test]
    fn lsap_nan_cost() {
        assert!(solve_lsap(&[f64::NAN, 1.0, 1.0, 1.0], 2).is_err());
    }
}