1 #pragma once 2 3 #include <ATen/Parallel.h> 4 #include <c10/core/thread_pool.h> 5 6 namespace at { 7 8 class TORCH_API PTThreadPool : public c10::ThreadPool { 9 public: 10 explicit PTThreadPool(int pool_size, int numa_node_id = -1) 11 : c10::ThreadPool(pool_size, numa_node_id, []() { 12 c10::setThreadName("PTThreadPool"); 13 at::init_num_threads(); 14 }) {} 15 }; 16 17 } // namespace at 18