xref: /aosp_15_r20/external/pytorch/benchmarks/fastrnns/fuser.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport torch
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker
def set_fuser(fuser_name, executor_name):
    """Configure the TorchScript fuser and graph executor for a benchmark run.

    The fuser choice is applied first; ``executor_name`` is applied
    afterwards, so any executor setting it touches overrides what the fuser
    branch configured (mirroring ``--executor`` overriding ``--fuser``).

    Args:
        fuser_name: One of
            ``"te"``      - profiling executor and graph-executor optimization
                            on, CPU fusion override off, GPU fusion override
                            on, TensorExpr fuser enabled;
            ``"old"``     - profiling executor and graph-executor optimization
                            off, GPU fusion override on, TensorExpr fuser
                            disabled (CPU fusion override left unchanged);
            ``"none"``    - profiling executor, graph-executor optimization,
                            CPU/GPU fusion overrides and TensorExpr fuser all
                            off;
            ``"default"`` - leave the current settings untouched.
        executor_name: ``"profiling"``, ``"simple"``, ``"legacy"`` or
            ``"default"``. Any other value (e.g. ``None``) leaves the
            executor settings exactly as the fuser branch left them.

    Raises:
        ValueError: if ``fuser_name`` is not one of the supported names.
    """
    valid_fusers = ("te", "old", "none", "default")
    # Explicit raise rather than ``assert``: asserts are removed under
    # ``python -O``, which would silently accept a misspelled fuser name.
    if fuser_name not in valid_fusers:
        raise ValueError(
            f"unknown fuser {fuser_name!r}; expected one of {valid_fusers}"
        )

    # NOTE: ``_get_graph_executor_optimize(flag)`` updates the setting when
    # given an argument (despite the "get" in its name).
    if fuser_name == "te":
        torch._C._jit_set_profiling_executor(True)
        torch._C._get_graph_executor_optimize(True)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
    elif fuser_name == "old":
        torch._C._jit_set_profiling_executor(False)
        torch._C._get_graph_executor_optimize(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
    elif fuser_name == "none":
        torch._C._jit_set_profiling_executor(False)
        torch._C._get_graph_executor_optimize(False)
        torch._C._jit_override_can_fuse_on_gpu(False)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_set_texpr_fuser_enabled(False)
    # fuser_name == "default": keep whatever is currently configured.

    # --executor overrides settings of --fuser
    if executor_name == "profiling":
        torch._C._jit_set_profiling_executor(True)
        torch._C._get_graph_executor_optimize(True)
    elif executor_name == "simple":
        torch._C._get_graph_executor_optimize(False)
    elif executor_name == "legacy":
        torch._C._jit_set_profiling_executor(False)
        torch._C._get_graph_executor_optimize(True)
    # executor_name == "default" (or None / anything else): no override.
37