1 #include <ATen/record_function.h> 2 #include <torch/csrc/dynamo/cpp_shim.h> 3 4 struct _PytorchRecordFunctionState { 5 at::RecordFunction guard; 6 _PytorchRecordFunctionState_PytorchRecordFunctionState7 _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {} 8 }; 9 _pytorch_record_function_enter(const char * name)10 _PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) { 11 _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState(); 12 state->guard.before(name); 13 return state; 14 } 15 16 static inline _PytorchRecordFunctionState* _pytorch_record_function_enter_with_kwinputs(const char * name,const std::unordered_map<std::string,c10::IValue> * kwargs)17 _pytorch_record_function_enter_with_kwinputs( 18 const char* name, 19 const std::unordered_map<std::string, c10::IValue>* kwargs) { 20 _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState(); 21 std::vector<c10::IValue> args; 22 state->guard.before(name, &args, kwargs); 23 return state; 24 } 25 _pytorch_record_function_enter_with_context(const char * name,const char * context)26 _PytorchRecordFunctionState* _pytorch_record_function_enter_with_context( 27 const char* name, 28 const char* context) { 29 auto map = std::unordered_map<std::string, c10::IValue>(); 30 map.insert({"context", c10::IValue(context)}); 31 return _pytorch_record_function_enter_with_kwinputs(name, &map); 32 } 33 _pytorch_record_function_exit(_PytorchRecordFunctionState * state)34 void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) { 35 if (state == nullptr) { 36 return; 37 } 38 delete state; 39 } 40