// xref: /aosp_15_r20/external/pytorch/torch/csrc/profiler/orchestration/python_tracer.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include <c10/util/ApproximateClock.h>
#include <c10/util/strong_type.h>

#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/profiler/util.h>
13 
14 namespace torch {
15 namespace profiler {
16 namespace impl {
17 
18 class RecordQueue;
19 struct Result;
20 namespace python_tracer {
21 
22 using TraceKey = strong::type<
23     uint64_t,
24     struct TraceKey_,
25     strong::regular,
26     strong::hashable,
27     strong::ostreamable>;
28 
29 struct CompressedEvent {
30   TraceKey key_;
31   uint64_t system_tid_{};
32   kineto::DeviceAndResource kineto_info_{};
33   c10::time_t enter_t_{};
34 };
35 
/*
Libtorch does not depend on Python (i.e. it cannot #include <Python.h>);
however, when the profiler is called from libtorch_python, it needs to be able
to ingest the data collected by the Python tracer (via `PyEval_SetProfile`).

To solve this dependency issue we define a virtual base class and a function
to register a factory for it. The Python tracer implements this interface and
exposes itself by calling `registerTracer` from `torch/csrc/autograd/init.cpp`.
This registration pattern for faux Python dependencies in libtorch is common
in the PyTorch codebase.
*/
47 struct TORCH_API PythonTracerBase {
48   static std::unique_ptr<PythonTracerBase> make(RecordQueue* queue);
49   virtual ~PythonTracerBase() = default;
50 
51   virtual void stop() = 0;
52   virtual void restart() = 0;
53   virtual std::vector<std::shared_ptr<Result>> getEvents(
54       std::function<c10::time_t(c10::approx_time_t)> time_converter,
55       std::vector<CompressedEvent>& enters,
56       c10::time_t end_time_ns) = 0;
57 };
58 
59 using MakeFn = std::unique_ptr<PythonTracerBase> (*)(RecordQueue*);
60 TORCH_API void registerTracer(MakeFn make_tracer);
61 } // namespace python_tracer
62 } // namespace impl
63 } // namespace profiler
64 } // namespace torch
65