#include <torch/extension.h>
#include <torch/library.h>
#include <ATen/Generator.h>
#include <ATen/Tensor.h>
#include <ATen/native/DistributionTemplates.h>
#include <ATen/native/cpu/DistributionTemplates.h>
#include <memory>
#include <optional>
#include <utility>
8
9 using namespace at;
10
// Number of TestCPUGenerator objects currently alive: incremented by the
// generator's constructor, decremented by its destructor, and reported by
// getInstanceCount(). File-local (internal linkage).
namespace {
size_t instance_count{0};
}  // namespace
12
13 struct TestCPUGenerator : public c10::GeneratorImpl {
TestCPUGeneratorTestCPUGenerator14 TestCPUGenerator(uint64_t value) : c10::GeneratorImpl{Device(DeviceType::CPU), DispatchKeySet(DispatchKey::CustomRNGKeyId)}, value_(value) {
15 ++instance_count;
16 }
~TestCPUGeneratorTestCPUGenerator17 ~TestCPUGenerator() {
18 --instance_count;
19 }
randomTestCPUGenerator20 uint32_t random() { return static_cast<uint32_t>(value_); }
random64TestCPUGenerator21 uint64_t random64() { return value_; }
set_current_seedTestCPUGenerator22 void set_current_seed(uint64_t seed) override { throw std::runtime_error("not implemented"); }
set_offsetTestCPUGenerator23 void set_offset(uint64_t offset) override { throw std::runtime_error("not implemented"); }
get_offsetTestCPUGenerator24 uint64_t get_offset() const override { throw std::runtime_error("not implemented"); }
current_seedTestCPUGenerator25 uint64_t current_seed() const override { throw std::runtime_error("not implemented"); }
seedTestCPUGenerator26 uint64_t seed() override { throw std::runtime_error("not implemented"); }
set_stateTestCPUGenerator27 void set_state(const c10::TensorImpl& new_state) override { throw std::runtime_error("not implemented"); }
get_stateTestCPUGenerator28 c10::intrusive_ptr<c10::TensorImpl> get_state() const override { throw std::runtime_error("not implemented"); }
clone_implTestCPUGenerator29 TestCPUGenerator* clone_impl() const override { throw std::runtime_error("not implemented"); }
30
device_typeTestCPUGenerator31 static DeviceType device_type() { return DeviceType::CPU; }
32
33 uint64_t value_;
34 };
35
random_(Tensor & self,std::optional<Generator> generator)36 Tensor& random_(Tensor& self, std::optional<Generator> generator) {
37 return at::native::templates::random_impl<native::templates::cpu::RandomKernel, TestCPUGenerator>(self, generator);
38 }
39
random_from_to(Tensor & self,int64_t from,optional<int64_t> to,std::optional<Generator> generator)40 Tensor& random_from_to(Tensor& self, int64_t from, optional<int64_t> to, std::optional<Generator> generator) {
41 return at::native::templates::random_from_to_impl<native::templates::cpu::RandomFromToKernel, TestCPUGenerator>(self, from, to, generator);
42 }
43
random_to(Tensor & self,int64_t to,std::optional<Generator> generator)44 Tensor& random_to(Tensor& self, int64_t to, std::optional<Generator> generator) {
45 return random_from_to(self, 0, to, generator);
46 }
47
createTestCPUGenerator(uint64_t value)48 Generator createTestCPUGenerator(uint64_t value) {
49 return at::make_generator<TestCPUGenerator>(value);
50 }
51
identity(Generator g)52 Generator identity(Generator g) {
53 return g;
54 }
55
getInstanceCount()56 size_t getInstanceCount() {
57 return instance_count;
58 }
59
// Bind the three `aten::random_` overloads to the TestCPUGenerator-backed
// kernels for the CustomRNGKeyId dispatch key (the key TestCPUGenerator
// carries in its key set).
TORCH_LIBRARY_IMPL(aten, CustomRNGKeyId, m) {
  m.impl("aten::random_", random_);
  m.impl("aten::random_.to", random_to);
  m.impl("aten::random_.from", random_from_to);
}
65
// Python-visible entry points of this extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("identity", &identity);
  m.def("getInstanceCount", &getInstanceCount);
  m.def("createTestCPUGenerator", &createTestCPUGenerator);
}
71