xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/data/samplers/stream.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1  #include <torch/data/samplers/stream.h>
2  #include <torch/serialize/archive.h>
3  #include <torch/types.h>
4  
5  #include <c10/util/Exception.h>
6  
7  #include <cstddef>
8  
9  namespace torch {
10  namespace data {
11  namespace samplers {
12  
BatchSize(size_t size)13  BatchSize::BatchSize(size_t size) : size_(size) {}
size() const14  size_t BatchSize::size() const noexcept {
15    return size_;
16  }
operator size_t() const17  BatchSize::operator size_t() const noexcept {
18    return size_;
19  }
20  
StreamSampler(size_t epoch_size)21  StreamSampler::StreamSampler(size_t epoch_size) : epoch_size_(epoch_size) {}
22  
reset(std::optional<size_t> new_size)23  void StreamSampler::reset(std::optional<size_t> new_size) {
24    if (new_size.has_value()) {
25      epoch_size_ = *new_size;
26    }
27    examples_retrieved_so_far_ = 0;
28  }
29  
next(size_t batch_size)30  std::optional<BatchSize> StreamSampler::next(size_t batch_size) {
31    AT_ASSERT(examples_retrieved_so_far_ <= epoch_size_);
32    if (examples_retrieved_so_far_ == epoch_size_) {
33      return nullopt;
34    }
35    if (examples_retrieved_so_far_ + batch_size > epoch_size_) {
36      batch_size = epoch_size_ - examples_retrieved_so_far_;
37    }
38    examples_retrieved_so_far_ += batch_size;
39    return BatchSize(batch_size);
40  }
41  
save(serialize::OutputArchive & archive) const42  void StreamSampler::save(serialize::OutputArchive& archive) const {
43    archive.write(
44        "examples_retrieved_so_far",
45        torch::tensor(
46            static_cast<int64_t>(examples_retrieved_so_far_), torch::kInt64),
47        /*is_buffer=*/true);
48  }
49  
load(serialize::InputArchive & archive)50  void StreamSampler::load(serialize::InputArchive& archive) {
51    auto tensor = torch::empty(1, torch::kInt64);
52    archive.read(
53        "examples_retrieved_so_far",
54        tensor,
55        /*is_buffer=*/true);
56    examples_retrieved_so_far_ = tensor.item<int64_t>();
57  }
58  
59  } // namespace samplers
60  } // namespace data
61  } // namespace torch
62