// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/dispatch/Dispatcher.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <chrono>
#include <list>
#include <sstream>
#include <utility>

#ifdef FBCODE_CAFFE2
#include <c10/util/static_tracepoint.h>
#endif

namespace c10 {

#ifdef FBCODE_CAFFE2
TORCH_SDT_DEFINE_SEMAPHORE(operator_start)
TORCH_SDT_DEFINE_SEMAPHORE(operator_end)
#endif

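// Dispatch tracing is keyed off the TORCH_SHOW_DISPATCH_TRACE environment
// variable; setting it to any value (e.g. TORCH_SHOW_DISPATCH_TRACE=1)
// enables the trace. The thread-local counter below tracks dispatch nesting
// depth so that _print_dispatch_trace() can indent nested dispatches into a
// readable call tree.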
bool show_dispatch_trace() {
    static char const* temp = getenv("TORCH_SHOW_DISPATCH_TRACE");
    return temp != nullptr;
}

static thread_local int64_t dispatch_trace_nesting_value_;

void dispatch_trace_nesting_incr() { ++dispatch_trace_nesting_value_; }
void dispatch_trace_nesting_decr() { --dispatch_trace_nesting_value_; }
int64_t dispatch_trace_nesting_value() { return dispatch_trace_nesting_value_; }

namespace detail {

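// Listeners live in a std::list rather than a vector: list iterators stay
// valid across insertions and other erasures, so addListener() can hand back
// a closure that erases exactly its own element in O(1) when the
// registration is torn down.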
class RegistrationListenerList final {
public:
  std::function<void()> addListener(std::unique_ptr<OpRegistrationListener> listener) {
    listeners_.push_back(std::move(listener));
    auto delete_it = --listeners_.end();
    return [this, delete_it] {
        listeners_.erase(delete_it);
    };
  }

  void callOnOperatorRegistered(const OperatorHandle& op) {
    for (auto& listener : listeners_) {
      listener->onOperatorRegistered(op);
    }
  }

  void callOnOperatorDeregistered(const OperatorHandle& op) {
    for (auto& listener : listeners_) {
      listener->onOperatorDeregistered(op);
    }
  }
private:
  std::list<std::unique_ptr<OpRegistrationListener>> listeners_;
};

void _print_dispatch_trace(const std::string& label, const std::string& op_name, const DispatchKeySet& dispatchKeySet) {
  auto nesting_value = dispatch_trace_nesting_value();
  for (int64_t i = 0; i < nesting_value; ++i) std::cerr << " ";
  std::cerr << label << " op=[" << op_name << "], key=[" << toString(dispatchKeySet.highestPriorityTypeId()) << "]" << std::endl;
}
} // namespace detail

OpRegistrationListener::~OpRegistrationListener() = default;

Dispatcher::Dispatcher()
: operators_()
, operatorLookupTable_()
, backendFallbackKernels_()
, listeners_(std::make_unique<detail::RegistrationListenerList>())
, cond_var_()
, guard_(std::make_shared<Guard>())
{}

Dispatcher::~Dispatcher() {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  guard_->alive.store(false);
}

C10_EXPORT Dispatcher& Dispatcher::realSingleton() {
  static Dispatcher _singleton;
  return _singleton;
}

std::optional<OperatorHandle> Dispatcher::findOp(const OperatorName& overload_name) {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::optional<OperatorHandle> {
    auto found = operatorLookupTable.find(overload_name);
    if (found == operatorLookupTable.end()) {
      return std::nullopt;
    }
    return found->second;
  });
}

// NB: If you add more waitFor* implementations, you also have to add
// appropriate notify_all() calls to the relevant register calls

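// The waitFor* methods exist for torchdeploy/multipy: a subordinate
// interpreter that races ahead of the main interpreter can block here (up to
// the 2s timeout) until the main interpreter has performed the def()/impl()
// it needs. The matching notify_all() calls are in registerDef() and
// registerImpl() below.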
void Dispatcher::waitForDef(const FunctionSchema& schema) {
  using namespace std::chrono_literals;
  std::unique_lock<std::mutex> lock(guard_->mutex);
  bool r = cond_var_.wait_for(lock, 2s, [&]{
    return findOp(schema.operator_name()) != std::nullopt;
  });
  TORCH_INTERNAL_ASSERT(r,
    "Expected main interpreter to define ", schema.operator_name(),
    ", but this didn't happen within timeout.  Are you trying to load "
    "different models in the same torchdeploy/multipy instance?  You "
    "must warmup each interpreter identically, e.g., import all "
    "the same dependencies.");
}

void Dispatcher::waitForImpl(const OperatorName& op_name, std::optional<c10::DispatchKey> maybe_dk) {
  using namespace std::chrono_literals;
  std::unique_lock<std::mutex> lock(guard_->mutex);
  auto dk = maybe_dk.value_or(DispatchKey::CompositeImplicitAutograd);
  auto op = findOrRegisterName_(op_name);
  bool r = cond_var_.wait_for(lock, 2s, [&]{
    // NB: this is slightly unsound for overrides, but overrides are
    // funny business anyway
    return op.hasKernelForDispatchKey(dk);
  });
  TORCH_INTERNAL_ASSERT(r,
    "Expected main interpreter to implement ", dk, " for ", op_name,
    ", but this didn't happen within timeout.  Are you trying to load "
    "different models in the same torchdeploy/multipy instance?  You "
    "must warmup each interpreter identically, e.g., import all "
    "the same dependencies.");
}

std::optional<OperatorHandle> Dispatcher::findSchema(const OperatorName& overload_name) {
  auto it = findOp(overload_name);
  if (it.has_value()) {
    if (it->hasSchema()) {
      return it;
    } else {
      return std::nullopt;
    }
  } else {
    return it;
  }
}

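// findSchemaOrThrow() is the lookup used by call sites that require the
// operator to exist. An illustrative use:
//
//   auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("aten::add", "Tensor");
//
// The two-step failure check below distinguishes "nothing registered at all"
// from "an impl() exists but no def() provided a schema".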
OperatorHandle Dispatcher::findSchemaOrThrow(const char* name, const char* overload_name) {
  auto it = findSchema({name, overload_name});
  if (!it.has_value()) {
    // Check if we have ANYTHING; if that's the case, that means you're
    // missing schema
    auto it2 = findOp({name, overload_name});
    if (!it2.has_value()) {
      TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name);
    } else {
      TORCH_CHECK(false, "Could not find schema for ", name, ".", overload_name,
        " but we found an implementation; did you forget to def() the operator?");
    }
  }
  return it.value();
}

const std::vector<OperatorName> Dispatcher::getAllOpNames() {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
    std::vector<OperatorName> allOpNames;
    for (const auto& op : operatorLookupTable) {
        allOpNames.push_back(op.first);
    }
    return allOpNames;
  });
}

// Postcondition: caller is responsible for disposing of registration when they
// are done
OperatorHandle Dispatcher::findOrRegisterName_(const OperatorName& op_name) {
  const auto found = findOp(op_name);
  if (found != std::nullopt) {
    return *found;
  }

  operators_.emplace_back(OperatorName(op_name));
  OperatorHandle handle(--operators_.end());
  operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
    operatorLookupTable.emplace(op_name, handle);
  });

  return handle;
}


// Adding an explicit destructor definition in the cpp to avoid a linker error
// in Windows builds: the Windows build doesn't produce the destructor symbol
// in the PyTorch libs, causing a linker failure in downstream projects.
// x-ref https://github.com/pytorch/pytorch/issues/70032
OperatorHandle::~OperatorHandle() = default;

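// Library registration backs the TORCH_LIBRARY macro. Roughly (illustrative
// sketch; the "myops"/"myadd" names are hypothetical):
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("myadd(Tensor a, Tensor b) -> Tensor");
//   }
//
// claims the "myops" namespace via registerLibrary(), and each m.def() lands
// in registerDef() below. Every register* method hands back a
// RegistrationHandleRAII whose deleter first checks guard->alive: if the
// Dispatcher singleton was already destroyed (static destruction order is
// unspecified), deregistration becomes a no-op instead of touching freed
// state.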
RegistrationHandleRAII Dispatcher::registerLibrary(std::string ns, std::string debug) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto found = libraries_.find(ns);
  TORCH_CHECK(
    found == libraries_.end(),
    "Only a single TORCH_LIBRARY can be used to register the namespace ", ns,
    "; please put all of your definitions in a single TORCH_LIBRARY block.  "
    "If you were trying to specify implementations, consider using TORCH_LIBRARY_IMPL "
    "(which can be duplicated).  If you really intended to define operators for a "
    "single namespace in a distributed way, you can use TORCH_LIBRARY_FRAGMENT to "
    "explicitly indicate this.  "
    "Previous registration of TORCH_LIBRARY was ",
    found->second, "; latest registration was ", debug
  );
  libraries_.emplace(ns, std::move(debug));
  return RegistrationHandleRAII([guard = this->guard_, this, ns] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterLibrary_(ns);
  });
}

void Dispatcher::deregisterLibrary_(const std::string& ns) {
  // we need a lock to avoid concurrent writes
  libraries_.erase(ns);
}

RegistrationHandleRAII Dispatcher::registerDef(FunctionSchema schema, std::string debug, std::vector<at::Tag> tags) {
  // we need a lock to avoid concurrent writes
  std::lock_guard<std::mutex> lock(guard_->mutex);

  OperatorName op_name = schema.operator_name();
  auto op = findOrRegisterName_(op_name);

  TORCH_CHECK(op.operatorDef_->def_count == 0, "Tried to register an operator (", schema, ") with the same name and overload name multiple times.",
                                                    " Each overload's schema should only be registered with a single call to def().",
                                                    " Duplicate registration: ", debug, ". Original registration: ", op.operatorDef_->op.debug());
  op.operatorDef_->op.registerSchema(std::move(schema), std::move(debug), std::move(tags));
  listeners_->callOnOperatorRegistered(op);

  // NB: do not increment the counts until AFTER error checking
  ++op.operatorDef_->def_count;
  ++op.operatorDef_->def_and_impl_count;

  cond_var_.notify_all();

  return RegistrationHandleRAII([guard = this->guard_, this, op, op_name] {
    // we need a lock to avoid concurrent writes
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterDef_(op, op_name);
  });
}

void Dispatcher::deregisterDef_(
    const OperatorHandle& op,
    const OperatorName& op_name) {
  TORCH_INTERNAL_ASSERT(op.schema().operator_name() == op_name);

  // reduce def_count and actually deregister if no references left
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_count > 0);
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);

  --op.operatorDef_->def_count;
  --op.operatorDef_->def_and_impl_count;
  if (0 == op.operatorDef_->def_count) {
    // note: call listeners *before* operator is removed, i.e. dispatcher is still valid for removed op
    // TODO: check that listeners are not relying on prepareForDeregistration()
    // invariant
    listeners_->callOnOperatorDeregistered(op);
    op.operatorDef_->op.deregisterSchema();
  }

  cleanup(op, op_name);
}

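// A "pystub" records which Python module provides the (abstract/meta)
// implementation of an operator, so that hitting the operator before that
// module has been imported can raise an actionable error (see
// throwIfHasPythonModule() below) instead of a bare missing-kernel failure.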
namespace {

// Maps OperatorName to (python module name, description) tuple.
using PythonModuleMapType = std::unordered_map<at::OperatorName, std::pair<const char*, const char*>>;
PythonModuleMapType& pythonModulesSingleton() {
  static PythonModuleMapType _data;
  return _data;
}

} // anonymous namespace

std::optional<std::pair<const char*, const char*>> Dispatcher::getPyStub(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto found = pythonModulesSingleton().find(op_name);
  if (found == pythonModulesSingleton().end()) {
    return std::nullopt;
  }
  return found->second;
}

RegistrationHandleRAII Dispatcher::registerPythonModule(
  const OperatorName& op_name,
  const char* pymodule,
  const char* context
) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  // If there are duplicates, we just let it through and warn about it.
  // Throwing an error during static initialization causes a crash that
  // doesn't give any sign of what happened.
  auto found = pythonModulesSingleton().find(op_name);
  if (found != pythonModulesSingleton().end()) {
    TORCH_WARN(
        "Tried to register a python registration stub (pystub) for ", op_name, " ",
        "that specifies the Python module ", pymodule, " "
        "but there already was a pystub that specifies the Python module ",
        found->second.first, ". We will override the existing pystub.");
  }
  pythonModulesSingleton()[op_name] = std::make_pair(pymodule, context);
  return RegistrationHandleRAII([guard = this->guard_, op_name] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    pythonModulesSingleton().erase(op_name);
  });
}

void Dispatcher::throwIfHasPythonModule(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto elt = pythonModulesSingleton().find(op_name);
  if (elt == pythonModulesSingleton().end()) {
    return;
  }
  const char* pymodule = elt->second.first;
  const char* context = elt->second.second;
  auto* interpreter = at::impl::PythonOpRegistrationTrampoline::getInterpreter();
  TORCH_CHECK(
      interpreter != nullptr,
      op_name,
      ": while attempting to run this operator with Meta Tensors: "
      "Either there is no meta kernel for this operator, or it is located "
      "in the python module ", pymodule, " which is not available "
      "because Python isn't available.");
  (*interpreter)->throw_abstract_impl_not_imported_error(toString(op_name), pymodule, context);
}

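// Kernel registration backs the TORCH_LIBRARY_IMPL macro. Roughly
// (illustrative sketch; the "myops"/"myadd" names are hypothetical):
//
//   TORCH_LIBRARY_IMPL(myops, CPU, m) {
//     m.impl("myadd", &myadd_cpu);
//   }
//
// arrives here with dispatch_key = DispatchKey::CPU. An impl() may be
// registered before its def(): findOrRegisterName_() then creates a
// schema-less placeholder entry (see findDanglingImpls() below).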
RegistrationHandleRAII Dispatcher::registerImpl(
  OperatorName op_name,
  std::optional<DispatchKey> dispatch_key,
  KernelFunction kernel,
  std::optional<impl::CppSignature> cpp_signature,
  std::unique_ptr<FunctionSchema> inferred_function_schema,
  std::string debug
) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  auto op = findOrRegisterName_(op_name);

  auto handle = op.operatorDef_->op.registerKernel(
    *this,
    dispatch_key,
    std::move(kernel),
    std::move(cpp_signature),
    std::move(inferred_function_schema),
    std::move(debug)
  );

  ++op.operatorDef_->def_and_impl_count;

  cond_var_.notify_all();

  return RegistrationHandleRAII([guard = this->guard_, this, op, op_name, dispatch_key, handle] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterImpl_(op, op_name, dispatch_key, handle);
  });
}

void Dispatcher::deregisterImpl_(const OperatorHandle& op, const OperatorName& op_name, std::optional<DispatchKey> dispatch_key, impl::OperatorEntry::AnnotatedKernelContainerIterator handle) {
  op.operatorDef_->op.deregisterKernel_(*this, dispatch_key, handle);

  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);

  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;

  cleanup(op, op_name);
}

RegistrationHandleRAII Dispatcher::registerName(OperatorName op_name) {
  std::lock_guard<std::mutex> lock(guard_->mutex);
  auto op = findOrRegisterName_(op_name);
  ++op.operatorDef_->def_and_impl_count;

  return RegistrationHandleRAII(
      [guard = this->guard_, this, op, op_name] {
        std::lock_guard<std::mutex> lock(guard->mutex);
        if (!guard->alive.load()) {
          return;
        }
        deregisterName_(op, op_name);
      }
  );
}

void Dispatcher::deregisterName_(
    const OperatorHandle& op,
    const OperatorName& op_name) {
  TORCH_INTERNAL_ASSERT(op.operator_name() == op_name);
  TORCH_INTERNAL_ASSERT(op.operatorDef_->def_and_impl_count > 0);
  --op.operatorDef_->def_and_impl_count;
  cleanup(op, op_name);
}

// Test if the operator entry is completely dead, and if so remove it completely
void Dispatcher::cleanup(const OperatorHandle& op, const OperatorName& op_name) {
  if (0 == op.operatorDef_->def_and_impl_count) {
    // NOTE: Making this call fast is the only reason OperatorHandle
    // stores operatorIterator_!
    operators_.erase(op.operatorIterator_);
    operatorLookupTable_.write([&] (ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) {
      operatorLookupTable.erase(op_name);
    });
  }
}

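// A backend fallback is a catch-all kernel consulted for every operator when
// the dispatch key has no operator-specific kernel. A common illustrative
// use is making a key fall through entirely, e.g.:
//
//   TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
//     m.fallback(torch::CppFunction::makeFallthrough());
//   }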
RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, KernelFunction kernel, std::string debug) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
  TORCH_CHECK(
    !backendFallbackKernels_[idx].kernel.isValid(),
    "Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
    backendFallbackKernels_[idx].debug, ", new registration ", debug
  );
  // NB: inferred function schema is always nullptr for fallbacks, as fallbacks
  // cannot be unboxed
  backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);
  }

  return RegistrationHandleRAII([guard = this->guard_, this, dispatchKey] {
    std::lock_guard<std::mutex> lock(guard->mutex);
    if (!guard->alive.load()) {
      return;
    }
    deregisterFallback_(dispatchKey);
  });
}

void Dispatcher::deregisterFallback_(DispatchKey dispatchKey) {
  auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
  backendFallbackKernels_[idx] = {};

  for (auto& op : operators_) {
    op.op.updateFallback(*this, dispatchKey);
  }
}


RegistrationHandleRAII Dispatcher::addRegistrationListener(std::unique_ptr<OpRegistrationListener> listener) {
  std::lock_guard<std::mutex> lock(guard_->mutex);

  for (auto iter = operators_.begin(); iter != operators_.end(); ++iter) {
    if (iter->def_count > 0) {
      listener->onOperatorRegistered(OperatorHandle(iter));
    }
  }

  auto removeListener = listeners_->addListener(std::move(listener));
  return RegistrationHandleRAII([guard = this->guard_, this, removeListener] {
      // NB: lock the captured guard's mutex (not this->guard_): the Dispatcher
      // may already be destroyed when this deleter runs, while the shared
      // Guard is kept alive by the capture.
      std::lock_guard<std::mutex> lock(guard->mutex);
      if (!guard->alive.load()) {
        return;
      }
      removeListener();
  });
}

void Dispatcher::checkInvariants() const {
  for (const auto& op : operators_) {
    op.op.checkInvariants();
  }
}

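// A "dangling impl" is an operator entry created by registerImpl() or
// registerName() that never received a def() with a schema; such entries
// usually indicate a registration mistake. This helper lets tests and
// debugging tools enumerate them.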
std::vector<OperatorHandle> Dispatcher::findDanglingImpls() const {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorHandle> {
    std::vector<OperatorHandle> opsWithDanglingImpls;
    for (const auto& op : operatorLookupTable) {
      if (!op.second.hasSchema()) {
        opsWithDanglingImpls.push_back(op.second);
      }
    }
    return opsWithDanglingImpls;
  });
}

std::vector<OperatorName> Dispatcher::getRegistrationsForDispatchKey(std::optional<DispatchKey> k) const {
  return operatorLookupTable_.read([&] (const ska::flat_hash_map<OperatorName, OperatorHandle>& operatorLookupTable) -> std::vector<OperatorName> {
    std::vector<OperatorName> op_names;
    for (const auto& op : operatorLookupTable) {
      // If no DispatchKey is specified, return all of the operators.
      if (!k || op.second.hasKernelForDispatchKey(*k)) {
          op_names.push_back(op.first);
      }
    }
    return op_names;
  });
}

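// Profiler support: a sequence number ties an operator's forward-pass
// RecordFunction range to the autograd node created for its backward, so a
// number is only recorded when the call will actually pass through an
// Autograd key with grad mode enabled.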
int64_t Dispatcher::sequenceNumberForRunningRecordFunction(DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
  int64_t seq_num = -1;
  // Setting sequence number in the Autograd case to associate
  // the forward range with the corresponding Autograd's node

  // Note: this records a sequence number for both Autograd keys, and for
  // non-Autograd keys where the dispatchKeySet still contains an autograd key.
  // This means that we might collect the same sequence number for two different
  // events if they both occurred above Autograd and still had the Autograd
  // dispatch key in the dispatch key set.
  // However, this usually doesn't happen: normally the first call will
  // go through the call() or callBoxed() path in the dispatcher, while
  // subsequent redispatches go through redispatch() or redispatchBoxed().
  // `call` has profiler instrumentation, whereas `redispatch` doesn't.
  // So usually, we'll collect a sequence number on the first call() if the
  // dispatch keys contain autograd, and not on subsequent redispatches.
  bool dispatchHasAutograd = !(dispatchKeySet & autograd_dispatch_keyset).empty();

  if (dispatchHasAutograd && at::GradMode::is_enabled()) {
    seq_num = at::sequence_number::peek();
  }
  return seq_num;
}

void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet, c10::ArrayRef<const c10::IValue> args) {
  guard.before(schema_ref, args, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}

void Dispatcher::runRecordFunction(at::RecordFunction& guard, at::RecordFunction::schema_ref_t schema_ref, DispatchKey dispatchKey, DispatchKeySet dispatchKeySet) {
  // Setting sequence number in the Autograd case to associate
  // the forward range with the corresponding Autograd's node
  guard.before(schema_ref, sequenceNumberForRunningRecordFunction(dispatchKey, dispatchKeySet));
}
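
// USDT (statically-defined tracing) probes, fbcode-only. The semaphores
// defined at the top of this file let an external tracer (e.g. bpftrace or
// SystemTap) flip these checks on at runtime; when no tracer is attached,
// the guards skip the probe-argument setup (the name() c_str() call)
// entirely.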
#ifdef FBCODE_CAFFE2
bool Dispatcher::profilingOperatorEvents() {
  return TORCH_SDT_IS_ENABLED(operator_start) || TORCH_SDT_IS_ENABLED(operator_end);
}

C10_NOINLINE void Dispatcher::fireOpStartUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_start)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_start, schema_ref.get().name().c_str());
  }
}

C10_NOINLINE void Dispatcher::fireOpEndUSDT(at::RecordFunction::schema_ref_t schema_ref) {
  if (TORCH_SDT_IS_ENABLED(operator_end)) {
    TORCH_SDT_WITH_SEMAPHORE(operator_end, schema_ref.get().name().c_str());
  }
}
#endif

} // namespace c10