1 #pragma once 2 3 #include <unordered_map> 4 5 #include <oneapi/dnnl/dnnl_graph.hpp> 6 #include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h> 7 #include <torch/csrc/jit/codegen/onednn/graph_helper.h> 8 #include <torch/csrc/jit/ir/ir.h> 9 #include <torch/csrc/jit/runtime/interpreter.h> 10 11 #include <c10/util/CallOnce.h> 12 13 namespace torch { 14 namespace jit { 15 namespace fuser { 16 namespace onednn { 17 18 using ArgSpec = LlgaTensorDesc; 19 using ArgSpecs = std::vector<ArgSpec>; 20 using RunArg = dnnl::graph::tensor; 21 using RunArgs = std::vector<RunArg>; 22 using TensorArgs = std::vector<at::Tensor>; 23 24 class LlgaKernel { 25 public: 26 explicit LlgaKernel(const Node* fusionNode); 27 28 void run(Stack& stack); 29 30 void initialize(const TensorArgs& inputs); 31 debugName()32 const std::string& debugName() const { 33 return debugName_; 34 } 35 36 private: 37 bool useOpaqueLayout(size_t offset) const; 38 39 // PyTorch copy constants inside the subgraph instead of referencing them. 40 // Constants inputs to the partition are no longer in the graph->inputs(). 41 // Need use the tid retrieved from the partition to find the missing 42 // constant inputs. 43 void initializeConstantInputs(); 44 45 ArgSpecs initializeInputSpecs(const TensorArgs& inputs); 46 47 ArgSpecs initializeOutputSpecs() const; 48 49 dnnl::graph::compiled_partition compile( 50 const dnnl::graph::partition& partition); 51 52 std::map<size_t, int64_t> initializeTensorIdToOccurence() const; 53 54 std::tuple<RunArgs, RunArgs> prepareRunArgs( 55 const TensorArgs& inputs, 56 TensorArgs& outputs) const; 57 genDebugName()58 static std::string genDebugName() { 59 static size_t debugId = 0; 60 return "LlgaPartition_" + std::to_string(debugId++); 61 } 62 toLogicalTensor(const ArgSpec & s)63 static dnnl::graph::logical_tensor toLogicalTensor(const ArgSpec& s) { 64 return s.logical_tensor(); 65 } 66 67 at::Device device_ = at::kCPU; 68 const Node* fusionNode_; 69 std::shared_ptr<Graph> graph_; 70 int64_t nGraphInputs_ = 0; // number of inputs to graph_ on the IR 71 int64_t nOutputs_ = 0; 72 std::map<size_t, Value*> tensorIdToValue_; 73 std::vector<int64_t> runArgsIdx_; 74 dnnl::graph::partition partition_; 75 // nPartitionInputs_ is the actual number of inputs to partition_ of graph_ 76 // needed by the backend. 77 // nPartitionInputs_ = nGraphInputs_ + constantInputs_.size() since Constant 78 // inputs are copied to the inside of the subgraph 79 int64_t nPartitionInputs_; 80 dnnl::graph::compiled_partition compilation_; 81 std::set<size_t> initializedInputIds_; 82 std::vector<Value*> constantValues_; 83 TensorArgs constantInputs_; 84 ArgSpecs inputSpecs_; 85 ArgSpecs outputSpecs_; 86 std::vector<dnnl::graph::logical_tensor> constantLogicalTensors_; 87 std::string debugName_; 88 c10::once_flag initialized_flag; 89 bool is_initialized_ = false; 90 }; 91 92 } // namespace onednn 93 } // namespace fuser 94 } // namespace jit 95 } // namespace torch 96