# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Micro-benchmark comparing three ways of evaluating the gradient of ``f``.

The same gradient function is timed in three forms: the plain
``functorch.grad`` transform, the result of passing it through ``make_fx``,
and the result of passing it through ``nnc_jit``.  Each variant is called
repeatedly on the same input and the total elapsed time is printed.
"""

import time

import torch
from functorch import grad, make_fx
from functorch.compile import nnc_jit


def f(x):
    """Return the sum of the element-wise sine of ``x``."""
    return torch.sin(x).sum()


# Shared input and the three gradient callables under test.  All setup
# (including whatever work make_fx / nnc_jit do when called here) happens
# once, outside the timed region.
inp = torch.randn(100)
grad_pt = grad(f)
grad_fx = make_fx(grad_pt)(inp)
grad_nnc = nnc_jit(grad_pt)


def bench(name, f, iters=10000, warmup=3):
    """Time ``iters`` calls of ``f`` and print the total elapsed seconds.

    Args:
        name: Label printed in front of the measurement.
        f: Zero-argument callable to benchmark.  (The parameter name shadows
            the module-level ``f``; it is kept for backward compatibility.)
        iters: Number of timed calls.
        warmup: Number of untimed calls made first, so that one-off costs on
            the first invocations are not included in the measurement.
    """
    for _ in range(warmup):
        f()
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted during the run.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    # Capture the end time before any formatting/printing work.
    elapsed = time.perf_counter() - begin
    print(f"{name}: ", elapsed)


# Labels are passed bare: bench() appends the ": " separator itself.
bench("Pytorch", lambda: grad_pt(inp))
bench("FX", lambda: grad_fx(inp))
bench("NNC", lambda: grad_nnc(inp))