1 *da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2 *da0073e9SAndroid Build Coastguard Worker
3 *da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4 *da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/utils/variadic.h>
5 *da0073e9SAndroid Build Coastguard Worker #include <torch/detail/static.h>
6 *da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
7 *da0073e9SAndroid Build Coastguard Worker
8 *da0073e9SAndroid Build Coastguard Worker #include <string>
9 *da0073e9SAndroid Build Coastguard Worker #include <type_traits>
10 *da0073e9SAndroid Build Coastguard Worker #include <vector>
11 *da0073e9SAndroid Build Coastguard Worker
12 *da0073e9SAndroid Build Coastguard Worker template <
13 *da0073e9SAndroid Build Coastguard Worker typename T,
14 *da0073e9SAndroid Build Coastguard Worker typename = std::enable_if_t<!torch::detail::is_module<T>::value>>
f(T && m)15 *da0073e9SAndroid Build Coastguard Worker bool f(T&& m) {
16 *da0073e9SAndroid Build Coastguard Worker return false;
17 *da0073e9SAndroid Build Coastguard Worker }
18 *da0073e9SAndroid Build Coastguard Worker
19 *da0073e9SAndroid Build Coastguard Worker template <typename T>
f(T && m)20 *da0073e9SAndroid Build Coastguard Worker torch::detail::enable_if_module_t<T, bool> f(T&& m) {
21 *da0073e9SAndroid Build Coastguard Worker return true;
22 *da0073e9SAndroid Build Coastguard Worker }
23 *da0073e9SAndroid Build Coastguard Worker
TEST(TestStatic, EnableIfModule) {
  namespace td = torch::detail;

  // Overload resolution on f(): a module argument must pick the
  // enable_if_module_t overload, a plain int must pick the fallback.
  ASSERT_TRUE(f(torch::nn::LinearImpl(1, 2)));
  ASSERT_FALSE(f(5));

  // These cases pin check_not_lvalue_references<Ts...>(): true when no Ts is an
  // lvalue reference, false as soon as one is. The extra parentheses stop the
  // commas between template arguments from being read as macro separators.
  ASSERT_TRUE((td::check_not_lvalue_references<int>()));
  ASSERT_TRUE((td::check_not_lvalue_references<float, int, char>()));
  ASSERT_FALSE((td::check_not_lvalue_references<float, int&, char>()));
  ASSERT_TRUE((td::check_not_lvalue_references<std::string>()));
  ASSERT_FALSE((td::check_not_lvalue_references<std::string&>()));
}
34 *da0073e9SAndroid Build Coastguard Worker
35 *da0073e9SAndroid Build Coastguard Worker namespace {
36 *da0073e9SAndroid Build Coastguard Worker
37 *da0073e9SAndroid Build Coastguard Worker struct A : torch::nn::Module {
forward__anon24ea8ac20111::A38 *da0073e9SAndroid Build Coastguard Worker int forward() {
39 *da0073e9SAndroid Build Coastguard Worker return 5;
40 *da0073e9SAndroid Build Coastguard Worker }
41 *da0073e9SAndroid Build Coastguard Worker };
42 *da0073e9SAndroid Build Coastguard Worker
43 *da0073e9SAndroid Build Coastguard Worker struct B : torch::nn::Module {
forward__anon24ea8ac20111::B44 *da0073e9SAndroid Build Coastguard Worker std::string forward(torch::Tensor tensor) {
45 *da0073e9SAndroid Build Coastguard Worker return "";
46 *da0073e9SAndroid Build Coastguard Worker }
47 *da0073e9SAndroid Build Coastguard Worker };
48 *da0073e9SAndroid Build Coastguard Worker
49 *da0073e9SAndroid Build Coastguard Worker struct C : torch::nn::Module {
forward__anon24ea8ac20111::C50 *da0073e9SAndroid Build Coastguard Worker float forward(torch::Tensor& tensor) {
51 *da0073e9SAndroid Build Coastguard Worker return 5.0;
52 *da0073e9SAndroid Build Coastguard Worker }
53 *da0073e9SAndroid Build Coastguard Worker };
54 *da0073e9SAndroid Build Coastguard Worker
55 *da0073e9SAndroid Build Coastguard Worker struct D : torch::nn::Module {
forward__anon24ea8ac20111::D56 *da0073e9SAndroid Build Coastguard Worker char forward(torch::Tensor&& tensor) {
57 *da0073e9SAndroid Build Coastguard Worker return 'x';
58 *da0073e9SAndroid Build Coastguard Worker }
59 *da0073e9SAndroid Build Coastguard Worker };
60 *da0073e9SAndroid Build Coastguard Worker
61 *da0073e9SAndroid Build Coastguard Worker struct E : torch::nn::Module {};
62 *da0073e9SAndroid Build Coastguard Worker
63 *da0073e9SAndroid Build Coastguard Worker } // anonymous namespace
64 *da0073e9SAndroid Build Coastguard Worker
65 *da0073e9SAndroid Build Coastguard Worker // Put in a function because macros don't handle the comma between arguments to
66 *da0073e9SAndroid Build Coastguard Worker // is_same well ...
67 *da0073e9SAndroid Build Coastguard Worker template <typename Module, typename ExpectedType, typename... Args>
assert_has_expected_type()68 *da0073e9SAndroid Build Coastguard Worker void assert_has_expected_type() {
69 *da0073e9SAndroid Build Coastguard Worker using ReturnType =
70 *da0073e9SAndroid Build Coastguard Worker typename torch::detail::return_type_of_forward<Module, Args...>::type;
71 *da0073e9SAndroid Build Coastguard Worker constexpr bool is_expected_type =
72 *da0073e9SAndroid Build Coastguard Worker std::is_same<ReturnType, ExpectedType>::value;
73 *da0073e9SAndroid Build Coastguard Worker ASSERT_TRUE(is_expected_type) << Module().name();
74 *da0073e9SAndroid Build Coastguard Worker }
75 *da0073e9SAndroid Build Coastguard Worker
TEST(TestStatic, ReturnTypeOfForward) {
  // forward() with no arguments.
  assert_has_expected_type<A, int>();
  // The argument passed by value, by lvalue reference and by rvalue reference;
  // Args must mirror forward()'s parameter type.
  assert_has_expected_type<B, std::string, torch::Tensor>();
  assert_has_expected_type<C, float, torch::Tensor&>();
  assert_has_expected_type<D, char, torch::Tensor&&>();
  // A module without its own forward() is expected to report void.
  assert_has_expected_type<E, void>();
}
83 *da0073e9SAndroid Build Coastguard Worker
// Exercises torch::apply: the test expects the callable to be invoked once per
// trailing argument, in argument order.
TEST(TestStatic, Apply) {
  std::vector<int> v;
  torch::apply([&v](int x) { v.push_back(x); }, 1, 2, 3, 4, 5);
  // Keep both sides of each ASSERT_EQ the same signedness. v.size() is size_t,
  // so compare against an unsigned literal, and convert the size_t index to int
  // before comparing with the stored ints. Mixed signed/unsigned comparisons
  // inside gtest's comparison helper trigger -Wsign-compare.
  ASSERT_EQ(v.size(), 5u);
  for (const auto i : c10::irange(v.size())) {
    ASSERT_EQ(v.at(i), static_cast<int>(i) + 1);
  }
}
92