xref: /aosp_15_r20/external/pytorch/c10/core/AutogradState.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/AutogradState.h>
2 
3 namespace c10 {
4 
5 namespace {
6 // By default, grad mode and multithreading are enabled, inference mode is
7 // disabled,
8 thread_local AutogradState autograd_state_tls = AutogradState(
9     /* grad_mode */ true,
10     /* inference_mode */ false,
11     /* fw_grad_mode */ true,
12     /* multithreading_enabled */ true);
13 } // namespace
14 
get_tls_state()15 AutogradState& AutogradState::get_tls_state() {
16   return autograd_state_tls;
17 }
18 
set_tls_state(AutogradState state)19 void AutogradState::set_tls_state(AutogradState state) {
20   autograd_state_tls = state;
21 }
22 
23 } // namespace c10
24