xref: /aosp_15_r20/external/pytorch/test/edge/operator_registry.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <utility>

#include <c10/util/ArrayRef.h>

#include "Evalue.h"
#include "kernel_runtime_context.h"
11 
12 namespace torch {
13 namespace executor {
14 
15 using KernelFunction = std::function<void(KernelRuntimeContext&, EValue**)>;
16 
17 template<typename T>
18 using ArrayRef = at::ArrayRef<T>;
19 
20 #define EXECUTORCH_SCOPE_PROF(x)
21 
22 struct Kernel {
23   const char* name_;
24   KernelFunction kernel_;
25 
26   Kernel() = default;
27 
28   /**
29    * We are doing a copy of the string pointer instead of duplicating the string
30    * itself, we require the lifetime of the kernel name to be at least as long
31    * as the kernel registry.
32    */
KernelKernel33   explicit Kernel(const char* name, KernelFunction func)
34       : name_(name), kernel_(func) {}
35 };
36 
37 /**
38  * See KernelRegistry::hasKernelFn()
39  */
40 bool hasKernelFn(const char* name);
41 
42 /**
43  * See KernelRegistry::getKernelFn()
44  */
45 KernelFunction& getKernelFn(const char* name);
46 
47 
48 [[nodiscard]] bool register_kernels(const ArrayRef<Kernel>&);
49 
50 struct KernelRegistry {
51  public:
KernelRegistryKernelRegistry52   KernelRegistry() : kernelRegSize_(0) {}
53 
54   bool register_kernels(const ArrayRef<Kernel>&);
55 
56   /**
57    * Checks whether an kernel with a given name is registered
58    */
59   bool hasKernelFn(const char* name);
60 
61   /**
62    * Checks whether an kernel with a given name is registered
63    */
64   KernelFunction& getKernelFn(const char* name);
65 
66  private:
67   std::map<const char*, KernelFunction> kernels_map_;
68   uint32_t kernelRegSize_;
69 };
70 
71 } // namespace executor
72 } // namespace torch
73