1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/pytree/function_ref.h>
10
11 #include <gtest/gtest.h>
12
13 using namespace ::testing;
14
15 using ::executorch::extension::pytree::FunctionRef;
16
17 namespace {
18 class Item {
19 private:
20 int32_t val_;
21 FunctionRef<void(int32_t&)> ref_;
22
23 public:
Item(int32_t val,FunctionRef<void (int32_t &)> ref)24 /* implicit */ Item(int32_t val, FunctionRef<void(int32_t&)> ref)
25 : val_(val), ref_(ref) {}
26
get()27 int32_t get() {
28 ref_(val_);
29 return val_;
30 }
31 };
32
// Plain function used as a FunctionRef target: overwrites its argument with 1.
void one(int32_t& out) {
  out = int32_t{1};
}
36
37 } // namespace
38
// A capturing lambda can be used through FunctionRef when it is an lvalue that
// is wrapped explicitly, and the FunctionRef observes the lambda's captures by
// reference.
TEST(FunctionRefTest, CapturingLambda) {
  // Named `captured` (not `one`) so it does not shadow the free function
  // `one` defined in the anonymous namespace of this file.
  int32_t captured = 1;
  auto f = [&](int32_t& i) { i = captured; };
  // Capturing lambdas must be wrapped in a FunctionRef explicitly.
  Item item(0, FunctionRef<void(int32_t&)>{f});
  EXPECT_EQ(item.get(), 1);

  // `f` captures by reference, so a later change to `captured` is visible
  // through the same FunctionRef stored inside `item`.
  captured = 2;
  EXPECT_EQ(item.get(), 2);

  // ERROR: each of the following is expected NOT to compile.
  //
  // Item item1(0, f);
  //   No implicit conversion from a capturing lambda; it has to be wrapped
  //   explicitly as done above.
  //
  // Item item2(0, [&](int32_t& i) { i = 2; });
  // FunctionRef<void(int32_t&)> ref([&](int32_t&){});
  //   A temporary capturing lambda is rejected, presumably because a
  //   (non-owning) FunctionRef bound to it would dangle once the temporary
  //   is destroyed.
}
49
// Non-capturing lambdas may be handed to FunctionRef (or to Item, which takes
// a FunctionRef by value) as temporaries, lvalues, or moved-from values.
// Every case starts from 0 and expects the callback to write 1.
TEST(FunctionRefTest, NonCapturingLambda) {
  int32_t value = 0;

  // Temporary lambda given straight to the FunctionRef constructor.
  FunctionRef<void(int32_t&)> from_temporary([](int32_t& x) { x = 1; });
  from_temporary(value);
  EXPECT_EQ(value, 1);

  // Named (lvalue) lambda given to the FunctionRef constructor.
  value = 0;
  auto set_one = [](int32_t& x) { x = 1; };
  FunctionRef<void(int32_t&)> from_named(set_one);
  from_named(value);
  EXPECT_EQ(value, 1);

  // Temporary lambda converted implicitly when passed to Item.
  Item from_inline(0, [](int32_t& x) { x = 1; });
  EXPECT_EQ(from_inline.get(), 1);

  // Lvalue lambda converted implicitly when passed to Item...
  auto setter = [](int32_t& x) { x = 1; };
  Item from_lvalue(0, setter);
  EXPECT_EQ(from_lvalue.get(), 1);

  // ...and the same lambda passed as an xvalue.
  Item from_xvalue(0, std::move(setter));
  EXPECT_EQ(from_xvalue.get(), 1);
}
72
// A plain function can back a FunctionRef, whether named directly or passed by
// explicit address, both when building a FunctionRef and when converting into
// Item's FunctionRef parameter.
TEST(FunctionRefTest, FunctionPointer) {
  int32_t value = 0;

  // Function name used to construct a FunctionRef directly.
  FunctionRef<void(int32_t&)> direct(one);
  direct(value);
  EXPECT_EQ(value, 1);

  // Function name passed to Item and converted implicitly...
  Item by_name(0, one);
  EXPECT_EQ(by_name.get(), 1);

  // ...and the same with an explicit address-of.
  Item by_address(0, &one);
  EXPECT_EQ(by_address.get(), 1);
}
85