xref: /aosp_15_r20/external/pytorch/test/cpp/api/meta_tensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1  #include <gtest/gtest.h>
2  
3  #include <ATen/MetaFunctions.h>
4  #include <torch/torch.h>
5  
6  #include <vector>
7  
// Adds two tensors that were first moved to the meta device and checks that
// the result is a meta tensor with sizes {3, 4}, while the tensors they were
// created from remain on the CPU.
TEST(MetaTensorTest, MetaDeviceApi) {
  const auto cpu_lhs = at::ones({4}, at::kFloat);
  const auto cpu_rhs = at::ones({3, 4}, at::kFloat);

  // When both operands are meta tensors, at::add() hands back a meta tensor
  // as well.
  const auto meta_lhs = cpu_lhs.to(c10::kMeta);
  const auto meta_rhs = cpu_rhs.to(c10::kMeta);
  const auto meta_sum = at::add(meta_lhs, meta_rhs);

  // The .to() calls above must not have changed the device of the originals.
  ASSERT_EQ(cpu_lhs.device(), c10::kCPU);
  ASSERT_EQ(cpu_rhs.device(), c10::kCPU);
  ASSERT_EQ(meta_sum.device(), c10::kMeta);

  const c10::IntArrayRef actual_shape = meta_sum.sizes();
  const std::vector<int64_t> expected_shape{3, 4};
  ASSERT_EQ(actual_shape, expected_shape);
}
21  
// Calls at::meta::add() directly on CPU tensors and checks that the result is
// a meta tensor with sizes {3, 4}, while the inputs remain on the CPU.
TEST(MetaTensorTest, MetaNamespaceApi) {
  const auto cpu_lhs = at::ones({4}, at::kFloat);
  const auto cpu_rhs = at::ones({3, 4}, at::kFloat);

  // Functions in the at::meta:: namespace accept tensors from any backend and
  // return a meta tensor, so no explicit device conversion is needed here.
  const auto meta_sum = at::meta::add(cpu_lhs, cpu_rhs);

  ASSERT_EQ(cpu_lhs.device(), c10::kCPU);
  ASSERT_EQ(cpu_rhs.device(), c10::kCPU);
  ASSERT_EQ(meta_sum.device(), c10::kMeta);

  const c10::IntArrayRef actual_shape = meta_sum.sizes();
  const std::vector<int64_t> expected_shape{3, 4};
  ASSERT_EQ(actual_shape, expected_shape);
}
36