xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/blob.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <type_traits>
4 
5 #include <c10/util/intrusive_ptr.h>
6 #include <c10/util/typeid.h>
7 #include <c10/macros/Macros.h>
8 
9 namespace caffe2 {
10 
11 class Tensor;
12 
13 /**
14  * @brief Blob is a general container that hosts a typed pointer.
15  *
16  * A Blob hosts a pointer as well as its type, and takes charge of deleting it
17  * properly when the blob is deallocated or re-allocated with a new type. A blob
18  * could contain anything, although the most common case is to contain a Tensor.
19  */
20 class TORCH_API Blob final : public c10::intrusive_ptr_target {
21  public:
22   /**
23    * Initializes an empty Blob.
24    */
Blob()25   Blob() noexcept : meta_() {}
~Blob()26   ~Blob() override {
27     Reset();
28   }
29 
Blob(Blob && other)30   Blob(Blob&& other) noexcept : Blob() {
31     swap(other);
32   }
33 
34   Blob& operator=(Blob&& other) noexcept {
35     Blob(std::move(other)).swap(*this);
36     return *this;
37   }
38 
39   /**
40    * Checks if the content stored in the blob is of type T.
41    */
42   template <class T>
IsType()43   bool IsType() const noexcept {
44     return meta_.Match<T>();
45   }
46 
47   /**
48    * Returns the meta info of the blob.
49    */
meta()50   const TypeMeta meta() const noexcept {
51     return meta_;
52   }
53 
54   /**
55    * Returns a printable typename of the blob.
56    */
TypeName()57   c10::string_view TypeName() const noexcept {
58     return meta_.name();
59   }
60 
61   /**
62    * @brief Gets the const reference of the stored object. The code checks if
63    * the stored object is of the desired type.
64    */
65   // TODO(jerryzh): add a Get(c10::DeviceType) function?
66   template <class T>
Get()67   const T& Get() const {
68     TORCH_INTERNAL_ASSERT(
69         IsType<T>(),
70         "wrong type for the Blob instance. Blob contains ",
71         meta_.name(),
72         " while caller expects ",
73         TypeMeta::TypeName<T>());
74     // TODO: after we add Get<Tensor>(c10::DeviceType)
75     // and changed all the callsites, we can add
76     // a static assert here to enforce T != Tensor
77     return *static_cast<const T*>(pointer_);
78   }
79 
GetRaw()80   const void* GetRaw() const noexcept {
81     return pointer_;
82   }
GetRaw()83   void* GetRaw() noexcept {
84     return pointer_;
85   }
86 
87   /**
88    * @brief Gets a mutable pointer to the stored object.
89    *
90    * If the current object is not of the right type, a new object is created
91    * and the old object is freed. Note that type T should have a default
92    * constructor. Otherwise, create the object yourself first, and use
93    * Reset().
94    */
95   template <class T>
GetMutable()96   T* GetMutable() {
97     static_assert(
98         std::is_default_constructible<T>::value,
99         "GetMutable can't be called with non-default-constructible types. "
100         "Try using specialized methods");
101     if (IsType<T>()) {
102       return static_cast<T*>(pointer_);
103     } else {
104       // TODO Re-enable logging
105       // VLOG(1) << "Create new mutable object " << TypeMeta::TypeName<T>();
106       return Reset<T>(new T());
107     }
108   }
109 
110   template <class T>
GetMutableOrNull()111   T* GetMutableOrNull() {
112     if (IsType<T>()) {
113       return static_cast<T*>(pointer_);
114     } else {
115       return nullptr;
116     }
117   }
118 
119   /**
120    * Sets the underlying object to the allocated one. The Blob then takes over
121    * the ownership of the passed in pointer. If there is already an object in
122    * the Blob, the old object is freed.
123    *
124    * This is used when the underlying class T does not have a default ctor, or
125    * complex initializations needs to be done outside the blob.
126    */
127   template <class T>
Reset(T * allocated)128   T* Reset(T* allocated) {
129     free_();
130     meta_ = TypeMeta::Make<T>();
131     pointer_ = static_cast<void*>(allocated);
132     has_ownership_ = true;
133     return allocated;
134   }
135 
136   /**
137    * Sets the underlying object to the allocated one, but does not take over
138    * the ownership of the passed in pointer. If there is already an object in
139    * the Blob, the old object is freed.
140    *
141    * Unlike Reset, this does not take over the ownership of the pointer and the
142    * caller is responsible for making sure that the lifetime of the allocated
143    * blob outlasts the lifetime of any access to this blob, until another Reset
144    * call is made or the blob is destructed.
145    */
146   template <class T>
ShareExternal(std::remove_const_t<T> * allocated)147   std::remove_const_t<T>* ShareExternal(
148       std::remove_const_t<T>* allocated) {
149     return static_cast<T*>(ShareExternal(
150         static_cast<void*>(allocated),
151         TypeMeta::Make<std::remove_const_t<T>>()));
152   }
153 
ShareExternal(void * allocated,const TypeMeta meta)154   void* ShareExternal(void* allocated, const TypeMeta meta) {
155     free_();
156     meta_ = meta;
157     pointer_ = allocated;
158     has_ownership_ = false;
159     return allocated;
160   }
161 
162   /**
163    * Resets the Blob to an empty one.
164    */
Reset()165   void Reset() {
166     free_();
167     pointer_ = nullptr;
168     meta_ = TypeMeta();
169     has_ownership_ = false;
170   }
171 
172   /**
173    * @brief Swaps the underlying storage of two blobs.
174    */
swap(Blob & rhs)175   void swap(Blob& rhs)  noexcept {
176     using std::swap;
177     swap(meta_, rhs.meta_);
178     swap(pointer_, rhs.pointer_);
179     swap(has_ownership_, rhs.has_ownership_);
180   }
181 
182  private:
free_()183   void free_() {
184     if (has_ownership_ && pointer_ != nullptr) {
185       (*meta_.deleteFn())(pointer_);
186     }
187   }
188 
189   TypeMeta meta_;
190   void* pointer_{nullptr};
191   bool has_ownership_{false};
192 
193   C10_DISABLE_COPY_AND_ASSIGN(Blob);
194 };
195 
swap(Blob & lhs,Blob & rhs)196 inline void swap(Blob& lhs, Blob& rhs)  noexcept {
197   lhs.swap(rhs);
198 }
199 
200 inline std::ostream& operator<<(std::ostream& out, const Blob& v) {
201   return out << "Blob[" << v.TypeName() << "]";
202 }
203 
204 } // namespace caffe2
205