xref: /aosp_15_r20/external/executorch/exir/operator/util.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from torchgen.model import FunctionSchema, SchemaKind
10from torchgen.native_function_generation import (
11    functional_to_out_signature,
12    mutable_to_out_signature,
13    self_to_out_signature,
14)
15from torchgen.utils import NamespaceHelper
16
17
def gen_out_variant_schema(func_op_schema: str) -> str:
    """
    Build the schema string of the out= variant for an operator schema.

    Args:
        func_op_schema: Operator schema text. It may carry a leading
            ``namespace::`` prefix, which is split off (with ``max_level=1``)
            before the remainder is parsed as a ``FunctionSchema``.

    Returns:
        The out-variant schema as a string. When the input had a namespace,
        the result is prefixed with that same namespace followed by ``::``;
        otherwise the bare schema string is returned.

    Raises:
        RuntimeError: If the parsed schema's kind is not one of inplace,
            functional, mutable or out.
    """
    # Separate the optional namespace from the schema proper, then parse.
    ns_helper = NamespaceHelper.from_namespaced_entity(
        namespaced_entity=func_op_schema, max_level=1
    )
    parsed = FunctionSchema.parse(ns_helper.entity_name)
    prefix = ns_helper.get_cpp_namespace(default="")

    # Each convertible kind maps to the torchgen helper that produces its
    # out= signature; a schema that is already `out` is used unchanged.
    converters = {
        SchemaKind.inplace: self_to_out_signature,
        SchemaKind.functional: functional_to_out_signature,
        SchemaKind.mutable: mutable_to_out_signature,
    }

    kind = parsed.kind()
    if kind == SchemaKind.out:
        out_schema = str(parsed)
    elif kind in converters:
        out_schema = str(converters[kind](parsed))
    else:
        raise RuntimeError(f"SchemaKind: {kind} is not supported")

    if not prefix:
        return out_schema
    return f"{prefix}::{out_schema}"
42