xref: /aosp_15_r20/external/pytorch/test/test_cpp_extensions_aot.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 # Owner(s): ["module: cpp-extensions"]
2 
3 import os
4 import re
5 import unittest
6 from itertools import repeat
7 from typing import get_args, get_origin, Union
8 
9 import torch
10 import torch.backends.cudnn
11 import torch.testing._internal.common_utils as common
12 import torch.utils.cpp_extension
13 from torch.testing._internal.common_cuda import TEST_CUDA
14 from torch.testing._internal.common_utils import IS_WINDOWS, skipIfTorchDynamo
15 
16 
# pytest is optional: record whether it is importable so that the extension
# imports further down can choose between `pytest.importorskip` and a plain
# `import`.
try:
    import pytest

    HAS_PYTEST = True
except ImportError:
    HAS_PYTEST = False
23 
# TODO: Rewrite these tests so that they can be collected via pytest without
# using run_test.py
#
# Bind the pre-built test extension modules (`cpp`, `maia`, `rng`) to the
# module-level names used by the test classes below. When pytest is available
# they are requested through `pytest.importorskip`; otherwise they are imported
# directly. An ImportError from either path is re-raised as a RuntimeError that
# points the user at run_test.py.
# NOTE(review): `pytest.importorskip` is documented to skip (rather than raise
# ImportError) when the module is missing, so the RuntimeError below is
# presumably only reachable through the plain-import branch -- confirm.
try:
    if HAS_PYTEST:
        cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
        maia_extension = pytest.importorskip("torch_test_cpp_extension.maia")
        rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
    else:
        import torch_test_cpp_extension.cpp as cpp_extension
        import torch_test_cpp_extension.maia as maia_extension
        import torch_test_cpp_extension.rng as rng_extension
except ImportError as e:
    raise RuntimeError(
        "test_cpp_extensions_aot.py cannot be invoked directly. Run "
        "`python run_test.py -i test_cpp_extensions_aot_ninja` instead."
    ) from e
