xref: /aosp_15_r20/external/pytorch/tools/testing/discover_tests.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport glob
4*da0073e9SAndroid Build Coastguard Workerimport os
5*da0073e9SAndroid Build Coastguard Workerimport sys
6*da0073e9SAndroid Build Coastguard Workerfrom pathlib import Path
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Prefix prepended to discovered C++ test names so they can be told apart
# from Python test names.
CPP_TEST_PREFIX = "cpp"
# Relative path of the directory holding the C++ test binaries.
CPP_TEST_PATH = "build/bin"
# Absolute C++ test directory: $CPP_TESTS_DIR when set, otherwise
# CPP_TEST_PATH resolved against the current working directory.
CPP_TESTS_DIR = os.path.abspath(os.environ.get("CPP_TESTS_DIR", CPP_TEST_PATH))
# Three directories above this (resolved) file.
REPO_ROOT = Path(__file__).resolve().parent.parent.parent
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Workerdef parse_test_module(test: str) -> str:
16*da0073e9SAndroid Build Coastguard Worker    return test.split(".")[0]
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
def discover_tests(
    base_dir: Path = REPO_ROOT / "test",
    cpp_tests_dir: str | Path | None = None,
    blocklisted_patterns: list[str] | None = None,
    blocklisted_tests: list[str] | None = None,
    extra_tests: list[str] | None = None,
) -> list[str]:
    """
    Discover Python and C++ test modules and return their names, sorted.

    Python tests are the ``test_*.py`` files found recursively under
    ``base_dir`` (glob follows symlinks, so domain library tests can be linked
    into the test directory). Each is reported as its path relative to
    ``base_dir`` without the ``.py`` suffix.

    C++ tests are every path found recursively under ``cpp_tests_dir``
    (``<base_dir.parent>/build/bin`` when None). Unlike Python tests they are
    plain binaries and may have any name, e.g. ``basic`` or ``atest``. Each is
    reported as ``cpp/<relative path>``, cut at the first ``.`` by
    ``parse_test_module``.

    Args:
        base_dir: Root directory searched for Python tests.
        cpp_tests_dir: Directory searched for C++ test binaries.
        blocklisted_patterns: Discovered names starting with any of these
            prefixes are dropped.
        blocklisted_tests: Discovered names equal to any of these are dropped.
            ``cpp/...`` entries also match after the same first-``.``
            truncation that is applied to discovered C++ names.
        extra_tests: Names appended to the result without discovery or
            blocklist filtering.

    Returns:
        Sorted list of unique test names. Discovered names use ``/`` as the
        path separator on every platform.
    """
    blocked_prefixes = tuple(blocklisted_patterns or ())
    blocked_names = set(blocklisted_tests or ())
    # Discovered C++ names are truncated at the first "." (parse_test_module),
    # so a C++ entry such as "cpp/cmake_install.cmake" could never match
    # verbatim. Block its truncated form as well.
    cpp_prefix = f"{CPP_TEST_PREFIX}/"
    blocked_names |= {
        parse_test_module(name) for name in blocked_names if name.startswith(cpp_prefix)
    }

    def skip_test_p(name: str) -> bool:
        # str.startswith with an empty tuple is False, matching "no patterns".
        return name.startswith(blocked_prefixes) or name in blocked_names

    # Escape the directories so glob metacharacters in them ("[", "*", "?")
    # are matched literally rather than interpreted as patterns.
    py_pattern = f"{glob.escape(str(base_dir))}/**/test_*.py"
    all_py_files = [Path(p) for p in glob.glob(py_pattern, recursive=True)]

    if cpp_tests_dir is None:
        cpp_tests_dir = f"{base_dir.parent}/{CPP_TEST_PATH}"
    cpp_pattern = f"{glob.escape(str(cpp_tests_dir))}/**/*"
    all_cpp_files = [Path(p) for p in glob.glob(cpp_pattern, recursive=True)]

    rc = [str(fname.relative_to(base_dir).with_suffix("")) for fname in all_py_files]
    # Add the cpp prefix for C++ tests so that we can tell them apart
    rc.extend(
        parse_test_module(f"{CPP_TEST_PREFIX}/{fname.relative_to(cpp_tests_dir)}")
        for fname in all_cpp_files
    )

    # Invert slashes on Windows so names always use "/" (the blocklists do too)
    if sys.platform == "win32":
        rc = [name.replace("\\", "/") for name in rc]
    rc = [test for test in rc if not skip_test_p(test)]
    if extra_tests is not None:
        rc += extra_tests
    # Binaries that differ only by extension (e.g. "foo.exe" and "foo.pdb")
    # collapse to the same name after truncation; report each name once.
    return sorted(set(rc))
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
# All discovered test names (Python tests under REPO_ROOT/test plus C++
# binaries under CPP_TESTS_DIR), minus the blocklists, plus extra_tests.
TESTS = discover_tests(
    cpp_tests_dir=CPP_TESTS_DIR,
    # Any discovered name starting with one of these prefixes is dropped.
    blocklisted_patterns=[
        "ao",
        "bottleneck_test",
        "custom_backend",
        "custom_operator",
        "fx",  # executed by test_fx.py
        "jit",  # executed by test_jit.py
        "mobile",
        "onnx_caffe2",
        "package",  # executed by test_package.py
        "quantization",  # executed by test_quantization.py
        "autograd",  # executed by test_autograd.py
    ],
    # Exact names to drop. Python names are relative to the test/ directory,
    # so they must not carry a leading "test/".
    blocklisted_tests=[
        "test_bundled_images",
        "test_cpp_extensions_aot",
        "test_determination",
        "test_jit_fuser",
        "test_jit_simple",
        "test_jit_string",
        "test_kernel_launch_checks",
        "test_nnapi",
        "test_static_runtime",
        "test_throughput_benchmark",
        "distributed/bin/test_script",
        "distributed/elastic/multiprocessing/bin/test_script",
        "distributed/launcher/bin/test_script",
        "distributed/launcher/bin/test_script_init_method",
        "distributed/launcher/bin/test_script_is_torchelastic_launched",
        "distributed/launcher/bin/test_script_local_rank",
        "distributed/test_c10d_spawn",
        "distributions/test_transforms",
        "distributions/test_utils",
        # Was "test/inductor/test_aot_inductor_utils", which can never match a
        # name produced relative to the test/ directory.
        "inductor/test_aot_inductor_utils",
        "onnx/test_pytorch_onnx_onnxruntime_cuda",
        "onnx/test_models",
        # These are not C++ tests
        f"{CPP_TEST_PREFIX}/CMakeFiles",
        f"{CPP_TEST_PREFIX}/CTestTestfile.cmake",
        f"{CPP_TEST_PREFIX}/Makefile",
        f"{CPP_TEST_PREFIX}/cmake_install.cmake",
        f"{CPP_TEST_PREFIX}/c10_intrusive_ptr_benchmark",
        f"{CPP_TEST_PREFIX}/example_allreduce",
        f"{CPP_TEST_PREFIX}/parallel_benchmark",
        f"{CPP_TEST_PREFIX}/protoc",
        f"{CPP_TEST_PREFIX}/protoc-3.13.0.0",
        f"{CPP_TEST_PREFIX}/torch_shm_manager",
        f"{CPP_TEST_PREFIX}/tutorial_tensorexpr",
    ],
    # Appended to the result as-is, bypassing discovery and the blocklists.
    extra_tests=[
        "test_cpp_extensions_aot_ninja",
        "test_cpp_extensions_aot_no_ninja",
        "distributed/elastic/timer/api_test",
        "distributed/elastic/timer/local_timer_example",
        "distributed/elastic/timer/local_timer_test",
        "distributed/elastic/events/lib_test",
        "distributed/elastic/metrics/api_test",
        "distributed/elastic/utils/logging_test",
        "distributed/elastic/utils/util_test",
        "distributed/elastic/utils/distributed_test",
        "distributed/elastic/multiprocessing/api_test",
        "doctests",
        "test_autoload_enable",
        "test_autoload_disable",
    ],
)


if __name__ == "__main__":
    # Running this module directly prints the discovered list.
    print(TESTS)
144