# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of common shape functions."""
import itertools

from tensorflow.python.framework import tensor_shape


def _broadcast_shape_helper(shape_x, shape_y):
  """Combines two shapes dimension-by-dimension under broadcasting rules.

  Shared implementation behind `is_broadcast_compatible` and
  `broadcast_shape`. Both callers only invoke it after checking that the
  ranks of `shape_x` and `shape_y` are known.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    `None` if some aligned pair of dimensions cannot be broadcast together;
    otherwise a list holding, for each output dimension, either a `Dimension`
    or `None` (when the output size cannot be determined).
  """
  # Line the shapes up from their trailing dimensions. Whichever shape is
  # shorter is padded on the left with size-1 dimensions so that every
  # position has a pair to combine.
  padding = tensor_shape.Dimension(1)
  aligned_pairs = list(
      itertools.zip_longest(
          reversed(shape_x.dims), reversed(shape_y.dims), fillvalue=padding))
  aligned_pairs.reverse()

  # Fold each aligned pair into one output dimension, following NumPy's
  # broadcasting rules:
  # http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
  combined = []
  for left, right in aligned_pairs:
    left_size = left.value
    right_size = right.value

    if left_size is None or right_size is None:
      # At least one side is unknown. A known size greater than 1 is trusted
      # (the program is assumed correct and the unknown side broadcasts to
      # it). Anything else leaves the output size unknown.
      # TODO(mrry): If we eliminate the shape checks in C++, we must still
      # assert that the unknown dim is either 1 or the same as the known dim.
      if left_size is not None and left_size > 1:
        combined.append(left)
      elif right_size is not None and right_size > 1:
        combined.append(right)
      else:
        combined.append(None)
      continue

    # Both sizes are known from here on.
    if left_size == 1:
      # The left dimension is stretched to match the right one.
      combined.append(right)
    elif right_size == 1:
      # The right dimension is stretched to match the left one.
      combined.append(left)
    elif left_size == right_size:
      # Equal sizes pass straight through to the output.
      combined.append(left.merge_with(right))
    else:
      # Two different sizes, neither of them 1: not broadcastable.
      return None
  return combined


def is_broadcast_compatible(shape_x, shape_y):
  """Checks whether `shape_x` and `shape_y` can be broadcast together.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    True when some shape exists to which both `shape_x` and `shape_y` can be
    broadcast. False otherwise, including whenever either rank is unknown.
  """
  ranks_known = shape_x.ndims is not None and shape_y.ndims is not None
  return ranks_known and (
      _broadcast_shape_helper(shape_x, shape_y) is not None)


def broadcast_shape(shape_x, shape_y):
  """Computes the shape produced by broadcasting `shape_x` with `shape_y`.

  Args:
    shape_x: A `TensorShape`
    shape_y: A `TensorShape`

  Returns:
    A `TensorShape` describing the broadcast result. This is a shape of
    unknown rank when either input's rank is unknown.

  Raises:
    ValueError: If the two shapes can not be broadcasted.
  """
  rank_unknown = shape_x.ndims is None or shape_y.ndims is None
  if rank_unknown:
    return tensor_shape.unknown_shape()

  dims = _broadcast_shape_helper(shape_x, shape_y)
  if dims is not None:
    return tensor_shape.TensorShape(dims)

  raise ValueError('Incompatible shapes for broadcasting. Two shapes are '
                   'compatible if for each dimension pair they are either '
                   'equal or one of them is 1. '
                   f'Received: {shape_x} and {shape_y}.')