1# Owner(s): ["oncall: jit"] 2 3import torch 4import torch._lazy.metrics as metrics 5import torch._lazy.ts_backend 6from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase 7 8 9torch._lazy.ts_backend.init() 10 11 12class LazyGeneratorTest(TestCase): 13 def test_generator(self): 14 """ 15 Test that generators are being inserted into the TorchScript 16 graph by setting different seeds before each call to 17 generate_tensor but the resulting tensor is the same 18 """ 19 20 def generate_tensor(): 21 g1 = torch.Generator() 22 g1.manual_seed(2023) 23 t1 = torch.tensor(1.0) 24 t1.uniform_(generator=g1) 25 26 g2 = torch.Generator() 27 g2.manual_seed(2024) 28 t2 = torch.tensor(1.0) 29 t2.normal_(generator=g2) 30 31 return t1, t2 32 33 torch.manual_seed(1) 34 35 with torch.device("cpu"): 36 cpu_t1, cpu_t2 = generate_tensor() 37 38 torch.manual_seed(2) 39 40 with torch.device("lazy"): 41 lazy_t1, lazy_t2 = generate_tensor() 42 43 torch._lazy.mark_step() 44 45 assert torch.allclose( 46 cpu_t1, lazy_t1.to("cpu") 47 ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}" 48 assert torch.allclose( 49 cpu_t2, lazy_t2.to("cpu") 50 ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}" 51 52 @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type") 53 def test_generator_causes_multiple_compiles(self): 54 """ 55 Test that inserting generators with different seed caused recompile 56 """ 57 58 def generate_tensor(seed): 59 t = torch.tensor(1.0) 60 g = torch.Generator() 61 g.manual_seed(seed) 62 t.uniform_(-1, 1, generator=g) 63 return t 64 65 metrics.reset() 66 67 with torch.device("lazy"): 68 t = generate_tensor(1) 69 torch._lazy.mark_step() 70 71 uncached_compile = metrics.counter_value("UncachedCompile") 72 assert ( 73 uncached_compile == 1 74 ), f"Expected 1 uncached compiles, got {uncached_compile}" 75 76 t = generate_tensor(2) 77 torch._lazy.mark_step() 78 79 uncached_compile = metrics.counter_value("UncachedCompile") 80 assert ( 81 uncached_compile == 2 82 ), f"Expected 2 uncached compiles, got {uncached_compile}" 83 84 t = generate_tensor(1) 85 torch._lazy.mark_step() 86 87 uncached_compile = metrics.counter_value("UncachedCompile") 88 assert ( 89 uncached_compile == 2 90 ), f"Expected 2 uncached compiles, got {uncached_compile}" 91 cached_compile = metrics.counter_value("CachedCompile") 92 assert ( 93 cached_compile == 1 94 ), f"Expected 1 cached compile, got {cached_compile}" 95 96 metrics.reset() 97 98 latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph() 99 assert 'torch.Generator(device="cpu", seed=1)' in latest_graph 100 assert "aten::uniform" in latest_graph 101 102 103if __name__ == "__main__": 104 run_tests() 105