40 
41 
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionAOT(common.TestCase):
    """Tests ahead-of-time cpp extensions

    NOTE: run_test.py's test_cpp_extensions_aot_ninja target
    also runs this test case, but with ninja enabled. If you are debugging
    a test failure here from the CI, check the logs for which target
    (test_cpp_extensions_aot_no_ninja vs test_cpp_extensions_aot_ninja)
    failed.
    """

    def test_extension_function(self):
        # The extension's `sigmoid_add` must agree with the eager computation.
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = cpp_extension.sigmoid_add(x, y)
        self.assertEqual(z, x.sigmoid() + y.sigmoid())
        # test pybind support torch.dtype cast.
        self.assertEqual(
            str(torch.float32), str(cpp_extension.get_math_type(torch.half))
        )

    def test_extension_module(self):
        # `forward(weights)` must equal multiplying the module's own tensor
        # (returned by `get()`) with `weights`.
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double)
        expected = mm.get().mm(weights)
        result = mm.forward(weights)
        self.assertEqual(expected, result)

    def test_backward(self):
        # Backpropagate the summed output and compare both gradients with the
        # values computed by hand for a matrix product whose upstream gradient
        # is all ones.
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double, requires_grad=True)
        result = mm.forward(weights)
        result.sum().backward()
        tensor = mm.get()

        expected_weights_grad = tensor.t().mm(torch.ones([4, 4], dtype=torch.double))
        self.assertEqual(weights.grad, expected_weights_grad)

        expected_tensor_grad = torch.ones([4, 4], dtype=torch.double).mm(weights.t())
        self.assertEqual(tensor.grad, expected_tensor_grad)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        # Imported lazily: the CUDA extension is only needed when this test
        # is not skipped.
        import torch_test_cpp_extension.cuda as cuda_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not torch.backends.mps.is_available(), "MPS not found")
    def test_mps_extension(self):
        import torch_test_cpp_extension.mps as mps_extension

        tensor_length = 100000
        x = torch.randn(tensor_length, device="cpu", dtype=torch.float32)
        y = torch.randn(tensor_length, device="cpu", dtype=torch.float32)

        # Run the same addition through the CPU and the MPS entry points and
        # compare the results on the CPU.
        cpu_output = mps_extension.get_cpu_add_output(x, y)
        mps_output = mps_extension.get_mps_add_output(x.to("mps"), y.to("mps"))

        self.assertEqual(cpu_output, mps_output.to("cpu"))

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cublas_extension(self):
        from torch_test_cpp_extension import cublas_extension

        # The function is expected to hand its input back unchanged.
        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        z = cublas_extension.noop_cublas_function(x)
        self.assertEqual(z, x)

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cusolver_extension(self):
        from torch_test_cpp_extension import cusolver_extension

        # The function is expected to hand its input back unchanged.
        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        z = cusolver_extension.noop_cusolver_function(x)
        self.assertEqual(z, x)

    @unittest.skipIf(common.IS_WINDOWS, "Not available on Windows")
    def test_no_python_abi_suffix_sets_the_correct_library_name(self):
        # For this test, run_test.py will call `python setup.py install` in the
        # cpp_extensions/no_python_abi_suffix_test folder, where the
        # `BuildExtension` class has a `no_python_abi_suffix` option set to
        # `True`. This *should* mean that on Python 3, the produced shared
        # library does not have an ABI suffix like
        # "cpython-37m-x86_64-linux-gnu" before the library suffix, e.g. "so".
        root = os.path.join("cpp_extensions", "no_python_abi_suffix_test", "build")
        # Match on the full ".so" suffix (dot included) so that an unrelated
        # file whose name merely ends in the letters "so" is not counted.
        matches = [f for _, _, fs in os.walk(root) for f in fs if f.endswith(".so")]
        self.assertEqual(len(matches), 1, msg=str(matches))
        self.assertEqual(matches[0], "no_python_abi_suffix_test.so", msg=str(matches))

    def test_optional(self):
        # A tensor argument reports "has value"; `None` reports "no value".
        has_value = cpp_extension.function_taking_optional(torch.ones(5))
        self.assertTrue(has_value)
        has_value = cpp_extension.function_taking_optional(None)
        self.assertFalse(has_value)

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(
        os.getenv("USE_NINJA", "0") == "0",
        "cuda extension with dlink requires ninja to build",
    )
    def test_cuda_dlink_libs(self):
        from torch_test_cpp_extension import cuda_dlink

        # The extension's `add` must agree with eager addition on CUDA.
        a = torch.randn(8, dtype=torch.float, device="cuda")
        b = torch.randn(8, dtype=torch.float, device="cuda")
        ref = a + b
        test = cuda_dlink.add(a, b)
        self.assertEqual(test, ref)
