1 use petgraph::algo::floyd_warshall;
2 use petgraph::{prelude::*, Directed, Graph, Undirected};
3 use std::collections::HashMap;
4 
/// All-pairs shortest paths on a directed graph where every edge costs 1.
///
/// The graph is two 4-cycles joined by the one-way edge `b -> e`, so the
/// second cycle can be reached from the first but not the other way round.
/// Pairs with no connecting path are expected to be reported as `i32::MAX`.
#[test]
fn floyd_warshall_uniform_weight() {
    let mut graph = Graph::<(), (), Directed>::new();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());
    let d = graph.add_node(());
    let e = graph.add_node(());
    let f = graph.add_node(());
    let g = graph.add_node(());
    let h = graph.add_node(());

    // a ----> b ----> e ----> f
    // ^       |       ^       |
    // |       v       |       v
    // d <---- c       h <---- g
    graph.extend_with_edges(&[
        (a, b),
        (b, c),
        (c, d),
        (d, a),
        (e, f),
        (b, e),
        (f, g),
        (g, h),
        (h, e),
    ]);

    let inf = std::i32::MAX;

    // Row = source node, column = target node, both in the order a..h.
    let expected: [[i32; 8]; 8] = [
        [0, 1, 2, 3, 2, 3, 4, 5],
        [3, 0, 1, 2, 1, 2, 3, 4],
        [2, 3, 0, 1, 4, 5, 6, 7],
        [1, 2, 3, 0, 3, 4, 5, 6],
        [inf, inf, inf, inf, 0, 1, 2, 3],
        [inf, inf, inf, inf, 3, 0, 1, 2],
        [inf, inf, inf, inf, 2, 3, 0, 1],
        [inf, inf, inf, inf, 1, 2, 3, 0],
    ];

    let distances = floyd_warshall(&graph, |_| 1_i32).unwrap();

    // Walk the matrix alongside the node list so each cell is compared with
    // the distance reported for the matching (source, target) pair.
    let nodes = [a, b, c, d, e, f, g, h];
    for (&from, row) in nodes.iter().zip(expected.iter()) {
        for (&to, want) in nodes.iter().zip(row.iter()) {
            assert_eq!(distances.get(&(from, to)), Some(want));
        }
    }
}
115 
/// All-pairs shortest paths on a small directed graph with per-edge weights.
///
/// Edges only point "forward" (a -> b -> c -> d plus shortcuts), so every
/// backward pair is expected to be reported as `i32::MAX`. The direct edges
/// `a -> c` (4) and `a -> d` (10) are more expensive than going through `b`,
/// which is what the expected distance of 3 for both pairs checks.
#[test]
fn floyd_warshall_weighted() {
    let mut graph = Graph::<(), (), Directed>::new();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());
    let d = graph.add_node(());

    graph.extend_with_edges(&[(a, b), (a, c), (a, d), (b, c), (b, d), (c, d)]);

    let inf = std::i32::MAX;

    // Row = source node, column = target node, both in the order a..d.
    let expected: [[i32; 4]; 4] = [
        [0, 1, 3, 3],
        [inf, 0, 2, 2],
        [inf, inf, 0, 2],
        [inf, inf, inf, 0],
    ];

    // (source, target, weight) triples. The graph has no self-loops, so the
    // diagonal entries are never looked up by the weight closure below.
    let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
        (a, a, 0),
        (a, b, 1),
        (a, c, 4),
        (a, d, 10),
        (b, b, 0),
        (b, c, 2),
        (b, d, 2),
        (c, c, 0),
        (c, d, 2),
    ]
    .iter()
    .map(|&(from, to, weight)| ((from, to), weight))
    .collect();

    // Any edge missing from the table falls back to `inf`.
    let distances = floyd_warshall(&graph, |edge| {
        weight_map
            .get(&(edge.source(), edge.target()))
            .copied()
            .unwrap_or(inf)
    })
    .unwrap();

    let nodes = [a, b, c, d];
    for (&from, row) in nodes.iter().zip(expected.iter()) {
        for (&to, want) in nodes.iter().zip(row.iter()) {
            assert_eq!(distances.get(&(from, to)), Some(want));
        }
    }
}
183 
/// All-pairs shortest paths on a small undirected graph with per-edge weights.
///
/// Every pair is reachable in both directions, so the expected distance
/// matrix is symmetric and contains no `i32::MAX` entries.
#[test]
fn floyd_warshall_weighted_undirected() {
    let mut graph = Graph::<(), (), Undirected>::new_undirected();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());
    let d = graph.add_node(());

    graph.extend_with_edges(&[(a, b), (a, c), (a, d), (b, d), (c, b), (c, d)]);

    let inf = std::i32::MAX;

    // Row = source node, column = target node, both in the order a..d.
    let expected: [[i32; 4]; 4] = [
        [0, 1, 3, 3],
        [1, 0, 2, 2],
        [3, 2, 0, 2],
        [3, 2, 2, 0],
    ];

    // (source, target, weight) triples. The endpoint order of each entry
    // matches the order used in `extend_with_edges` above (e.g. `(c, b)`,
    // not `(b, c)`) — presumably the order in which edge references report
    // their endpoints; verify against petgraph if this table is edited.
    // The diagonal entries have no corresponding edge in the graph.
    let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
        (a, a, 0),
        (a, b, 1),
        (a, c, 4),
        (a, d, 10),
        (b, b, 0),
        (b, d, 2),
        (c, b, 2),
        (c, c, 0),
        (c, d, 2),
    ]
    .iter()
    .map(|&(from, to, weight)| ((from, to), weight))
    .collect();

    // Any edge missing from the table falls back to `inf`.
    let distances = floyd_warshall(&graph, |edge| {
        weight_map
            .get(&(edge.source(), edge.target()))
            .copied()
            .unwrap_or(inf)
    })
    .unwrap();

    let nodes = [a, b, c, d];
    for (&from, row) in nodes.iter().zip(expected.iter()) {
        for (&to, want) in nodes.iter().zip(row.iter()) {
            assert_eq!(distances.get(&(from, to)), Some(want));
        }
    }
}
251 
/// A graph containing a negative-weight cycle must make `floyd_warshall`
/// return an error instead of a distance map.
///
/// The cycle is `a -> b -> c -> a` with weights 1, -3 and 1, for a total
/// of -1.
#[test]
fn floyd_warshall_negative_cycle() {
    let mut graph = Graph::<(), (), Directed>::new();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());

    graph.extend_with_edges(&[(a, b), (b, c), (c, a)]);

    let inf = std::i32::MAX;

    // (source, target, weight) triples. The graph has no self-loops, so the
    // diagonal entries are never looked up by the weight closure below.
    let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
        (a, a, 0),
        (a, b, 1),
        (b, b, 0),
        (b, c, -3),
        (c, c, 0),
        (c, a, 1),
    ]
    .iter()
    .map(|&(from, to, weight)| ((from, to), weight))
    .collect();

    // Any edge missing from the table falls back to `inf`.
    let outcome = floyd_warshall(&graph, |edge| {
        weight_map
            .get(&(edge.source(), edge.target()))
            .copied()
            .unwrap_or(inf)
    });

    assert!(outcome.is_err());
}
285