1 #include <gtest/gtest.h> 2 3 #include <ATen/MetaFunctions.h> 4 #include <torch/torch.h> 5 6 #include <vector> 7 TEST(MetaTensorTest,MetaDeviceApi)8 TEST(MetaTensorTest, MetaDeviceApi) { 9 auto a = at::ones({4}, at::kFloat); 10 auto b = at::ones({3, 4}, at::kFloat); 11 // at::add() will return a meta tensor if its inputs are also meta tensors. 12 auto out_meta = at::add(a.to(c10::kMeta), b.to(c10::kMeta)); 13 14 ASSERT_EQ(a.device(), c10::kCPU); 15 ASSERT_EQ(b.device(), c10::kCPU); 16 ASSERT_EQ(out_meta.device(), c10::kMeta); 17 c10::IntArrayRef sizes_actual = out_meta.sizes(); 18 std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4}; 19 ASSERT_EQ(sizes_actual, sizes_expected); 20 } 21 TEST(MetaTensorTest,MetaNamespaceApi)22 TEST(MetaTensorTest, MetaNamespaceApi) { 23 auto a = at::ones({4}, at::kFloat); 24 auto b = at::ones({3, 4}, at::kFloat); 25 // The at::meta:: namespace take in tensors from any backend 26 // and return a meta tensor. 27 auto out_meta = at::meta::add(a, b); 28 29 ASSERT_EQ(a.device(), c10::kCPU); 30 ASSERT_EQ(b.device(), c10::kCPU); 31 ASSERT_EQ(out_meta.device(), c10::kMeta); 32 c10::IntArrayRef sizes_actual = out_meta.sizes(); 33 std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4}; 34 ASSERT_EQ(sizes_actual, sizes_expected); 35 } 36