xref: /aosp_15_r20/external/pytorch/tools/testing/modulefinder_determinator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import modulefinder
4import os
5import sys
6import warnings
7from pathlib import Path
8from typing import Any
9
10
# Directory three levels above this file's resolved location; the test scripts
# analysed by get_dep_modules() are looked up under REPO_ROOT / "test".
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
12
# These tests are slow enough that it's worth calculating whether the patch
# touched any related files first. This list was manually generated, but for
# every run with --determine-from, we use another generated list based on this
# one and the previous test stats.
#
# Tests that are NOT in the list passed to should_run_test() are always run
# without any dependency analysis.
TARGET_DET_LIST = [
    # test_autograd.py is not slow, so it does not belong here. But
    # note that if you try to add it back it will run into
    # https://bugs.python.org/issue40350 because it imports files
    # under test/autograd/.
    "test_binary_ufuncs",
    "test_cpp_extensions_aot_ninja",
    "test_cpp_extensions_aot_no_ninja",
    "test_cpp_extensions_jit",
    "test_cpp_extensions_open_device_registration",
    "test_cpp_extensions_stream_and_event",
    "test_cpp_extensions_mtia_backend",
    "test_cuda",
    "test_cuda_primary_ctx",
    "test_dataloader",
    "test_determination",
    "test_futures",
    "test_jit",
    "test_jit_legacy",
    "test_jit_profiling",
    "test_linalg",
    "test_multiprocessing",
    "test_nn",
    "test_numpy_interop",
    "test_optim",
    "test_overrides",
    "test_pruning_op",
    "test_quantization",
    "test_reductions",
    "test_serialization",
    "test_shape_ops",
    "test_sort_and_select",
    "test_tensorboard",
    "test_testing",
    "test_torch",
    "test_utils",
    "test_view_ops",
]
55
56
# Memo for get_dep_modules(): maps a test name to the set of module names that
# modulefinder discovered for that test's script.
_DEP_MODULES_CACHE: dict[str, set[str]] = {}
58
59
def should_run_test(
    target_det_list: list[str], test: str, touched_files: list[str], options: Any
) -> bool:
    """Decide whether ``test`` has to run given the files touched by a patch.

    Args:
        target_det_list: Names of the tests for which determination is
            attempted. Any test not in this list is always run.
        test: Test name; everything after the first ``"."`` is dropped before
            the name is looked up.
        touched_files: Paths of the changed files. They are split on
            ``os.sep`` and classified with ``test_impact_of_file``.
        options: Object exposing a boolean ``verbose`` attribute; when true,
            the reason for each decision is written to stderr.

    Returns:
        ``True`` if the test should be run, ``False`` if it can be skipped.
    """
    test = parse_test_module(test)
    # Some tests are faster to execute than to determine.
    if test not in target_det_list:
        if options.verbose:
            print_to_stderr(f"Running {test} without determination")
        return True
    # HACK: "no_ninja" / "ninja" are not real modules; strip the suffix to get
    # the name of the script that actually exists on disk.
    for suffix in ("_no_ninja", "_ninja"):
        if test.endswith(suffix):
            test = test[: -len(suffix)]

    dep_modules = get_dep_modules(test)
    test_module = test.replace("/", ".")
    test_namespace = "test."

    for touched_file in touched_files:
        file_type = test_impact_of_file(touched_file)
        if file_type == "NONE":
            continue
        if file_type in ("CI", "UNKNOWN"):
            # CI configuration changes force every test to run, and
            # uncategorized source files are assumed to affect every test.
            log_test_reason(file_type, touched_file, test, options)
            return True
        if file_type in ("TORCH", "CAFFE2", "TEST"):
            parts = os.path.splitext(touched_file)[0].split(os.sep)
            touched_module = ".".join(parts)
            # test/ path does not have a "test." namespace. Strip only the
            # leading prefix: splitting on every "test." occurrence would
            # mangle names such as "test.mytest.foo" into "my".
            if touched_module.startswith(test_namespace):
                touched_module = touched_module[len(test_namespace) :]
            if touched_module in dep_modules or touched_module == test_module:
                log_test_reason(file_type, touched_file, test, options)
                return True

    # No touched file was found to affect this test, so skip it.
    if options.verbose:
        print_to_stderr(f"Determination is skipping {test}")

    return False
107
108
109def test_impact_of_file(filename: str) -> str:
110    """Determine what class of impact this file has on test runs.
111
112    Possible values:
113        TORCH - torch python code
114        CAFFE2 - caffe2 python code
115        TEST - torch test code
116        UNKNOWN - may affect all tests
117        NONE - known to have no effect on test outcome
118        CI - CI configuration files
119    """
120    parts = filename.split(os.sep)
121    if parts[0] in [".jenkins", ".circleci", ".ci"]:
122        return "CI"
123    if parts[0] in ["docs", "scripts", "CODEOWNERS", "README.md"]:
124        return "NONE"
125    elif parts[0] == "torch":
126        if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
127            return "TORCH"
128    elif parts[0] == "caffe2":
129        if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
130            return "CAFFE2"
131    elif parts[0] == "test":
132        if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"):
133            return "TEST"
134
135    return "UNKNOWN"
136
137
def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> None:
    """Report on stderr which touched file caused ``test`` to be selected.

    Nothing is written unless ``options.verbose`` is truthy.
    """
    if not options.verbose:
        return
    reason = f"Determination found {file_type} file {filename} -- running {test}"
    print_to_stderr(reason)
143
144
def get_dep_modules(test: str) -> set[str]:
    """Return the names of every module the given test script depends on.

    The script ``REPO_ROOT / "test" / f"{test}.py"`` is analysed with
    ``modulefinder.ModuleFinder``. Results are memoized per test name in
    ``_DEP_MODULES_CACHE``, so repeated calls return the same set object.
    """
    # Serve repeated requests from the cache.
    try:
        return _DEP_MODULES_CACHE[test]
    except KeyError:
        pass

    script_path = REPO_ROOT / "test" / f"{test}.py"

    # Ideally exclude all third party modules, to speed up calculation.
    skipped_packages = [
        "scipy",
        "numpy",
        "numba",
        "multiprocessing",
        "sklearn",
        "setuptools",
        "hypothesis",
        "llvmlite",
        "joblib",
        "email",
        "importlib",
        "unittest",
        "urllib",
        "json",
        "collections",
        # Modules below are excluded because they are hitting
        # https://bugs.python.org/issue40350
        # Trigger AttributeError: 'NoneType' object has no attribute 'is_package'
        "mpl_toolkits",
        "google",
        "onnx",
        # Triggers RecursionError
        "mypy",
    ]
    # NOTE(review): an earlier comment here said run_script could not be used
    # directly because some platforms default to ASCII, yet run_script is what
    # is called below -- confirm whether that concern still applies.
    module_finder = modulefinder.ModuleFinder(excludes=skipped_packages)

    # Warnings raised while the script is being analysed are suppressed.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        module_finder.run_script(str(script_path))

    discovered = set(module_finder.modules)
    _DEP_MODULES_CACHE[test] = discovered
    return discovered
187
188
def parse_test_module(test: str) -> str:
    """Return the part of ``test`` that precedes its first ``"."``.

    A name without any dot is returned unchanged.
    """
    module_name, _, _ = test.partition(".")
    return module_name
191
192
def print_to_stderr(message: str) -> None:
    """Write ``message`` followed by a newline to the current ``sys.stderr``."""
    sys.stderr.write(f"{message}\n")
195