1 #pragma once 2 3 #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h> 4 #include <torch/csrc/lazy/core/ir.h> 5 #include <torch/csrc/lazy/core/ir_builder.h> 6 #include <torch/csrc/lazy/core/shape_inference.h> 7 #include <torch/csrc/lazy/generated/LazyNonNativeIr.h> 8 #include <torch/csrc/lazy/ts_backend/dynamic_ir.h> 9 #include <torch/csrc/lazy/ts_backend/ops/device_data.h> 10 #include <torch/csrc/lazy/ts_backend/ops/generic.h> 11 #include <torch/csrc/lazy/ts_backend/ts_node.h> 12 13 namespace torch { 14 namespace lazy { 15 16 struct TorchScriptIrBuilder : IrBuilder { MakeDeviceDataTorchScriptIrBuilder17 NodePtr MakeDeviceData( 18 const std::shared_ptr<BackendData>& data) const override { 19 return DeviceData::Create(data); 20 } 21 // TODO: Scalar node is not currently used by ts_backend. Enable reusing 22 // Scalar node later if needed. MakeScalarTorchScriptIrBuilder23 NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type) 24 const override { 25 return MakeNode<Scalar>(value, type); 26 } MakeExpandTorchScriptIrBuilder27 NodePtr MakeExpand( 28 const Value& input0, 29 const std::vector<int64_t>& size, 30 const bool& is_scalar_expand) const override { 31 return ReuseOrMakeNode<Expand>(input0, size, is_scalar_expand); 32 } 33 NodePtr MakeCast( 34 const Value& input0, 35 const at::ScalarType& dtype, 36 const std::optional<at::ScalarType>& stype = 37 std::nullopt) const override { 38 return ReuseOrMakeNode<Cast>(input0, dtype, stype); 39 } MakeTensorListTorchScriptIrBuilder40 NodePtr MakeTensorList(const OpList& inputs) const override { 41 return ReuseOrMakeNode<TensorList>(inputs); 42 } 43 // Generic needs cleanup 44 NodePtr MakeGeneric( 45 const OpKind& op, 46 const OpList& operands, 47 const Shape& shape, 48 const size_t& num_outputs = 1, 49 const hash_t& hash_seed = 50 static_cast<uint32_t>(0x5a2d296e9)) const override { 51 return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed); 52 } 53 54 // dynamic ir nodes 55 // TODO: verify if IR node reusing works for Dynamic shape ops MakeSizeNodeTorchScriptIrBuilder56 NodePtr MakeSizeNode(const Value& input, size_t dim) const override { 57 return MakeNode<SizeNode>(input, dim); 58 } MakeSizeAddTorchScriptIrBuilder59 NodePtr MakeSizeAdd(const Value& a, const Value& b) const override { 60 return MakeNode<SizeAdd>(a, b); 61 } MakeSizeMulTorchScriptIrBuilder62 NodePtr MakeSizeMul(const Value& a, const Value& b) const override { 63 return MakeNode<SizeMul>(a, b); 64 } MakeSizeDivTorchScriptIrBuilder65 NodePtr MakeSizeDiv(const Value& a, const Value& b) const override { 66 return MakeNode<SizeDiv>(a, b); 67 } 68 }; 69 70 } // namespace lazy 71 } // namespace torch 72