#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <chrono>
#include <iostream>
#include <list>
#include <sstream>
#include <utility>

#ifdef FBCODE_CAFFE2
#include <c10/util/static_tracepoint.h>
#endif

namespace c10 {

#ifdef FBCODE_CAFFE2
TORCH_SDT_DEFINE_SEMAPHORE(operator_start)
TORCH_SDT_DEFINE_SEMAPHORE(operator_end)
#endif

bool show_dispatch_trace() {
  static char const* temp = getenv("TORCH_SHOW_DISPATCH_TRACE");
  return temp != nullptr;
}

static thread_local int64_t dispatch_trace_nesting_value_;

void dispatch_trace_nesting_incr() { ++dispatch_trace_nesting_value_; }
void dispatch_trace_nesting_decr() { --dispatch_trace_nesting_value_; }
int64_t dispatch_trace_nesting_value() { return dispatch_trace_nesting_value_; }

namespace detail {

class RegistrationListenerList final {
public:
  std::function<void()> addListener(std::unique_ptr<OpRegistrationListener> listener) {
    listeners_.push_back(std::move(listener));
    auto delete_it = --listeners_.end();
    return [this, delete_it] {
      listeners_.erase(delete_it);
    };
  }

  void callOnOperatorRegistered(const OperatorHandle& op) {
    for (auto& listener : listeners_) {
      listener->onOperatorRegistered(op);
    }
  }

  void callOnOperatorDeregistered(const OperatorHandle& op) {
    for (auto& listener : listeners_) {
      listener->onOperatorDeregistered(op);
    }
  }

private:
  std::list<std::unique_ptr<OpRegistrationListener>> listeners_;
};

void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
  auto nesting_value = dispatch_trace_nesting_value();
  for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
  std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
}

} // namespace detail

OpRegistrationListener::~OpRegistrationListener() = default;

Dispatcher::Dispatcher()
: operators_()
, operatorLookupTable_()
, backendFallbackKernels_()
, listeners_(std::make_unique<detail::RegistrationListenerList>())
, cond_var_()
, guard_(std::make_shared<Guard>())
{}

Dispatcher::~Dispatcher() {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  guard_->alive.store(false);
}

C10_EXPORT Dispatcher& Dispatcher::realSingleton() {
  static Dispatcher _singleton;
  return _singleton;
}

std::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_name) {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::optional<OperatorHandle> {
    auto found = operatorLookupTable.find(overload_name);
    if (found == operatorLookupTable.end()) {
      return std::nullopt;
    }
    return found->second;
  });
}

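// Note on concurrency (a sketch of the invariant, assuming operatorLookupTable_
// is the LeftRight-wrapped hash map declared in Dispatcher.h): read() above can
// proceed concurrently with writers, while all mutation goes through
// operatorLookupTable_.write() under the dispatcher lock. This is what lets
// findOp() sit on the hot dispatch path without taking guard_->mutex.
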
// NB: If you add more waitFor* implementations, you also have to add
// appropriate notify_all() calls to the relevant register calls

void Dispatcher::waitForDef(const FunctionSchema& schema) {
  using namespace std::chrono_literals;
  std::unique_lock<std::mutex> lock(guard_->mutex);
  bool r = cond_var_.wait_for(lock, 2s, [&]{
    return findOp(schema.operator_name()) != std::nullopt;
  });
  TORCH_INTERNAL_ASSERT(r,
    "Expected main interpreter to define ", schema.operator_name(),
    ", but this didn't happen within timeout. Are you trying to load "
    "different models in the same torchdeploy/multipy instance? You "
    "must warmup each interpreter identically, e.g., import all "
    "the same dependencies.");
}

void Dispatcher::waitForImpl(const OperatorName& op_name, std::optional<c10::DispatchKey> maybe_dk) {
  using namespace std::chrono_literals;
  std::unique_lock<std::mutex> lock(guard_->mutex);
  auto dk = maybe_dk.value_or(DispatchKey::CompositeImplicitAutograd);
  auto op = findOrRegisterName_(op_name);
  bool r = cond_var_.wait_for(lock, 2s, [&]{
    // NB: this is slightly unsound for overrides, but overrides are
    // funny business anyway
    return op.hasKernelForDispatchKey(dk);
  });
  TORCH_INTERNAL_ASSERT(r,
    "Expected main interpreter to implement ", dk, " for ", op_name,
    ", but this didn't happen within timeout. Are you trying to load "
    "different models in the same torchdeploy/multipy instance? You "
    "must warmup each interpreter identically, e.g., import all "
    "the same dependencies.");
}

std::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
  auto it = findOp(overload_name);
  if (it.has_value()) {
    if (it->hasSchema()) {
      return it;
    } else {
      return std::nullopt;
    }
  } else {
    return it;
  }
}

OperatorHandle Dispatcher::findSchemaOrThrow(const char* name, const char* overload_name) {
  auto it = findSchema({name, overload_name});
  if (!it.has_value()) {
    // Check if we have ANYTHING; if that's the case, that means you're
    // missing schema
    auto it2 = findOp({name, overload_name});
    if (!it2.has_value()) {
      TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name);
    } else {
      TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name,
        " but we found an implementation; did you forget to def() the operator?");
    }
  }
  return it.value();
}

const std::vector<OperatorName> Dispatcher::getAllOpNames() {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
    std::vector<OperatorName> allOpNames;
    for (const auto& op : operatorLookupTable) {
      allOpNames.push_back(op.first);
    }
    return allOpNames;
  });
}

// Postcondition: caller is responsible for disposing of the registration when
// they are done
OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) {
  const auto found = findOp(op_name);
  if (found != std::nullopt) {
    return *found;
  }

  operators_.emplace_back(OperatorName(op_name));
  OperatorHandle handle(--operators_.end());
  operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
    operatorLookupTable.emplace(op_name, handle);
  });

  return handle;
}

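// Note: operators_ is a std::list, so the iterator stored inside the
// OperatorHandle returned above stays valid across later insertions and
// erasures of other entries; cleanup() below relies on this to erase an
// entry in O(1) without invalidating other outstanding handles.
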
// Adding an explicit destructor definition in the cpp to avoid a linker error
// in Windows builds. The Windows build doesn't produce the destructor symbol
// in the PyTorch libs, causing a linker failure in downstream projects.
// x-ref https://github.com/pytorch/pytorch/issues/70032
OperatorHandle::~OperatorHandle() = default;

RegistrationHandleRAII Dispatcher::registerLibrary(std::string ns, std::string debug) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto found = libraries_.find(ns);
  TORCH_CHECK(
    found == libraries_.end(),
    "Only a single TORCH_LIBRARY can be used to register the namespace ", ns,
    "; please put all of your definitions in a single TORCH_LIBRARY block. "
    "If you were trying to specify implementations, consider using TORCH_LIBRARY_IMPL "
    "(which can be duplicated). If you really intended to define operators for a "
    "single namespace in a distributed way, you can use TORCH_LIBRARY_FRAGMENT to "
    "explicitly indicate this. "
    "Previous registration of TORCH_LIBRARY was ",
    found->second, "; latest registration was ", debug
  );
  libraries_.emplace(ns, std::move(debug));
  return RegistrationHandleRAII([guard = this->guard_, this, ns] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterLibrary_(ns);
  });
}

void Dispatcher::deregisterLibrary_(const std::string& ns) {
  // we need a lock to avoid concurrent writes
  libraries_.erase(ns);
}

RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
  // we need a lock to avoid concurrent writes
  std::lock_guard<std::mutex> lock(guard_->mutex);

  OperatorName op_name = schema.operator_name();
  auto op = findOrRegisterName_(op_name);

  TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
    " Each overload's schema should only be registered with a single call to def().",
    " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), std::move(tags));
  listeners_->callOnOperatorRegistered(op);

  // NB: do not increment the counts until AFTER error checking
  ++op.operatorDef_->def_count;
  ++op.operatorDef_->def_and_impl_count;

  cond_var_.notify_all();

  return RegistrationHandleRAII([guard = this->guard_, this, op, op_name] {
    // we need a lock to avoid concurrent writes
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterDef_(op, op_name);
  });
}

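// Bookkeeping sketch: def_count tracks how many def() registrations exist for
// this entry (at most one is allowed at a time), while def_and_impl_count also
// counts impl() kernels and registerName() references. The schema is dropped
// when def_count reaches zero; the entry itself is removed by cleanup() only
// once def_and_impl_count reaches zero.
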
void Dispatcher::deregisterDef_(
    const OperatorHandle& op,
    const OperatorName& op_name) {
  TORCH_INTERNAL_ASSERT(op.schema().operator_name() == op_name);

  // reduce def_count and actually deregister if no references left
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_count > 0);
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);

  --op.operatorDef_->def_count;
  --op.operatorDef_->def_and_impl_count;
  if (0 == op.operatorDef_->def_count) {
    // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
    // TODO: check that listeners are not relying on prepareForDeregistration()
    // invariant
    listeners_->callOnOperatorDeregistered(op);
    op.operatorDef_->op.deregisterSchema();
  }

  cleanup(op, op_name);
}

namespace {

// Maps OperatorName to (python module name, description) tuple.
using PythonModuleMapType = std::unordered_map<at::OperatorName, std::pair<const char*, const char*>>;

PythonModuleMapType& pythonModulesSingleton() {
  static PythonModuleMapType _data;
  return _data;
}

} // anonymous namespace

std::optional<std::pair<const char*, const char*>> Dispatcher::getPyStub(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto found = pythonModulesSingleton().find(op_name);
  if (found == pythonModulesSingleton().end()) {
    return std::nullopt;
  }
  return found->second;
}

RegistrationHandleRAII Dispatcher::registerPythonModule(
  const OperatorName& op_name,
  const char* pymodule,
  const char* context
) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  // If there are duplicates, we just let it through and warn about it.
  // Throwing an error during static initialization causes a crash that
  // doesn't give any sign of what happened.
  auto found = pythonModulesSingleton().find(op_name);
  if (found != pythonModulesSingleton().end()) {
    TORCH_WARN(
      "Tried to register a python registration stub (pystub) for ", op_name, " ",
      "that specifies the Python module ", pymodule, " "
      "but there already was a pystub that specifies the Python module ",
      found->second.first, ". We will override the existing pystub.");
  }
  pythonModulesSingleton()[op_name] = std::make_pair(pymodule, context);
  return RegistrationHandleRAII([guard = this->guard_, op_name] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    pythonModulesSingleton().erase(op_name);
  });
}

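// Usage sketch (API assumed from torch/library.h, not defined in this file):
// a pystub is normally installed from a TORCH_LIBRARY block so that
// meta-tensor calls can tell users which Python module to import first, e.g.:
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("myop(Tensor x) -> Tensor");
//     m.set_python_module("myops.abstract_impls");
//   }
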
void Dispatcher::throwIfHasPythonModule(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto elt = pythonModulesSingleton().find(op_name);
  if (elt == pythonModulesSingleton().end()) {
    return;
  }
  const char* pymodule = elt->second.first;
  const char* context = elt->second.second;
  auto* interpreter = at::impl::PythonOpRegistrationTrampoline::getInterpreter();
  TORCH_CHECK(
    interpreter != nullptr,
    op_name,
    ": while attempting to run this operator with Meta Tensors: "
    "Either there is no meta kernel for this operator, or it is located "
    "in the python module ", pymodule, " which is not available "
    "because Python isn't available.");
  (*interpreter)->throw_abstract_impl_not_imported_error(toString(op_name), pymodule, context);
}

RegistrationHandleRAII Dispatcher::registerImpl(
  OperatorName op_name,
  std::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  std::optional<impl::CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  auto op = findOrRegisterName_(op_name);

  auto handle = op.operatorDef_->op.registerKernel(
    *this,
    dispatch_key,
    std::move(kernel),
    std::move(cpp_signature),
    std::move(inferred_function_schema),
    std::move(debug)
  );

  ++op.operatorDef_->def_and_impl_count;

  cond_var_.notify_all();

  return RegistrationHandleRAII([guard = this->guard_, this, op, op_name, dispatch_key, handle] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterImpl_(op, op_name, dispatch_key, handle);
  });
}

void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, std::optional<DispatchKey> dispatch_key, impl::OperatorEntry::AnnotatedKernelContainerIterator handle) {
  op.operatorDef_->op.deregisterKernel_(*this, dispatch_key, handle);

  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);

  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;

  cleanup(op, op_name);
}

RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto op = findOrRegisterName_(op_name);
  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII(
    [guard = this->guard_, this, op, op_name] {
      std::lock_guard<std::mutex> lock(guard->mutex);
      if (!guard->alive.load()) {
        return;
      }
      deregisterName_(op, op_name);
    }
  );
}

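// Note on the RAII pattern shared by the register* methods above: each
// deregistration lambda captures guard_ (a shared_ptr) by value, so a handle
// that outlives the Dispatcher singleton at shutdown can still lock
// guard->mutex, observe guard->alive == false, and bail out instead of
// touching destroyed dispatcher state.
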
void Dispatcher::deregisterName_(
    const OperatorHandle& op,
    const OperatorName& op_name) {
  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;
  cleanup(op, op_name);
}

// Test if the operator entry is completely dead, and if so remove it completely
void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) {
  if (0 == op.operatorDef_->def_and_impl_count) {
    // NOTE: Making this call fast is the only reason OperatorHandle
    // stores operatorIterator_!
    operators_.erase(op.operatorIterator_);
    operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
      operatorLookupTable.erase(op_name);
    });
  }
}

RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
  TORCH_CHECK(
    !backendFallbackKernels_[idx].kernel.isValid(),
    "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
    backendFallbackKernels_[idx].debug, ", new registration ", debug
  );
  // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
  // cannot be unboxed
  backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);
  }

  return RegistrationHandleRAII([guard = this->guard_, this, dispatchKey] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterFallback_(dispatchKey);
  });
}

void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  backendFallbackKernels_[idx] = {};

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);
  }
}

RegistrationHandleRAII Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
    if (iter->def_count > 0) {
      listener->onOperatorRegistered(OperatorHandle(iter));
    }
  }

  auto removeListener = listeners_->addListener(std::move(listener));
  return RegistrationHandleRAII([guard = this->guard_, this, removeListener] {
    // use the captured shared_ptr (not guard_ via `this`) so taking the lock
    // is safe even if the Dispatcher has already been destroyed
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    removeListener();
  });
}

void Dispatcher::checkInvariants() const {
  for (const auto& op : operators_) {
    op.op.checkInvariants();
  }
}

std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorHandle> {
    std::vector<OperatorHandle> opsWithDanglingImpls;
    for (const auto& op : operatorLookupTable) {
      if (!op.second.hasSchema()) {
        opsWithDanglingImpls.push_back(op.second);
      }
    }
    return opsWithDanglingImpls;
  });
}

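// A "dangling impl" is an operator entry that has impl() registrations but no
// def(), i.e. no schema: it is reachable through findOp() but not through
// findSchema(). This usually indicates a missing TORCH_LIBRARY def(); the
// Python side exposes this list (binding name assumed) as
// torch._C._dispatch_find_dangling_impls.
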
std::vector<OperatorName> Dispatcher::getRegistrationsForDispatchKey(std::optional<DispatchKey> k) const {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
    std::vector<OperatorName> op_names;
    for (const auto& op : operatorLookupTable) {
      // If no DispatchKey is specified, print all of the operators.
      if (!k || op.second.hasKernelForDispatchKey(*k)) {
        op_names.push_back(op.first);
      }
    }
    return op_names;
  });
}

int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
  int64_t seq_num = -1;
  // Setting sequence number in the Autograd case to associate
  // the forward range with the corresponding Autograd's node

  // Note: this records a sequence number for both Autograd keys, and for
  // non-Autograd keys where the dispatchKeySet still contains an autograd key.
  // This means that we might collect the same sequence number for two
  // different events if they both occurred above Autograd and still had the
  // Autograd dispatch key in the dispatch key set.
  // However, this usually doesn't happen: normally the first call will
  // go through the call() or callBoxed() path in the dispatcher, while
  // subsequent redispatches go through redispatch() or redispatchBoxed().
  // `call` has profiler instrumentation, whereas `redispatch` doesn't.
  // So usually, we'll collect a sequence number on the first call() if the
  // dispatch keys contain autograd, and not on subsequent redispatches.
  bool dispatchHasAutograd = !(dispatchKeySet & autograd_dispatch_keyset).empty();

  if (dispatchHasAutograd && at::GradMode::is_enabled()) {
    seq_num = at::sequence_number::peek();
  }
  return seq_num;
}

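// The sequence number computed above is what profiler tooling uses to stitch
// a forward-pass RecordFunction event to the autograd Node that later runs in
// the backward pass; -1 means no correlation is available (grad mode disabled,
// or no autograd key in the set).
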
void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet, c10::ArrayRef<const c10::IValue> args) {
  guard.before(schema_ref, args, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}

void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
  // Setting sequence number in the Autograd case to associate
  // the forward range with the corresponding Autograd's node
  guard.before(schema_ref, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}

#ifdef FBCODE_CAFFE2
bool Dispatcher::profilingOperatorEvents() {
  return TORCH_SDT_IS_ENABLED(operator_start) || TORCH_SDT_IS_ENABLED(operator_end);
}

C10_NOINLINE void Dispatcher::fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_start)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_start, schema_ref.get().name().c_str());
  }
}

C10_NOINLINE void Dispatcher::fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_end)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_end, schema_ref.get().name().c_str());
  }
}
#endif

} // namespace c10