1 #include <oneapi/dnnl/dnnl_graph.hpp>
2 #include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
3 #include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
4 #include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
5 #include <torch/csrc/jit/codegen/onednn/guard_shape.h>
6 #include <torch/csrc/jit/codegen/onednn/interface.h>
7 #include <torch/csrc/jit/codegen/onednn/kernel.h>
8 #include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
9 #include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
10 #include <torch/csrc/jit/jit_log.h>
11 #include <torch/csrc/jit/passes/decompose_ops.h>
12 #include <torch/csrc/jit/passes/pass_manager.h>
13 #include <torch/csrc/jit/passes/remove_mutation.h>
14 #include <torch/csrc/jit/passes/tensorexpr_fuser.h>
15 #include <torch/csrc/jit/runtime/custom_operator.h>
16 #include <torch/csrc/jit/runtime/graph_executor.h>
17 #include <torch/csrc/jit/runtime/operator_options.h>
18
19 namespace torch {
20 namespace jit {
21 namespace fuser {
22 namespace onednn {
23
fuseGraph(std::shared_ptr<Graph> & g)24 void fuseGraph(std::shared_ptr<Graph>& g) {
25 // Follow the process of the tensorexpr_fuser in profiling mode:
26 // Remove prim::profile nodes and embed the profile info directly in the
27 // IR in value types to avoid breaking the fusion patterns.
28 // Will add shape guard after LLGA optimization passes and
29 // wipe the tensor type information from the IR, so that it's not
30 // accidentally used by any other pass.
31
32 // We rely on the shape specialization and shape guard to ensure the validity
33 // of the cached compilation in the kernel, thus only support profiling mode.
34 // TODO: add check on oneDNNFusionGroup to ensure allShapesAreKnown on nodes
35 // to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown
36 if (getProfilingMode()) {
37 GRAPH_DUMP(
38 "Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA "
39 "optimization pass",
40 g);
41 RemoveProfileNodesAndSpecializeTypes(g);
42 GRAPH_DUMP(
43 "After RemoveProfileNodesAndSpecializeTypes. Before mutation removal",
44 g);
45
46 RemoveTensorMutation(g, [](Node* nodeToFunctionalize) {
47 static std::unordered_set<Symbol> supportedOps = {
48 aten::add_,
49 aten::mul_,
50 aten::tanh_,
51 aten::elu_,
52 aten::relu_,
53 aten::relu6_,
54 aten::gelu_,
55 aten::sqrt_,
56 aten::sigmoid_,
57 aten::hardtanh_,
58 aten::abs_,
59 aten::square_,
60 aten::pow_,
61 aten::leaky_relu_,
62 aten::round_,
63 aten::exp_,
64 aten::abs_,
65 aten::hardswish_,
66 aten::silu_};
67 return supportedOps.count(nodeToFunctionalize->kind()) != 0;
68 });
69 RemoveListMutation(g);
70 GRAPH_DUMP("After mutation removal. Before DecomposeSiluForLlga", g);
71 DecomposeSiluForLLGA(g);
72 GRAPH_DUMP("After DecomposeSiluForLlga. Before PrepareBinaryForLLGA", g);
73 PrepareBinaryForLLGA(g);
74 GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
75 DeferSizeCheck(g);
76 GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g);
77 dnnl::graph::set_constant_tensor_cache(true);
78 CreateLlgaSubgraphs(g);
79 GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g);
80 PropagateLayout(g);
81 GRAPH_DUMP(
82 "After PropagateLayout. Before prepareFusionGroupAndGuardOutputs", g);
83
84 // Add shape guard for profiling mode and wipe the tensor type information
85 // from the IR
86 prepareFusionGroupAndGuardOutputs(g->block());
87 GRAPH_DUMP(
88 "After prepareFusionGroupAndGuardOutputs. Before "
89 "RemoveTensorTypeSpecializations",
90 g);
91 RemoveTensorTypeSpecializations(g);
92 GRAPH_DUMP(
93 "After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
94 g);
95 }
96 }
97
98 } // namespace onednn
99 } // namespace fuser
100
createLlgaKernel(const Node * node)101 static Operation createLlgaKernel(const Node* node) {
102 auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
103 return [kernel](Stack& stack) {
104 RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
105 kernel->run(stack);
106 return 0;
107 };
108 }
109
// Register the executor for prim::oneDNNFusionGroup nodes.
// INTERNAL_SPECIAL_CASE alias analysis is used since fusion groups wrap
// arbitrary subgraphs whose aliasing cannot be described by a schema.
RegisterOperators oneDNNFusionGroupOp({
    torch::jit::Operator(
        prim::oneDNNFusionGroup,
        createLlgaKernel,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});
116
117 // Currently, we convert some scalar inputs, such as the second argument of
118 // binary ops to a 1D tensor. Other scalar inputs are prim::Constant nodes.
119 // But if we have any scalar inputs to guard in the future, some logic here
120 // would have to be changed.
createLlgaGuardKernel(const Node * node)121 static Operation createLlgaGuardKernel(const Node* node) {
122 return [node](Stack& stack) {
123 #ifdef GRAPH_DEBUG_ENABLED
124 GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());
125 #endif
126 std::vector<TypePtr> types = node->tys(attr::types);
127 const auto num_inputs = types.size();
128 #ifdef GRAPH_DEBUG_ENABLED
129 GRAPH_DEBUG("num_inputs to guard: ", num_inputs);
130 #endif
131 for (size_t i = 0; i < num_inputs; i++) {
132 #ifdef GRAPH_DEBUG_ENABLED
133 GRAPH_DEBUG("checking input ", i);
134 #endif
135 auto& input = peek(stack, i, num_inputs);
136 const c10::TensorTypePtr& guard_tensor_type =
137 types[i]->cast<TensorType>();
138
139 if (!input.isTensor()) {
140 #ifdef GRAPH_DEBUG_ENABLED
141 GRAPH_DEBUG("input ", i, " is not a tensor, return false");
142 #endif
143 push(stack, IValue(false));
144 return;
145 }
146 const at::Tensor& tensor = input.toTensor();
147
148 // If input tensor is of mkldnn, it's originated from an upstream
149 // LLGA partition that has passed the check on input shapes.
150 // It is valid to continue here as long as the output shapes from
151 // oneDNN graph partitions are determined by the input shapes.
152 if (tensor.is_mkldnn()) {
153 #ifdef GRAPH_DEBUG_ENABLED
154 GRAPH_DEBUG("input ", i, " is_mkldnn, continue");
155 #endif
156 continue;
157 }
158
159 if (!guard_tensor_type->matchTensor(tensor)) {
160 #ifdef GRAPH_DEBUG_ENABLED
161 GRAPH_DEBUG("input ", i, " check failed, return false");
162 #endif
163 push(stack, IValue(false));
164 return;
165 }
166 }
167 #ifdef GRAPH_DEBUG_ENABLED
168 GRAPH_DEBUG("all check done, return true");
169 #endif
170 push(stack, IValue(true));
171 return;
172 };
173 }
174
// Register the runtime type-guard operator for prim::oneDNNFusionGuard
// nodes. FROM_SCHEMA alias analysis suffices: the guard only reads its
// tensor inputs and pushes a bool.
RegisterOperators oneDNNGuardOp({
    torch::jit::Operator(
        prim::oneDNNFusionGuard,
        createLlgaGuardKernel,
        AliasAnalysisKind::FROM_SCHEMA),
});
181 } // namespace jit
182 } // namespace torch
183