1 #pragma once 2 3 #include <torch/csrc/jit/ir/ir.h> 4 5 namespace torch { 6 namespace jit { 7 namespace fuser { 8 namespace onednn { 9 10 // Prepare binary ops for LLGA 11 // 12 // The pass does the following: 13 // 14 // - Convert scalar input of aten::add and aten::mul into Float tensor with 15 // dimension [1] 16 // 17 // - Decompose fused add into aten::mul + aten::add when alpha != 1.0 18 // 19 // - Eliminate identity add/mul, i.e., tensor + 0, tensor * 1 20 // 21 void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph); 22 23 } // namespace onednn 24 } // namespace fuser 25 } // namespace jit 26 } // namespace torch 27