# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from torchgen.model import FunctionSchema, SchemaKind
from torchgen.native_function_generation import (
    functional_to_out_signature,
    mutable_to_out_signature,
    self_to_out_signature,
)
from torchgen.utils import NamespaceHelper


def gen_out_variant_schema(func_op_schema: str) -> str:
    """
    Build the schema string of the out= variant for an operator schema.

    Args:
        func_op_schema: Operator schema string. It may carry a leading
            namespace (at most one level is split off), which is separated
            from the schema proper before parsing.

    Returns:
        The out= variant schema as a string. When the input carried a
        namespace, the result is prefixed with ``<namespace>::``; otherwise
        the bare schema string is returned. A schema that is already an
        out= variant is returned in its parsed-and-restringified form.

    Raises:
        RuntimeError: If the parsed schema's kind is not one of inplace,
            functional, mutable or out.
    """
    # Separate the (optional) namespace from the schema text itself.
    ns_helper = NamespaceHelper.from_namespaced_entity(
        namespaced_entity=func_op_schema, max_level=1
    )
    parsed = FunctionSchema.parse(ns_helper.entity_name)
    cpp_namespace = ns_helper.get_cpp_namespace(default="")

    # Pick the conversion that matches the schema kind; the kind is looked up
    # once and the branches are mutually exclusive.
    kind = parsed.kind()
    if kind == SchemaKind.out:
        # Already an out= variant: nothing to convert.
        out_schema = parsed
    elif kind == SchemaKind.functional:
        out_schema = functional_to_out_signature(parsed)
    elif kind == SchemaKind.inplace:
        out_schema = self_to_out_signature(parsed)
    elif kind == SchemaKind.mutable:
        out_schema = mutable_to_out_signature(parsed)
    else:
        raise RuntimeError(f"SchemaKind: {kind} is not supported")

    rendered = str(out_schema)
    if not cpp_namespace:
        return rendered
    # Re-attach the namespace that was split off before parsing.
    return f"{cpp_namespace}::{rendered}"