// torch/csrc/jit/codegen/onednn/interface.cpp
#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/codegen/onednn/interface.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

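// fuseGraph() is the entry point of the LLGA (oneDNN Graph) fusion pass: it
// rewrites a profiled TorchScript graph in place, grouping fusible nodes into
// prim::oneDNNFusionGroup subgraphs that are guarded at runtime by
// prim::oneDNNFusionGuard nodes.
//
// A minimal usage sketch (illustrative only; `module` is a hypothetical
// torch::jit::Module, and in practice this pass is typically invoked by the
// profiling graph executor once oneDNN fusion is enabled, e.g. via
// torch.jit.enable_onednn_fusion(True) on the Python side):
//
//   std::shared_ptr<Graph> g = module.get_method("forward").graph();
//   torch::jit::fuser::onednn::fuseGraph(g);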
void fuseGraph(std::shared_ptr<Graph>& g) {
  // Follow the process of the tensorexpr_fuser in profiling mode:
  // remove prim::profile nodes and embed the profile info directly in the
  // IR as value types, so that the fusion patterns are not broken.
  // A shape guard is added after the LLGA optimization passes, and the
  // tensor type information is then wiped from the IR so that it is not
  // accidentally used by any other pass.

  // We rely on shape specialization and the shape guard to ensure the
  // validity of the cached compilation in the kernel, so only profiling mode
  // is supported.
  // TODO: add a check on oneDNNFusionGroup to ensure allShapesAreKnown on
  // nodes to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown
  if (getProfilingMode()) {
    GRAPH_DUMP(
        "Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA "
        "optimization pass",
        g);
    RemoveProfileNodesAndSpecializeTypes(g);
    GRAPH_DUMP(
        "After RemoveProfileNodesAndSpecializeTypes. Before mutation removal",
        g);

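    // RemoveTensorMutation functionalizes in-place ops: every node whose kind
    // appears in the allow-list below (e.g. aten::relu_) is rewritten to its
    // out-of-place counterpart (aten::relu), so that downstream LLGA fusion
    // passes only need to match the functional forms.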
    RemoveTensorMutation(g, [](Node* nodeToFunctionalize) {
      static std::unordered_set<Symbol> supportedOps = {
          aten::add_,
          aten::mul_,
          aten::tanh_,
          aten::elu_,
          aten::relu_,
          aten::relu6_,
          aten::gelu_,
          aten::sqrt_,
          aten::sigmoid_,
          aten::hardtanh_,
          aten::abs_,
          aten::square_,
          aten::pow_,
          aten::leaky_relu_,
          aten::round_,
          aten::exp_,
          aten::hardswish_,
          aten::silu_};
      return supportedOps.count(nodeToFunctionalize->kind()) != 0;
    });
    RemoveListMutation(g);
    GRAPH_DUMP("After mutation removal. Before DecomposeSiluForLlga", g);
    DecomposeSiluForLLGA(g);
    GRAPH_DUMP("After DecomposeSiluForLlga. Before PrepareBinaryForLLGA", g);
    PrepareBinaryForLLGA(g);
    GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
    DeferSizeCheck(g);
    GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g);
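    // Enable oneDNN Graph's constant tensor cache before creating the
    // partitions; this lets constant tensors (e.g. weights) be cached and
    // reused across executions of the compiled partitions.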
    dnnl::graph::set_constant_tensor_cache(true);
    CreateLlgaSubgraphs(g);
    GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g);
    PropagateLayout(g);
    GRAPH_DUMP(
        "After PropagateLayout. Before prepareFusionGroupAndGuardOutputs", g);

    // Add shape guard for profiling mode and wipe the tensor type information
    // from the IR
    prepareFusionGroupAndGuardOutputs(g->block());
    GRAPH_DUMP(
        "After prepareFusionGroupAndGuardOutputs. Before "
        "RemoveTensorTypeSpecializations",
        g);
    RemoveTensorTypeSpecializations(g);
    GRAPH_DUMP(
        "After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
        g);
  }
}

} // namespace onednn
} // namespace fuser

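// Wraps an LlgaKernel in a JIT Operation for a prim::oneDNNFusionGroup node.
// The kernel is constructed once per node and captured by the returned
// closure; compilation of the underlying oneDNN Graph partition is cached
// inside the kernel, and its validity is ensured by the shape guard added in
// fuseGraph() above. RECORD_FUNCTION makes each run visible to the profiler
// under the kernel's debug name.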
static Operation createLlgaKernel(const Node* node) {
  auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
  return [kernel](Stack& stack) {
    RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
    kernel->run(stack);
    return 0;
  };
}

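// Binds the execution kernel to prim::oneDNNFusionGroup.
// AliasAnalysisKind::INTERNAL_SPECIAL_CASE indicates that alias analysis has
// dedicated handling for this node kind rather than deriving aliasing from a
// schema, as is done for other fusion groups.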
RegisterOperators oneDNNFusionGroupOp({
    torch::jit::Operator(
        prim::oneDNNFusionGroup,
        createLlgaKernel,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

// Currently, some scalar inputs, such as the second argument of binary ops,
// are converted to 1D tensors; the remaining scalar inputs are prim::Constant
// nodes. If any scalar inputs need to be guarded in the future, the logic
// here would have to change.
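// The guard Operation below does not consume its inputs: it peeks at the
// inputs of the guarded fusion group, compares each tensor against the
// profiled TensorType stored in the node's attr::types attribute, and pushes
// a single bool. A false result is expected to steer the surrounding
// check/If construct onto the unfused fallback path.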
static Operation createLlgaGuardKernel(const Node* node) {
  return [node](Stack& stack) {
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());
#endif
    std::vector<TypePtr> types = node->tys(attr::types);
    const auto num_inputs = types.size();
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("num_inputs to guard: ", num_inputs);
#endif
    for (size_t i = 0; i < num_inputs; i++) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("checking input ", i);
#endif
      auto& input = peek(stack, i, num_inputs);
      const c10::TensorTypePtr& guard_tensor_type =
          types[i]->cast<TensorType>();

      if (!input.isTensor()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is not a tensor, return false");
#endif
        push(stack, IValue(false));
        return;
      }
      const at::Tensor& tensor = input.toTensor();

      // If the input tensor is an mkldnn tensor, it originated from an
      // upstream LLGA partition that has already passed the input-shape
      // check. It is valid to continue here as long as the output shapes of
      // oneDNN graph partitions are determined by their input shapes.
      if (tensor.is_mkldnn()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is_mkldnn, continue");
#endif
        continue;
      }

      if (!guard_tensor_type->matchTensor(tensor)) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " check failed, return false");
#endif
        push(stack, IValue(false));
        return;
      }
    }
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("all check done, return true");
#endif
    push(stack, IValue(true));
    return;
  };
}

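// Binds the guard kernel to prim::oneDNNFusionGuard. With
// AliasAnalysisKind::FROM_SCHEMA, aliasing information is derived from the
// operator's schema rather than special-cased in alias analysis.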
RegisterOperators oneDNNGuardOp({
    torch::jit::Operator(
        prim::oneDNNFusionGuard,
        createLlgaGuardKernel,
        AliasAnalysisKind::FROM_SCHEMA),
});
} // namespace jit
} // namespace torch