xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_misc.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <string>
3 
4 #include <c10/util/int128.h>
5 #include <torch/csrc/lazy/core/hash.h>
6 
7 namespace torch {
8 namespace lazy {
9 
// Exercises the lazy-tensor hashing entry points (Hash, one-argument MHash and
// two-argument MHash) on a pair of values of type T.
//
// Two properties are checked:
//  * repeatable: hashing example_a twice gives the same result each time;
//  * sensitive:  replacing example_a with example_b changes the result.
//
// Callers are expected to pass two values that differ (and should therefore
// hash differently).
template <typename T>
void test_hash_repeatable_sensitive(const T& example_a, const T& example_b) {
  const auto expect_repeatable = [&example_a]() {
    EXPECT_EQ(Hash(example_a), Hash(example_a));
    EXPECT_EQ(MHash(example_a), MHash(example_a));
    EXPECT_EQ(MHash(example_a, example_a), MHash(example_a, example_a));
  };

  const auto expect_sensitive = [&example_a, &example_b]() {
    EXPECT_NE(Hash(example_a), Hash(example_b));
    EXPECT_NE(MHash(example_a), MHash(example_b));
    EXPECT_NE(MHash(example_a, example_a), MHash(example_a, example_b));
  };

  expect_repeatable();
  expect_sensitive();
}
22 
// Hashing a c10::Scalar must depend only on the value it holds, not on stray
// bytes elsewhere in its object representation.
//
// History: this test used to overwrite the *first* byte of `b` and was skipped
// as broken (https://github.com/pytorch/pytorch/issues/99883). Whether the
// first byte is "unused" depends on c10::Scalar's private field order and on
// endianness, so that write could land on bytes the Scalar actually uses. The
// garbage is now written to the *last* byte of the object instead.
// NOTE(review): this assumes c10::Scalar is a type tag plus a union wider than
// int64_t (as described below), so its last byte is never part of the tag or
// of the int64_t payload; confirm against c10/core/Scalar.h.
TEST(HashTest, Scalar) {
  c10::Scalar a(0);
  c10::Scalar b(0);

  // If the Scalar were no bigger than its int64_t payload there would be no
  // unused storage to corrupt, and this test would be meaningless.
  static_assert(
      sizeof(c10::Scalar) > sizeof(int64_t),
      "c10::Scalar is expected to be larger than an int64_t");

  // Simulate some garbage in the unused bits of the tagged union that is
  // c10::Scalar, which is bigger than the int64_t we're currently using it
  // with. Byte-wise access through an unsigned 8-bit pointer is used so no
  // other member is read or written.
  auto* const raw_bytes = reinterpret_cast<uint8_t*>(&b);
  raw_bytes[sizeof(c10::Scalar) - 1] = 1;

  // The actual 'value' of the Scalar as a 64-bit int shouldn't have changed...
  EXPECT_EQ(a.toLong(), b.toLong());
  // ...and the hash should ignore this garbage.
  EXPECT_EQ(Hash(a), Hash(b));
  EXPECT_EQ(MHash(a), MHash(b));
  EXPECT_EQ(MHash(a, a), MHash(a, b));
}
40 
// Checks that Hash/MHash are repeatable and distinguish close-but-different
// inputs for each argument type exercised below: strings, built-in
// integral types, c10 types, std::optional and containers.
TEST(HashTest, Sanity) {
  // String: the two inputs differ in a single character ("ipsum" vs "Jpsum").
  test_hash_repeatable_sensitive(
      std::string(
          "Lorem ipsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus."),
      std::string(
          "Lorem Jpsum dolor sit amet, consectetur adipiscing elit. Ut at suscipit purus."));

  // Number types. Values are written as hex bit patterns. For the signed
  // types, patterns above the type's maximum wrap to negative values (modular
  // conversion: guaranteed since C++20, implementation-defined but modular on
  // mainstream compilers before that). Each pair still differs.
  test_hash_repeatable_sensitive(true, false);
  test_hash_repeatable_sensitive(
      static_cast<int8_t>(0xfa), static_cast<int8_t>(0xfb));
  test_hash_repeatable_sensitive(
      static_cast<int16_t>(0xface), static_cast<int16_t>(0xfade));
  test_hash_repeatable_sensitive(
      static_cast<int32_t>(0xfaceb000), static_cast<int32_t>(0xfadeb000));
  test_hash_repeatable_sensitive(
      static_cast<int64_t>(0x1faceb000), static_cast<int64_t>(0x1fadeb000));
  test_hash_repeatable_sensitive(
      static_cast<uint8_t>(0xfa), static_cast<uint8_t>(0xfb));
  test_hash_repeatable_sensitive(
      static_cast<uint16_t>(0xface), static_cast<uint16_t>(0xfade));
  test_hash_repeatable_sensitive(
      static_cast<uint32_t>(0xfaceb000), static_cast<uint32_t>(0xfadeb000));
  test_hash_repeatable_sensitive(
      static_cast<uint64_t>(0x1faceb000), static_cast<uint64_t>(0x1fadeb000));

  // c10 types
  test_hash_repeatable_sensitive(c10::ScalarType::Bool, c10::ScalarType::Byte);
  test_hash_repeatable_sensitive(c10::Scalar(1.334), c10::Scalar(1.335));
  test_hash_repeatable_sensitive(c10::Scalar(true), c10::Scalar(false));
  test_hash_repeatable_sensitive(c10::Scalar(12345), c10::Scalar(12354));

  // std::optional: an engaged value must hash differently from an empty one.
  test_hash_repeatable_sensitive(
      std::optional<std::string>("I have value!"),
      std::optional<std::string>(std::nullopt));

  // Containers: same length, contents shifted by one element.
  const std::vector<int32_t> a{0, 1, 1, 2, 3, 5, 8};
  const std::vector<int32_t> b{1, 1, 2, 3, 5, 8, 12};
  test_hash_repeatable_sensitive(a, b);
  test_hash_repeatable_sensitive(
      c10::ArrayRef<int32_t>(a), c10::ArrayRef<int32_t>(b));

  // std::vector<bool> is a special case: it is a bit-packed specialization
  // whose elements are proxy objects rather than addressable bools, so it
  // needs to be exercised separately.
  const std::vector<bool> bool_a{true, false, false, true};
  const std::vector<bool> bool_b{true, true, false, true};
  test_hash_repeatable_sensitive(bool_a, bool_b);
}
83 
84 } // namespace lazy
85 } // namespace torch
86