1use std::cmp::Ordering;
22use std::collections::BinaryHeap;
23
24use crate::algorithms::paths::dijkstra::DijkstraMode;
25use crate::core::graph::EdgeId;
26use crate::core::{Graph, IgraphError, IgraphResult, VertexId};
27
28#[derive(Copy, Clone)]
31struct Frontier(f64, u64, VertexId);
32
33impl PartialEq for Frontier {
34 fn eq(&self, other: &Self) -> bool {
35 self.0 == other.0 && self.1 == other.1 && self.2 == other.2
36 }
37}
38impl Eq for Frontier {}
39impl Ord for Frontier {
40 fn cmp(&self, other: &Self) -> Ordering {
41 other
43 .0
44 .total_cmp(&self.0)
45 .then(other.1.cmp(&self.1))
46 .then(other.2.cmp(&self.2))
47 }
48}
49impl PartialOrd for Frontier {
50 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
51 Some(self.cmp(other))
52 }
53}
54
55fn validate_weights(graph: &Graph, weights: Option<&[f64]>) -> IgraphResult<()> {
56 let Some(w) = weights else {
57 return Ok(());
58 };
59 let m = graph.ecount();
60 if w.len() != m {
61 return Err(IgraphError::InvalidArgument(format!(
62 "weights vector size ({}) differs from edge count ({})",
63 w.len(),
64 m
65 )));
66 }
67 for (e, &v) in w.iter().enumerate() {
68 if v.is_nan() {
69 return Err(IgraphError::InvalidArgument(format!(
70 "weight at edge {e} is NaN"
71 )));
72 }
73 if v < 0.0 {
74 return Err(IgraphError::InvalidArgument(format!(
75 "weight at edge {e} is negative ({v}); A* requires non-negative weights"
76 )));
77 }
78 }
79 Ok(())
80}
81
82fn incident_for_mode(graph: &Graph, v: VertexId, mode: DijkstraMode) -> IgraphResult<Vec<EdgeId>> {
86 if !graph.is_directed() {
87 return graph.incident(v);
88 }
89 match mode {
90 DijkstraMode::Out => graph.incident(v),
91 DijkstraMode::In => graph.incident_in(v),
92 DijkstraMode::All => {
93 let mut out = graph.incident(v)?;
94 out.extend(graph.incident_in(v)?);
95 Ok(out)
96 }
97 }
98}
99
100pub fn a_star_path<H: Fn(VertexId, VertexId) -> f64>(
138 graph: &Graph,
139 from: VertexId,
140 to: VertexId,
141 weights: Option<&[f64]>,
142 mode: DijkstraMode,
143 heuristic: H,
144) -> IgraphResult<Option<(Vec<VertexId>, Vec<EdgeId>)>> {
145 let n = graph.vcount();
146 if from >= n {
147 return Err(IgraphError::VertexOutOfRange { id: from, n });
148 }
149 if to >= n {
150 return Err(IgraphError::VertexOutOfRange { id: to, n });
151 }
152 validate_weights(graph, weights)?;
153
154 if from == to {
155 return Ok(Some((vec![from], Vec::new())));
156 }
157
158 let n_us = n as usize;
159 let mut dist = vec![f64::INFINITY; n_us];
160 let mut parent_eid: Vec<Option<EdgeId>> = vec![None; n_us];
161 let mut closed = vec![false; n_us];
162
163 let mut tiebreaker: u64 = 0;
164 let mut next_tb = || {
165 let t = tiebreaker;
166 tiebreaker += 1;
167 t
168 };
169
170 let mut heap: BinaryHeap<Frontier> = BinaryHeap::new();
171 dist[from as usize] = 0.0;
172 let h0 = heuristic(from, to);
173 if h0.is_nan() || h0 < 0.0 {
174 return Err(IgraphError::InvalidArgument(format!(
175 "heuristic returned invalid estimate ({h0}); must be non-negative and not NaN"
176 )));
177 }
178 heap.push(Frontier(h0, next_tb(), from));
179
180 let mut found = false;
181 while let Some(Frontier(_, _, u)) = heap.pop() {
182 if closed[u as usize] {
183 continue;
184 }
185 closed[u as usize] = true;
186 if u == to {
187 found = true;
188 break;
189 }
190
191 for eid in incident_for_mode(graph, u, mode)? {
192 let w = match weights {
193 None => 1.0,
194 Some(ws) => ws[eid as usize],
195 };
196 if !w.is_finite() {
197 continue;
198 }
199 let v = graph.edge_other(eid as EdgeId, u)?;
200 if closed[v as usize] {
201 continue;
202 }
203 let altdist = dist[u as usize] + w;
204 let curdist = dist[v as usize];
205 if !curdist.is_finite() || altdist < curdist {
206 dist[v as usize] = altdist;
207 parent_eid[v as usize] = Some(eid as EdgeId);
208 let h = heuristic(v, to);
209 if h.is_nan() || h < 0.0 {
210 return Err(IgraphError::InvalidArgument(format!(
211 "heuristic returned invalid estimate ({h}); must be non-negative and not NaN"
212 )));
213 }
214 heap.push(Frontier(altdist + h, next_tb(), v));
215 }
216 }
217 }
218
219 if !found {
220 return Ok(None);
221 }
222
223 let mut vs = Vec::new();
225 let mut es = Vec::new();
226 let mut cur = to;
227 while let Some(eid) = parent_eid[cur as usize] {
228 es.push(eid);
229 vs.push(cur);
230 cur = graph.edge_other(eid, cur)?;
231 }
232 vs.push(cur);
233 vs.reverse();
234 es.reverse();
235 Ok(Some((vs, es)))
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Constant-zero heuristic; turns A* into plain Dijkstra.
    fn null_h(_: VertexId, _: VertexId) -> f64 {
        0.0
    }

    #[test]
    fn unit_weights_match_bfs_chain() {
        let mut g = Graph::with_vertices(4);
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(2, 3).unwrap();
        let (vs, es) = a_star_path(&g, 0, 3, None, DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 1, 2, 3]);
        assert_eq!(es, vec![0, 1, 2]);
    }

    #[test]
    fn weighted_triangle_with_shortcut() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(1, 2).unwrap();
        // The direct edge 0-2 (weight 4) is longer than the detour 0-1-2 (weight 3).
        let weights = [1.0, 4.0, 2.0];
        let (vs, es) = a_star_path(&g, 0, 2, Some(&weights), DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 1, 2]);
        assert_eq!(es, vec![0, 2]);
    }

    #[test]
    fn unreachable_target_returns_none() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        assert_eq!(
            a_star_path(&g, 0, 2, None, DijkstraMode::Out, null_h).unwrap(),
            None
        );
    }

    #[test]
    fn from_equals_to_singleton_chain() {
        let g = Graph::with_vertices(3);
        let (vs, es) = a_star_path(&g, 1, 1, None, DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![1]);
        assert!(es.is_empty());
    }

    #[test]
    fn admissible_heuristic_finds_same_path_as_null() {
        let mut g = Graph::with_vertices(4);
        // Two equal-length routes from 0 to 2: 0-1-2 and 0-3-2.
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        g.add_edge(0, 3).unwrap();
        g.add_edge(3, 2).unwrap();
        let h = |v: VertexId, target: VertexId| -> f64 {
            if v == target {
                0.0
            } else {
                1.0
            }
        };
        let with_h = a_star_path(&g, 0, 2, None, DijkstraMode::Out, h)
            .unwrap()
            .unwrap();
        let with_null = a_star_path(&g, 0, 2, None, DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(with_h.0.len(), 3);
        assert_eq!(with_h.0[0], 0);
        assert_eq!(with_h.0[2], 2);
        assert_eq!(with_h, with_null);
    }

    #[test]
    fn directed_in_mode_walks_reverse_edges() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(1, 2).unwrap();
        assert_eq!(
            a_star_path(&g, 2, 0, None, DijkstraMode::Out, null_h).unwrap(),
            None
        );
        let (vs, es) = a_star_path(&g, 2, 0, None, DijkstraMode::In, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![2, 1, 0]);
        assert_eq!(es, vec![1, 0]);
    }

    #[test]
    fn directed_all_mode_ignores_edge_direction() {
        let mut g = Graph::new(3, true).unwrap();
        g.add_edge(0, 1).unwrap();
        g.add_edge(2, 1).unwrap();
        assert_eq!(
            a_star_path(&g, 0, 2, None, DijkstraMode::Out, null_h).unwrap(),
            None
        );
        let (vs, es) = a_star_path(&g, 0, 2, None, DijkstraMode::All, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 1, 2]);
        assert_eq!(es, vec![0, 1]);
    }

    #[test]
    fn negative_weight_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let weights = [-1.0_f64];
        assert!(a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn nan_weight_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let weights = [f64::NAN];
        assert!(a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn weights_size_mismatch_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let weights = [1.0_f64, 2.0];
        assert!(a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn out_of_range_source_errors() {
        let g = Graph::with_vertices(2);
        assert!(a_star_path(&g, 99, 0, None, DijkstraMode::Out, null_h).is_err());
        assert!(a_star_path(&g, 0, 99, None, DijkstraMode::Out, null_h).is_err());
    }

    #[test]
    fn negative_heuristic_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let bad_h = |_v: VertexId, _t: VertexId| -1.0_f64;
        assert!(a_star_path(&g, 0, 1, None, DijkstraMode::Out, bad_h).is_err());
    }

    #[test]
    fn nan_heuristic_errors() {
        let mut g = Graph::with_vertices(2);
        g.add_edge(0, 1).unwrap();
        let bad_h = |_v: VertexId, _t: VertexId| f64::NAN;
        assert!(a_star_path(&g, 0, 1, None, DijkstraMode::Out, bad_h).is_err());
    }

    #[test]
    fn infinity_weight_skipped() {
        let mut g = Graph::with_vertices(3);
        g.add_edge(0, 1).unwrap();
        g.add_edge(0, 2).unwrap();
        g.add_edge(2, 1).unwrap();
        // The direct edge 0-1 has infinite weight and must be treated as absent.
        let weights = [f64::INFINITY, 1.0, 1.0];
        let (vs, _) = a_star_path(&g, 0, 1, Some(&weights), DijkstraMode::Out, null_h)
            .unwrap()
            .unwrap();
        assert_eq!(vs, vec![0, 2, 1]);
    }
}