1 #include <gtest/gtest.h>
2
3 #include <ATen/Functions.h>
4 #include <ATen/NativeFunctions.h>
5 #include <ATen/Tensor.h>
6 #include <ATen/core/ivalue.h>
7 #include <c10/util/intrusive_ptr.h>
8 #include <c10/util/MaybeOwned.h>
9
10 #include <memory>
11 #include <string>
12
13 namespace {
14
15 using at::Tensor;
16 using c10::IValue;
17
18 struct MyString : public c10::intrusive_ptr_target, public std::string {
19 using std::string::string;
20 };
21
/// Typed gtest fixture exercising c10::MaybeOwned<T>.
///
/// SetUp() (defined after the per-type helpers) populates:
///   - `borrowed`: a MaybeOwned created with MaybeOwned<T>::borrowed(borrowFrom);
///   - `owned`:    a MaybeOwned created in place from `ownCopy`;
///   - `owned2`:   a MaybeOwned created from a temporary copy of `ownCopy2`.
template <typename T>
class MaybeOwnedTest : public ::testing::Test {
 public:
  // Source values. They are declared before the MaybeOwned members below, so
  // they are constructed first and destroyed last.
  T borrowFrom;
  T ownCopy;
  T ownCopy2;
  // The MaybeOwned objects under test.
  c10::MaybeOwned<T> borrowed;
  c10::MaybeOwned<T> owned;
  c10::MaybeOwned<T> owned2;

 protected:
  void SetUp() override; // defined below helpers
  // Resets every member explicitly; the statement order below is intentional.
  void TearDown() override {
    // Release everything to try to trigger ASAN violations in the
    // test that broke things.
    borrowFrom = T();
    ownCopy = T();
    ownCopy2 = T();

    borrowed = c10::MaybeOwned<T>();
    owned = c10::MaybeOwned<T>();
    owned2 = c10::MaybeOwned<T>();
  }

};
47
48
49 //////////////////// Helpers that differ per tested type. ////////////////////
50
51 template <typename T>
52 T getSampleValue();
53
54 template <typename T>
55 T getSampleValue2();
56
57 template <typename T>
58 void assertBorrow(const c10::MaybeOwned<T>&, const T&);
59
60 template <typename T>
61 void assertOwn(const c10::MaybeOwned<T>&, const T&, size_t useCount = 2);
62
63 ////////////////// Helper implementations for intrusive_ptr. //////////////////
64 template<>
getSampleValue()65 c10::intrusive_ptr<MyString> getSampleValue() {
66 return c10::make_intrusive<MyString>("hello");
67 }
68
69 template<>
getSampleValue2()70 c10::intrusive_ptr<MyString> getSampleValue2() {
71 return c10::make_intrusive<MyString>("goodbye");
72 }
73
are_equal(const c10::intrusive_ptr<MyString> & lhs,const c10::intrusive_ptr<MyString> & rhs)74 bool are_equal(const c10::intrusive_ptr<MyString>& lhs, const c10::intrusive_ptr<MyString>& rhs) {
75 if (!lhs || !rhs) {
76 return !lhs && !rhs;
77 }
78 return *lhs == *rhs;
79 }
80
81 template <>
assertBorrow(const c10::MaybeOwned<c10::intrusive_ptr<MyString>> & mo,const c10::intrusive_ptr<MyString> & borrowedFrom)82 void assertBorrow(
83 const c10::MaybeOwned<c10::intrusive_ptr<MyString>>& mo,
84 const c10::intrusive_ptr<MyString>& borrowedFrom) {
85 EXPECT_EQ(*mo, borrowedFrom);
86 EXPECT_EQ(mo->get(), borrowedFrom.get());
87 EXPECT_EQ(borrowedFrom.use_count(), 1);
88 }
89
90 template <>
assertOwn(const c10::MaybeOwned<c10::intrusive_ptr<MyString>> & mo,const c10::intrusive_ptr<MyString> & original,size_t useCount)91 void assertOwn(
92 const c10::MaybeOwned<c10::intrusive_ptr<MyString>>& mo,
93 const c10::intrusive_ptr<MyString>& original,
94 size_t useCount) {
95 EXPECT_EQ(*mo, original);
96 EXPECT_EQ(mo->get(), original.get());
97 EXPECT_NE(&*mo, &original);
98 EXPECT_EQ(original.use_count(), useCount);
99 }
100
101 //////////////////// Helper implementations for Tensor. ////////////////////
102
103 template<>
getSampleValue()104 Tensor getSampleValue() {
105 return at::zeros({2, 2}).to(at::kCPU);
106 }
107
108 template<>
getSampleValue2()109 Tensor getSampleValue2() {
110 return at::native::ones({2, 2}).to(at::kCPU);
111 }
112
are_equal(const Tensor & lhs,const Tensor & rhs)113 bool are_equal(const Tensor& lhs, const Tensor& rhs) {
114 if (!lhs.defined() || !rhs.defined()) {
115 return !lhs.defined() && !rhs.defined();
116 }
117 return at::native::cpu_equal(lhs, rhs);
118 }
119
120 template <>
assertBorrow(const c10::MaybeOwned<Tensor> & mo,const Tensor & borrowedFrom)121 void assertBorrow(
122 const c10::MaybeOwned<Tensor>& mo,
123 const Tensor& borrowedFrom) {
124 EXPECT_TRUE(mo->is_same(borrowedFrom));
125 EXPECT_EQ(borrowedFrom.use_count(), 1);
126 }
127
128 template <>
assertOwn(const c10::MaybeOwned<Tensor> & mo,const Tensor & original,size_t useCount)129 void assertOwn(
130 const c10::MaybeOwned<Tensor>& mo,
131 const Tensor& original,
132 size_t useCount) {
133 EXPECT_TRUE(mo->is_same(original));
134 EXPECT_EQ(original.use_count(), useCount);
135 }
136
137 //////////////////// Helper implementations for IValue. ////////////////////
138
139 template<>
getSampleValue()140 IValue getSampleValue() {
141 return IValue(getSampleValue<Tensor>());
142 }
143
144 template<>
getSampleValue2()145 IValue getSampleValue2() {
146 return IValue("hello");
147 }
148
are_equal(const IValue & lhs,const IValue & rhs)149 bool are_equal(const IValue& lhs, const IValue& rhs) {
150 if (lhs.isTensor() != rhs.isTensor()) {
151 return false;
152 }
153 if (lhs.isTensor() && rhs.isTensor()) {
154 return lhs.toTensor().equal(rhs.toTensor());
155 }
156 return lhs == rhs;
157 }
158
159 template <>
assertBorrow(const c10::MaybeOwned<IValue> & mo,const IValue & borrowedFrom)160 void assertBorrow(
161 const c10::MaybeOwned<IValue>& mo,
162 const IValue& borrowedFrom) {
163 if (!borrowedFrom.isPtrType()) {
164 EXPECT_EQ(*mo, borrowedFrom);
165 } else {
166 EXPECT_EQ(mo->internalToPointer(), borrowedFrom.internalToPointer());
167 EXPECT_EQ(borrowedFrom.use_count(), 1);
168 }
169 }
170
171 template <>
assertOwn(const c10::MaybeOwned<IValue> & mo,const IValue & original,size_t useCount)172 void assertOwn(
173 const c10::MaybeOwned<IValue>& mo,
174 const IValue& original,
175 size_t useCount) {
176 if (!original.isPtrType()) {
177 EXPECT_EQ(*mo, original);
178 } else {
179 EXPECT_EQ(mo->internalToPointer(), original.internalToPointer());
180 EXPECT_EQ(original.use_count(), useCount);
181 }
182 }
183
184 template <typename T>
SetUp()185 void MaybeOwnedTest<T>::SetUp() {
186 borrowFrom = getSampleValue<T>();
187 ownCopy = getSampleValue<T>();
188 ownCopy2 = getSampleValue<T>();
189 borrowed = c10::MaybeOwned<T>::borrowed(borrowFrom);
190 owned = c10::MaybeOwned<T>::owned(std::in_place, ownCopy);
191 owned2 = c10::MaybeOwned<T>::owned(T(ownCopy2));
192 }
193
194 using MaybeOwnedTypes = ::testing::Types<
195 c10::intrusive_ptr<MyString>,
196 at::Tensor,
197 c10::IValue
198 >;
199
200 TYPED_TEST_SUITE(MaybeOwnedTest, MaybeOwnedTypes);
201
// Straight after SetUp: the borrow refers to borrowFrom, and each owner holds
// its own copy (default expected use count of the sources).
TYPED_TEST(MaybeOwnedTest, SimpleDereferencingString) {
  assertOwn(this->owned2, this->ownCopy2);
  assertOwn(this->owned, this->ownCopy);
  assertBorrow(this->borrowed, this->borrowFrom);
}
207
// Default-constructed locals can later be assigned a borrowing and an owning
// MaybeOwned.
TYPED_TEST(MaybeOwnedTest, DefaultCtor) {
  using MO = c10::MaybeOwned<TypeParam>;
  MO borrowed;
  MO owned;

  // Reset the fixture's wrappers so they don't skew the reference counts.
  this->borrowed = MO();
  this->owned = MO();

  borrowed = MO::borrowed(this->borrowFrom);
  owned = MO::owned(std::in_place, this->ownCopy);

  assertBorrow(borrowed, this->borrowFrom);
  assertOwn(owned, this->ownCopy);
}
219
// Copy construction preserves the mode. A copied borrow still borrows, and a
// copied owner makes another owned copy, so the source use count becomes 3.
TYPED_TEST(MaybeOwnedTest, CopyConstructor) {
  auto copyOfBorrowed = this->borrowed;
  auto copyOfOwned = this->owned;
  auto copyOfOwned2 = this->owned2;

  assertBorrow(this->borrowed, this->borrowFrom);
  assertBorrow(copyOfBorrowed, this->borrowFrom);

  assertOwn(this->owned, this->ownCopy, 3);
  assertOwn(copyOfOwned, this->ownCopy, 3);

  assertOwn(this->owned2, this->ownCopy2, 3);
  assertOwn(copyOfOwned2, this->ownCopy2, 3);
}
234
// Dereferencing an rvalue MaybeOwned yields the value. A borrow is left
// intact, while an owner is moved out of.
TYPED_TEST(MaybeOwnedTest, MoveDereferencing) {
  using MO = c10::MaybeOwned<TypeParam>;
  // Use a value that differs from the borrowed one.
  this->owned = MO::owned(std::in_place, getSampleValue2<TypeParam>());

  const bool borrowedMatches =
      are_equal(*std::move(this->borrowed), getSampleValue<TypeParam>());
  EXPECT_TRUE(borrowedMatches);
  const bool ownedMatches =
      are_equal(*std::move(this->owned), getSampleValue2<TypeParam>());
  EXPECT_TRUE(ownedMatches);

  // The borrow is untouched.
  assertBorrow(this->borrowed, this->borrowFrom);

  // The owner now holds a value equal to a default-constructed TypeParam
  // (null intrusive_ptr / undefined Tensor / default IValue).
  EXPECT_TRUE(are_equal(*this->owned, TypeParam()));
}
248
// Move construction transfers the mode: the new borrow still refers to
// borrowFrom, and the new owners hold the owned copies.
TYPED_TEST(MaybeOwnedTest, MoveConstructor) {
  auto fromBorrowed = std::move(this->borrowed);
  auto fromOwned = std::move(this->owned);
  auto fromOwned2 = std::move(this->owned2);

  assertBorrow(fromBorrowed, this->borrowFrom);
  assertOwn(fromOwned, this->ownCopy);
  assertOwn(fromOwned2, this->ownCopy2);
}
258
// Copy-assigning into a MaybeOwned that currently owns a default-constructed
// value takes on the source's mode.
TYPED_TEST(MaybeOwnedTest, CopyAssignmentIntoOwned) {
  using MO = c10::MaybeOwned<TypeParam>;
  MO copiedBorrowed = MO::owned(std::in_place);
  MO copiedOwned = MO::owned(std::in_place);
  MO copiedOwned2 = MO::owned(std::in_place);

  copiedBorrowed = this->borrowed;
  copiedOwned = this->owned;
  copiedOwned2 = this->owned2;

  assertBorrow(this->borrowed, this->borrowFrom);
  assertBorrow(copiedBorrowed, this->borrowFrom);

  assertOwn(this->owned, this->ownCopy, 3);
  assertOwn(copiedOwned, this->ownCopy, 3);

  assertOwn(this->owned2, this->ownCopy2, 3);
  assertOwn(copiedOwned2, this->ownCopy2, 3);
}
275
// Copy-assigning into a MaybeOwned that currently borrows another value takes
// on the source's mode.
TYPED_TEST(MaybeOwnedTest, CopyAssignmentIntoBorrowed) {
  using MO = c10::MaybeOwned<TypeParam>;
  // The targets' initial borrow sources. They are declared before the
  // MaybeOwned objects so they are destroyed after them.
  auto initialBorrowSource = getSampleValue2<TypeParam>();
  auto initialOwnSource = getSampleValue2<TypeParam>();
  MO copiedBorrowed = MO::borrowed(initialBorrowSource);
  MO copiedOwned = MO::borrowed(initialOwnSource);
  MO copiedOwned2 = MO::borrowed(initialOwnSource);

  copiedBorrowed = this->borrowed;
  copiedOwned = this->owned;
  copiedOwned2 = this->owned2;

  assertBorrow(this->borrowed, this->borrowFrom);
  assertBorrow(copiedBorrowed, this->borrowFrom);

  assertOwn(this->owned, this->ownCopy, 3);
  assertOwn(copiedOwned, this->ownCopy, 3);

  assertOwn(this->owned2, this->ownCopy2, 3);
  assertOwn(copiedOwned2, this->ownCopy2, 3);
}
295
296
// Move-assigning into a MaybeOwned that currently owns a default-constructed
// value transfers the source's mode.
TYPED_TEST(MaybeOwnedTest, MoveAssignmentIntoOwned) {
  using MO = c10::MaybeOwned<TypeParam>;
  MO movedBorrowed = MO::owned(std::in_place);
  MO movedOwned = MO::owned(std::in_place);
  MO movedOwned2 = MO::owned(std::in_place);

  movedBorrowed = std::move(this->borrowed);
  movedOwned = std::move(this->owned);
  movedOwned2 = std::move(this->owned2);

  assertBorrow(movedBorrowed, this->borrowFrom);
  assertOwn(movedOwned, this->ownCopy);
  assertOwn(movedOwned2, this->ownCopy2);
}
311
312
// Move-assigning into a MaybeOwned that currently borrows another value
// transfers the source's mode.
TYPED_TEST(MaybeOwnedTest, MoveAssignmentIntoBorrowed) {
  using MO = c10::MaybeOwned<TypeParam>;
  // Initial borrow source. It is declared first so it outlives the borrowers.
  auto initialSource = getSampleValue2<TypeParam>();
  MO movedBorrowed = MO::borrowed(initialSource);
  MO movedOwned = MO::borrowed(initialSource);
  MO movedOwned2 = MO::borrowed(initialSource);

  movedBorrowed = std::move(this->borrowed);
  movedOwned = std::move(this->owned);
  movedOwned2 = std::move(this->owned2);

  assertBorrow(movedBorrowed, this->borrowFrom);
  assertOwn(movedOwned, this->ownCopy);
  assertOwn(movedOwned2, this->ownCopy2);
}
327
// Assigning a MaybeOwned to itself must leave both its mode and the source's
// reference count unchanged.
TYPED_TEST(MaybeOwnedTest, SelfAssignment) {
  this->borrowed = this->borrowed;
  assertBorrow(this->borrowed, this->borrowFrom);

  this->owned = this->owned;
  assertOwn(this->owned, this->ownCopy);

  this->owned2 = this->owned2;
  assertOwn(this->owned2, this->ownCopy2);
}
337
338 } // namespace
339