1 #include <torch/data/datasets/mnist.h>
2
3 #include <torch/data/example.h>
4 #include <torch/types.h>
5
6 #include <c10/util/Exception.h>
7
8 #include <cstddef>
9 #include <fstream>
10 #include <string>
11
12 namespace torch {
13 namespace data {
14 namespace datasets {
15 namespace {
// Number of examples in the training and test splits respectively.
constexpr uint32_t kTrainSize{60000};
constexpr uint32_t kTestSize{10000};
// Header magic numbers that must open the image and label files.
constexpr uint32_t kImageMagicNumber{2051};
constexpr uint32_t kTargetMagicNumber{2049};
// Every image is kImageRows x kImageColumns single-byte pixels.
constexpr uint32_t kImageRows{28};
constexpr uint32_t kImageColumns{28};
// File names looked up inside the dataset root directory.
constexpr const char* kTrainImagesFilename{"train-images-idx3-ubyte"};
constexpr const char* kTrainTargetsFilename{"train-labels-idx1-ubyte"};
constexpr const char* kTestImagesFilename{"t10k-images-idx3-ubyte"};
constexpr const char* kTestTargetsFilename{"t10k-labels-idx1-ubyte"};
26
// Reports whether the host stores the least-significant byte of a
// multi-byte integer at the lowest address (i.e. is little-endian).
bool check_is_little_endian() {
  const uint32_t probe = 1;
  const auto* lowest_address_byte = reinterpret_cast<const uint8_t*>(&probe);
  return *lowest_address_byte == 1;
}
31
// Reverses the byte order of a 32-bit word, e.g. 0x12345678 -> 0x78563412.
constexpr uint32_t flip_endianness(uint32_t value) {
  return (value >> 24u) | ((value >> 8u) & 0x0000ff00u) |
      ((value << 8u) & 0x00ff0000u) | (value << 24u);
}
36
read_int32(std::ifstream & stream)37 uint32_t read_int32(std::ifstream& stream) {
38 static const bool is_little_endian = check_is_little_endian();
39 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
40 uint32_t value;
41 AT_ASSERT(stream.read(reinterpret_cast<char*>(&value), sizeof value));
42 return is_little_endian ? flip_endianness(value) : value;
43 }
44
expect_int32(std::ifstream & stream,uint32_t expected)45 uint32_t expect_int32(std::ifstream& stream, uint32_t expected) {
46 const auto value = read_int32(stream);
47 // clang-format off
48 TORCH_CHECK(value == expected,
49 "Expected to read number ", expected, " but found ", value, " instead");
50 // clang-format on
51 return value;
52 }
53
// Appends `tail` to the directory path `head`, inserting a '/' separator
// unless `head` is empty or already ends with '/'.
//
// An empty `head` yields `tail` unchanged (a path relative to the current
// working directory).
std::string join_paths(std::string head, const std::string& tail) {
  // back() on an empty string is undefined behavior, so check emptiness first.
  if (!head.empty() && head.back() != '/') {
    head.push_back('/');
  }
  head += tail;
  return head;
}
61
read_images(const std::string & root,bool train)62 Tensor read_images(const std::string& root, bool train) {
63 const auto path =
64 join_paths(root, train ? kTrainImagesFilename : kTestImagesFilename);
65 std::ifstream images(path, std::ios::binary);
66 TORCH_CHECK(images, "Error opening images file at ", path);
67
68 const auto count = train ? kTrainSize : kTestSize;
69
70 // From http://yann.lecun.com/exdb/mnist/
71 expect_int32(images, kImageMagicNumber);
72 expect_int32(images, count);
73 expect_int32(images, kImageRows);
74 expect_int32(images, kImageColumns);
75
76 auto tensor =
77 torch::empty({count, 1, kImageRows, kImageColumns}, torch::kByte);
78 images.read(reinterpret_cast<char*>(tensor.data_ptr()), tensor.numel());
79 return tensor.to(torch::kFloat32).div_(255);
80 }
81
read_targets(const std::string & root,bool train)82 Tensor read_targets(const std::string& root, bool train) {
83 const auto path =
84 join_paths(root, train ? kTrainTargetsFilename : kTestTargetsFilename);
85 std::ifstream targets(path, std::ios::binary);
86 TORCH_CHECK(targets, "Error opening targets file at ", path);
87
88 const auto count = train ? kTrainSize : kTestSize;
89
90 expect_int32(targets, kTargetMagicNumber);
91 expect_int32(targets, count);
92
93 auto tensor = torch::empty(count, torch::kByte);
94 targets.read(reinterpret_cast<char*>(tensor.data_ptr()), count);
95 return tensor.to(torch::kInt64);
96 }
97 } // namespace
98
MNIST(const std::string & root,Mode mode)99 MNIST::MNIST(const std::string& root, Mode mode)
100 : images_(read_images(root, mode == Mode::kTrain)),
101 targets_(read_targets(root, mode == Mode::kTrain)) {}
102
get(size_t index)103 Example<> MNIST::get(size_t index) {
104 return {images_[index], targets_[index]};
105 }
106
size() const107 std::optional<size_t> MNIST::size() const {
108 return images_.size(0);
109 }
110
111 // NOLINTNEXTLINE(bugprone-exception-escape)
is_train() const112 bool MNIST::is_train() const noexcept {
113 return images_.size(0) == kTrainSize;
114 }
115
images() const116 const Tensor& MNIST::images() const {
117 return images_;
118 }
119
targets() const120 const Tensor& MNIST::targets() const {
121 return targets_;
122 }
123
124 } // namespace datasets
125 } // namespace data
126 } // namespace torch
127