xref: /aosp_15_r20/external/pytorch/test/cpp_extensions/rng_extension.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/Generator.h>
#include <ATen/Tensor.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/cpu/DistributionTemplates.h>
#include <memory>
#include <utility>
8 
9 using namespace at;
10 
// Number of TestCPUGenerator objects currently alive: incremented by the
// constructor, decremented by the destructor, reported by getInstanceCount().
static size_t instance_count{0};
12 
13 struct TestCPUGenerator : public c10::GeneratorImpl {
TestCPUGeneratorTestCPUGenerator14   TestCPUGenerator(uint64_t value) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, value_(value) {
15     ++instance_count;
16   }
~TestCPUGeneratorTestCPUGenerator17   ~TestCPUGenerator() {
18     --instance_count;
19   }
randomTestCPUGenerator20   uint32_t random() { return static_cast<uint32_t>(value_); }
random64TestCPUGenerator21   uint64_t random64() { return value_; }
set_current_seedTestCPUGenerator22   void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
set_offsetTestCPUGenerator23   void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); }
get_offsetTestCPUGenerator24   uint64_t get_offset() const override { throw std::runtime_error("not implemented"); }
current_seedTestCPUGenerator25   uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
seedTestCPUGenerator26   uint64_t seed() override { throw std::runtime_error("not implemented"); }
set_stateTestCPUGenerator27   void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
get_stateTestCPUGenerator28   c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); }
clone_implTestCPUGenerator29   TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }
30 
device_typeTestCPUGenerator31   static DeviceType device_type() { return DeviceType::CPU; }
32 
33   uint64_t value_;
34 };
35 
random_(Tensor & self,std::optional<Generator> generator)36 Tensor& random_(Tensor& self, std::optional<Generator> generator) {
37   return at::native::templates::random_impl<native::templates::cpu::RandomKernel, TestCPUGenerator>(self, generator);
38 }
39 
random_from_to(Tensor & self,int64_t from,optional<int64_t> to,std::optional<Generator> generator)40 Tensor& random_from_to(Tensor& self, int64_t from, optional<int64_t> to, std::optional<Generator> generator) {
41   return at::native::templates::random_from_to_impl<native::templates::cpu::RandomFromToKernel, TestCPUGenerator>(self, from, to, generator);
42 }
43 
random_to(Tensor & self,int64_t to,std::optional<Generator> generator)44 Tensor& random_to(Tensor& self, int64_t to, std::optional<Generator> generator) {
45   return random_from_to(self, 0, to, generator);
46 }
47 
createTestCPUGenerator(uint64_t value)48 Generator createTestCPUGenerator(uint64_t value) {
49   return at::make_generator<TestCPUGenerator>(value);
50 }
51 
identity(Generator g)52 Generator identity(Generator g) {
53   return g;
54 }
55 
getInstanceCount()56 size_t getInstanceCount() {
57   return instance_count;
58 }
59 
// Registers the three random_ overloads for the CustomRNGKeyId dispatch key so
// that calls carrying a TestCPUGenerator are routed to the kernels above.
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
  // random_(from, to)
  m.impl("aten::random_.from", random_from_to);
  // random_(to)
  m.impl("aten::random_.to", random_to);
  // random_()
  m.impl("aten::random_", random_);
}
65 
// Python bindings for the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Factory for a generator that always yields the given value.
  m.def("createTestCPUGenerator", &createTestCPUGenerator);

  // Number of TestCPUGenerator instances currently alive.
  m.def("getInstanceCount", &getInstanceCount);

  // Returns the Generator it is given.
  m.def("identity", &identity);
}
71