xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/ts_backend/ir_builder.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/lazy/core/internal_ops/ltc_ops.h>
4 #include <torch/csrc/lazy/core/ir.h>
5 #include <torch/csrc/lazy/core/ir_builder.h>
6 #include <torch/csrc/lazy/core/shape_inference.h>
7 #include <torch/csrc/lazy/generated/LazyNonNativeIr.h>
8 #include <torch/csrc/lazy/ts_backend/dynamic_ir.h>
9 #include <torch/csrc/lazy/ts_backend/ops/device_data.h>
10 #include <torch/csrc/lazy/ts_backend/ops/generic.h>
11 #include <torch/csrc/lazy/ts_backend/ts_node.h>
12 
13 namespace torch {
14 namespace lazy {
15 
16 struct TorchScriptIrBuilder : IrBuilder {
MakeDeviceDataTorchScriptIrBuilder17   NodePtr MakeDeviceData(
18       const std::shared_ptr<BackendData>& data) const override {
19     return DeviceData::Create(data);
20   }
21   // TODO: Scalar node is not currently used by ts_backend. Enable reusing
22   // Scalar node later if needed.
MakeScalarTorchScriptIrBuilder23   NodePtr MakeScalar(const at::Scalar& value, const at::ScalarType& type)
24       const override {
25     return MakeNode<Scalar>(value, type);
26   }
MakeExpandTorchScriptIrBuilder27   NodePtr MakeExpand(
28       const Value& input0,
29       const std::vector<int64_t>& size,
30       const bool& is_scalar_expand) const override {
31     return ReuseOrMakeNode<Expand>(input0, size, is_scalar_expand);
32   }
33   NodePtr MakeCast(
34       const Value& input0,
35       const at::ScalarType& dtype,
36       const std::optional<at::ScalarType>& stype =
37           std::nullopt) const override {
38     return ReuseOrMakeNode<Cast>(input0, dtype, stype);
39   }
MakeTensorListTorchScriptIrBuilder40   NodePtr MakeTensorList(const OpList& inputs) const override {
41     return ReuseOrMakeNode<TensorList>(inputs);
42   }
43   // Generic needs cleanup
44   NodePtr MakeGeneric(
45       const OpKind& op,
46       const OpList& operands,
47       const Shape& shape,
48       const size_t& num_outputs = 1,
49       const hash_t& hash_seed =
50           static_cast<uint32_t>(0x5a2d296e9)) const override {
51     return MakeNode<Generic>(op, operands, shape, num_outputs, hash_seed);
52   }
53 
54   // dynamic ir nodes
55   // TODO: verify if IR node reusing works for Dynamic shape ops
MakeSizeNodeTorchScriptIrBuilder56   NodePtr MakeSizeNode(const Value& input, size_t dim) const override {
57     return MakeNode<SizeNode>(input, dim);
58   }
MakeSizeAddTorchScriptIrBuilder59   NodePtr MakeSizeAdd(const Value& a, const Value& b) const override {
60     return MakeNode<SizeAdd>(a, b);
61   }
MakeSizeMulTorchScriptIrBuilder62   NodePtr MakeSizeMul(const Value& a, const Value& b) const override {
63     return MakeNode<SizeMul>(a, b);
64   }
MakeSizeDivTorchScriptIrBuilder65   NodePtr MakeSizeDiv(const Value& a, const Value& b) const override {
66     return MakeNode<SizeDiv>(a, b);
67   }
68 };
69 
70 } // namespace lazy
71 } // namespace torch
72