xref: /aosp_15_r20/external/pytorch/tools/test/test_cmake.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import contextlib
4import os
5import typing
6import unittest
7import unittest.mock
8from typing import Iterator, Sequence
9
10import tools.setup_helpers.cmake
11import tools.setup_helpers.env  # noqa: F401 unused but resolves circular import
12
13
14T = typing.TypeVar("T")
15
16
17class TestCMake(unittest.TestCase):
18    @unittest.mock.patch("multiprocessing.cpu_count")
19    def test_build_jobs(self, mock_cpu_count: unittest.mock.MagicMock) -> None:
20        """Tests that the number of build jobs comes out correctly."""
21        mock_cpu_count.return_value = 13
22        cases = [
23            # MAX_JOBS, USE_NINJA, IS_WINDOWS,         want
24            (("8", True, False), ["-j", "8"]),  # noqa: E201,E241
25            ((None, True, False), None),  # noqa: E201,E241
26            (("7", False, False), ["-j", "7"]),  # noqa: E201,E241
27            ((None, False, False), ["-j", "13"]),  # noqa: E201,E241
28            (("6", True, True), ["-j", "6"]),  # noqa: E201,E241
29            ((None, True, True), None),  # noqa: E201,E241
30            (("11", False, True), ["/p:CL_MPCount=11"]),  # noqa: E201,E241
31            ((None, False, True), ["/p:CL_MPCount=13"]),  # noqa: E201,E241
32        ]
33        for (max_jobs, use_ninja, is_windows), want in cases:
34            with self.subTest(
35                MAX_JOBS=max_jobs, USE_NINJA=use_ninja, IS_WINDOWS=is_windows
36            ):
37                with contextlib.ExitStack() as stack:
38                    stack.enter_context(env_var("MAX_JOBS", max_jobs))
39                    stack.enter_context(
40                        unittest.mock.patch.object(
41                            tools.setup_helpers.cmake, "USE_NINJA", use_ninja
42                        )
43                    )
44                    stack.enter_context(
45                        unittest.mock.patch.object(
46                            tools.setup_helpers.cmake, "IS_WINDOWS", is_windows
47                        )
48                    )
49
50                    cmake = tools.setup_helpers.cmake.CMake()
51
52                    with unittest.mock.patch.object(cmake, "run") as cmake_run:
53                        cmake.build({})
54
55                    cmake_run.assert_called_once()
56                    (call,) = cmake_run.mock_calls
57                    build_args, _ = call.args
58
59                if want is None:
60                    self.assertNotIn("-j", build_args)
61                else:
62                    self.assert_contains_sequence(build_args, want)
63
64    @staticmethod
65    def assert_contains_sequence(
66        sequence: Sequence[T], subsequence: Sequence[T]
67    ) -> None:
68        """Raises an assertion if the subsequence is not contained in the sequence."""
69        if len(subsequence) == 0:
70            return  # all sequences contain the empty subsequence
71
72        # Iterate over all windows of len(subsequence). Stop if the
73        # window matches.
74        for i in range(len(sequence) - len(subsequence) + 1):
75            candidate = sequence[i : i + len(subsequence)]
76            assert len(candidate) == len(subsequence)  # sanity check
77            if candidate == subsequence:
78                return  # found it
79        raise AssertionError(f"{subsequence} not found in {sequence}")
80
81
@contextlib.contextmanager
def env_var(key: str, value: str | None) -> Iterator[None]:
    """Temporarily set ``key`` to ``value`` (or remove it when ``value`` is None).

    Whatever ``key`` held on entry, including being absent, is put back when
    the ``with`` block exits, whether it exits normally or via an exception.
    """
    restore_to = os.getenv(key)  # None when the variable is currently unset
    set_env_var(key, value)
    try:
        yield
    finally:
        set_env_var(key, restore_to)
93
94
def set_env_var(key: str, value: str | None) -> None:
    """Assign ``value`` to environment variable ``key``; ``None`` removes it instead.

    Removing a variable that is not set is a no-op.
    """
    if value is not None:
        os.environ[key] = value
        return
    os.environ.pop(key, None)
101
102
if __name__ == "__main__":
    # Discover and run the TestCase classes of this module when it is executed
    # directly as a script (module="__main__" is unittest.main's default).
    unittest.main(module="__main__")
105