xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/graph_to_function_def.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Utility to convert a Graph to a FunctionDef."""
16
17import re
18
19from tensorflow.core.framework import function_pb2
20from tensorflow.core.framework import op_def_pb2
21from tensorflow.python.framework import op_def_registry
22
23
24def _make_argname_from_tensor_name(name):
25  return re.sub(":0$", "", name).replace(":", "_o")
26
27
def _tensor_to_argdef(t, name=None, used_names=None):
  """Builds an `OpDef.ArgDef` describing tensor `t`.

  Args:
    t: tensor whose `name` and `dtype` populate the ArgDef.
    name: explicit argument name. When given it is used verbatim and
      `used_names` is neither consulted nor updated.
    used_names: optional set of names already taken. When `name` is None and
      this is provided, the derived name is made unique by appending
      "_U<i>" (smallest free i), and the chosen name is added to the set.

  Returns:
    The populated `OpDef.ArgDef`.
  """
  arg = op_def_pb2.OpDef.ArgDef()
  if name is not None:
    arg.name = name
  else:
    candidate = _make_argname_from_tensor_name(t.name)
    if used_names is not None:
      if candidate in used_names:
        # Probe "_U0", "_U1", ... until an unused suffix is found.
        suffix = 0
        while "%s_U%d" % (candidate, suffix) in used_names:
          suffix += 1
        candidate = "%s_U%d" % (candidate, suffix)
      used_names.add(candidate)
    arg.name = candidate
  arg.type = t.dtype.as_datatype_enum
  return arg
47
48
49def _is_in_placeholders(op, func_arg_placeholders):
50  """Checks whether any output of this op is in func_arg_placeholders."""
51  return op.values() and any(x.name in func_arg_placeholders
52                             for x in op.values())
53
54
55def _get_node_def(op):
56  return op.node_def  # pylint: disable=protected-access
57
58
59def _get_op_def(op):
60  return op.op_def or op_def_registry.get(op.type)
61
62
def _create_input_dict(function_graph,
                       func_arg_placeholders,
                       initial_value=None):
  """Maps graph tensor/op names to the names used inside a FunctionDef.

  Args:
    function_graph: graph whose `get_operations()` are scanned.
    func_arg_placeholders: container of tensor names that are function
      arguments. Any op producing one of them is mapped to its own op name.
    initial_value: optional mapping copied into the result before scanning.

  Returns:
    A new dict where, for every non-placeholder op:
      * each output tensor name maps to "<op_name>:<output_arg_name>:<i>";
      * the op name itself maps to the entry of its first output.
    Placeholder ops map their op name to itself.
  """
  input_dict = {} if initial_value is None else dict(initial_value)
  for op in function_graph.get_operations():
    if _is_in_placeholders(op, func_arg_placeholders):
      input_dict[op.name] = op.name
      continue

    op_def = _get_op_def(op)
    attrs = _get_node_def(op).attr
    flat_index = 0  # Position of the current output in op.values().
    for arg_def in op_def.output_arg:
      # An output arg may expand to several tensors, sized either by an int
      # attr (number_attr) or by the length of a type-list attr.
      if arg_def.number_attr:
        count = attrs[arg_def.number_attr].i
      elif arg_def.type_list_attr:
        count = len(attrs[arg_def.type_list_attr].list.type)
      else:
        count = 1
      for i in range(count):
        func_name = "%s:%s:%d" % (op.name, arg_def.name, i)
        input_dict[op.values()[flat_index].name] = func_name
        if flat_index == 0:
          input_dict[op.name] = func_name
        flat_index += 1
  return input_dict
92
93
def _add_op_node(op, func, input_dict):
  """Appends `op`'s NodeDef to `func`, rewriting its data inputs.

  Args:
    op: graph operation to add.
    func: the `FunctionDef` under construction. Its `node_def` list is
      extended, and `signature.is_stateful` is set when `op` is stateful.
    input_dict: mapping from graph tensor names to function tensor names.
      Every non-control input of `op` must be a key.
  """
  # Note that extend() makes a copy in this case, see:
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
  func.node_def.extend([_get_node_def(op)])
  node_def = func.node_def[-1]
  for index, input_name in enumerate(list(node_def.input)):
    # Control inputs ("^name") are kept unchanged.
    if input_name.startswith("^"):
      continue
    assert input_name in input_dict, ("%s missing from %s" %
                                      (input_name, input_dict.items()))
    node_def.input[index] = input_dict[input_name]
  # The function is stateful if any of its operations are stateful.
  # NOTE(mrry): The "Const" node typically does not have an `OpDef` associated
  # with it, so we assume any nodes without an `OpDef` are stateless.
  # TODO(skyewm): Remove the `is not None` test after we transition to the C
  # API.
  op_def = op.op_def
  if op_def is not None and op_def.is_stateful:
    func.signature.is_stateful = True
115
116
def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops in `operations`.  The
  operations become the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function.  They must be lists of
  tensors present in the graph.  The lists can optionally be empty.

  Args:
    graph: Graph.
    operations: the operations to put in the function. Must be a subset of
     the operations in the graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs. When None or
      empty, output names are derived from the output tensor names.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    ValueError: if out_names is non-empty and its length differs from
      `outputs`, or if out_names contains duplicate names.
  """
  # Validate out_names before doing any work. An empty list is treated like
  # None, which is what the error message below has always advertised.
  if out_names:
    if len(outputs) != len(out_names):
      raise ValueError(
          f"out_names must be either empty or equal in size to outputs. "
          f"len(out_names) = {len(out_names)} len(outputs) = {len(outputs)}")
    if len(out_names) != len(set(out_names)):
      raise ValueError(
          f"Must not have duplicates in out_names. Received: {out_names}")
  else:
    out_names = None

  func = function_pb2.FunctionDef()
  func.signature.name = "_"
  used_names = set()
  func.signature.input_arg.extend(
      [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
  # Initializes the input map with all placeholder input tensors.
  initial_dict = {
      t.name: arg.name for t, arg in zip(inputs, func.signature.input_arg)
  }

  if out_names is None:
    # Generated output names are deduplicated among themselves only; the
    # input names above use a separate set.
    used_names = set()
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
  else:
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])

  func_arg_placeholders = {i.name for i in inputs}
  input_dict = _create_input_dict(graph, func_arg_placeholders,
                                  initial_value=initial_dict)

  for op in operations:
    # Ops producing function inputs become arguments, not body nodes.
    if not _is_in_placeholders(op, func_arg_placeholders):
      _add_op_node(op, func, input_dict)

  # output_arg names are either the generated names or out_names, so one loop
  # covers both cases.
  for arg, o in zip(func.signature.output_arg, outputs):
    func.ret[arg.name] = input_dict[o.name]

  return func
184