1 #pragma once 2 3 #include <cstring> 4 #include <functional> 5 #include <map> 6 7 #include "Evalue.h" 8 #include "kernel_runtime_context.h" 9 10 #include <c10/util/ArrayRef.h> 11 12 namespace torch { 13 namespace executor { 14 15 using KernelFunction = std::function<void(KernelRuntimeContext&, EValue**)>; 16 17 template<typename T> 18 using ArrayRef = at::ArrayRef<T>; 19 20 #define EXECUTORCH_SCOPE_PROF(x) 21 22 struct Kernel { 23 const char* name_; 24 KernelFunction kernel_; 25 26 Kernel() = default; 27 28 /** 29 * We are doing a copy of the string pointer instead of duplicating the string 30 * itself, we require the lifetime of the kernel name to be at least as long 31 * as the kernel registry. 32 */ KernelKernel33 explicit Kernel(const char* name, KernelFunction func) 34 : name_(name), kernel_(func) {} 35 }; 36 37 /** 38 * See KernelRegistry::hasKernelFn() 39 */ 40 bool hasKernelFn(const char* name); 41 42 /** 43 * See KernelRegistry::getKernelFn() 44 */ 45 KernelFunction& getKernelFn(const char* name); 46 47 48 [[nodiscard]] bool register_kernels(const ArrayRef<Kernel>&); 49 50 struct KernelRegistry { 51 public: KernelRegistryKernelRegistry52 KernelRegistry() : kernelRegSize_(0) {} 53 54 bool register_kernels(const ArrayRef<Kernel>&); 55 56 /** 57 * Checks whether an kernel with a given name is registered 58 */ 59 bool hasKernelFn(const char* name); 60 61 /** 62 * Checks whether an kernel with a given name is registered 63 */ 64 KernelFunction& getKernelFn(const char* name); 65 66 private: 67 std::map<const char*, KernelFunction> kernels_map_; 68 uint32_t kernelRegSize_; 69 }; 70 71 } // namespace executor 72 } // namespace torch 73