xref: /aosp_15_r20/external/pytorch/test/package/test_digraph.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["oncall: package/deploy"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom torch.package._digraph import DiGraph
4*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import run_tests
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workertry:
8*da0073e9SAndroid Build Coastguard Worker    from .common import PackageTestCase
9*da0073e9SAndroid Build Coastguard Workerexcept ImportError:
10*da0073e9SAndroid Build Coastguard Worker    # Support the case where we run this file directly.
11*da0073e9SAndroid Build Coastguard Worker    from common import PackageTestCase
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker
class TestDiGraph(PackageTestCase):
    """Test the DiGraph structure we use to represent dependencies in PackageExporter"""

    def test_successors(self):
        """successors() yields exactly the targets of a node's outgoing edges."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        # Compare the full set so spurious extra successors are also caught.
        self.assertEqual(set(g.successors("foo")), {"bar", "baz"})
        # A node added via add_node with no edges has no successors.
        self.assertEqual(list(g.successors("qux")), [])

    def test_predecessors(self):
        """predecessors() yields exactly the sources of a node's incoming edges."""
        g = DiGraph()
        g.add_edge("foo", "bar")
        g.add_edge("foo", "baz")
        g.add_node("qux")

        self.assertEqual(set(g.predecessors("bar")), {"foo"})
        self.assertEqual(set(g.predecessors("baz")), {"foo"})
        # A node added via add_node with no edges has no predecessors.
        self.assertEqual(list(g.predecessors("qux")), [])

    def test_successor_not_in_graph(self):
        """Asking for successors of an unknown node raises ValueError."""
        g = DiGraph()
        # The result is never iterated, so the error must come from the call itself.
        with self.assertRaises(ValueError):
            g.successors("not in graph")

    def test_predecessor_not_in_graph(self):
        """Asking for predecessors of an unknown node raises ValueError."""
        g = DiGraph()
        # The result is never iterated, so the error must come from the call itself.
        with self.assertRaises(ValueError):
            g.predecessors("not in graph")

    def test_node_attrs(self):
        """Keyword arguments to add_node are stored as node attributes."""
        g = DiGraph()
        g.add_node("foo", my_attr=1, other_attr=2)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)
        self.assertEqual(g.nodes["foo"]["other_attr"], 2)

    def test_node_attr_update(self):
        """Re-adding an existing node overwrites the attributes passed again."""
        g = DiGraph()
        g.add_node("foo", my_attr=1)
        self.assertEqual(g.nodes["foo"]["my_attr"], 1)

        g.add_node("foo", my_attr="different")
        self.assertEqual(g.nodes["foo"]["my_attr"], "different")

    def test_edges(self):
        """edges yields each added (source, target) pair exactly once."""
        g = DiGraph()
        g.add_edge(1, 2)
        g.add_edge(2, 3)
        g.add_edge(1, 3)
        g.add_edge(4, 5)

        # Order-insensitive comparison that also checks the count of each edge.
        self.assertCountEqual(list(g.edges), [(1, 2), (2, 3), (1, 3), (4, 5)])

    def test_iter(self):
        """Iterating the graph yields its nodes."""
        g = DiGraph()
        g.add_node(1)
        g.add_node(2)
        g.add_node(3)

        self.assertEqual(set(g), {1, 2, 3})

    def test_contains(self):
        """`in` reports whether a node has been added to the graph."""
        g = DiGraph()
        g.add_node("yup")

        self.assertIn("yup", g)
        self.assertNotIn("nup", g)

    def test_contains_non_hashable(self):
        """`in` with an unhashable value returns False instead of raising."""
        g = DiGraph()
        self.assertNotIn([1, 2, 3], g)

    def test_forward_closure(self):
        """forward_transitive_closure returns the start node plus everything reachable from it."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")
        self.assertEqual(g.forward_transitive_closure("1"), {"1", "2", "3"})
        # "5" points into "4" but is not reachable from it, so it is excluded.
        self.assertEqual(g.forward_transitive_closure("4"), {"4", "3"})

    def test_all_paths(self):
        """all_paths renders only the edges that lie on some src -> dst path."""
        g = DiGraph()
        g.add_edge("1", "2")
        g.add_edge("1", "7")
        g.add_edge("7", "8")
        g.add_edge("8", "3")
        g.add_edge("2", "3")
        g.add_edge("5", "4")
        g.add_edge("4", "3")

        result = g.all_paths("1", "3")
        # The result is DOT-style text with one `"a" -> "b";` statement per
        # line. The edge order is not guaranteed, so compare the set of edge
        # statements. Selecting the "->" lines avoids depending on how many
        # header statements precede the edges.
        actual = {
            line.strip().rstrip(";")
            for line in result.splitlines()
            if "->" in line
        }
        # "5" -> "4" -> "3" reaches dst but is not reachable from src, so those
        # edges must be absent.
        expected = {
            '"2" -> "3"',
            '"1" -> "7"',
            '"7" -> "8"',
            '"1" -> "2"',
            '"8" -> "3"',
        }
        self.assertEqual(actual, expected)
129*da0073e9SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Executed only when this file is run as a script rather than collected
    # by a test runner; hands control to the shared run_tests entry point.
    run_tests()
132