xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/common_shapes.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A library of common shape functions."""
16import itertools
17
18from tensorflow.python.framework import tensor_shape
19
20
def _broadcast_shape_helper(shape_x, shape_y):
  """Helper function for is_broadcast_compatible and broadcast_shape.

  Aligns the two shapes from their trailing (rightmost) dimension. The
  shorter shape is padded on the left with size-1 dimensions. Each aligned
  pair is then combined following NumPy broadcasting rules.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    None if the shapes are not broadcast compatible. Otherwise, a list with one
    entry per output dimension, in leading-to-trailing order; an entry is None
    when that output size cannot be determined.
  """
  # Walk both shapes from the last axis backwards, filling in size-1
  # dimensions for the shorter one. Then flip the pairs back into
  # leading-to-trailing order.
  aligned_pairs = list(
      itertools.zip_longest(
          reversed(shape_x.dims),
          reversed(shape_y.dims),
          fillvalue=tensor_shape.Dimension(1)))
  aligned_pairs.reverse()

  # Combine each pair per the NumPy broadcasting rules:
  # https://numpy.org/doc/stable/user/basics.broadcasting.html
  result = []
  for lhs, rhs in aligned_pairs:
    lhs_size = lhs.value
    rhs_size = rhs.value

    if lhs_size is not None and rhs_size is not None:
      # Both sizes are known.
      if lhs_size == 1:
        result.append(rhs)  # lhs is stretched to match rhs.
      elif rhs_size == 1:
        result.append(lhs)  # rhs is stretched to match lhs.
      elif lhs_size == rhs_size:
        result.append(lhs.merge_with(rhs))
      else:
        # Two different sizes, neither of which is 1: cannot broadcast.
        return None
      continue

    # At least one size is unknown. If the known side is larger than 1, trust
    # that the program is correct and the unknown side broadcasts to it.
    # Otherwise the output size cannot be determined here.
    # TODO(mrry): If we eliminate the shape checks in C++, we must still
    # assert that the unknown dim is either 1 or the same as the known dim.
    if lhs_size is not None and lhs_size > 1:
      result.append(lhs)
    elif rhs_size is not None and rhs_size > 1:
      result.append(rhs)
    else:
      result.append(None)
  return result
69
70
def is_broadcast_compatible(shape_x, shape_y):
  """Returns True if `shape_x` and `shape_y` are broadcast compatible.

  The check is conservative: if the rank of either shape is unknown, this
  returns False, even though a common broadcast shape might exist.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True if both ranks are known and a shape exists that both `shape_x` and
    `shape_y` can be broadcasted to. False otherwise.
  """
  both_ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  if not both_ranks_known:
    return False
  merged_dims = _broadcast_shape_helper(shape_x, shape_y)
  return merged_dims is not None
85
86
def broadcast_shape(shape_x, shape_y):
  """Returns the broadcasted shape between `shape_x` and `shape_y`.

  If the rank of either input is unknown, the result is a fully unknown
  shape. Otherwise, dimensions are combined pairwise from the right following
  NumPy broadcasting rules; individual output dimensions may be unknown.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` representing the broadcasted shape.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  if shape_x.ndims is not None and shape_y.ndims is not None:
    merged_dims = _broadcast_shape_helper(shape_x, shape_y)
    if merged_dims is not None:
      return tensor_shape.TensorShape(merged_dims)
    raise ValueError('Incompatible shapes for broadcasting. Two shapes are '
                     'compatible if for each dimension pair they are either '
                     'equal or one of them is 1. '
                     f'Received: {shape_x} and {shape_y}.')
  # At least one rank is unknown, so nothing can be said about the result.
  return tensor_shape.unknown_shape()
109