#include <torch/csrc/jit/passes/fuse_relu.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/subgraph_matcher.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
6
7 namespace torch::jit {
8
9 namespace {
fuseAddReluImpl(std::shared_ptr<Graph> & graph)10 void fuseAddReluImpl(std::shared_ptr<Graph>& graph) {
11 SubgraphRewriter rewriter;
12
13 std::string add_relu_0 = R"(
14 graph(%a, %b, %alpha):
15 %add_res = aten::add(%a, %b, %alpha)
16 %res = aten::relu(%add_res)
17 return (%res))";
18 std::string add_relu_fused = R"(
19 graph(%a, %b, %alpha):
20 %res = aten::_add_relu(%a, %b, %alpha)
21 return (%res))";
22 rewriter.RegisterRewritePattern(add_relu_0, add_relu_fused);
23
24 std::string add_relu_1 = R"(
25 graph(%a, %b, %alpha):
26 %add_res = aten::add(%a, %b, %alpha)
27 %res = aten::relu_(%add_res)
28 return (%res))";
29 rewriter.RegisterRewritePattern(add_relu_1, add_relu_fused);
30
31 std::string add_inplace_relu_1 = R"(
32 graph(%a, %b, %alpha):
33 %add_res = aten::add_(%a, %b, %alpha)
34 %res = aten::relu_(%add_res)
35 return (%res))";
36 std::string add_inplace_relu_fused = R"(
37 graph(%a, %b, %alpha):
38 %res = aten::_add_relu_(%a, %b, %alpha)
39 return (%res))";
40 rewriter.RegisterRewritePattern(add_inplace_relu_1, add_inplace_relu_fused);
41
42 std::string add_out_relu = R"(
43 graph(%a, %b, %alpha, %out):
44 %add_res = aten::add(%a, %b, %alpha, %out)
45 %res = aten::relu_(%add_res)
46 return (%res))";
47 std::string add_out_relu_fused = R"(
48 graph(%a, %b, %alpha, %out):
49 %res = aten::_add_relu(%a, %b, %alpha, %out)
50 return (%res))";
51
52 rewriter.RegisterRewritePattern(add_out_relu, add_out_relu_fused);
53
54 rewriter.runOnGraph(graph);
55 // NB: Patterns that are left out are add_ + relu and add_out + relu
56 // This is because inplace mutation of the tensor done by add_ will be lost if
57 // inplace mutation of the same tensor actually does add+relu
58 }
59 } // namespace
60
FuseAddRelu(script::Module & module)61 void FuseAddRelu(script::Module& module) {
62 auto graph = module.get_method("forward").graph();
63 fuseAddReluImpl(graph);
64 }
65
FuseAddRelu(std::shared_ptr<Graph> & graph)66 void FuseAddRelu(std::shared_ptr<Graph>& graph) {
67 fuseAddReluImpl(graph);
68 }
69 } // namespace torch::jit
70