"""Micro-benchmark: eager PyTorch vs. a functorch AOT/TVM-compiled function.

Times the forward + backward pass of a broadcasted multiply-and-reduce over
the module-level tensors ``a`` and ``b``.  An optional JAX variant of the
same computation is provided in :func:`bench_jax`.
"""
import time

import torch
import torch.utils
from functorch.compile import aot_function, tvm_compile


# ``a`` is the differentiable input; ``b`` is a constant it broadcasts against.
a = torch.randn(2000, 1, 4, requires_grad=True)
b = torch.randn(1, 2000, 4)


def f(a):
    """Multiply ``a`` by the module-level ``b`` and sum over dimension 0."""
    return (a * b).sum(dim=0)


# Separate compilers (and tuning log files) for the forward and backward graphs.
fw_compiler = tvm_compile(target="llvm", tuning_logfile="fw_keops")
bw_compiler = tvm_compile(target="llvm", tuning_logfile="bw_keops")
compiled_f = aot_function(f, fw_compiler, bw_compiler)

# To benchmark AOT tracing alone, swap in pass-through compilers *before* the
# ``aot_function`` call above:
#   fw_compiler = lambda x, _: x
#   bw_compiler = lambda x, _: x
iters = 10

# Warm-up: run forward and backward once so one-time compilation cost is not
# included in the timed loop.
out = compiled_f(a)
out.sum().backward()
# Drop the warm-up gradient so the first timed iteration starts from the same
# state (no accumulated ``a.grad``) as every later iteration.
a.grad = None


def bench(func):
    """Time ``iters`` forward/backward passes of ``func`` on ``a``.

    Each iteration computes ``func(a).sin().sum()``, back-propagates, and
    clears ``a.grad``.  The total elapsed time in seconds is printed.

    Args:
        func: Callable taking the tensor ``a`` and returning a tensor.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(iters):
        result = func(a).sin()
        result.sum().backward()
        # Reset so gradients do not accumulate across iterations.
        a.grad = None
    print(time.perf_counter() - begin)


def bench_jax():
    """Time ``iters`` evaluations of the jitted JAX gradient of the same op.

    JAX is imported lazily so the rest of the script runs without it.  The
    total elapsed time in seconds is printed.
    """
    import jax
    import jax.numpy as jnp

    jax_a = jnp.array(a.detach().numpy())
    jax_b = jnp.array(b.detach().numpy())

    def f(a):
        # Same computation as the PyTorch path, reduced to a scalar so that
        # ``jax.grad`` applies.  A plain int axis is equivalent to ``[0]``.
        return jnp.sin((a * jax_b).sum(axis=0)).sum()

    jit_f = jax.jit(jax.grad(f))
    # Warm-up call triggers compilation; wait for it to finish so none of
    # that work leaks into the timed region.
    jit_f(jax_a).block_until_ready()
    begin = time.perf_counter()
    for _ in range(iters):
        # Block on each result so the timing covers the actual computation.
        jit_f(jax_a).block_until_ready()
    print(time.perf_counter() - begin)


bench(f)
bench(compiled_f)
# Optional JAX comparison (requires jax to be installed):
# bench_jax()