xref: /aosp_15_r20/external/pytorch/test/edge/operator_registry.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Exception.h>
2 #include <operator_registry.h>
3 
4 namespace torch {
5 namespace executor {
6 
getKernelRegistry()7 KernelRegistry& getKernelRegistry() {
8   static KernelRegistry kernel_registry;
9   return kernel_registry;
10 }
11 
register_kernels(const ArrayRef<Kernel> & kernels)12 bool register_kernels(const ArrayRef<Kernel>& kernels) {
13   return getKernelRegistry().register_kernels(kernels);
14 }
15 
register_kernels(const ArrayRef<Kernel> & kernels)16 bool KernelRegistry::register_kernels(
17     const ArrayRef<Kernel>& kernels) {
18   for (const auto& kernel : kernels) {
19     this->kernels_map_[kernel.name_] = kernel.kernel_;
20   }
21   return true;
22 }
23 
hasKernelFn(const char * name)24 bool hasKernelFn(const char* name) {
25   return getKernelRegistry().hasKernelFn(name);
26 }
27 
hasKernelFn(const char * name)28 bool KernelRegistry::hasKernelFn(const char* name) {
29   auto kernel = this->kernels_map_.find(name);
30   return kernel != this->kernels_map_.end();
31 }
32 
getKernelFn(const char * name)33 KernelFunction& getKernelFn(const char* name) {
34   return getKernelRegistry().getKernelFn(name);
35 }
36 
getKernelFn(const char * name)37 KernelFunction& KernelRegistry::getKernelFn(const char* name) {
38   auto kernel = this->kernels_map_.find(name);
39   TORCH_CHECK_MSG(kernel != this->kernels_map_.end(), "Kernel not found!");
40   return kernel->second;
41 }
42 
43 
44 } // namespace executor
45 } // namespace torch
46