xref: /aosp_15_r20/external/pytorch/functorch/examples/compilation/simple_function.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import time
8
9import torch
10from functorch import grad, make_fx
11from functorch.compile import nnc_jit
12
13
def f(x):
    """Apply ``torch.sin`` elementwise to ``x`` and return the sum as a scalar tensor."""
    elementwise = torch.sin(x)
    return elementwise.sum()
16
17
# One random input vector is shared by every variant, so all of them are
# timed on the same data.
inp = torch.randn(100)

# Eager reference: functorch's functional gradient of ``f``.
grad_pt = grad(f)

# Two alternative versions of the same gradient function:
#   grad_fx  -- make_fx(grad_pt) is called with ``inp`` as the example input,
#               tracing grad_pt into a callable graph.
#   grad_nnc -- nnc_jit(grad_pt), the NNC-backed variant (presumably compiled
#               on first call; the bench warmup would absorb that -- confirm).
grad_fx, grad_nnc = make_fx(grad_pt)(inp), nnc_jit(grad_pt)
22
23
def bench(name, f, iters=10000, warmup=3):
    """Time ``iters`` back-to-back calls of the zero-argument callable ``f``.

    ``f`` is first called ``warmup`` times without timing, so one-off
    first-call costs are kept out of the measurement.

    Args:
        name: Label printed in front of the elapsed time.
        f: Zero-argument callable to benchmark; its return value is ignored.
        iters: Number of timed calls.
        warmup: Number of untimed calls made before timing starts.

    Returns:
        Total elapsed seconds for the ``iters`` timed calls (also printed).
    """
    for _ in range(warmup):
        f()
    # perf_counter is monotonic and uses the highest-resolution clock
    # available; time.time() tracks the wall clock, which can jump
    # (NTP/DST adjustments) and is coarser on some platforms.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    elapsed = time.perf_counter() - begin
    print(f"{name}: {elapsed:.6f}s")
    return elapsed
31
32
# bench() already appends ": " after the label, so pass bare names here to
# avoid output like "Pytorch: : <time>".
bench("PyTorch", lambda: grad_pt(inp))
bench("FX", lambda: grad_fx(inp))
bench("NNC", lambda: grad_nnc(inp))
36