xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/onednn/prepare_binary.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/ir/ir.h>
4 
5 namespace torch {
6 namespace jit {
7 namespace fuser {
8 namespace onednn {
9 
10 // Prepares the binary ops (aten::add / aten::mul) in `graph` for LLGA.
11 //
12 // The pass does the following:
13 //
14 // - Converts a scalar input of aten::add and aten::mul into a Float tensor
15 //   with dimension [1] (a single-element tensor)
16 //
17 // - Decomposes a fused add into aten::mul + aten::add when alpha != 1.0
18 //
19 // - Eliminates identity add/mul, i.e., tensor + 0 and tensor * 1
20 // `graph`: the graph to transform (returns void, so presumably edited in place).
21 void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph);
22 
23 } // namespace onednn
24 } // namespace fuser
25 } // namespace jit
26 } // namespace torch
27