xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/data/datasets/mnist.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/data/datasets/mnist.h>
2 
3 #include <torch/data/example.h>
4 #include <torch/types.h>
5 
6 #include <c10/util/Exception.h>
7 
8 #include <cstddef>
9 #include <fstream>
10 #include <string>
11 
12 namespace torch {
13 namespace data {
14 namespace datasets {
15 namespace {
// Number of examples in the training and test splits.
constexpr uint32_t kTrainSize{60000};
constexpr uint32_t kTestSize{10000};

// First 32-bit header field expected in the image and label files.
constexpr uint32_t kImageMagicNumber{2051};
constexpr uint32_t kTargetMagicNumber{2049};

// Height and width of every image, in pixels.
constexpr uint32_t kImageRows{28};
constexpr uint32_t kImageColumns{28};

// File names looked up inside the dataset root directory.
constexpr const char* kTrainImagesFilename{"train-images-idx3-ubyte"};
constexpr const char* kTrainTargetsFilename{"train-labels-idx1-ubyte"};
constexpr const char* kTestImagesFilename{"t10k-images-idx3-ubyte"};
constexpr const char* kTestTargetsFilename{"t10k-labels-idx1-ubyte"};
26 
// Returns true when the host stores the least-significant byte of a
// multi-byte integer first in memory (little-endian byte order).
bool check_is_little_endian() {
  const uint32_t probe = 0x1u;
  const auto* lowest_address_byte = reinterpret_cast<const uint8_t*>(&probe);
  return *lowest_address_byte == 0x1u;
}
31 
// Reverses the byte order of a 32-bit value (e.g. 0x12345678 -> 0x78563412).
constexpr uint32_t flip_endianness(uint32_t value) {
  uint32_t reversed = 0;
  // Peel bytes off the low end of `value` and push them onto the low end of
  // `reversed`, so the first byte taken ends up most significant.
  for (int byte = 0; byte < 4; ++byte) {
    reversed = (reversed << 8u) | (value & 0xffu);
    value >>= 8u;
  }
  return reversed;
}
36 
read_int32(std::ifstream & stream)37 uint32_t read_int32(std::ifstream& stream) {
38   static const bool is_little_endian = check_is_little_endian();
39   // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
40   uint32_t value;
41   AT_ASSERT(stream.read(reinterpret_cast<char*>(&value), sizeof value));
42   return is_little_endian ? flip_endianness(value) : value;
43 }
44 
expect_int32(std::ifstream & stream,uint32_t expected)45 uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
46   const auto value = read_int32(stream);
47   // clang-format off
48   TORCH_CHECK(value == expected,
49       "Expected to read number ", expected, " but found ", value, " instead");
50   // clang-format on
51   return value;
52 }
53 
// Appends `tail` to the directory path `head`, inserting a '/' separator when
// `head` is non-empty and does not already end in one. An empty `head` yields
// `tail` unchanged, i.e. a path relative to the current directory.
std::string join_paths(std::string head, const std::string& tail) {
  // Check for emptiness first: std::string::back() on an empty string is
  // undefined behaviour, which an empty dataset root would otherwise trigger.
  if (!head.empty() && head.back() != '/') {
    head.push_back('/');
  }
  head += tail;
  return head;
}
61 
read_images(const std::string & root,bool train)62 Tensor read_images(const std::string& root, bool train) {
63   const auto path =
64       join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
65   std::ifstream images(path, std::ios::binary);
66   TORCH_CHECK(images, "Error opening images file at ", path);
67 
68   const auto count = train ? kTrainSize : kTestSize;
69 
70   // From http://yann.lecun.com/exdb/mnist/
71   expect_int32(images, kImageMagicNumber);
72   expect_int32(images, count);
73   expect_int32(images, kImageRows);
74   expect_int32(images, kImageColumns);
75 
76   auto tensor =
77       torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);
78   images.read(reinterpret_cast<char*>(tensor.data_ptr()), tensor.numel());
79   return tensor.to(torch::kFloat32).div_(255);
80 }
81 
read_targets(const std::string & root,bool train)82 Tensor read_targets(const std::string& root, bool train) {
83   const auto path =
84       join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
85   std::ifstream targets(path, std::ios::binary);
86   TORCH_CHECK(targets, "Error opening targets file at ", path);
87 
88   const auto count = train ? kTrainSize : kTestSize;
89 
90   expect_int32(targets, kTargetMagicNumber);
91   expect_int32(targets, count);
92 
93   auto tensor = torch::empty(count, torch::kByte);
94   targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count);
95   return tensor.to(torch::kInt64);
96 }
97 } // namespace
98 
MNIST(const std::string & root,Mode mode)99 MNIST::MNIST(const std::string& root, Mode mode)
100     : images_(read_images(root, mode == Mode::kTrain)),
101       targets_(read_targets(root, mode == Mode::kTrain)) {}
102 
get(size_t index)103 Example<> MNIST::get(size_t index) {
104   return {images_[index], targets_[index]};
105 }
106 
size() const107 std::optional<size_t> MNIST::size() const {
108   return images_.size(0);
109 }
110 
111 // NOLINTNEXTLINE(bugprone-exception-escape)
is_train() const112 bool MNIST::is_train() const noexcept {
113   return images_.size(0) == kTrainSize;
114 }
115 
images() const116 const Tensor& MNIST::images() const {
117   return images_;
118 }
119 
targets() const120 const Tensor& MNIST::targets() const {
121   return targets_;
122 }
123 
124 } // namespace datasets
125 } // namespace data
126 } // namespace torch
127