# Owner(s): ["oncall: package/deploy"]

from torch.package._digraph import DiGraph
from torch.testing._internal.common_utils import run_tests


try:
    from .common import PackageTestCase
except ImportError:
    # Support the case where we run this file directly.
    from common import PackageTestCase


class TestDiGraph(PackageTestCase):
    """Test the DiGraph structure we use to represent dependencies in PackageExporter."""

    def test_successors(self):
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        self.assertIn("bar", list(g.successors("foo")))
        self.assertIn("baz", list(g.successors("foo")))
        # A node added without edges has no successors.
        self.assertEqual(len(list(g.successors("qux"))), 0)

    def test_predecessors(self):
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        self.assertIn("foo", list(g.predecessors("bar")))
        self.assertIn("foo", list(g.predecessors("baz")))
        # A node added without edges has no predecessors.
        self.assertEqual(len(list(g.predecessors("qux"))), 0)

    def test_successor_not_in_graph(self):
        # Querying an unknown node must raise rather than return an empty result.
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.successors("not in graph")

    def test_predecessor_not_in_graph(self):
        # Querying an unknown node must raise rather than return an empty result.
        g = DiGraph()
        with self.assertRaises(ValueError):
            g.predecessors("not in graph")

    def test_node_attrs(self):
        # Keyword arguments to add_node are stored as node attributes.
        g = DiGraph()
        g.add_node("foo", my_attr=1, other_attr=2)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)
        self.assertEqual(g.nodes["foo"]["other_attr"], 2)

    def test_node_attr_update(self):
        g = DiGraph()
        g.add_node("foo", my_attr=1)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)

        # Re-adding an existing node overwrites the given attribute.
        g.add_node("foo", my_attr="different")
        self.assertEqual(g.nodes["foo"]["my_attr"], "different")

    def test_edges(self):
        g = DiGraph()
        g.add_edge(1, 2)
        g.add_edge(2, 3)
        g.add_edge(1, 3)
        g.add_edge(4, 5)

        edge_list = list(g.edges)
        self.assertEqual(len(edge_list), 4)

        self.assertIn((1, 2), edge_list)
        self.assertIn((2, 3), edge_list)
        self.assertIn((1, 3), edge_list)
        self.assertIn((4, 5), edge_list)

    def test_iter(self):
        # Iterating the graph yields its nodes.
        g = DiGraph()
        g.add_node(1)
        g.add_node(2)
        g.add_node(3)

        nodes = set()
        nodes.update(g)

        self.assertEqual(nodes, {1, 2, 3})

    def test_contains(self):
        g = DiGraph()
        g.add_node("yup")

        self.assertIn("yup", g)
        self.assertNotIn("nup", g)

    def test_contains_non_hashable(self):
        # Membership tests with an unhashable value must return False
        # instead of propagating a TypeError.
        g = DiGraph()
        self.assertNotIn([1, 2, 3], g)

    def test_forward_closure(self):
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")
        # The closure includes the start node itself, and only nodes
        # reachable by following edges forward.
        self.assertEqual(g.forward_transitive_closure("1"), {"1", "2", "3"})
        self.assertEqual(g.forward_transitive_closure("4"), {"4", "3"})

    def test_all_paths(self):
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("1", "7")
        g.add_edge("7", "8")
        g.add_edge("8", "3")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        result = g.all_paths("1", "3")
        # NOTE(review): `result` is presumably a Graphviz DOT string whose
        # statements are ';'-terminated; [2:-1] drops what appear to be two
        # header statements and the trailing closing text -- confirm against
        # DiGraph.all_paths if its output format changes.
        # Edges are compared as a set because their emission order is not
        # guaranteed to be deterministic.
        actual = {i.strip("\n") for i in result.split(";")[2:-1]}
        expected = {
            '"2" -> "3"',
            '"1" -> "7"',
            '"7" -> "8"',
            '"1" -> "2"',
            '"8" -> "3"',
        }
        self.assertEqual(actual, expected)


if __name__ == "__main__":
    run_tests()