xref: /aosp_15_r20/external/pytorch/c10/util/Load.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1  #pragma once
2  #include <c10/macros/Macros.h>
3  #include <cstring>
4  
5  namespace c10 {
6  namespace detail {
7  
8  template <typename T>
9  struct LoadImpl {
applyLoadImpl10    C10_HOST_DEVICE static T apply(const void* src) {
11      return *reinterpret_cast<const T*>(src);
12    }
13  };
14  
15  template <>
16  struct LoadImpl<bool> {
17    C10_HOST_DEVICE static bool apply(const void* src) {
18      static_assert(sizeof(bool) == sizeof(char));
19      // NOTE: [Loading boolean values]
20      // Protect against invalid boolean values by loading as a byte
21      // first, then converting to bool (see gh-54789).
22      return *reinterpret_cast<const unsigned char*>(src);
23    }
24  };
25  
26  } // namespace detail
27  
28  template <typename T>
29  C10_HOST_DEVICE T load(const void* src) {
30    return c10::detail::LoadImpl<T>::apply(src);
31  }
32  
33  template <typename scalar_t>
34  C10_HOST_DEVICE scalar_t load(const scalar_t* src) {
35    return c10::detail::LoadImpl<scalar_t>::apply(src);
36  }
37  
38  } // namespace c10
39