xref: /aosp_15_r20/external/pytorch/c10/core/WrapDimMinimal.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/WrapDimMinimal.h>
2 
3 namespace c10::detail {
4 
5 template <typename T>
maybe_wrap_dim_slow(T dim,T dim_post_expr,bool wrap_scalar)6 T maybe_wrap_dim_slow(T dim, T dim_post_expr, bool wrap_scalar) {
7   TORCH_CHECK_INDEX(
8       dim_post_expr >= 0, "Rank cannot be negative but got ", dim_post_expr);
9 
10   if (dim_post_expr == 0) {
11     TORCH_CHECK_INDEX(
12         wrap_scalar,
13         "Dimension specified as ",
14         dim,
15         " but tensor has no dimensions");
16     return c10::maybe_wrap_dim(
17         std::move(dim), /*dim_post_expr=*/1, /*wrap_scalar=*/false);
18   }
19 
20   T min = dim_post_expr * -1;
21   T max = dim_post_expr - 1;
22   TORCH_CHECK_INDEX(
23       min <= dim && dim <= max,
24       "Dimension out of range (expected to be in range of [",
25       min,
26       ", ",
27       max,
28       "], but got ",
29       dim,
30       ")");
31 
32   TORCH_INTERNAL_ASSERT(
33       false, "should never reach here as dim should be out-of-bounds");
34 }
35 
36 // Explicitly instantiate the template at the two types it will be used
37 template C10_API int64_t
38 maybe_wrap_dim_slow(int64_t dim, int64_t dim_post_expr, bool wrap_scalar);
39 template C10_API SymInt
40 maybe_wrap_dim_slow(SymInt dim, SymInt dim_post_expr, bool wrap_scalar);
41 
42 } // namespace c10::detail
43