1use crate::core::{IgraphError, IgraphResult};
17
18pub fn solve_lsap(costs: &[f64], n: usize) -> IgraphResult<Vec<u32>> {
65 if n == 0 {
66 if costs.is_empty() {
67 return Ok(Vec::new());
68 }
69 return Err(IgraphError::InvalidArgument(
70 "solve_lsap: n=0 but costs is non-empty".into(),
71 ));
72 }
73
74 let expected_len = n
75 .checked_mul(n)
76 .ok_or_else(|| IgraphError::InvalidArgument("solve_lsap: n*n overflows".into()))?;
77 if costs.len() != expected_len {
78 return Err(IgraphError::InvalidArgument(format!(
79 "solve_lsap: costs length {} != n*n = {}",
80 costs.len(),
81 expected_len
82 )));
83 }
84
85 for (i, &v) in costs.iter().enumerate() {
86 if v.is_nan() {
87 return Err(IgraphError::InvalidArgument(format!(
88 "solve_lsap: costs[{i}] is NaN"
89 )));
90 }
91 }
92
93 let assignment = hungarian(costs, n);
94 Ok(assignment)
95}
96
97fn hungarian(costs: &[f64], n: usize) -> Vec<u32> {
98 let mut c = vec![vec![0.0_f64; n + 1]; n + 1];
100 for i in 1..=n {
101 for j in 1..=n {
102 c[i][j] = costs[(i - 1) * n + (j - 1)];
103 }
104 }
105
106 preprocess(&mut c, n);
107
108 let mut s = vec![0_usize; n + 1];
110 let mut f = vec![0_usize; n + 1];
112 let mut na = 0_usize;
113
114 preassign(&c, n, &mut s, &mut f, &mut na);
115
116 while na < n {
117 let mut ri = vec![false; n + 1]; let mut ci = vec![false; n + 1]; if cover(&mut c, n, &mut s, &mut f, &mut na, &mut ri, &mut ci) {
121 reduce(&mut c, n, &ri, &ci);
122 }
123 }
124
125 (1..=n)
127 .map(|i| u32::try_from(s[i] - 1).unwrap_or(0))
128 .collect()
129}
130
/// Reduces the 1-based cost matrix `c` in place.
///
/// Each row `1..=n` first has its minimum subtracted; afterwards each column
/// `1..=n` has its minimum (taken after the row pass) subtracted. Row 0 and
/// column 0 are padding and are never touched.
#[allow(clippy::needless_range_loop)] // the column pass walks one column index across all rows
fn preprocess(c: &mut [Vec<f64>], n: usize) {
    if n == 0 {
        return;
    }

    // Strict `<` keeps the first minimum encountered in scan order.
    let pick_smaller = |best: f64, x: f64| if x < best { x } else { best };

    for row in &mut c[1..=n] {
        let row_min = row[2..=n].iter().copied().fold(row[1], pick_smaller);
        for x in &mut row[1..=n] {
            *x -= row_min;
        }
    }

    for col in 1..=n {
        let col_min = c[2..=n]
            .iter()
            .map(|row| row[col])
            .fold(c[1][col], pick_smaller);
        for row in &mut c[1..=n] {
            row[col] -= col_min;
        }
    }
}
159
/// Builds a greedy initial assignment on the zero entries of the reduced matrix.
///
/// `c` is indexed 1-based (row 0 and column 0 are ignored). `na` is reset to 0.
/// While some untaken row still counts a zero in an untaken column, the row
/// with the fewest such zeros is chosen (lowest index on ties) and paired with
/// its untaken zero column that has the fewest zeros (lowest index on ties).
/// Each pairing sets `s[row] = col`, `f[col] = row` and increments `na`.
/// Entries of `s`/`f` for rows/columns left unassigned are not written.
#[allow(clippy::needless_range_loop)] // 1-based scans over several parallel arrays
fn preassign(c: &[Vec<f64>], n: usize, s: &mut [usize], f: &mut [usize], na: &mut usize) {
    *na = 0;
    let mut row_taken = vec![false; n + 1];
    let mut col_taken = vec![false; n + 1];

    // Zero counts per row (decremented as columns get taken) and per column.
    let mut row_zeros = vec![0_usize; n + 1];
    let mut col_zeros = vec![0_usize; n + 1];
    for i in 1..=n {
        for j in (1..=n).filter(|&j| c[i][j] == 0.0) {
            row_zeros[i] += 1;
            col_zeros[j] += 1;
        }
    }

    // `min_by_key` returns the first minimal element, i.e. the lowest index.
    while let Some(row) = (1..=n)
        .filter(|&i| !row_taken[i] && row_zeros[i] > 0)
        .min_by_key(|&i| row_zeros[i])
    {
        let best_col = (1..=n)
            .filter(|&j| !col_taken[j] && c[row][j] == 0.0)
            .min_by_key(|&j| col_zeros[j]);

        match best_col {
            Some(col) => {
                *na += 1;
                s[row] = col;
                f[col] = row;
                row_taken[row] = true;
                col_taken[col] = true;

                // Column `col` is no longer available to rows with a zero there.
                for i in (1..=n).filter(|&i| c[i][col] == 0.0) {
                    row_zeros[i] = row_zeros[i].saturating_sub(1);
                }
                col_zeros[col] = 0;
            }
            // Guard: drop the row from consideration so the loop always ends.
            None => row_zeros[row] = 0,
        }
    }
}
223
/// One covering pass of the Hungarian method over the reduced matrix `c`.
///
/// Rows without an assignment start uncovered and marked; assigned rows start
/// covered; all columns start uncovered. Marked rows are processed lowest index
/// first. Each zero of the processed row in an uncovered column either covers
/// that column and uncovers and marks the row currently assigned to it, or, if
/// the column is free, (re)assigns the processed row to it.
///
/// Returns `false` right after such an assignment (incrementing `na` if the row
/// had no column before). Returns `true` when no marked row reached a free
/// zero; `ri`/`ci` (`true` = covered) then describe the cover for `reduce`.
/// `c` is only read.
#[allow(clippy::needless_range_loop, clippy::many_single_char_names)] // 1-based parallel index arrays
fn cover(
    c: &mut [Vec<f64>],
    n: usize,
    s: &mut [usize],
    f: &mut [usize],
    na: &mut usize,
    ri: &mut [bool],
    ci: &mut [bool],
) -> bool {
    let mut marked = vec![false; n + 1];
    for idx in 1..=n {
        let unassigned = s[idx] == 0;
        ri[idx] = !unassigned;
        marked[idx] = unassigned;
        ci[idx] = false;
    }

    while let Some(r) = (1..=n).find(|&i| marked[i]) {
        for j in 1..=n {
            if ci[j] || c[r][j] != 0.0 {
                continue;
            }

            let owner = f[j];
            if owner != 0 {
                // Column taken: cover it and continue the search from its owner.
                ri[owner] = false;
                marked[owner] = true;
                ci[j] = true;
            } else {
                // Free zero column: move row `r` onto it.
                if s[r] == 0 {
                    *na += 1;
                }
                let previous = s[r];
                if previous != 0 {
                    f[previous] = 0;
                }
                f[j] = r;
                s[r] = j;
                return false;
            }
        }
        marked[r] = false;
    }

    true
}
295
/// Dual update of the reduced matrix `c` using the cover flags from `cover`.
///
/// Let `m` be the smallest entry whose row and column are both uncovered
/// (`ri[i] == false && ci[j] == false`), scanned row-major starting from
/// `f64::MAX`. `m` is subtracted from every such doubly-uncovered entry and
/// added to every doubly-covered entry; entries covered exactly once are left
/// unchanged. Row 0 and column 0 are ignored.
#[allow(clippy::needless_range_loop)] // row/column cover flags are indexed in parallel with `c`
fn reduce(c: &mut [Vec<f64>], n: usize, ri: &[bool], ci: &[bool]) {
    let mut delta = f64::MAX;
    for i in (1..=n).filter(|&i| !ri[i]) {
        for j in (1..=n).filter(|&j| !ci[j]) {
            let value = c[i][j];
            if value < delta {
                delta = value;
            }
        }
    }

    for i in 1..=n {
        let row_covered = ri[i];
        for j in 1..=n {
            match (row_covered, ci[j]) {
                (false, false) => c[i][j] -= delta,
                (true, true) => c[i][j] += delta,
                _ => {}
            }
        }
    }
}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns `true` if `p` has length `n` and uses every column `0..n` exactly once.
    fn is_valid_permutation(p: &[u32], n: usize) -> bool {
        if p.len() != n {
            return false;
        }
        let mut used = vec![false; n];
        for &col in p {
            let c = col as usize;
            if c >= n || used[c] {
                return false;
            }
            used[c] = true;
        }
        true
    }

    /// Total cost of assigning each row `i` to column `p[i]` in a row-major matrix.
    fn assignment_cost(costs: &[f64], n: usize, p: &[u32]) -> f64 {
        (0..n).map(|i| costs[i * n + p[i] as usize]).sum()
    }

    #[test]
    fn lsap_empty() {
        let p = solve_lsap(&[], 0).unwrap();
        assert!(p.is_empty());
    }

    #[test]
    fn lsap_1x1() {
        let p = solve_lsap(&[42.0], 1).unwrap();
        assert_eq!(p, vec![0]);
    }

    #[test]
    fn lsap_2x2_identity() {
        let costs = vec![1.0, 100.0, 100.0, 1.0];
        let p = solve_lsap(&costs, 2).unwrap();
        assert!(is_valid_permutation(&p, 2));
        let cost = assignment_cost(&costs, 2, &p);
        assert!((cost - 2.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_2x2_swap() {
        let costs = vec![100.0, 1.0, 1.0, 100.0];
        let p = solve_lsap(&costs, 2).unwrap();
        assert!(is_valid_permutation(&p, 2));
        let cost = assignment_cost(&costs, 2, &p);
        assert!((cost - 2.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_3x3() {
        // Optimum is row 0 -> col 2, row 1 -> col 1, row 2 -> col 0: 69 + 37 + 11.
        let costs = vec![82.0, 83.0, 69.0, 77.0, 37.0, 49.0, 11.0, 69.0, 5.0];
        let p = solve_lsap(&costs, 3).unwrap();
        assert!(is_valid_permutation(&p, 3));
        let cost = assignment_cost(&costs, 3, &p);
        assert!((cost - 117.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_4x4() {
        let costs = vec![
            10.0, 5.0, 13.0, 15.0, //
            3.0, 9.0, 18.0, 3.0, //
            13.0, 6.0, 12.0, 14.0, //
            12.0, 8.0, 14.0, 9.0,
        ];
        let p = solve_lsap(&costs, 4).unwrap();
        assert!(is_valid_permutation(&p, 4));
        let cost = assignment_cost(&costs, 4, &p);
        let min_cost = brute_force_min_cost(&costs, 4);
        assert!(
            (cost - min_cost).abs() < 1e-10,
            "Hungarian cost {cost} != brute force min {min_cost}"
        );
    }

    #[test]
    fn lsap_uniform() {
        let costs = vec![5.0; 9];
        let p = solve_lsap(&costs, 3).unwrap();
        assert!(is_valid_permutation(&p, 3));
        let cost = assignment_cost(&costs, 3, &p);
        assert!((cost - 15.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_diagonal() {
        let n = 5;
        let mut costs = vec![100.0; n * n];
        for i in 0..n {
            costs[i * n + i] = 1.0;
        }
        let p = solve_lsap(&costs, n).unwrap();
        assert!(is_valid_permutation(&p, n));
        let cost = assignment_cost(&costs, n, &p);
        assert!((cost - 5.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_negative_costs() {
        // Unique optimum: row 0 -> col 1 (-5), row 1 -> col 0 (-3), total -8.
        let costs = vec![-1.0, -5.0, -3.0, -2.0];
        let p = solve_lsap(&costs, 2).unwrap();
        assert!(is_valid_permutation(&p, 2));
        assert_eq!(p, vec![1, 0]);
        let cost = assignment_cost(&costs, 2, &p);
        assert!((cost + 8.0).abs() < 1e-10);
    }

    #[test]
    fn lsap_invalid_size() {
        assert!(solve_lsap(&[1.0, 2.0], 2).is_err());
    }

    #[test]
    fn lsap_zero_n_with_costs_is_error() {
        assert!(solve_lsap(&[1.0], 0).is_err());
    }

    #[test]
    fn lsap_size_overflow_is_error() {
        // n * n overflows usize, so validation must fail before solving.
        assert!(solve_lsap(&[], usize::MAX).is_err());
    }

    #[test]
    fn lsap_nan_cost() {
        assert!(solve_lsap(&[f64::NAN, 1.0, 1.0, 1.0], 2).is_err());
    }

    /// Minimum assignment cost over all `n!` permutations (small `n` only).
    fn brute_force_min_cost(costs: &[f64], n: usize) -> f64 {
        let mut perm: Vec<usize> = (0..n).collect();
        let mut min_cost = f64::MAX;
        loop {
            let cost: f64 = (0..n).map(|i| costs[i * n + perm[i]]).sum();
            if cost < min_cost {
                min_cost = cost;
            }
            if !next_permutation(&mut perm) {
                break;
            }
        }
        min_cost
    }

    /// Advances `arr` to the next lexicographic permutation; returns `false`
    /// once `arr` is the last (descending) permutation.
    fn next_permutation(arr: &mut [usize]) -> bool {
        let n = arr.len();
        if n < 2 {
            return false;
        }
        let mut i = n - 1;
        while i > 0 && arr[i - 1] >= arr[i] {
            i -= 1;
        }
        if i == 0 {
            return false;
        }
        let mut j = n - 1;
        while arr[j] <= arr[i - 1] {
            j -= 1;
        }
        arr.swap(i - 1, j);
        arr[i..].reverse();
        true
    }
}