xref: /aosp_15_r20/external/pytorch/torch/fx/annotate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.proxy import Proxy
3*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility
4*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def annotate(val, type):
    """
    Annotates a Proxy object with a given type.

    When ``val`` is a ``torch.fx.Proxy``, ``type`` is recorded on the proxy's
    underlying graph node (``val.node.type``). Any other value is passed
    through untouched.

    Args:
        val (object): An object to be annotated if its type is torch.fx.Proxy.
        type (object): A type to be assigned to a given proxy object as val.

    Returns:
        The given val (the same object that was passed in).

    Raises:
        RuntimeError: If ``val`` is a Proxy whose node already has a truthy
            ``type`` set.
    """
    # Only proxies carry a graph node; everything else has nothing to annotate.
    if not isinstance(val, Proxy):
        return val

    node = val.node
    existing = node.type
    # Any truthy existing type counts as "already annotated"; refuse to overwrite.
    if existing:
        message = (
            "Tried to annotate a value that already had a type on it!"
            f" Existing type is {existing} "
            f"and new type is {type}. "
            "This could happen if you tried to annotate a function parameter "
            "value (in which case you should use the type slot "
            "on the function signature) or you called "
            "annotate on the same value twice"
        )
        raise RuntimeError(message)

    node.type = type
    return val
33