1import time 2import timeit 3 4import numpy as np 5 6import torch 7 8 9def add1(x): 10 return x + 1 11 12 13def bench(name, fn, requires_grad): 14 torch._dynamo.reset() 15 x = torch.randn(1, requires_grad=requires_grad) 16 start = time.perf_counter() 17 for _ in range(3): 18 fn(x) 19 end = time.perf_counter() 20 21 results = timeit.repeat(lambda: fn(x), number=1000, repeat=1000) 22 print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)") 23 24 25def main(): 26 print("requires_grad=False") 27 bench("eager ", add1, False) 28 bench("compiled", torch.compile(add1), False) 29 print() 30 print("requires_grad=True") 31 bench("eager ", add1, True) 32 bench("compiled", torch.compile(add1), True) 33 print() 34 print("inference_mode()") 35 with torch.inference_mode(): 36 bench("eager ", add1, False) 37 bench("compiled", torch.compile(add1), False) 38 39 40if __name__ == "__main__": 41 main() 42