xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/IndexingUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/IndexingUtils.h>
3 
4 namespace at::native {
5 
canUse32BitIndexMath(const TensorBase & t,int64_t max_elem)6 bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) {
7   auto elements = t.sym_numel();
8   if (elements >= max_elem) {
9     return false;
10   }
11   if (elements == 0) {
12     return max_elem > 0;
13   }
14 
15   c10::SymInt offset = 0;
16   auto linearId = elements - 1;
17 
18   // NOTE: Assumes all strides are positive, which is true for now
19   // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
20   for (int i = t.dim() - 1; i >= 0; --i) {
21     auto curDimIndex = linearId % t.sym_size(i);
22     auto curDimOffset = curDimIndex * t.sym_stride(i);
23     offset += curDimOffset;
24     linearId /= t.sym_size(i);
25   }
26 
27   if (offset >= max_elem) {
28     return false;
29   }
30 
31   return true;
32 }
33 
34 } // namespace at::native
35