xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/MaybeOwned_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/Functions.h>
4 #include <ATen/NativeFunctions.h>
5 #include <ATen/Tensor.h>
6 #include <ATen/core/ivalue.h>
7 #include <c10/util/intrusive_ptr.h>
8 #include <c10/util/MaybeOwned.h>
9 
10 #include <memory>
11 #include <string>
12 
13 namespace {
14 
15 using at::Tensor;
16 using c10::IValue;
17 
18 struct MyString : public c10::intrusive_ptr_target, public std::string {
19   using std::string::string;
20 };
21 
22 template <typename T>
23 class MaybeOwnedTest : public ::testing::Test {
24  public:
25   T borrowFrom;
26   T ownCopy;
27   T ownCopy2;
28   c10::MaybeOwned<T> borrowed;
29   c10::MaybeOwned<T> owned;
30   c10::MaybeOwned<T> owned2;
31 
32  protected:
33   void SetUp() override; // defined below helpers
TearDown()34   void TearDown() override {
35     // Release everything to try to trigger ASAN violations in the
36     // test that broke things.
37     borrowFrom = T();
38     ownCopy = T();
39     ownCopy2 = T();
40 
41     borrowed = c10::MaybeOwned<T>();
42     owned = c10::MaybeOwned<T>();
43     owned2 = c10::MaybeOwned<T>();
44   }
45 
46 };
47 
48 
49 //////////////////// Helpers that differ per tested type. ////////////////////
50 
51 template <typename T>
52 T getSampleValue();
53 
54 template <typename T>
55 T getSampleValue2();
56 
57 template <typename T>
58 void assertBorrow(const c10::MaybeOwned<T>&, const T&);
59 
60 template <typename T>
61 void assertOwn(const c10::MaybeOwned<T>&, const T&, size_t useCount = 2);
62 
63 ////////////////// Helper implementations for intrusive_ptr. //////////////////
64 template<>
getSampleValue()65 c10::intrusive_ptr<MyString> getSampleValue() {
66   return c10::make_intrusive<MyString>("hello");
67 }
68 
69 template<>
getSampleValue2()70 c10::intrusive_ptr<MyString> getSampleValue2() {
71   return c10::make_intrusive<MyString>("goodbye");
72 }
73 
are_equal(const c10::intrusive_ptr<MyString> & lhs,const c10::intrusive_ptr<MyString> & rhs)74 bool are_equal(const c10::intrusive_ptr<MyString>& lhs, const c10::intrusive_ptr<MyString>& rhs) {
75   if (!lhs || !rhs) {
76     return !lhs && !rhs;
77   }
78   return *lhs == *rhs;
79 }
80 
81 template <>
assertBorrow(const c10::MaybeOwned<c10::intrusive_ptr<MyString>> & mo,const c10::intrusive_ptr<MyString> & borrowedFrom)82 void assertBorrow(
83     const c10::MaybeOwned<c10::intrusive_ptr<MyString>>& mo,
84     const c10::intrusive_ptr<MyString>& borrowedFrom) {
85   EXPECT_EQ(*mo, borrowedFrom);
86   EXPECT_EQ(mo->get(), borrowedFrom.get());
87   EXPECT_EQ(borrowedFrom.use_count(), 1);
88 }
89 
90 template <>
assertOwn(const c10::MaybeOwned<c10::intrusive_ptr<MyString>> & mo,const c10::intrusive_ptr<MyString> & original,size_t useCount)91 void assertOwn(
92     const c10::MaybeOwned<c10::intrusive_ptr<MyString>>& mo,
93     const c10::intrusive_ptr<MyString>& original,
94     size_t useCount) {
95   EXPECT_EQ(*mo, original);
96   EXPECT_EQ(mo->get(), original.get());
97   EXPECT_NE(&*mo, &original);
98   EXPECT_EQ(original.use_count(), useCount);
99 }
100 
101 //////////////////// Helper implementations for Tensor. ////////////////////
102 
103 template<>
getSampleValue()104 Tensor getSampleValue() {
105   return at::zeros({2, 2}).to(at::kCPU);
106 }
107 
108 template<>
getSampleValue2()109 Tensor getSampleValue2() {
110   return at::native::ones({2, 2}).to(at::kCPU);
111 }
112 
are_equal(const Tensor & lhs,const Tensor & rhs)113 bool are_equal(const Tensor& lhs, const Tensor& rhs) {
114   if (!lhs.defined() || !rhs.defined()) {
115     return !lhs.defined() && !rhs.defined();
116   }
117   return at::native::cpu_equal(lhs, rhs);
118 }
119 
120 template <>
assertBorrow(const c10::MaybeOwned<Tensor> & mo,const Tensor & borrowedFrom)121 void assertBorrow(
122     const c10::MaybeOwned<Tensor>& mo,
123     const Tensor& borrowedFrom) {
124   EXPECT_TRUE(mo->is_same(borrowedFrom));
125   EXPECT_EQ(borrowedFrom.use_count(), 1);
126 }
127 
128 template <>
assertOwn(const c10::MaybeOwned<Tensor> & mo,const Tensor & original,size_t useCount)129 void assertOwn(
130     const c10::MaybeOwned<Tensor>& mo,
131     const Tensor& original,
132     size_t useCount) {
133   EXPECT_TRUE(mo->is_same(original));
134   EXPECT_EQ(original.use_count(), useCount);
135 }
136 
137 //////////////////// Helper implementations for IValue. ////////////////////
138 
139 template<>
getSampleValue()140 IValue getSampleValue() {
141   return IValue(getSampleValue<Tensor>());
142 }
143 
144 template<>
getSampleValue2()145 IValue getSampleValue2() {
146   return IValue("hello");
147 }
148 
are_equal(const IValue & lhs,const IValue & rhs)149 bool are_equal(const IValue& lhs, const IValue& rhs) {
150   if (lhs.isTensor() != rhs.isTensor()) {
151     return false;
152   }
153   if (lhs.isTensor() && rhs.isTensor()) {
154     return lhs.toTensor().equal(rhs.toTensor());
155   }
156   return lhs == rhs;
157 }
158 
159 template <>
assertBorrow(const c10::MaybeOwned<IValue> & mo,const IValue & borrowedFrom)160 void assertBorrow(
161     const c10::MaybeOwned<IValue>& mo,
162     const IValue& borrowedFrom) {
163   if (!borrowedFrom.isPtrType()) {
164     EXPECT_EQ(*mo, borrowedFrom);
165   } else {
166     EXPECT_EQ(mo->internalToPointer(), borrowedFrom.internalToPointer());
167     EXPECT_EQ(borrowedFrom.use_count(), 1);
168   }
169 }
170 
171 template <>
assertOwn(const c10::MaybeOwned<IValue> & mo,const IValue & original,size_t useCount)172 void assertOwn(
173     const c10::MaybeOwned<IValue>& mo,
174     const IValue& original,
175     size_t useCount) {
176   if (!original.isPtrType()) {
177     EXPECT_EQ(*mo, original);
178   } else {
179     EXPECT_EQ(mo->internalToPointer(), original.internalToPointer());
180     EXPECT_EQ(original.use_count(), useCount);
181   }
182 }
183 
184 template <typename T>
SetUp()185 void MaybeOwnedTest<T>::SetUp() {
186   borrowFrom = getSampleValue<T>();
187   ownCopy = getSampleValue<T>();
188   ownCopy2 = getSampleValue<T>();
189   borrowed = c10::MaybeOwned<T>::borrowed(borrowFrom);
190   owned = c10::MaybeOwned<T>::owned(std::in_place, ownCopy);
191   owned2 = c10::MaybeOwned<T>::owned(T(ownCopy2));
192 }
193 
194 using MaybeOwnedTypes = ::testing::Types<
195   c10::intrusive_ptr<MyString>,
196   at::Tensor,
197   c10::IValue
198   >;
199 
200 TYPED_TEST_SUITE(MaybeOwnedTest, MaybeOwnedTypes);
201 
// The wrappers produced by SetUp() report the expected borrow/own state.
TYPED_TEST(MaybeOwnedTest, SimpleDereferencingString) {
  auto& fixture = *this;
  assertBorrow(fixture.borrowed, fixture.borrowFrom);
  assertOwn(fixture.owned, fixture.ownCopy);
  assertOwn(fixture.owned2, fixture.ownCopy2);
}
207 
// Default-constructed wrappers can later be assigned a borrow or an owned
// value.
TYPED_TEST(MaybeOwnedTest, DefaultCtor) {
  using MO = c10::MaybeOwned<TypeParam>;

  // Drop the fixture's own wrappers first so they don't add to the reference
  // counts checked below.
  this->borrowed = MO();
  this->owned = MO();

  MO borrowedLocal;
  MO ownedLocal;
  borrowedLocal = MO::borrowed(this->borrowFrom);
  ownedLocal = MO::owned(std::in_place, this->ownCopy);

  assertBorrow(borrowedLocal, this->borrowFrom);
  assertOwn(ownedLocal, this->ownCopy);
}
219 
// Copy-constructing a wrapper keeps its mode. Copying a borrow adds no
// reference, while copying an owned wrapper adds one more, for 3 in total:
// the source value, the fixture's wrapper and the copy.
TYPED_TEST(MaybeOwnedTest, CopyConstructor) {
  using MO = c10::MaybeOwned<TypeParam>;
  const MO copiedBorrowed(this->borrowed);
  const MO copiedOwned(this->owned);
  const MO copiedOwned2(this->owned2);

  assertBorrow(this->borrowed, this->borrowFrom);
  assertBorrow(copiedBorrowed, this->borrowFrom);

  assertOwn(this->owned, this->ownCopy, 3);
  assertOwn(copiedOwned, this->ownCopy, 3);
  assertOwn(this->owned2, this->ownCopy2, 3);
  assertOwn(copiedOwned2, this->ownCopy2, 3);
}
234 
// Dereferencing an rvalue wrapper yields its value in either mode. Only the
// owning wrapper gives up its contents.
TYPED_TEST(MaybeOwnedTest, MoveDereferencing) {
  using MO = c10::MaybeOwned<TypeParam>;

  // Give `owned` a value that differs from the one `borrowed` refers to.
  this->owned = MO::owned(std::in_place, getSampleValue2<TypeParam>());

  EXPECT_TRUE(are_equal(*std::move(this->borrowed), getSampleValue<TypeParam>()));
  EXPECT_TRUE(are_equal(*std::move(this->owned), getSampleValue2<TypeParam>()));

  // The borrow is left intact.
  assertBorrow(this->borrowed, this->borrowFrom);

  // The owned wrapper was moved out of and now compares equal to a
  // default-constructed TypeParam (e.g. a null intrusive_ptr or an
  // undefined Tensor).
  EXPECT_TRUE(are_equal(*this->owned, TypeParam()));
}
248 
// Move-constructing a wrapper keeps its mode and does not change the source
// values' reference counts.
TYPED_TEST(MaybeOwnedTest, MoveConstructor) {
  using MO = c10::MaybeOwned<TypeParam>;
  MO movedBorrowed(std::move(this->borrowed));
  MO movedOwned(std::move(this->owned));
  MO movedOwned2(std::move(this->owned2));

  assertBorrow(movedBorrowed, this->borrowFrom);
  assertOwn(movedOwned, this->ownCopy);
  assertOwn(movedOwned2, this->ownCopy2);
}
258 
// Copy-assigning into wrappers that currently own a value makes them take on
// the source's mode.
TYPED_TEST(MaybeOwnedTest, CopyAssignmentIntoOwned) {
  using MO = c10::MaybeOwned<TypeParam>;

  // Targets start out owning a value constructed in place with no arguments.
  auto copiedBorrowed = MO::owned(std::in_place);
  auto copiedOwned = MO::owned(std::in_place);
  auto copiedOwned2 = MO::owned(std::in_place);

  copiedBorrowed = this->borrowed;
  copiedOwned = this->owned;
  copiedOwned2 = this->owned2;

  // Borrows add no reference...
  assertBorrow(this->borrowed, this->borrowFrom);
  assertBorrow(copiedBorrowed, this->borrowFrom);
  // ...while every owned copy adds one: source + fixture wrapper + copy.
  assertOwn(this->owned, this->ownCopy, 3);
  assertOwn(copiedOwned, this->ownCopy, 3);
  assertOwn(this->owned2, this->ownCopy2, 3);
  assertOwn(copiedOwned2, this->ownCopy2, 3);
}
275 
// Copy-assigning into wrappers that currently borrow some other value makes
// them take on the source's mode.
TYPED_TEST(MaybeOwnedTest, CopyAssignmentIntoBorrowed) {
  using MO = c10::MaybeOwned<TypeParam>;

  // Unrelated values for the targets to borrow initially. They are declared
  // first, so they outlive the wrappers below.
  auto otherBorrowFrom = getSampleValue2<TypeParam>();
  auto otherOwnCopy = getSampleValue2<TypeParam>();
  auto copiedBorrowed = MO::borrowed(otherBorrowFrom);
  auto copiedOwned = MO::borrowed(otherOwnCopy);
  auto copiedOwned2 = MO::borrowed(otherOwnCopy);

  copiedBorrowed = this->borrowed;
  copiedOwned = this->owned;
  copiedOwned2 = this->owned2;

  assertBorrow(this->borrowed, this->borrowFrom);
  assertBorrow(copiedBorrowed, this->borrowFrom);

  // Source + fixture wrapper + copy.
  assertOwn(this->owned, this->ownCopy, 3);
  assertOwn(this->owned2, this->ownCopy2, 3);
  assertOwn(copiedOwned, this->ownCopy, 3);
  assertOwn(copiedOwned2, this->ownCopy2, 3);
}
295 
296 
// Move-assigning into wrappers that currently own a value transfers the
// source's mode without adding references.
TYPED_TEST(MaybeOwnedTest, MoveAssignmentIntoOwned) {
  using MO = c10::MaybeOwned<TypeParam>;

  // Targets start out owning a value constructed in place with no arguments.
  auto movedBorrowed = MO::owned(std::in_place);
  auto movedOwned = MO::owned(std::in_place);
  auto movedOwned2 = MO::owned(std::in_place);

  movedBorrowed = std::move(this->borrowed);
  movedOwned = std::move(this->owned);
  movedOwned2 = std::move(this->owned2);

  assertBorrow(movedBorrowed, this->borrowFrom);
  assertOwn(movedOwned, this->ownCopy);
  assertOwn(movedOwned2, this->ownCopy2);
}
311 
312 
// Move-assigning into wrappers that currently borrow some other value
// transfers the source's mode without adding references.
TYPED_TEST(MaybeOwnedTest, MoveAssignmentIntoBorrowed) {
  using MO = c10::MaybeOwned<TypeParam>;

  // Declared first so it outlives the wrappers that borrow it.
  auto otherValue = getSampleValue2<TypeParam>();
  auto movedBorrowed = MO::borrowed(otherValue);
  auto movedOwned = MO::borrowed(otherValue);
  auto movedOwned2 = MO::borrowed(otherValue);

  movedBorrowed = std::move(this->borrowed);
  movedOwned = std::move(this->owned);
  movedOwned2 = std::move(this->owned2);

  assertBorrow(movedBorrowed, this->borrowFrom);
  assertOwn(movedOwned, this->ownCopy);
  assertOwn(movedOwned2, this->ownCopy2);
}
327 
// Assigning a wrapper to itself must leave both its mode and the reference
// counts of the value it refers to unchanged.
TYPED_TEST(MaybeOwnedTest, SelfAssignment) {
  // Each assignment goes through a const alias. operator= still receives the
  // very object it is invoked on.
  const auto& borrowedSelf = this->borrowed;
  const auto& ownedSelf = this->owned;
  const auto& owned2Self = this->owned2;

  this->borrowed = borrowedSelf;
  this->owned = ownedSelf;
  this->owned2 = owned2Self;

  assertBorrow(this->borrowed, this->borrowFrom);
  assertOwn(this->owned, this->ownCopy);
  assertOwn(this->owned2, this->ownCopy2);
}
337 
338 } // namespace
339