xref: /aosp_15_r20/external/pytorch/functorch/examples/compilation/eager_fusion.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import time
2
3import torch
4import torch.utils
5from functorch.compile import aot_function, tvm_compile
6
7
# Benchmark inputs. ``a`` is the tensor gradients are taken with respect to;
# ``b`` is a constant operand. Their shapes broadcast against each other
# ((2000, 1, 4) * (1, 2000, 4) -> (2000, 2000, 4)).
a = torch.randn((2000, 1, 4)).requires_grad_()
b = torch.randn((1, 2000, 4))
10
11
def f(a):
    """Multiply ``a`` elementwise by the module-level tensor ``b`` and sum over dim 0.

    With the tensors defined in this script (``a``: (2000, 1, 4),
    ``b``: (1, 2000, 4)) the product broadcasts to (2000, 2000, 4) and the
    returned tensor has shape (2000, 4).
    """
    scaled = a * b
    return scaled.sum(dim=0)
14
15
# Number of timed iterations; read by bench() and bench_jax().
iters = 10

# Separate TVM compilers for the forward and the backward graph, each with
# its own tuning_logfile, both using target="llvm".
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
compiled_f = aot_function(f, fw_compiler=fw_compiler, bw_compiler=bw_compiler)

# Disabled alternative: identity compilers that hand back their first
# argument unchanged instead of going through TVM.
# fw_compiler = lambda x, _: x
# bw_compiler = lambda x, _: x

# One untimed forward + backward pass through the compiled function before
# benchmarking (presumably so one-off compilation cost is paid here).
out = compiled_f(a)
out.sum().backward()
25
26
def bench(func):
    """Time ``iters`` forward + backward passes of ``func`` and print the total.

    Every iteration evaluates ``func(a)`` on the module-level tensor ``a``,
    applies ``sin``, reduces with ``sum`` and backpropagates. ``a.grad`` is
    reset afterwards so gradients do not accumulate between iterations.

    Args:
        func: Callable that accepts the module-level tensor ``a`` and returns
            a tensor.

    Prints the elapsed seconds for all ``iters`` iterations together (a
    total, not a per-iteration average). Returns ``None``.
    """
    # perf_counter() is monotonic and uses the highest-resolution clock
    # available, so the measurement cannot be skewed by system clock
    # adjustments the way time.time() can.
    start = time.perf_counter()
    for _ in range(iters):
        out = func(a).sin()
        out.sum().backward()
        a.grad = None
    print(time.perf_counter() - start)
34
35
def bench_jax():
    """Time ``iters`` calls of the jitted JAX gradient of the same computation.

    The module-level torch tensors ``a`` and ``b`` are copied into JAX arrays,
    and ``jax.jit(jax.grad(...))`` is built for
    ``sin((x * b).sum(axis=0)).sum()``. The function is called once untimed,
    then ``iters`` times inside the timed region.

    Prints the elapsed seconds for all ``iters`` calls together. Returns
    ``None``. JAX is imported lazily so the rest of the script does not
    require it.
    """
    import jax
    import jax.numpy as jnp

    jax_a = jnp.array(a.detach().numpy())
    jax_b = jnp.array(b.detach().numpy())

    # Named ``loss``/``x`` rather than ``f``/``a`` so the module-level
    # function and tensor of those names are not shadowed.
    def loss(x):
        # A plain int axis is equivalent to the one-element list and stays
        # hashable should the reduction treat ``axis`` as a static argument.
        return jnp.sin((x * jax_b).sum(axis=0)).sum()

    grad_fn = jax.jit(jax.grad(loss))

    # Warm-up call. Waiting on its result keeps any still-pending work out of
    # the timed region, and binding ``out`` here means the final
    # block_until_ready() is valid even when ``iters`` is 0.
    out = grad_fn(jax_a)
    out.block_until_ready()

    start = time.perf_counter()
    for _ in range(iters):
        out = grad_fn(jax_a)
    # Wait for the last result so the measurement covers the actual
    # computation, not just the dispatch of the calls.
    out.block_until_ready()
    print(time.perf_counter() - start)
54
55
# Run the benchmark on the eager function first, then on the AOT-compiled
# one; each call prints one elapsed-seconds line, in that order.
for fn in (f, compiled_f):
    bench(fn)
# The JAX comparison is disabled by default:
# bench_jax()
59