# Owner(s): ["oncall: jit"] import gc import os import sys import unittest from typing import NamedTuple import torch from torch.testing import FileCheck from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import ( NoTest, skipCUDANonDefaultStreamIf, skipIfRocm, TEST_CUDA, ) from torch.testing._internal.jit_utils import JitTestCase # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) # If GPU is not available, then do not run the tests if not TEST_CUDA: print("CUDA not available, skipping tests", file=sys.stderr) JitTestCase = NoTest # noqa: F811 TEST_LARGE_TENSOR = TEST_CUDA # If GPU is available, then initialize the cuda context and check # if there is memory available to allocate for LARGE Tensors. if TEST_CUDA: torch.ones(1).cuda() # initialize cuda context TEST_LARGE_TENSOR = torch.cuda.get_device_properties(0).total_memory >= 5e9 if __name__ == "__main__": raise RuntimeError( "This test file is not meant to be run directly, use:\n\n" "\tpython test/test_jit.py TESTNAME\n\n" "instead." ) class TestCUDA(JitTestCase): """ A suite of tests for the CUDA API in TorchScript. """ def tearDown(self): gc.collect() torch.cuda.empty_cache() super().tearDown() @skipIfRocm @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU") def test_cuda_synchronize(self): # Test device synchronization. @torch.jit.script def test_device_synchronize(): prev_current_device_index = torch.cuda.current_device() torch.cuda.synchronize() torch.cuda.synchronize("cuda") torch.cuda.synchronize("cuda:0") torch.cuda.synchronize(0) torch.cuda.synchronize(torch.device("cuda:1")) after_current_device_index = torch.cuda.current_device() # Check if the current device index is same as the device index before # synchronizing the device. return prev_current_device_index == after_current_device_index @torch.jit.script def test_multi_device_synchronize(): torch.cuda.synchronize(torch.device("cuda:0")) prev_current_device_index = torch.cuda.current_device() torch.cuda.synchronize(1) after_current_device_index = torch.cuda.current_device() # Check if the current device index is same as the device index before # synchronizing the device. 

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_stream_args(self):
        # Test stream creation with default arguments.
        @torch.jit.script
        def stream_default_args() -> bool:
            s = torch.cuda.Stream()
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_device() -> bool:
            s = torch.cuda.Stream(priority=0)
            return s.device_index() == torch.cuda.current_device()

        @torch.jit.script
        def stream_default_args_for_priority() -> bool:
            d = torch.device("cuda:1")
            s = torch.cuda.Stream(d)
            return s.device_index() == 1

        @torch.jit.script
        def stream_args_all() -> bool:
            d = torch.device("cuda:0")
            s = torch.cuda.Stream(d, 0)
            return s.device_index() == 0

        self.assertTrue(stream_default_args())
        self.assertTrue(stream_default_args_for_device())
        self.assertTrue(stream_default_args_for_priority())
        self.assertTrue(stream_args_all())

    def test_event_args(self):
        # Test Event creation with default arguments.
        @torch.jit.script
        def event_default_args() -> bool:
            e = torch.cuda.Event()
            return e is not None

        self.assertTrue(event_default_args())

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    def test_current_stream(self):
        # Query the current stream on a device and check that the stream's
        # device index matches the device ID.
        @torch.jit.script
        def fn():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.current_stream(torch.device("cuda:1"))
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)

        # Test the current_stream API by passing the device ID as an argument
        # and check that the stream's device index matches the device ID.
        @torch.jit.script
        def fn_with_device_index_args():
            device_index = torch.cuda.current_device()
            s0 = torch.cuda.current_stream(device_index)
            s1 = torch.cuda.current_stream(1)
            s2 = torch.cuda.current_stream(0)

            return s0.device_index(), s1.device_index(), s2.device_index()

        d0, d1, d2 = fn_with_device_index_args()
        # By default, the current device ID is 0.
        self.assertEqual(0, d0)
        self.assertEqual(1, d1)
        self.assertEqual(0, d2)
        self.assertEqual(d0, d2)
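
    # In eager mode, torch.cuda.current_stream likewise accepts either a bare
    # device index or a torch.device. A minimal eager sketch; the helper name
    # is illustrative and nothing in the suite calls it.
    @staticmethod
    def _sketch_current_stream_eager() -> bool:
        s_by_index = torch.cuda.current_stream(0)
        s_by_device = torch.cuda.current_stream(torch.device("cuda:0"))
        # Both spellings resolve to the same stream on device 0.
        return s_by_index == s_by_device and s_by_index.device.index == 0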

    @skipIfRocm
    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
    @unittest.skipIf(not TEST_LARGE_TENSOR, "not enough memory")
    @skipCUDANonDefaultStreamIf(True)
    def test_streams_and_events(self):
        # Test the default_stream API by passing the device ID as an argument
        # and check that the stream's device index matches the device ID.
        @torch.jit.script
        def test_default_streams_with_device_index_args():
            s0 = torch.cuda.default_stream(0)
            s1 = torch.cuda.default_stream(1)

            return s0.device_index(), s1.device_index()

        d0, d1 = test_default_streams_with_device_index_args()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)

        # This test checks that the default stream ID is 0 on each device.
        @torch.jit.script
        def test_default_streams():
            s0 = torch.cuda.default_stream(torch.device("cuda:0"))
            s1 = torch.cuda.default_stream(torch.device("cuda:1"))

            d = torch.device("cuda:1")

            # Check that the current stream ID and the default stream ID are
            # the same on the current device. The current device ID is 0 by
            # default.
            s2 = torch.cuda.current_stream(torch.device("cuda:0"))
            check_s2 = s2.id() == s0.id()
            check_d0 = torch.cuda.current_device() == s2.device_index()

            # Set the current device to d and check that the current stream
            # is the default stream on d.
            with torch.cuda.device(d):
                s3 = torch.cuda.current_stream(d)
                check_s3 = s3.id() == s1.id()
                check_d1 = torch.cuda.current_device() == s3.device_index()

            # Check that the current device was reset to 0.
            is_device_d0 = torch.cuda.current_device() == s2.device_index()

            return (
                s0.device_index(),
                s1.device_index(),
                check_s2,
                check_s3,
                check_d0,
                check_d1,
                is_device_d0,
            )

        (
            d0,
            d1,
            check_s2,
            check_s3,
            check_d0,
            check_d1,
            is_device_d0,
        ) = test_default_streams()

        self.assertEqual(d0, 0)
        self.assertEqual(d1, 1)
        self.assertTrue(check_s2)
        self.assertTrue(check_s3)
        self.assertTrue(check_d0)
        self.assertTrue(check_d1)
        self.assertTrue(is_device_d0)

        # This test checks that the stream context manager is a no-op when
        # the stream passed to `with torch.cuda.stream` is None.
        @torch.jit.script
        def test_set_none_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)

            # When the stream is None, check that this operation is a no-op.
            with torch.cuda.stream(None):
                cur_device_index = torch.cuda.current_device()
                is_device_index_same = cur_device_index == device_index
                is_current_stream_same = (
                    torch.cuda.current_stream(device).id() == current_stream.id()
                )
                is_default_stream_same = (
                    torch.cuda.default_stream(device).id() == default_stream.id()
                )

            # Check that the device index, current stream, and default stream
            # have not changed.
            are_streams_same = (
                is_device_index_same
                and is_current_stream_same
                and is_default_stream_same
            )

            return are_streams_same

        self.assertTrue(test_set_none_stream())

        # This test checks that the device context manager is a no-op when
        # the device passed to `with torch.cuda.device` is None.
        @torch.jit.script
        def test_set_device_none():
            device_index = torch.cuda.current_device()
            # When the device is None, check that this operation is a no-op.
            with torch.cuda.device(None):
                # Check that the current device is unchanged.
                is_device_same = torch.cuda.current_device() == device_index

            return is_device_same

        self.assertTrue(test_set_device_none())

        # Check that a TorchScript CUDA stream is created on the current
        # device.
        @torch.jit.script
        def test_simple_stream():
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            return device_index == s.device_index()

        self.assertTrue(test_simple_stream(), "Could not create Stream!")

        # Class used to store results for the test: test_get_stream.
        class Result(NamedTuple):
            t1: torch.Tensor
            t2: torch.Tensor
            is_current_and_default_stream_same: bool
            is_default_and_user_stream_not_same: bool
            is_stream_set: bool
            is_stream_reset: bool
            default_stream_query: bool
            default_stream_id: int
            user_stream_id: int
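
        # TorchScript can construct and return NamedTuple values directly,
        # which is why Result is declared above. A minimal sketch of the
        # pattern (the Pair type below is hypothetical, for illustration
        # only):
        #
        #   class Pair(NamedTuple):
        #       ok: bool
        #       ident: int
        #
        #   @torch.jit.script
        #   def make_pair() -> Pair:
        #       return Pair(True, 0)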

        # This test checks several stream properties.
        @torch.jit.script
        def test_get_stream():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            default_stream = torch.cuda.default_stream(device)
            user_stream = torch.cuda.Stream()

            # Check that the current and default streams are the same on the
            # device.
            is_current_and_default_stream_same = (
                current_stream.id() == default_stream.id()
            )
            # Check that the user stream and the default stream are not the
            # same on the device.
            is_default_and_user_stream_not_same = (
                default_stream.id() != user_stream.id()
            )

            with torch.cuda.stream(user_stream):
                is_stream_set = (
                    torch.cuda.current_stream(device).id() == user_stream.id()
                )

            # Check that the stream was reset to current_stream.
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            tensor1 = torch.rand(10000, 10000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            default_stream.synchronize()
            default_stream_query = default_stream.query()

            # Capture all the results in a Result instance.
            res = Result(
                tensor1,
                tensor2,
                is_current_and_default_stream_same,
                is_default_and_user_stream_not_same,
                is_stream_set,
                is_stream_reset,
                default_stream_query,
                default_stream.id(),
                user_stream.id(),
            )
            return res

        result = test_get_stream()

        self.assertEqual(torch.matmul(result.t1, result.t1), result.t2)
        self.assertTrue(result.is_current_and_default_stream_same)
        self.assertTrue(result.is_default_and_user_stream_not_same)
        self.assertTrue(result.is_stream_set)
        self.assertTrue(result.is_stream_reset)
        # Check that the default stream ID is always 0.
        self.assertEqual(result.default_stream_id, 0)
        # Check that the user stream ID is always nonzero.
        self.assertNotEqual(result.user_stream_id, 0)

        # Test the stream context manager. This test checks that the current
        # stream is switched to the user stream inside the context manager.
        @torch.jit.script
        def test_stream_context():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            current_stream = torch.cuda.current_stream(device)
            user_stream = torch.cuda.Stream()
            A = torch.rand(1000, 1000, device="cuda")

            with torch.cuda.stream(user_stream):
                check = torch.cuda.current_stream(device).id() == user_stream.id()
                B = torch.mm(A, A).to("cuda")
            # Wait for B to be computed.
            user_stream.synchronize()
            # Check that the stream has been reset on the current device.
            is_stream_reset = (
                torch.cuda.current_stream(device).id() == current_stream.id()
            )

            return A, B, check, is_stream_reset

        A, B, is_stream_set, is_stream_reset = test_stream_context()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertTrue(
            is_stream_set, "Error: Current stream was not set to user stream!"
        )
        self.assertTrue(
            is_stream_reset, "Error: The stream was not restored to previous stream!"
        )
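
        # The pattern above is the standard way to use a side stream: enqueue
        # work inside `with torch.cuda.stream(s)`, then call s.synchronize()
        # before reading the result on the host. A minimal eager sketch,
        # illustrative only:
        #
        #   s = torch.cuda.Stream()
        #   with torch.cuda.stream(s):
        #       out = torch.mm(a, a)  # enqueued on s, runs asynchronously
        #   s.synchronize()  # block the host until s has drained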

        # Test multiple nested streams. Check that the operations are
        # computed as expected on each stream. This test has been adapted
        # from the eager mode tests available at test/test_cuda.py.
        @torch.jit.script
        def test_multiple_stream():
            prev_device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(prev_device_index))
            prev_current_stream = torch.cuda.current_stream(device)

            d1 = torch.device("cuda:0")
            d2 = torch.device("cuda:1")
            s1 = torch.cuda.Stream(d1, 0)
            s2 = torch.cuda.Stream(d2, 0)

            A = torch.rand(1000, 1000, device="cuda")
            B = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                C = torch.mm(A, A).to("cuda")
                # Check that the stream and device have been set to s1.
                is_stream_s1 = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1 = torch.cuda.current_device() == s1.device_index()
                with torch.cuda.stream(s2):
                    # Check that the stream and device have been set to s2.
                    is_stream_s2 = torch.cuda.current_stream(d2).id() == s2.id()
                    is_device_s2 = torch.cuda.current_device() == s2.device_index()
                    D = torch.mm(B, B).to("cuda")
                # Check that the stream and device have been restored to s1.
                is_stream_s1_after = torch.cuda.current_stream(d1).id() == s1.id()
                is_device_s1_after = (
                    torch.cuda.current_device() == s1.device_index()
                )

            # Wait for D to be computed.
            s2.synchronize()
            # Wait for C to be computed on s1.
            s1.synchronize()

            # Check that the stream and device have been restored to the
            # previous stream and device.
            is_device_current = torch.cuda.current_device() == prev_device_index
            is_stream_current = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_stream_s1
                and is_stream_s2
                and is_stream_s1_after
                and is_stream_current
            )
            check_device = (
                is_device_s1
                and is_device_s2
                and is_device_s1_after
                and is_device_current
            )
            return A, B, C, D, check_stream, check_device

        A, B, C, D, check_stream, check_device = test_multiple_stream()

        self.assertEqual(torch.matmul(A, A), C)
        self.assertEqual(torch.matmul(B, B), D)
        self.assertTrue(check_stream)
        self.assertTrue(check_device)

        # Test multiple streams waiting on each other for operations to
        # complete.
        @torch.jit.script
        def test_data_dependency_between_streams():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            prev_current_stream = torch.cuda.current_stream(device)
            d = torch.device("cuda:0")
            s1 = torch.cuda.Stream(d, 0)
            s2 = torch.cuda.Stream(d, 0)
            event = torch.cuda.Event(False, False, False)

            A = torch.rand(1000, 1000, device="cuda")
            with torch.cuda.stream(s1):
                is_stream_s1 = torch.cuda.current_stream(device).id() == s1.id()
                B = torch.mm(A, A).to("cuda")
            s1.record_event(event)
            # Check that the current stream has been reset.
            is_current_stream_1 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )
            # Wait for the ops on s1 to complete.
            s2.wait_event(event)
            with torch.cuda.stream(s2):
                is_stream_s2 = torch.cuda.current_stream(device).id() == s2.id()
                C = torch.mm(B, B).to("cuda")
            # Wait for C to be computed.
            s2.synchronize()
            # Check that the current stream has been reset.
            is_current_stream_2 = (
                torch.cuda.current_stream(device).id() == prev_current_stream.id()
            )

            check_stream = (
                is_current_stream_1
                and is_current_stream_2
                and is_stream_s1
                and is_stream_s2
            )
            return A, B, C, check_stream

        A, B, C, check_stream = test_data_dependency_between_streams()
        self.assertEqual(torch.matmul(A, A), B)
        self.assertEqual(torch.matmul(B, B), C)
        self.assertTrue(check_stream)
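
        # record_event/wait_event express an ordering between streams without
        # blocking the host: work enqueued on s2 after the wait only starts
        # once the event recorded on s1 has fired. A minimal eager sketch,
        # illustrative only:
        #
        #   e = torch.cuda.Event()
        #   with torch.cuda.stream(s1):
        #       b = torch.mm(a, a)
        #   s1.record_event(e)
        #   s2.wait_event(e)  # later work on s2 now runs after e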

        # Test a simple CUDA event. Check that the CUDA event was created
        # successfully.
        @torch.jit.script
        def test_simple_event():
            e = torch.cuda.Event(True, False, False)
            return e is not None

        self.assertTrue(test_simple_event(), "Could not create CUDA Event!")

        # Record CUDA events around a torch.mm operation on the current
        # stream and then check that the elapsed time is greater than 0. This
        # test is also an adaptation of the eager mode CUDA tests available
        # at test/test_cuda.py.
        @torch.jit.script
        def test_event():
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            stream = torch.cuda.current_stream(device)
            event = torch.cuda.Event(True, False, False)
            is_true_event_query = event.query()
            start_event = torch.cuda.Event(True, False, False)
            stream.record_event(start_event)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            stream.record_event(event)
            event.synchronize()
            is_again_true_event_query = event.query()

            if not (is_true_event_query and is_again_true_event_query):
                return -1.0
            return start_event.elapsed_time(event)

        self.assertGreater(test_event(), 0)

        # Check stream synchronization when a large tensor multiplication is
        # computed on the stream. stream.query() should be True once the
        # synchronization is done.
        @torch.jit.script
        def test_stream_synchronize() -> float:
            device_index = torch.cuda.current_device()
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s):
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
            s.synchronize()
            e_tok.record(s)
            e_tok.synchronize()

            if not s.query():
                return -1.0

            # It is not necessary to check e_tik and e_tok, as elapsed_time
            # would throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_stream_synchronize(), 0)

        # Test event synchronization for an event that records a stream
        # performing a large tensor multiplication. Check that the elapsed
        # time is greater than 0 and that stream.query() evaluates to True.
        @torch.jit.script
        def test_event_synchronize() -> float:
            s = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, False, False)
            e_tok = torch.cuda.Event(True, False, False)

            e_tik.record(s)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s):
                tensor = torch.mm(tensor1, tensor1).to("cuda")
            s.record_event(e_tok)
            e_tok.synchronize()
            s.synchronize()

            if not s.query():
                return -1.0

            # It is not necessary to check e_tik and e_tok, as elapsed_time
            # would throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_synchronize(), 0)
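
        # The positional Event flags used throughout this file mirror the
        # eager torch.cuda.Event signature: (enable_timing, blocking,
        # interprocess). elapsed_time requires both endpoints to be created
        # with timing enabled and to have completed. A minimal eager sketch,
        # illustrative only:
        #
        #   start = torch.cuda.Event(enable_timing=True)
        #   end = torch.cuda.Event(enable_timing=True)
        #   start.record()
        #   # ... enqueue GPU work here ...
        #   end.record()
        #   end.synchronize()
        #   ms = start.elapsed_time(end)  # elapsed time in milliseconds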

        # Test event wait. Check that the event waits for all operations on
        # the stream to finish, and check synchronization and query behavior
        # on the streams and events. This test is adapted from the eager mode
        # CUDA tests; see test/test_cuda.py.
        @torch.jit.script
        def test_event_wait() -> float:
            device_index = torch.cuda.current_device()
            device = torch.device("cuda:" + str(device_index))
            s0 = torch.cuda.current_stream(device)
            s1 = torch.cuda.Stream()
            e_tik = torch.cuda.Event(True, True, False)
            e_tok = torch.cuda.Event(True, True, False)

            e_tik.record(s0)
            tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
            with torch.cuda.stream(s0):
                tensor2 = torch.mm(tensor1, tensor1).cuda()
                e_sync = torch.cuda.Event(True, False, False)
                e_sync.record(torch.cuda.current_stream(device))
                e_sync.wait(s1)
                with torch.cuda.stream(s1):
                    tensor3 = torch.rand(1000000000, 1000000000, device="cuda")
                    tensor4 = torch.mm(tensor3, tensor3).cuda()
                s1.synchronize()

            e_tok.record(torch.cuda.current_stream(device))
            e_tok.synchronize()
            s0.synchronize()

            if not s0.query() or not s1.query() or not e_sync.query():
                return -1.0

            # It is not necessary to check e_tik and e_tok, as elapsed_time
            # would throw an exception otherwise.
            return e_tik.elapsed_time(e_tok)

        self.assertGreater(test_event_wait(), 0)

        # Test Stream.wait_event. Check that the stream waits on the event.
        @torch.jit.script
        def test_wait_event():
            d1 = torch.device("cuda:1")

            with torch.cuda.device(d1):
                s0 = torch.cuda.current_stream(d1)
                tensor1 = torch.rand(1000000000, 1000000000, device="cuda")
                tensor2 = torch.mm(tensor1, tensor1).to("cuda")
                e0 = torch.cuda.Event(False, False, False)
                s0.record_event(e0)

            s1 = torch.cuda.current_stream(torch.device("cuda:0"))
            s1.wait_event(e0)
            s1.synchronize()

            return e0.query() and s0.query() and s1.query()

        self.assertTrue(test_wait_event())

    # Test that a scripted module with CUDA streams can be saved, loaded,
    # and executed.
    def test_save_load(self):
        class Model(torch.nn.Module):
            def forward(self):
                s = torch.cuda.Stream()
                a = torch.rand(3, 4, device="cuda")
                b = torch.rand(3, 4, device="cuda")

                with torch.cuda.stream(s):
                    is_stream_s = (
                        torch.cuda.current_stream(s.device).id() == s.id()
                    )
                    c = torch.cat((a, b), 0).cuda()
                s.synchronize()
                return is_stream_s, a, b, c

        model = Model()

        # Script the model and run it.
        script_model = torch.jit.script(model)
        is_stream_s, a, b, c = script_model()
        # Verify that the output is correct.
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a, b), 0), c)

        # Save and load the scripted model.
        load_model = self.getExportImportCopy(script_model)
        is_stream_s, a_load, b_load, c_load = load_model()
        self.assertTrue(is_stream_s)
        self.assertEqual(torch.cat((a_load, b_load), 0), c_load)
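
    # For reference, the getExportImportCopy helper used above round-trips a
    # scripted module through serialization. The same round trip can be done
    # explicitly with torch.jit.save and torch.jit.load. A minimal sketch;
    # the helper name is illustrative and nothing in the suite calls it.
    @staticmethod
    def _sketch_save_load_roundtrip(script_model):
        import io

        buffer = io.BytesIO()
        torch.jit.save(script_model, buffer)
        buffer.seek(0)
        return torch.jit.load(buffer)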

    # Make sure that cuda._exchange_device doesn't get DCE'd.
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test__exchange_device_op(self):
        def fn(device: int, tensor):
            torch.cuda._exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to run
        # this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_exchange_device(").run(g)

    # Make sure that cuda._maybe_exchange_device doesn't get DCE'd.
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test__maybe_exchange_device_op(self):
        def fn(device: int, tensor):
            torch.cuda._maybe_exchange_device(device)
            return tensor.cos().relu()

        fn_s = torch.jit.script(fn)
        # Just check the graph, don't run it. Otherwise, we'd need to run
        # this test on a multi-gpu CI runner, which is overkill.
        g = fn_s.graph
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
        torch._C._jit_pass_inline(g)
        FileCheck().check("cuda::_maybe_exchange_device(").run(g)
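
    # Note: torch._C._jit_pass_inline mutates the graph in place, so the
    # second FileCheck in each test above confirms that the op also survives
    # inlining. A minimal usage sketch, illustrative only:
    #
    #   fn_s = torch.jit.script(fn)
    #   print(fn_s.graph)  # scripted graph, before inlining
    #   torch._C._jit_pass_inline(fn_s.graph)
    #   print(fn_s.graph)  # the op should still appear here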