1from __future__ import annotations 2 3import modulefinder 4import os 5import sys 6import warnings 7from pathlib import Path 8from typing import Any 9 10 11REPO_ROOT = Path(__file__).resolve().parent.parent.parent 12 13# These tests are slow enough that it's worth calculating whether the patch 14# touched any related files first. This list was manually generated, but for every 15# run with --determine-from, we use another generated list based on this one and the 16# previous test stats. 17TARGET_DET_LIST = [ 18 # test_autograd.py is not slow, so it does not belong here. But 19 # note that if you try to add it back it will run into 20 # https://bugs.python.org/issue40350 because it imports files 21 # under test/autograd/. 22 "test_binary_ufuncs", 23 "test_cpp_extensions_aot_ninja", 24 "test_cpp_extensions_aot_no_ninja", 25 "test_cpp_extensions_jit", 26 "test_cpp_extensions_open_device_registration", 27 "test_cpp_extensions_stream_and_event", 28 "test_cpp_extensions_mtia_backend", 29 "test_cuda", 30 "test_cuda_primary_ctx", 31 "test_dataloader", 32 "test_determination", 33 "test_futures", 34 "test_jit", 35 "test_jit_legacy", 36 "test_jit_profiling", 37 "test_linalg", 38 "test_multiprocessing", 39 "test_nn", 40 "test_numpy_interop", 41 "test_optim", 42 "test_overrides", 43 "test_pruning_op", 44 "test_quantization", 45 "test_reductions", 46 "test_serialization", 47 "test_shape_ops", 48 "test_sort_and_select", 49 "test_tensorboard", 50 "test_testing", 51 "test_torch", 52 "test_utils", 53 "test_view_ops", 54] 55 56 57_DEP_MODULES_CACHE: dict[str, set[str]] = {} 58 59 60def should_run_test( 61 target_det_list: list[str], test: str, touched_files: list[str], options: Any 62) -> bool: 63 test = parse_test_module(test) 64 # Some tests are faster to execute than to determine. 65 if test not in target_det_list: 66 if options.verbose: 67 print_to_stderr(f"Running {test} without determination") 68 return True 69 # HACK: "no_ninja" is not a real module 70 if test.endswith("_no_ninja"): 71 test = test[: (-1 * len("_no_ninja"))] 72 if test.endswith("_ninja"): 73 test = test[: (-1 * len("_ninja"))] 74 75 dep_modules = get_dep_modules(test) 76 77 for touched_file in touched_files: 78 file_type = test_impact_of_file(touched_file) 79 if file_type == "NONE": 80 continue 81 elif file_type == "CI": 82 # Force all tests to run if any change is made to the CI 83 # configurations. 84 log_test_reason(file_type, touched_file, test, options) 85 return True 86 elif file_type == "UNKNOWN": 87 # Assume uncategorized source files can affect every test. 88 log_test_reason(file_type, touched_file, test, options) 89 return True 90 elif file_type in ["TORCH", "CAFFE2", "TEST"]: 91 parts = os.path.splitext(touched_file)[0].split(os.sep) 92 touched_module = ".".join(parts) 93 # test/ path does not have a "test." namespace 94 if touched_module.startswith("test."): 95 touched_module = touched_module.split("test.")[1] 96 if touched_module in dep_modules or touched_module == test.replace( 97 "/", "." 98 ): 99 log_test_reason(file_type, touched_file, test, options) 100 return True 101 102 # If nothing has determined the test has run, don't run the test. 103 if options.verbose: 104 print_to_stderr(f"Determination is skipping {test}") 105 106 return False 107 108 109def test_impact_of_file(filename: str) -> str: 110 """Determine what class of impact this file has on test runs. 111 112 Possible values: 113 TORCH - torch python code 114 CAFFE2 - caffe2 python code 115 TEST - torch test code 116 UNKNOWN - may affect all tests 117 NONE - known to have no effect on test outcome 118 CI - CI configuration files 119 """ 120 parts = filename.split(os.sep) 121 if parts[0] in [".jenkins", ".circleci", ".ci"]: 122 return "CI" 123 if parts[0] in ["docs", "scripts", "CODEOWNERS", "README.md"]: 124 return "NONE" 125 elif parts[0] == "torch": 126 if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"): 127 return "TORCH" 128 elif parts[0] == "caffe2": 129 if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"): 130 return "CAFFE2" 131 elif parts[0] == "test": 132 if parts[-1].endswith(".py") or parts[-1].endswith(".pyi"): 133 return "TEST" 134 135 return "UNKNOWN" 136 137 138def log_test_reason(file_type: str, filename: str, test: str, options: Any) -> None: 139 if options.verbose: 140 print_to_stderr( 141 f"Determination found {file_type} file {filename} -- running {test}" 142 ) 143 144 145def get_dep_modules(test: str) -> set[str]: 146 # Cache results in case of repetition 147 if test in _DEP_MODULES_CACHE: 148 return _DEP_MODULES_CACHE[test] 149 150 test_location = REPO_ROOT / "test" / f"{test}.py" 151 152 # HACK: some platforms default to ascii, so we can't just run_script :( 153 finder = modulefinder.ModuleFinder( 154 # Ideally exclude all third party modules, to speed up calculation. 155 excludes=[ 156 "scipy", 157 "numpy", 158 "numba", 159 "multiprocessing", 160 "sklearn", 161 "setuptools", 162 "hypothesis", 163 "llvmlite", 164 "joblib", 165 "email", 166 "importlib", 167 "unittest", 168 "urllib", 169 "json", 170 "collections", 171 # Modules below are excluded because they are hitting https://bugs.python.org/issue40350 172 # Trigger AttributeError: 'NoneType' object has no attribute 'is_package' 173 "mpl_toolkits", 174 "google", 175 "onnx", 176 # Triggers RecursionError 177 "mypy", 178 ], 179 ) 180 181 with warnings.catch_warnings(): 182 warnings.simplefilter("ignore") 183 finder.run_script(str(test_location)) 184 dep_modules = set(finder.modules.keys()) 185 _DEP_MODULES_CACHE[test] = dep_modules 186 return dep_modules 187 188 189def parse_test_module(test: str) -> str: 190 return test.split(".")[0] 191 192 193def print_to_stderr(message: str) -> None: 194 print(message, file=sys.stderr) 195