162 
163 
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):
    """Pybind tests for ahead-of-time cpp extensions

    These tests verify the types returned from cpp code using custom type
    casters. By exercising pybind, we also verify that the type casters work
    properly.

    For each type caster in `torch/csrc/utils/pybind.h` we create a pybind
    function that takes no arguments and returns the type_caster type. The
    second argument to `PYBIND11_TYPE_CASTER` should be the type we expect to
    receive in python, in these tests we verify this at run-time.
    """

    @staticmethod
    def expected_return_type(func):
        """
        Our Pybind functions have a signature of the form `() -> return_type`.

        The text following `-> ` on the first matching line of `func.__doc__`
        is evaluated and the resulting type object is returned.

        Raises AssertionError when the docstring is missing or contains no
        `-> return_type` line.
        """
        # Imports needed for the `eval` below.
        from typing import List, Tuple  # noqa: F401

        match = re.search("-> (.*)\n", func.__doc__ or "")
        if match is None:
            raise AssertionError(
                f"{func!r} has no `-> return_type` line in its docstring"
            )
        # NOTE: `eval` is only acceptable because the docstring belongs to a
        # function of our own test extension, i.e. it is trusted input.
        return eval(match.group(1))

    def check(self, func):
        """Call `func` and verify its result against its documented type.

        List and tuple annotations are checked element-wise; every other
        annotation is checked with a plain `isinstance` assertion.
        """
        val = func()
        expected = self.expected_return_type(func)
        origin = get_origin(expected)
        if origin is list:
            self.check_list(val, expected)
        elif origin is tuple:
            self.check_tuple(val, expected)
        else:
            self.assertIsInstance(val, expected)

    def check_list(self, vals, expected):
        """Assert `vals` is a list whose items match `expected`'s item type."""
        self.assertIsInstance(vals, list)
        list_type = get_args(expected)[0]
        for val in vals:
            self.assertIsInstance(val, list_type)

    def check_tuple(self, vals, expected):
        """Assert `vals` is a tuple whose items match `expected`'s item types.

        `Tuple[T, ...]` checks every item against `T`; otherwise items are
        paired positionally with the annotated types.
        """
        self.assertIsInstance(vals, tuple)
        tuple_types = get_args(expected)
        # Only a two-argument annotation can be the variadic `Tuple[T, ...]`
        # form; checking the length first keeps one-element annotations such
        # as `Tuple[T]` from raising IndexError.
        if len(tuple_types) == 2 and tuple_types[1] is ...:
            tuple_types = repeat(tuple_types[0])
        for val, tuple_type in zip(vals, tuple_types):
            self.assertIsInstance(val, tuple_type)

    def check_union(self, funcs):
        """Special handling for Union type casters.

        A single cpp type can sometimes be cast to different types in python.
        In these cases we expect to get exactly one function per python type.
        """
        # Verify that all functions have the same return type.
        return_types = {self.expected_return_type(f) for f in funcs}
        self.assertEqual(
            len(return_types), 1, f"Expected a single return type, got {return_types}"
        )
        union_type = return_types.pop()
        self.assertIs(Union, get_origin(union_type))
        # SymInt is inconvenient to test, so don't require it
        expected_types = set(get_args(union_type)) - {torch.SymInt}
        for func in funcs:
            val = func()
            # Find the still-unclaimed type this value satisfies before
            # mutating the set, so the set is never modified mid-iteration.
            matched = next((tp for tp in expected_types if isinstance(val, tp)), None)
            if matched is None:
                self.fail(f"{val} is not an instance of {expected_types}")
            expected_types.remove(matched)
        self.assertFalse(
            expected_types, f"Missing functions for types {expected_types}"
        )

    def test_pybind_return_types(self):
        # Each function is checked in its own subTest so one failing caster
        # does not hide the results of the others.
        functions = [
            cpp_extension.get_complex,
            cpp_extension.get_device,
            cpp_extension.get_generator,
            cpp_extension.get_intarrayref,
            cpp_extension.get_memory_format,
            cpp_extension.get_storage,
            cpp_extension.get_symfloat,
            cpp_extension.get_symintarrayref,
            cpp_extension.get_tensor,
        ]
        union_functions = [
            [cpp_extension.get_symint],
        ]
        for func in functions:
            with self.subTest(msg=f"check {func.__name__}"):
                self.check(func)
        for funcs in union_functions:
            with self.subTest(msg=f"check {[f.__name__ for f in funcs]}"):
                self.check_union(funcs)
259 
260 
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestMAIATensor(common.TestCase):
    """Tests tensor creation and operator calls on the "maia" device.

    Several tests compare `maia_extension.get_test_int()` with a fixed value
    after running an operation on "maia" tensors.
    """

    def test_unregistered(self):
        # `arange` works on the CPU but is expected to raise on "maia".
        cpu_range = torch.arange(0, 10, device="cpu")
        with self.assertRaisesRegex(RuntimeError, "Could not run"):
            maia_range = torch.arange(0, 10, device="maia")

    @skipIfTorchDynamo("dynamo cannot model maia device")
    def test_zeros(self):
        cpu_tensor = torch.empty(5, 5, device="cpu")
        self.assertEqual(cpu_tensor.device, torch.device("cpu"))

        # Default dtype on "maia": the tensor lands on device index 0.
        maia_tensor = torch.empty(5, 5, device="maia")
        self.assertEqual(maia_tensor.device, torch.device("maia", 0))
        self.assertEqual(maia_extension.get_test_int(), 0)
        self.assertEqual(torch.get_default_dtype(), maia_tensor.dtype)

        # An explicitly requested dtype has to be honoured as well.
        maia_int64 = torch.empty((5, 5), dtype=torch.int64, device="maia")
        self.assertEqual(maia_extension.get_test_int(), 0)
        self.assertEqual(torch.int64, maia_int64.dtype)

    def test_add(self):
        # The test int is expected to read 0 after each `empty` and 1 after
        # the addition.
        lhs = torch.empty(5, 5, device="maia", requires_grad=True)
        self.assertEqual(maia_extension.get_test_int(), 0)

        rhs = torch.empty(5, 5, device="maia")
        self.assertEqual(maia_extension.get_test_int(), 0)

        total = lhs + rhs
        self.assertEqual(maia_extension.get_test_int(), 1)

    def test_conv_backend_override(self):
        # A 4d input is used on purpose: it avoids the view4d step inside
        # _convolution, which would need additional overrides.
        inp = torch.empty(2, 4, 10, 2, device="maia", requires_grad=True)
        kernel = torch.empty(6, 4, 2, 2, device="maia", requires_grad=True)
        offset = torch.empty(6, device="maia")

        # Forward pass must go through the override.
        out = torch.nn.functional.conv2d(
            inp, kernel, offset, stride=2, padding=0, dilation=1, groups=1
        )
        self.assertEqual(maia_extension.get_test_int(), 2)
        self.assertEqual(out.shape[0], inp.shape[0])
        self.assertEqual(out.shape[1], kernel.shape[0])

        # Backward pass must go through the override too.
        # Double backward is dispatched to _convolution_double_backward and is
        # left untested here because it needs more computation/overrides.
        grads = torch.autograd.grad(out, inp, out, create_graph=True)
        self.assertEqual(maia_extension.get_test_int(), 3)
        self.assertEqual(grads[0].shape, inp.shape)
