#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/core/NamedTensor.h>

#include <ATen/core/TensorBase.h>

namespace at {

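// Thread-local switch for named tensor support. When it is off, tensors
// behave as if they were unnamed: get_named_tensor_meta() (below) returns
// nullptr and the Named dispatch key is excluded from dispatch.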
thread_local bool NamesMode_enabled = true;

bool NamesMode::is_enabled() {
  return NamesMode_enabled;
}

void NamesMode::set_enabled(bool enabled) {
  NamesMode_enabled = enabled;
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::Named, !enabled);
}

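// Sets (or, for std::nullopt, clears) the dimension names of `tensor` in
// place and returns the same tensor. The first overload always validates the
// names; the move overload takes a flag so callers that have already
// validated can skip the check.
//
// A minimal usage sketch (illustration only; the tensor `t` is assumed to be
// a 2-D strided CPU/CUDA tensor and is not defined in this file):
//
//   std::vector<at::Dimname> names = {
//       at::Dimname::fromSymbol(c10::Symbol::dimname("N")),
//       at::Dimname::fromSymbol(c10::Symbol::dimname("C"))};
//   at::internal_set_names_inplace(t, std::move(names), /*validate_names=*/true);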
const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names) {
  impl::internal_set_names_inplace(tensor.unsafeGetTensorImpl(), names, /*validate_names=*/true);
  return tensor;
}

const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names) {
  impl::internal_set_names_inplace(tensor.unsafeGetTensorImpl(), std::move(names), validate_names);
  return tensor;
}

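// Returns a list of `len` wildcard ('*') names. The result is a view into a
// shared static buffer, so no allocation is performed per call.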
DimnameList default_names(size_t len) {
  static std::vector<Dimname> all_unnamed(kMaxNamedTensorDim, Dimname::wildcard());
  TORCH_INTERNAL_ASSERT(
      len <= kMaxNamedTensorDim,
      "Only tensors with up to ", kMaxNamedTensorDim, " dims are supported.");
  return DimnameList(&all_unnamed.front(), len);
}

static void check_unique_names(DimnameList names) {
  // Strategy: Compare each element with the ones that come after it.
  // Although this is O(N^2), in practice N is small (no more than 25).
  for (auto it = names.begin(); it != names.end(); ++it) {
    if (it->isWildcard()) continue;
    const auto dup = std::find(it + 1, names.end(), *it);
    TORCH_CHECK(dup == names.end(),
        "Cannot construct a tensor with duplicate names. Got names: ",
        names, ".");
  }
}

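// Validation: a name list is legal for a tensor of rank `tensor_dim` iff the
// rank does not exceed kMaxNamedTensorDim, the list has exactly one name per
// dimension, and all non-wildcard names are unique.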
void check_names_valid_for(const TensorBase& tensor, DimnameList names) {
  return impl::check_names_valid_for(tensor.unsafeGetTensorImpl(), names);
}

void check_names_valid_for(size_t tensor_dim, DimnameList names) {
  TORCH_CHECK(
      tensor_dim <= kMaxNamedTensorDim,
      "Named tensors only support up to ", kMaxNamedTensorDim, " dims: "
      "Attempted to create a tensor with dim ", tensor_dim, " with names ", names);
  TORCH_CHECK(tensor_dim == names.size(),
      "Number of names (", names.size(), ") and "
      "number of dimensions in tensor (", tensor_dim, ") ",
      "do not match. Attempted to create a tensor with names ", names);
  check_unique_names(names);
}

namespace impl {

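// Returns the tensor's NamedTensorMeta, or nullptr when the tensor has none
// or names mode is currently disabled (which makes all tensors appear
// unnamed).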
static NamedTensorMeta* get_named_tensor_meta(TensorImpl* impl) {
  if (!NamesMode::is_enabled()) {
    return nullptr;
  }
  return static_cast<NamedTensorMeta*>(impl->named_tensor_meta());
}

static const NamedTensorMeta* get_named_tensor_meta(const TensorImpl* impl) {
  if (!NamesMode::is_enabled()) {
    return nullptr;
  }
  return static_cast<const NamedTensorMeta*>(impl->named_tensor_meta());
}

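// Convenience overload: validates `names` against the tensor's current rank.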
void check_names_valid_for(TensorImpl* impl, DimnameList names) {
  check_names_valid_for(impl->dim(), names);
}

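// Core setter. Named tensors are only supported for strided CPU, CUDA, XPU,
// and PrivateUse1 tensors. Passing std::nullopt, or a list in which every
// name is a wildcard, drops the metadata entirely so the tensor is treated
// as unnamed.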
void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names) {
  TORCH_CHECK(impl->layout() == Layout::Strided,
      "NYI: named tensors only support strided layout");
  TORCH_CHECK(impl->device().is_cpu() || impl->device().is_cuda() || impl->device().is_xpu() || impl->device().is_privateuseone(),
      "NYI: named tensors only support CPU, CUDA, XPU or ", c10::get_privateuse1_backend(), " tensors.");
  if (!names) {
    impl->set_named_tensor_meta(nullptr);
    return;
  }
  if (validate_names) {
    check_names_valid_for(impl, *names);
  }
  // Do this after validation!
  if (std::all_of(names->begin(), names->end(), [](const Dimname& n) { return n.isWildcard(); })) {
    impl->set_named_tensor_meta(nullptr);
    return;
  }
  auto* meta = get_named_tensor_meta(impl);
  if (meta == nullptr) {
    // Constructor is private
    impl->set_named_tensor_meta(std::make_unique<NamedTensorMeta>(NamedTensorMeta::HasNonWildcard, *names));
  } else {
    meta->set_names(NamedTensorMeta::HasNonWildcard, *names);
  }
}

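// Move overload: same logic as above, but takes ownership of `names` to
// avoid copying the vector. Note that this overload does not repeat the
// layout and device checks.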
void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names) {
  if (validate_names) {
    check_names_valid_for(impl, names);
  }
  // Do this after validation!
  if (std::all_of(names.begin(), names.end(), [](const Dimname& n) { return n.isWildcard(); })) {
    impl->set_named_tensor_meta(nullptr);
    return;
  }
  auto* meta = get_named_tensor_meta(impl);
  if (meta == nullptr) {
    impl->set_named_tensor_meta(std::make_unique<NamedTensorMeta>(NamedTensorMeta::HasNonWildcard, std::move(names)));
  } else {
    meta->set_names(NamedTensorMeta::HasNonWildcard, std::move(names));
  }
}

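// Returns the tensor's names if it carries named-tensor metadata (and names
// mode is enabled); otherwise std::nullopt.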
std::optional<DimnameList> get_opt_names(const TensorImpl* impl) {
  const auto* meta = get_named_tensor_meta(impl);
  if (meta == nullptr) {
    return std::nullopt;
  } else {
    return meta->names();
  }
}

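// Like get_opt_names(), but falls back to a list of wildcard names of length
// impl->dim() when the tensor is unnamed.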
DimnameList get_names(const TensorImpl* impl) {
  auto maybe_names = get_opt_names(impl);
  if (maybe_names) {
    return *maybe_names;
  }
  return default_names(impl->dim());
}

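// True only when the tensor actually carries named-tensor metadata and names
// mode is enabled.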
bool has_names(const TensorImpl* impl) {
  return impl->has_named_tensor_meta() && NamesMode::is_enabled();
}

} // namespace impl

} // namespace at