1 use petgraph::algo::floyd_warshall;
2 use petgraph::{prelude::*, Directed, Graph, Undirected};
3 use std::collections::HashMap;
4
/// Shortest paths with every edge costing 1 on a directed graph made of two
/// 4-cycles (`a -> b -> c -> d -> a` and `e -> f -> g -> h -> e`) joined by the
/// one-way edge `b -> e`.
///
/// No node of the `e..=h` cycle can reach the `a..=d` cycle, so those pairs are
/// expected to carry the unreachable marker `i32::MAX`.
#[test]
fn floyd_warshall_uniform_weight() {
    let mut graph: Graph<(), (), Directed> = Graph::new();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());
    let d = graph.add_node(());
    let e = graph.add_node(());
    let f = graph.add_node(());
    let g = graph.add_node(());
    let h = graph.add_node(());

    graph.extend_with_edges(&[
        (a, b),
        (b, c),
        (c, d),
        (d, a),
        (e, f),
        (b, e),
        (f, g),
        (g, h),
        (h, e),
    ]);
    // a ----> b ----> e ----> f
    // ^       |       ^       |
    // |       v       |       v
    // d <---- c       h <---- g

    // The test expects unreachable pairs to be reported as `i32::MAX`.
    let inf = i32::MAX;
    let expected_res: HashMap<(NodeIndex, NodeIndex), i32> = [
        ((a, a), 0),
        ((a, b), 1),
        ((a, c), 2),
        ((a, d), 3),
        ((a, e), 2),
        ((a, f), 3),
        ((a, g), 4),
        ((a, h), 5),
        ((b, a), 3),
        ((b, b), 0),
        ((b, c), 1),
        ((b, d), 2),
        ((b, e), 1),
        ((b, f), 2),
        ((b, g), 3),
        ((b, h), 4),
        ((c, a), 2),
        ((c, b), 3),
        ((c, c), 0),
        ((c, d), 1),
        ((c, e), 4),
        ((c, f), 5),
        ((c, g), 6),
        ((c, h), 7),
        ((d, a), 1),
        ((d, b), 2),
        ((d, c), 3),
        ((d, d), 0),
        ((d, e), 3),
        ((d, f), 4),
        ((d, g), 5),
        ((d, h), 6),
        ((e, a), inf),
        ((e, b), inf),
        ((e, c), inf),
        ((e, d), inf),
        ((e, e), 0),
        ((e, f), 1),
        ((e, g), 2),
        ((e, h), 3),
        ((f, a), inf),
        ((f, b), inf),
        ((f, c), inf),
        ((f, d), inf),
        ((f, e), 3),
        ((f, f), 0),
        ((f, g), 1),
        ((f, h), 2),
        ((g, a), inf),
        ((g, b), inf),
        ((g, c), inf),
        ((g, d), inf),
        ((g, e), 2),
        ((g, f), 3),
        ((g, g), 0),
        ((g, h), 1),
        ((h, a), inf),
        ((h, b), inf),
        ((h, c), inf),
        ((h, d), inf),
        ((h, e), 1),
        ((h, f), 2),
        ((h, g), 3),
        ((h, h), 0),
    ]
    .iter()
    .copied()
    .collect();

    let res = floyd_warshall(&graph, |_| 1_i32).expect("graph has no negative cycle");

    // Check every ordered pair, naming the pair on failure.
    let nodes = [a, b, c, d, e, f, g, h];
    for &from in &nodes {
        for &to in &nodes {
            let expected = expected_res[&(from, to)];
            assert_eq!(
                res.get(&(from, to)),
                Some(&expected),
                "wrong shortest distance from {:?} to {:?}",
                from,
                to
            );
        }
    }
}
115
/// Shortest paths on a weighted directed acyclic graph in which the direct
/// edges `a -> c` (cost 4) and `a -> d` (cost 10) are beaten by detours through `b`.
///
/// Pairs with no directed path are expected to carry the unreachable marker `i32::MAX`.
#[test]
fn floyd_warshall_weighted() {
    let mut graph: Graph<(), (), Directed> = Graph::new();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());
    let d = graph.add_node(());

    graph.extend_with_edges(&[(a, b), (a, c), (a, d), (b, c), (b, d), (c, d)]);

    // The test expects unreachable pairs to be reported as `i32::MAX`.
    let inf = i32::MAX;
    let expected_res: HashMap<(NodeIndex, NodeIndex), i32> = [
        ((a, a), 0),
        ((a, b), 1),
        ((a, c), 3), // a -> b -> c (1 + 2) beats the direct edge (4)
        ((a, d), 3), // a -> b -> d (1 + 2) beats the direct edge (10)
        ((b, a), inf),
        ((b, b), 0),
        ((b, c), 2),
        ((b, d), 2),
        ((c, a), inf),
        ((c, b), inf),
        ((c, c), 0),
        ((c, d), 2),
        ((d, a), inf),
        ((d, b), inf),
        ((d, c), inf),
        ((d, d), 0),
    ]
    .iter()
    .copied()
    .collect();

    // Cost of each edge added above, keyed by (source, target).
    let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
        ((a, b), 1),
        ((a, c), 4),
        ((a, d), 10),
        ((b, c), 2),
        ((b, d), 2),
        ((c, d), 2),
    ]
    .iter()
    .copied()
    .collect();

    // A missing weight is a bug in the test setup: fail loudly instead of
    // silently treating the edge as unreachable.
    let res = floyd_warshall(&graph, |edge| {
        *weight_map
            .get(&(edge.source(), edge.target()))
            .expect("every edge must have an entry in weight_map")
    })
    .expect("graph has no negative cycle");

    // Check every ordered pair, naming the pair on failure.
    let nodes = [a, b, c, d];
    for &from in &nodes {
        for &to in &nodes {
            let expected = expected_res[&(from, to)];
            assert_eq!(
                res.get(&(from, to)),
                Some(&expected),
                "wrong shortest distance from {:?} to {:?}",
                from,
                to
            );
        }
    }
}
183
/// Shortest paths on a weighted undirected graph: every distance is expected to
/// be symmetric, and the cheap `a - b` edge makes detours through `b` shorter
/// than the direct `a - c` (cost 4) and `a - d` (cost 10) edges.
#[test]
fn floyd_warshall_weighted_undirected() {
    let mut graph: Graph<(), (), Undirected> = Graph::new_undirected();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());
    let d = graph.add_node(());

    graph.extend_with_edges(&[(a, b), (a, c), (a, d), (b, d), (c, b), (c, d)]);

    let expected_res: HashMap<(NodeIndex, NodeIndex), i32> = [
        ((a, a), 0),
        ((a, b), 1),
        ((a, c), 3), // a - b - c (1 + 2) beats the direct edge (4)
        ((a, d), 3), // a - b - d (1 + 2) beats the direct edge (10)
        ((b, a), 1),
        ((b, b), 0),
        ((b, c), 2),
        ((b, d), 2),
        ((c, a), 3),
        ((c, b), 2),
        ((c, c), 0),
        ((c, d), 2),
        ((d, a), 3),
        ((d, b), 2),
        ((d, c), 2),
        ((d, d), 0),
    ]
    .iter()
    .copied()
    .collect();

    // Cost of each edge, keyed by the (source, target) orientation in which the
    // edge was added above (note `(c, b)`, not `(b, c)`).
    let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
        ((a, b), 1),
        ((a, c), 4),
        ((a, d), 10),
        ((b, d), 2),
        ((c, b), 2),
        ((c, d), 2),
    ]
    .iter()
    .copied()
    .collect();

    // A missing weight is a bug in the test setup: fail loudly instead of
    // silently treating the edge as unreachable.
    let res = floyd_warshall(&graph, |edge| {
        *weight_map
            .get(&(edge.source(), edge.target()))
            .expect("every edge must have an entry in weight_map")
    })
    .expect("graph has no negative cycle");

    // Check every ordered pair, naming the pair on failure.
    let nodes = [a, b, c, d];
    for &from in &nodes {
        for &to in &nodes {
            let expected = expected_res[&(from, to)];
            assert_eq!(
                res.get(&(from, to)),
                Some(&expected),
                "wrong shortest distance from {:?} to {:?}",
                from,
                to
            );
        }
    }
}
251
/// A directed 3-cycle `a -> b -> c -> a` with total cost 1 - 3 + 1 = -1 is a
/// negative cycle, so `floyd_warshall` is expected to return an error.
#[test]
fn floyd_warshall_negative_cycle() {
    let mut graph: Graph<(), (), Directed> = Graph::new();
    let a = graph.add_node(());
    let b = graph.add_node(());
    let c = graph.add_node(());

    graph.extend_with_edges(&[(a, b), (b, c), (c, a)]);

    // Cost of each edge added above, keyed by (source, target).
    let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
        ((a, b), 1),
        ((b, c), -3),
        ((c, a), 1),
    ]
    .iter()
    .copied()
    .collect();

    // A missing weight is a bug in the test setup: fail loudly instead of
    // silently treating the edge as unreachable (which would hide the cycle).
    let res = floyd_warshall(&graph, |edge| {
        *weight_map
            .get(&(edge.source(), edge.target()))
            .expect("every edge must have an entry in weight_map")
    });

    assert!(res.is_err(), "a negative cycle must be reported as an error");
}
285