xref: /aosp_15_r20/external/pytorch/tools/lldb/pytorch_lldb.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Any
2
3import lldb  # type: ignore[import]
4
5
def get_target() -> Any:
    """Return lldb's currently selected target.

    If the debugger has no usable target selected, an error message is
    printed and ``None`` is returned instead.
    """
    selected = lldb.debugger.GetSelectedTarget()
    if selected:
        return selected
    print("[-] error: no target available. please add a target to lldb.")
    return None
12
13
class DisableBreakpoints:
    """
    Context-manager to temporarily disable all lldb breakpoints, useful if
    there is a risk to hit one during the evaluation of one of our custom
    commands.

    On exit, only the breakpoints that were enabled when the context was
    entered are re-enabled. Breakpoints the user had disabled beforehand stay
    disabled, and nested uses behave correctly. If no target is available
    (get_target() prints an error and returns None), the context manager
    does nothing instead of raising AttributeError.
    """

    def __enter__(self) -> None:
        self._target = get_target()
        # IDs of the breakpoints that were enabled on entry; only these are
        # turned back on in __exit__.
        self._enabled_ids: list = []
        if self._target is None:
            return

        for index in range(self._target.GetNumBreakpoints()):
            bp = self._target.GetBreakpointAtIndex(index)
            if bp.IsValid() and bp.IsEnabled():
                self._enabled_ids.append(bp.GetID())

        if self._target.DisableAllBreakpoints() is False:
            print("[-] error: failed to disable all breakpoints.")

    def __exit__(self, etype: Any, evalue: Any, tb: Any) -> None:
        target = self._target
        if target is None:
            return

        all_restored = True
        for bp_id in self._enabled_ids:
            bp = target.FindBreakpointByID(bp_id)
            if not bp.IsValid():
                # The breakpoint was deleted while disabled; nothing to restore.
                continue
            bp.SetEnabled(True)
            if not bp.IsEnabled():
                all_restored = False

        if not all_restored:
            print("[-] error: failed to enable all breakpoints.")
32
33
def IntArrayRef_summary(valobj: Any, internal_dict: Any, options: Any) -> str:
    """Print human readable representation of c10::IntArrayRef.

    Evaluates ``torch::gdb::int_array_ref_string(<name>)`` in the inferior,
    where ``<name>`` is ``valobj.GetName()``, with breakpoints disabled. It
    returns the text between the first and last double quote of the
    result's description.

    If no target is selected or the expression fails to evaluate, an
    ``<error: ...>`` placeholder is returned instead of a mangled slice.
    """
    with DisableBreakpoints():
        target = get_target()
        if target is None:
            return "<error: no target>"
        name = valobj.GetName()
        result = target.EvaluateExpression(
            f"torch::gdb::int_array_ref_string({name})"
        )
        error = result.GetError()
        if error.Fail():
            return f"<error: {error.GetCString()}>"

        str_result = str(result)
        # The payload is the double-quoted string inside the value's
        # description; strip everything outside the outermost quotes.
        start = str_result.find('"')
        end = str_result.rfind('"')
        if start == -1 or end <= start:
            # Unexpected shape: show it verbatim rather than chopping it.
            return str_result
        return str_result[start + 1 : end]
45
46
def DispatchKeyset_summary(valobj: Any, internal_dict: Any, options: Any) -> str:
    """Print human readable representation of c10::DispatchKeyset.

    Evaluates ``torch::gdb::dispatch_keyset_string(<name>)`` in the inferior,
    where ``<name>`` is ``valobj.GetName()``, with breakpoints disabled. It
    returns the text between the first and last double quote of the
    result's description.

    If no target is selected or the expression fails to evaluate, an
    ``<error: ...>`` placeholder is returned instead of a mangled slice.
    """
    with DisableBreakpoints():
        target = get_target()
        if target is None:
            return "<error: no target>"
        keyset = valobj.GetName()
        result = target.EvaluateExpression(
            f"torch::gdb::dispatch_keyset_string({keyset})"
        )
        error = result.GetError()
        if error.Fail():
            return f"<error: {error.GetCString()}>"

        str_result = str(result)
        # The payload is the double-quoted string inside the value's
        # description; strip everything outside the outermost quotes.
        start = str_result.find('"')
        end = str_result.rfind('"')
        if start == -1 or end <= start:
            # Unexpected shape: show it verbatim rather than chopping it.
            return str_result
        return str_result[start + 1 : end]
58
59
def Tensor_summary(valobj: Any, internal_dict: Any, options: Any) -> str:
    """Print a human readable representation of the given at::Tensor.

    at::Tensor instances do not have a C++ implementation of a repr method: in
    pytorch, this is done by pure-Python code. As such, print <tensor>
    internally creates a Python wrapper for the given tensor and calls repr()
    on it.

    The repr is obtained by evaluating ``torch::gdb::tensor_repr(<name>)`` in
    the inferior (``<name>`` being ``valobj.GetName()``), with breakpoints
    disabled. The returned buffer is then released with ``free()``.

    If no target is selected or the expression fails, an ``<error: ...>``
    placeholder is returned and nothing is freed.

    Usage:
        print self
    """
    with DisableBreakpoints():
        target = get_target()
        if target is None:
            return "<error: no target>"
        tensor = valobj.GetName()
        result = target.EvaluateExpression(f"torch::gdb::tensor_repr({tensor})")
        error = result.GetError()
        if error.Fail():
            # Nothing was allocated (or we cannot know its address): do not
            # attempt to free anything.
            return f"<error: {error.GetCString()}>"

        str_result = str(result)

        # The repr buffer is freed in the inferior once its text has been
        # copied into str_result. A null pointer is skipped. The explicit cast
        # keeps the call well-typed when free()'s prototype is known.
        address = result.GetValueAsUnsigned(0)
        if address:
            target.EvaluateExpression(f"(void)free((void *){address:#x})")

        start = str_result.find("tensor")
        if start == -1:
            # Unexpected shape: show it verbatim rather than an empty line.
            return "\n" + str_result
        end = str_result.rfind('"')
        if end < start:
            end = len(str_result)
        return "\n" + str_result[start:end]
78
79
# Entry point run by lldb when this file is loaded with `command script import`:
# it registers the summary providers defined above.
def __lldb_init_module(debugger: Any, internal_dict: Any) -> Any:
    """Register the PyTorch summary providers in the ``torch`` category.

    Each provider is attached with ``type summary add <type> -F <function>
    -w torch``, in a fixed order. A short usage banner is printed afterwards.
    """
    registrations = (
        "type summary add c10::IntArrayRef -F pytorch_lldb.IntArrayRef_summary -w torch",
        "type summary add c10::DispatchKeySet -F pytorch_lldb.DispatchKeyset_summary -w torch",
        "type summary add at::Tensor -F pytorch_lldb.Tensor_summary -w torch",
    )
    for command in registrations:
        debugger.HandleCommand(command)

    banner = (
        "Pretty Printing lldb summary for PyTorch AT types has been installed and is ready for use. "
        "This category is enabled by default. To disable run: `type category disable torch`"
    )
    usage = "Usage:\n\tprint <at::tensor>\n\tprint <c10::IntArrayRef>\n\tprint <c10::DispatchKeySet>"
    print(banner)
    print(usage)
98