# Owner(s): ["oncall: jit"]

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase

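# Initialize the TorchScript-based lazy backend; after this call, tensors
# created under torch.device("lazy") are traced into a graph and executed
# only when the graph is materialized (e.g., by torch._lazy.mark_step()).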
torch._lazy.ts_backend.init()


class LazyGeneratorTest(TestCase):
    def test_generator(self):
        """
        Test that generators are inserted into the TorchScript graph:
        a different global seed is set before the CPU and lazy calls to
        generate_tensor, yet the resulting tensors match because the
        randomness comes from the explicitly seeded generators.
        """

        def generate_tensor():
            g1 = torch.Generator()
            g1.manual_seed(2023)
            t1 = torch.tensor(1.0)
            t1.uniform_(generator=g1)

            g2 = torch.Generator()
            g2.manual_seed(2024)
            t2 = torch.tensor(1.0)
            t2.normal_(generator=g2)

            return t1, t2

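        # The global seed differs between the CPU run (1) and the lazy run (2);
        # if the explicit generators above were ignored, the two runs would
        # produce different values and the assertions below would fail.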
        torch.manual_seed(1)

        with torch.device("cpu"):
            cpu_t1, cpu_t2 = generate_tensor()

        torch.manual_seed(2)

        with torch.device("lazy"):
            lazy_t1, lazy_t2 = generate_tensor()

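        # mark_step() cuts the lazy trace here, compiling and executing the
        # accumulated graph so lazy_t1 and lazy_t2 hold concrete values.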
        torch._lazy.mark_step()

        assert torch.allclose(
            cpu_t1, lazy_t1.to("cpu")
        ), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
        assert torch.allclose(
            cpu_t2, lazy_t2.to("cpu")
        ), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"

    @skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
    def test_generator_causes_multiple_compiles(self):
        """
        Test that inserting generators with different seeds causes a recompile
        """

        def generate_tensor(seed):
            t = torch.tensor(1.0)
            g = torch.Generator()
            g.manual_seed(seed)
            t.uniform_(-1, 1, generator=g)
            return t

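        # Reset the lazy backend's metric counters so the compile counts
        # checked below start from zero.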
        metrics.reset()

        with torch.device("lazy"):
            t = generate_tensor(1)
            torch._lazy.mark_step()

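            # "UncachedCompile" counts graphs compiled from scratch; tracing
            # seed 1 for the first time should trigger exactly one.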
            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 1
            ), f"Expected 1 uncached compile, got {uncached_compile}"

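            # A different seed changes the traced graph, so the backend must
            # compile again from scratch.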
            t = generate_tensor(2)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 2
            ), f"Expected 2 uncached compiles, got {uncached_compile}"

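            # Reusing seed 1 reproduces an already-compiled graph, so this
            # run should hit the compilation cache rather than recompile.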
            t = generate_tensor(1)
            torch._lazy.mark_step()

            uncached_compile = metrics.counter_value("UncachedCompile")
            assert (
                uncached_compile == 2
            ), f"Expected 2 uncached compiles, got {uncached_compile}"
            cached_compile = metrics.counter_value("CachedCompile")
            assert (
                cached_compile == 1
            ), f"Expected 1 cached compile, got {cached_compile}"

        metrics.reset()

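        # The most recently compiled TorchScript graph should embed the
        # generator (with its device and seed) and the aten::uniform call,
        # confirming that generators are captured in the trace.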
        latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
        assert 'torch.Generator(device="cpu", seed=1)' in latest_graph
        assert "aten::uniform" in latest_graph


if __name__ == "__main__":
    run_tests()