xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/NamedTensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/core/NamedTensor.h>
3 
4 #include <ATen/core/TensorBase.h>
5 
6 namespace at {
7 
// Per-thread flag backing NamesMode: returned by NamesMode::is_enabled() and
// overwritten by NamesMode::set_enabled(). Each thread starts with it set to true.
thread_local bool NamesMode_enabled{true};
9 
is_enabled()10 bool NamesMode::is_enabled() {
11   return NamesMode_enabled;
12 }
13 
set_enabled(bool enabled)14 void NamesMode::set_enabled(bool enabled) {
15   NamesMode_enabled = enabled;
16   c10::impl::tls_set_dispatch_key_excluded(DispatchKey::Named, !enabled);
17 }
18 
// TensorBase front-end for impl::internal_set_names_inplace (optional overload).
// Always requests validation of `names`, then hands back the same tensor
// reference it was given.
const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names) {
  auto* const tensor_impl = tensor.unsafeGetTensorImpl();
  impl::internal_set_names_inplace(tensor_impl, names, /*validate_names=*/true);
  return tensor;
}
23 
// TensorBase front-end for impl::internal_set_names_inplace (vector overload).
// Ownership of `names` is forwarded with std::move; `validate_names` is passed
// through unchanged. Returns the same tensor reference it was given.
const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names) {
  auto* const tensor_impl = tensor.unsafeGetTensorImpl();
  impl::internal_set_names_inplace(tensor_impl, std::move(names), validate_names);
  return tensor;
}
28 
default_names(size_t len)29 DimnameList default_names(size_t len) {
30   static std::vector<Dimname> all_unnamed(kMaxNamedTensorDim, Dimname::wildcard());
31     TORCH_INTERNAL_ASSERT(
32         len <= kMaxNamedTensorDim,
33         "Only tensors with up to ", kMaxNamedTensorDim, " are supported.");
34   return DimnameList(&all_unnamed.front(), len);
35 }
36 
check_unique_names(DimnameList names)37 static void check_unique_names(DimnameList names) {
38   // Strategy: Compare each element with the ones that come after it.
39   // Although this is O(N^2), in practice N is small (no more than 25).
40   for (auto it = names.begin(); it != names.end(); ++it) {
41     if (it->isWildcard()) continue;
42     auto dup = std::find(it + 1, names.end(), *it);
43     while (dup != names.end()) {
44       TORCH_CHECK(false,
45           "Cannot construct a tensor with duplicate names. Got names: ",
46           names, ".");
47     }
48   }
49 }
50 
check_names_valid_for(const TensorBase & tensor,DimnameList names)51 void check_names_valid_for(const TensorBase& tensor, DimnameList names) {
52   return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
53 }
54 
check_names_valid_for(size_t tensor_dim,DimnameList names)55 void check_names_valid_for(size_t tensor_dim, DimnameList names) {
56   TORCH_CHECK(
57       tensor_dim <= kMaxNamedTensorDim,
58       "Named tensors only support up to ", kMaxNamedTensorDim, " dims: "
59       "Attempted to create a tensor with dim ", tensor_dim, " with names ", names);
60   TORCH_CHECK(tensor_dim == names.size(),
61       "Number of names (", names.size(), ") and "
62       "number of dimensions in tensor (", tensor_dim, ") ",
63       "do not match. Attempted to create a tensor with names ", names);
64   check_unique_names(names);
65 }
66 
67 namespace impl {
68 
get_named_tensor_meta(TensorImpl * impl)69 static NamedTensorMeta* get_named_tensor_meta(TensorImpl* impl) {
70   if (!NamesMode::is_enabled()) {
71     return nullptr;
72   }
73   return static_cast<NamedTensorMeta*>(impl->named_tensor_meta());
74 }
75 
get_named_tensor_meta(const TensorImpl * impl)76 static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) {
77   if (!NamesMode::is_enabled()) {
78     return nullptr;
79   }
80   return static_cast<const NamedTensorMeta*>(impl->named_tensor_meta());
81 }
82 
check_names_valid_for(TensorImpl * impl,DimnameList names)83 void check_names_valid_for(TensorImpl* impl, DimnameList names) {
84   check_names_valid_for(impl->dim(), names);
85 }
86 
internal_set_names_inplace(TensorImpl * impl,std::optional<DimnameList> names,bool validate_names)87 void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names) {
88   TORCH_CHECK(impl->layout() == Layout::Strided,
89       "NYI: named tensors only support strided layout");
90   TORCH_CHECK(impl->device().is_cpu() || impl->device().is_cuda() || impl->device().is_xpu() || impl->device().is_privateuseone(),
91       "NYI: named tensors only support CPU, CUDA, XPU or ", c10::get_privateuse1_backend(), " tensors.");
92   if (!names) {
93     impl->set_named_tensor_meta(nullptr);
94     return;
95   }
96   if (validate_names) {
97     check_names_valid_for(impl, *names);
98   }
99   // Do this after validation!
100   if (std::all_of(names->begin(), names->end(), [](const Dimname& n) { return n.isWildcard(); })) {
101     impl->set_named_tensor_meta(nullptr);
102     return;
103   }
104   auto* meta = get_named_tensor_meta(impl);
105   if (meta == nullptr) {
106     // Constructor is private
107     impl->set_named_tensor_meta(std::make_unique<NamedTensorMeta>(NamedTensorMeta::HasNonWildcard, *names));
108   } else {
109     meta->set_names(NamedTensorMeta::HasNonWildcard, *names);
110   }
111 }
112 
internal_set_names_inplace(TensorImpl * impl,std::vector<Dimname> && names,bool validate_names)113 void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names) {
114   if (validate_names) {
115     check_names_valid_for(impl, names);
116   }
117   // Do this after validation!
118   if (std::all_of(names.begin(), names.end(), [](const Dimname& n) { return n.isWildcard(); })) {
119     impl->set_named_tensor_meta(nullptr);
120     return;
121   }
122   auto* meta = get_named_tensor_meta(impl);
123   if (meta == nullptr) {
124     impl->set_named_tensor_meta(std::make_unique<NamedTensorMeta>(NamedTensorMeta::HasNonWildcard, std::move(names)));
125   } else {
126     meta->set_names(NamedTensorMeta::HasNonWildcard, std::move(names));
127   }
128 }
129 
get_opt_names(const TensorImpl * impl)130 std::optional<DimnameList> get_opt_names(const TensorImpl* impl) {
131   const auto* meta = get_named_tensor_meta(impl);
132   if (meta == nullptr) {
133     return std::nullopt;
134   } else {
135     return meta->names();
136   }
137 }
138 
get_names(const TensorImpl * impl)139 DimnameList get_names(const TensorImpl* impl) {
140   auto maybe_names = get_opt_names(impl);
141   if (maybe_names) {
142     return *maybe_names;
143   }
144   return default_names(impl->dim());
145 }
146 
has_names(const TensorImpl * impl)147 bool has_names(const TensorImpl* impl) {
148   return impl->has_named_tensor_meta() && NamesMode::is_enabled();
149 }
150 
151 } // namespace impl
152 
153 } // namespace at
154