xref: /aosp_15_r20/external/pytorch/tools/test/test_test_run.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 import sys
2 import unittest
3 from pathlib import Path
4 
5 
# Repository root: three directory levels above this file (same as
# ``Path(__file__).resolve().parent.parent.parent``).
REPO_ROOT = Path(__file__).resolve().parents[2]

# Make the repository root importable so ``tools.testing`` resolves even when
# this file is executed directly as a script.
sys.path.append(str(REPO_ROOT))
try:
    from tools.testing.test_run import ShardedTest, TestRun
except ModuleNotFoundError as err:
    # Report on stderr (it is a diagnostic, not test output) and name the
    # missing module so the failure is actionable.
    print(f"Can't import required modules, exiting: {err}", file=sys.stderr)
    sys.exit(1)
14 
15 
class TestTestRun(unittest.TestCase):
    """Tests for the set-like operators ``|``, ``-`` and ``&`` on ``TestRun``.

    The expectations below model a ``TestRun`` as a selection of tests within a
    single test file:

    * ``TestRun("foo")`` selects the whole file ``foo``;
    * ``TestRun("foo::bar")`` and ``TestRun("foo", included=[...])`` select only
      the named entries;
    * ``TestRun("foo", excluded=[...])`` selects the whole file minus the named
      entries;
    * ``TestRun.empty()`` selects nothing and is falsy.
    """

    def test_union_with_full_run(self) -> None:
        run1 = TestRun("foo")
        run2 = TestRun("foo::bar")

        # The whole file absorbs any subset of itself, in either order.
        self.assertEqual(run1 | run2, run1)
        self.assertEqual(run2 | run1, run1)

    def test_union_with_inclusions(self) -> None:
        run1 = TestRun("foo::bar")
        run2 = TestRun("foo::baz")

        expected = TestRun("foo", included=["bar", "baz"])

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_non_overlapping_exclusions(self) -> None:
        run1 = TestRun("foo", excluded=["bar"])
        run2 = TestRun("foo", excluded=["baz"])

        # Each run contains what the other excludes, so together they cover
        # the whole file.
        expected = TestRun("foo")

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_overlapping_exclusions(self) -> None:
        run1 = TestRun("foo", excluded=["bar", "car"])
        run2 = TestRun("foo", excluded=["bar", "caz"])

        # Only entries excluded by *both* runs stay excluded.
        expected = TestRun("foo", excluded=["bar"])

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_mixed_inclusion_exclusions(self) -> None:
        run1 = TestRun("foo", excluded=["baz", "car"])
        run2 = TestRun("foo", included=["baz"])

        # run2 adds "baz" back, so only "car" remains excluded.
        expected = TestRun("foo", excluded=["car"])

        self.assertEqual(run1 | run2, expected)
        self.assertEqual(run2 | run1, expected)

    def test_union_with_mixed_files_fails(self) -> None:
        run1 = TestRun("foo")
        run2 = TestRun("bar")

        # Runs over different files cannot be combined.
        with self.assertRaises(AssertionError):
            run1 | run2

    def test_union_with_empty_file_yields_orig_file(self) -> None:
        run1 = TestRun("foo")
        run2 = TestRun.empty()

        self.assertEqual(run1 | run2, run1)
        self.assertEqual(run2 | run1, run1)

    def test_subtracting_full_run_yields_empty(self) -> None:
        run1 = TestRun("foo::bar")
        run2 = TestRun("foo")

        # Removing the whole file from any subset of it leaves nothing.
        self.assertEqual(run1 - run2, TestRun.empty())

    def test_subtracting_empty_file_yields_orig_file(self) -> None:
        run1 = TestRun("foo")
        run2 = TestRun.empty()

        self.assertEqual(run1 - run2, run1)
        self.assertEqual(run2 - run1, TestRun.empty())

    def test_empty_is_falsey(self) -> None:
        self.assertFalse(TestRun.empty())

    def test_subtracting_inclusion_from_full_run(self) -> None:
        run1 = TestRun("foo")
        run2 = TestRun("foo::bar")

        expected = TestRun("foo", excluded=["bar"])

        self.assertEqual(run1 - run2, expected)

    def test_subtracting_inclusion_from_overlapping_inclusion(self) -> None:
        run1 = TestRun("foo", included=["bar", "baz"])
        run2 = TestRun("foo::baz")

        self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))

    def test_subtracting_inclusion_from_nonoverlapping_inclusion(self) -> None:
        run1 = TestRun("foo", included=["bar", "baz"])
        run2 = TestRun("foo", included=["car"])

        self.assertEqual(run1 - run2, TestRun("foo", included=["bar", "baz"]))

    def test_subtracting_exclusion_from_full_run(self) -> None:
        run1 = TestRun("foo")
        run2 = TestRun("foo", excluded=["bar"])

        # Removing "everything but bar" from the whole file leaves just "bar".
        self.assertEqual(run1 - run2, TestRun("foo", included=["bar"]))

    def test_subtracting_exclusion_from_superset_exclusion(self) -> None:
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", excluded=["baz"])

        # run1 is a subset of run2 (it excludes more), so run1 - run2 is empty
        # and run2 - run1 is exactly what run1 excludes but run2 keeps.
        self.assertEqual(run1 - run2, TestRun.empty())
        self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))

    def test_subtracting_exclusion_from_nonoverlapping_exclusion(self) -> None:
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", excluded=["car"])

        # A - B keeps only what B excludes (and A still contains).
        self.assertEqual(run1 - run2, TestRun("foo", included=["car"]))
        self.assertEqual(run2 - run1, TestRun("foo", included=["bar", "baz"]))

    def test_subtracting_inclusion_from_exclusion_without_overlaps(self) -> None:
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", included=["bar"])

        # "bar" is already excluded from run1, so the runs are disjoint and
        # subtraction leaves each one unchanged.
        self.assertEqual(run1 - run2, run1)
        self.assertEqual(run2 - run1, run2)

    def test_subtracting_inclusion_from_exclusion_with_overlaps(self) -> None:
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", included=["bar", "car"])

        # "car" is selected by both runs: it joins run1's exclusions, and it is
        # dropped from run2, leaving only "bar".
        self.assertEqual(run1 - run2, TestRun("foo", excluded=["bar", "baz", "car"]))
        self.assertEqual(run2 - run1, TestRun("foo", included=["bar"]))

    def test_and(self) -> None:
        run1 = TestRun("foo", included=["bar", "baz"])
        run2 = TestRun("foo", included=["bar", "car"])

        expected = TestRun("foo", included=["bar"])

        # Intersection should not depend on operand order (mirrors the union
        # tests, which check both orders).
        self.assertEqual(run1 & run2, expected)
        self.assertEqual(run2 & run1, expected)

    def test_and_exclusions(self) -> None:
        run1 = TestRun("foo", excluded=["bar", "baz"])
        run2 = TestRun("foo", excluded=["bar", "car"])

        # Anything excluded by either run is excluded from the intersection.
        expected = TestRun("foo", excluded=["bar", "baz", "car"])

        self.assertEqual(run1 & run2, expected)
        self.assertEqual(run2 & run1, expected)
155 
156 
class TestShardedTest(unittest.TestCase):
    """Tests for ``ShardedTest``'s translation of a ``TestRun`` into pytest args."""

    def test_get_pytest_args(self) -> None:
        # The two positional ints are ShardedTest's shard arguments (presumably
        # shard index and shard count -- confirm against ShardedTest).
        selection = TestRun("foo", included=["bar", "baz"])
        shard = ShardedTest(selection, 1, 1)

        # Included entries are expected to become a single ``-k`` expression
        # joined with ``or``.
        self.assertListEqual(shard.get_pytest_args(), ["-k", "bar or baz"])
165 
166 
if __name__ == "__main__":
    # Discover and run every TestCase defined in this module when executed as
    # a script (``module="__main__"`` is unittest.main's default, spelled out).
    unittest.main(module="__main__")
169