310         self.assertEqual(grad[0].shape, input.shape)
311 
312 
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestRNGExtension(common.TestCase):
    """Tests the generator objects provided by the `rng` test extension."""

    def setUp(self):
        # No extra fixtures; simply defers to the base-class setUp.
        super().setUp()

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_rng(self):
        # Reference tensor: ten int64 elements, all equal to 42.
        fourty_two = torch.full((10,), 42, dtype=torch.int64)

        # Baseline: with the default generator the result must differ from
        # the all-42 reference.
        t = torch.empty(10, dtype=torch.int64).random_()
        self.assertNotEqual(t, fourty_two)

        # Same expectation with an explicit stock CPU generator.
        gen = torch.Generator(device="cpu")
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertNotEqual(t, fourty_two)

        # The remaining statements are order-sensitive: after each step the
        # extension's instance counter is compared with the expected number of
        # live test generators.
        self.assertEqual(rng_extension.getInstanceCount(), 0)
        gen = rng_extension.createTestCPUGenerator(42)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        # Binding a second Python name must not create another instance.
        copy = gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy)
        # Passing the generator through the extension and back must not
        # create another instance either.
        copy2 = rng_extension.identity(copy)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy2)
        # With the test generator, `random_` is expected to yield all 42s.
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(t, fourty_two)
        # The count is expected to stay at 1 until the last of the three
        # names is deleted, and to drop to 0 only then.
        del gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy2
        self.assertEqual(rng_extension.getInstanceCount(), 0)
346         self.assertEqual(rng_extension.getInstanceCount(), 0)
347 
348 
@torch.testing._internal.common_utils.markDynamoStrictTest
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
class TestTorchLibrary(common.TestCase):
    """Tests the `torch_library::logical_and` op, both eagerly and scripted."""

    def test_torch_library(self):
        # Imported for its side effect of making the op available under
        # `torch.ops.torch_library`.
        import torch_test_cpp_extension.torch_library  # noqa: F401

        def eager_and(a: bool, b: bool):
            return torch.ops.torch_library.logical_and(a, b)

        # (arguments, expected truthiness) for every input combination, in
        # the order they are asserted.
        truth_table = (
            ((True, True), True),
            ((True, False), False),
            ((False, True), False),
            ((False, False), False),
        )

        for operands, expected in truth_table:
            if expected:
                self.assertTrue(eager_and(*operands))
            else:
                self.assertFalse(eager_and(*operands))

        # Script only after the eager checks passed, then repeat the table.
        scripted_and = torch.jit.script(eager_and)
        for operands, expected in truth_table:
            if expected:
                self.assertTrue(scripted_and(*operands))
            else:
                self.assertFalse(scripted_and(*operands))

        # The scripted graph has to reference the custom op by name.
        self.assertIn("torch_library::logical_and", str(scripted_and.graph))
367         self.assertIn("torch_library::logical_and", str(s.graph))
368 
369 
# Script entry point: hand control to the test runner provided by
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    common.run_tests()
372