1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS 2 #include <ATen/native/IndexingUtils.h> 3 4 namespace at::native { 5 canUse32BitIndexMath(const TensorBase & t,int64_t max_elem)6bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) { 7 auto elements = t.sym_numel(); 8 if (elements >= max_elem) { 9 return false; 10 } 11 if (elements == 0) { 12 return max_elem > 0; 13 } 14 15 c10::SymInt offset = 0; 16 auto linearId = elements - 1; 17 18 // NOTE: Assumes all strides are positive, which is true for now 19 // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) 20 for (int i = t.dim() - 1; i >= 0; --i) { 21 auto curDimIndex = linearId % t.sym_size(i); 22 auto curDimOffset = curDimIndex * t.sym_stride(i); 23 offset += curDimOffset; 24 linearId /= t.sym_size(i); 25 } 26 27 if (offset >= max_elem) { 28 return false; 29 } 30 31 return true; 32 } 33 34 } // namespace at::native 35