xref: /aosp_15_r20/external/pytorch/torch/csrc/dynamo/cpp_shim.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/record_function.h>
#include <torch/csrc/dynamo/cpp_shim.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
3  
4  struct _PytorchRecordFunctionState {
5    at::RecordFunction guard;
6  
_PytorchRecordFunctionState_PytorchRecordFunctionState7    _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {}
8  };
9  
_pytorch_record_function_enter(const char * name)10  _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
11    _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
12    state->guard.before(name);
13    return state;
14  }
15  
16  static inline _PytorchRecordFunctionState*
_pytorch_record_function_enter_with_kwinputs(const char * name,const std::unordered_map<std::string,c10::IValue> * kwargs)17  _pytorch_record_function_enter_with_kwinputs(
18      const char* name,
19      const std::unordered_map<std::string, c10::IValue>* kwargs) {
20    _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
21    std::vector<c10::IValue> args;
22    state->guard.before(name, &args, kwargs);
23    return state;
24  }
25  
_pytorch_record_function_enter_with_context(const char * name,const char * context)26  _PytorchRecordFunctionState* _pytorch_record_function_enter_with_context(
27      const char* name,
28      const char* context) {
29    auto map = std::unordered_map<std::string, c10::IValue>();
30    map.insert({"context", c10::IValue(context)});
31    return _pytorch_record_function_enter_with_kwinputs(name, &map);
32  }
33  
_pytorch_record_function_exit(_PytorchRecordFunctionState * state)34  void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
35    if (state == nullptr) {
36      return;
37    }
38    delete state;
39  }
40