xref: /aosp_15_r20/external/pytorch/c10/core/InferenceMode.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <c10/core/InferenceMode.h>

3 namespace c10 {
4 // Invariant:
5 //   is_enabled() ==
6 //   !c10::impl::tls_is_dispatch_key_included(DispatchKey::ADInplaceOrView);
7 // InferenceMode::is_enabled() is in perf critical path (TensorImpl constructor)
8 // so it worths a separate TLS to skip the DispatchKeySet check.
is_enabled()9 bool InferenceMode::is_enabled() {
10   return AutogradState::get_tls_state().get_inference_mode();
11 }
12 } // namespace c10