// torch/csrc/jit/passes/fuse_relu.cpp
#include <torch/csrc/jit/passes/fuse_relu.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

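// This pass fuses add + relu sequences in a TorchScript graph into the
// combined aten::_add_relu ops, using SubgraphRewriter with the IR patterns
// registered below.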
namespace torch::jit {

namespace {
void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
  SubgraphRewriter rewriter;

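  // Each pattern/replacement pair below is written as a TorchScript IR graph;
  // runOnGraph rewrites every occurrence of a registered pattern into its
  // fused counterpart.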
  std::string add_relu_0 = R"(
    graph(%a, %b, %alpha):
        %add_res = aten::add(%a, %b, %alpha)
        %res = aten::relu(%add_res)
        return (%res))";
  std::string add_relu_fused = R"(
    graph(%a, %b, %alpha):
        %res = aten::_add_relu(%a, %b, %alpha)
        return (%res))";
  rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused);

  std::string add_relu_1 = R"(
    graph(%a, %b, %alpha):
        %add_res = aten::add(%a, %b, %alpha)
        %res = aten::relu_(%add_res)
        return (%res))";
  rewriter.RegisterRewritePattern(add_relu_1, add_relu_fused);

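  // Fully in-place chain: add_ writes the sum into %a and relu_ then clamps
  // it, so the in-place fused op _add_relu_ can take over both steps.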
  std::string add_inplace_relu_1 = R"(
    graph(%a, %b, %alpha):
        %add_res = aten::add_(%a, %b, %alpha)
        %res = aten::relu_(%add_res)
        return (%res))";
  std::string add_inplace_relu_fused = R"(
    graph(%a, %b, %alpha):
        %res = aten::_add_relu_(%a, %b, %alpha)
        return (%res))";
  rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused);

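  // out= variant: aten::add writes the sum into %out and relu_ clamps it in
  // place, so the fused _add_relu taking %out leaves the same result in %out.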
  std::string add_out_relu = R"(
    graph(%a, %b, %alpha, %out):
        %add_res = aten::add(%a, %b, %alpha, %out)
        %res = aten::relu_(%add_res)
        return (%res))";
  std::string add_out_relu_fused = R"(
    graph(%a, %b, %alpha, %out):
        %res = aten::_add_relu(%a, %b, %alpha, %out)
        return (%res))";

  rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused);

  rewriter.runOnGraph(graph);
  // NB: Patterns that are left out are add_ + relu and add_out + relu.
  // This is because the in-place mutation done by add_ (storing the plain sum)
  // would be lost if the same tensor were instead mutated with the fused
  // add+relu, changing what other users of that tensor observe.
}
} // namespace

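// Usage (sketch): the pass can be run either on a scripted module, rewriting
// its forward method's graph in place, or directly on a graph, e.g.
//
//   script::Module m = ...;          // some loaded/compiled module (assumed)
//   FuseAddRelu(m);
//
//   std::shared_ptr<Graph> g = ...;  // any standalone JIT graph (assumed)
//   FuseAddRelu(g);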
void FuseAddRelu(script::Module& module) {
  auto graph = module.get_method("forward").graph();
  fuseAddReluImpl(graph);
}

void FuseAddRelu(std::shared_ptr<Graph>& graph) {
  fuseAddReluImpl(graph);
}
} // namespace torch::jit