# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for StructuredTensor."""

from absl.testing import parameterized

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.platform import googletest


# TODO(edloper): Move this to a common util package (forked from ragged).
class _SliceBuilder:
  """Helper to construct arguments for __getitem__.

  Usage: `_SliceBuilder()[<expr>]` returns the slice_spec that Python
  generates for `<expr>`.
  """

  def __getitem__(self, slice_spec):
    # Hand back whatever Python passed in as the subscript, unchanged.
    return slice_spec


# TODO(edloper): Move this to a common util package (forked from ragged).
SLICE_BUILDER = _SliceBuilder()


# TODO(edloper): Move this to a common util package (forked from ragged).
def _make_tensor_slice_spec(slice_spec, use_constant=True):
  """Wraps all integers in an extended slice spec w/ a tensor.

  This function is used to help test slicing when the slice spec contains
  tensors, rather than integers.

  Args:
    slice_spec: The extended slice spec.  Either a single piece or a tuple of
      pieces.
    use_constant: If true, then wrap each integer with a tf.constant.  If
      false, then wrap each integer with a placeholder (built with
      `array_ops.placeholder_with_default`) whose default is that constant.

  Returns:
    A copy of slice_spec, but with each integer i replaced with tf.constant(i)
    (or with a scalar placeholder defaulting to it, if `use_constant` is
    false).  Integers inside `slice` objects are replaced too; all other
    pieces are returned unchanged.
  """

  def make_piece_scalar(piece):
    if isinstance(piece, int):
      scalar = constant_op.constant(piece)
      if use_constant:
        return scalar
      else:
        return array_ops.placeholder_with_default(scalar, [])
    elif isinstance(piece, slice):
      # Rebuild the slice with each of its three components wrapped.
      return slice(
          make_piece_scalar(piece.start), make_piece_scalar(piece.stop),
          make_piece_scalar(piece.step))
    else:
      # Strings, None, Ellipsis, etc. are passed through as-is.
      return piece

  if isinstance(slice_spec, tuple):
    return tuple(make_piece_scalar(piece) for piece in slice_spec)
  else:
    return make_piece_scalar(slice_spec)


EXAMPLE_STRUCT = {
    # f1: scalar value field
    "f1": 1,
    # f2: matrix field
    "f2": [[1, 2], [3, 4]],
    # f3: scalar structure field
    "f3": {"f3_1": 1},
    # f4: vector structure field
    "f4": [{"f4_1": 1, "f4_2": b"a"}, {"f4_1": 2, "f4_2": b"b"}],
    # f5: matrix structure field
    "f5": [[{"f5_1": 1}, {"f5_1": 2}], [{"f5_1": 3}, {"f5_1": 4}]],
}

EXAMPLE_STRUCT_2 = {
    # f1: scalar value field
    "f1": 5,
    # f2: matrix field
    "f2": [[6, 7], [8, 9]],
    # f3: scalar structure field
    "f3": {"f3_1": 9},
    # f4: vector structure field
    "f4": [{"f4_1": 5, "f4_2": b"A"}, {"f4_1": 6, "f4_2": b"B"}],
    # f5: matrix structure field
    "f5": [[{"f5_1": 6}, {"f5_1": 7}], [{"f5_1": 8}, {"f5_1": 9}]],
}

# Five copies of EXAMPLE_STRUCT followed by one EXAMPLE_STRUCT_2 (length 6).
EXAMPLE_STRUCT_VECTOR = [EXAMPLE_STRUCT] * 5 + [EXAMPLE_STRUCT_2]

EXAMPLE_STRUCT_SPEC1 = structured_tensor.StructuredTensorSpec([], {
    "f1": tensor_spec.TensorSpec([], dtypes.int32),
    "f2": tensor_spec.TensorSpec([2, 2], dtypes.int32),
    "f3": structured_tensor.StructuredTensorSpec(
        [], {"f3_1": tensor_spec.TensorSpec([], dtypes.int32)}),
    "f4": structured_tensor.StructuredTensorSpec(
        [2], {"f4_1": tensor_spec.TensorSpec([2], dtypes.int32),
              "f4_2": tensor_spec.TensorSpec([2], dtypes.string)}),
    "f5": structured_tensor.StructuredTensorSpec(
        [2, 2], {"f5_1": tensor_spec.TensorSpec([2, 2], dtypes.int32)}),
})


@test_util.run_all_in_graph_and_eager_modes
class StructuredTensorSliceTest(test_util.TensorFlowTestCase,
                                parameterized.TestCase):
  """Tests for `StructuredTensor.__getitem__`."""

  def assertAllEqual(self, a, b, msg=None):
    """Asserts that `a` and `b` are equal, handling StructuredTensors.

    * If neither value is a StructuredTensor, defers to the base class.
    * If both are StructuredTensors, checks that their shapes are compatible,
      that they have the same field names, and that each field is equal.
    * If exactly one is a StructuredTensor, the other is expected to be a
      dict (when the StructuredTensor has rank 0) or a list/tuple (otherwise),
      and is compared field-by-field or element-by-element.

    Args:
      a: The first value to compare.
      b: The second value to compare.
      msg: Optional message to report on failure.  It is forwarded to every
        nested comparison, so failures deep inside a structure keep their
        context.
    """
    if not (isinstance(a, structured_tensor.StructuredTensor) or
            isinstance(b, structured_tensor.StructuredTensor)):
      super(StructuredTensorSliceTest, self).assertAllEqual(a, b, msg)
    elif (isinstance(a, structured_tensor.StructuredTensor) and
          isinstance(b, structured_tensor.StructuredTensor)):
      a_shape = tensor_shape.as_shape(a.shape)
      b_shape = tensor_shape.as_shape(b.shape)
      a_shape.assert_is_compatible_with(b_shape)
      self.assertEqual(set(a.field_names()), set(b.field_names()), msg)
      for field in a.field_names():
        self.assertAllEqual(a.field_value(field), b.field_value(field), msg)
    elif isinstance(b, structured_tensor.StructuredTensor):
      # Normalize so that the StructuredTensor is always the first argument.
      self.assertAllEqual(b, a, msg)
    else:
      # `a` is a StructuredTensor and `b` is a plain Python value.
      if a.rank == 0:
        self.assertIsInstance(b, dict, msg)
        self.assertEqual(set(a.field_names()), set(b), msg)
        for (key, b_val) in b.items():
          a_val = a.field_value(key)
          self.assertAllEqual(a_val, b_val, msg)
      else:
        self.assertIsInstance(b, (list, tuple), msg)
        a.shape[:1].assert_is_compatible_with([len(b)])
        for i, b_item in enumerate(b):
          self.assertAllEqual(a[i], b_item, msg)

  def _TestGetItem(self, struct, slice_spec, expected):
    """Helper function for testing StructuredTensor.__getitem__.

    Checks that calling `struct.__getitem__(slice_spec)` returns the expected
    value.  Checks three different configurations for each slice spec:

      * Call __getitem__ with the slice spec as-is (with int values)
      * Call __getitem__ with int values in the slice spec wrapped in
        `tf.constant()`.
      * Call __getitem__ with int values in the slice spec wrapped in
        a placeholder created by `placeholder_with_default` (so value is not
        known at graph construction time).

    Args:
      struct: The StructuredTensor to test.
      slice_spec: The slice spec.
      expected: The expected value of struct.__getitem__(slice_spec), as a
        python list.
    """
    tensor_slice_spec1 = _make_tensor_slice_spec(slice_spec, True)
    tensor_slice_spec2 = _make_tensor_slice_spec(slice_spec, False)
    value1 = struct.__getitem__(slice_spec)
    value2 = struct.__getitem__(tensor_slice_spec1)
    value3 = struct.__getitem__(tensor_slice_spec2)
    # The same failure message is used for all three configurations.
    msg = "slice_spec=%s" % (slice_spec,)
    self.assertAllEqual(value1, expected, msg)
    self.assertAllEqual(value2, expected, msg)
    self.assertAllEqual(value3, expected, msg)

  @parameterized.parameters([
      # Simple indexing
      (SLICE_BUILDER["f1"], EXAMPLE_STRUCT["f1"]),
      (SLICE_BUILDER["f2"], EXAMPLE_STRUCT["f2"]),
      (SLICE_BUILDER["f3"], EXAMPLE_STRUCT["f3"]),
      (SLICE_BUILDER["f4"], EXAMPLE_STRUCT["f4"]),
      (SLICE_BUILDER["f5"], EXAMPLE_STRUCT["f5"]),
      # Multidimensional indexing
      (SLICE_BUILDER["f2", 1], EXAMPLE_STRUCT["f2"][1]),
      (SLICE_BUILDER["f3", "f3_1"], EXAMPLE_STRUCT["f3"]["f3_1"]),
      (SLICE_BUILDER["f4", 1], EXAMPLE_STRUCT["f4"][1]),
      (SLICE_BUILDER["f4", 1, "f4_2"], EXAMPLE_STRUCT["f4"][1]["f4_2"]),
      (SLICE_BUILDER["f5", 0, 1], EXAMPLE_STRUCT["f5"][0][1]),
      (SLICE_BUILDER["f5", 0, 1, "f5_1"], EXAMPLE_STRUCT["f5"][0][1]["f5_1"]),
      # Multidimensional slicing
      (SLICE_BUILDER["f2", 1:], EXAMPLE_STRUCT["f2"][1:]),
      (SLICE_BUILDER["f4", :1], EXAMPLE_STRUCT["f4"][:1]),
      (SLICE_BUILDER["f4", 1:, "f4_2"], [b"b"]),
      (SLICE_BUILDER["f4", :, "f4_2"], [b"a", b"b"]),
      (SLICE_BUILDER["f5", :, :, "f5_1"], [[1, 2], [3, 4]]),
      # Slicing over multiple keys
      (SLICE_BUILDER[:], EXAMPLE_STRUCT),
      # List-valued key.
      (["f2", 1], EXAMPLE_STRUCT["f2"][1]),
  ])
  def testGetitemFromScalarStruct(self, slice_spec, expected):
    """Tests __getitem__ on a scalar StructuredTensor (EXAMPLE_STRUCT)."""
    # By default, lists are converted to RaggedTensors.
    struct = structured_tensor.StructuredTensor.from_pyval(EXAMPLE_STRUCT)
    self._TestGetItem(struct, slice_spec, expected)

    # Using an explicit TypeSpec, we can convert them to Tensors instead.
    struct2 = structured_tensor.StructuredTensor.from_pyval(
        EXAMPLE_STRUCT, EXAMPLE_STRUCT_SPEC1)
    self._TestGetItem(struct2, slice_spec, expected)

  @parameterized.parameters([
      (SLICE_BUILDER[2], EXAMPLE_STRUCT_VECTOR[2]),
      (SLICE_BUILDER[5], EXAMPLE_STRUCT_VECTOR[5]),
      (SLICE_BUILDER[-2], EXAMPLE_STRUCT_VECTOR[-2]),
      (SLICE_BUILDER[-1], EXAMPLE_STRUCT_VECTOR[-1]),
      (SLICE_BUILDER[2, "f1"], EXAMPLE_STRUCT_VECTOR[2]["f1"]),
      (SLICE_BUILDER[-1, "f1"], EXAMPLE_STRUCT_VECTOR[-1]["f1"]),
      (SLICE_BUILDER[5:], EXAMPLE_STRUCT_VECTOR[5:]),
      (SLICE_BUILDER[3:, "f1"], [1, 1, 5]),
      (SLICE_BUILDER[::2, "f1"], [1, 1, 1]),
      (SLICE_BUILDER[1::2, "f1"], [1, 1, 5]),
      (SLICE_BUILDER[4:, "f5", 0, 1, "f5_1"], [2, 7], True),
      (SLICE_BUILDER[4:, "f5", :, :, "f5_1"],
       [[[1, 2], [3, 4]], [[6, 7], [8, 9]]]),
  ])  # pyformat: disable
  def testGetitemFromVectorStruct(self, slice_spec, expected,
                                  test_requires_typespec=False):
    """Tests __getitem__ on a vector StructuredTensor.

    Args:
      slice_spec: The slice spec to apply.
      expected: The expected result, as a python value.
      test_requires_typespec: If true, skip the variant that builds the
        StructuredTensor without an explicit TypeSpec.
    """
    # By default, lists are converted to RaggedTensors.
    if not test_requires_typespec:
      struct_vector = structured_tensor.StructuredTensor.from_pyval(
          EXAMPLE_STRUCT_VECTOR)
      self._TestGetItem(struct_vector, slice_spec, expected)

    # Using an explicit TypeSpec, we can convert them to Tensors instead.
    struct_vector2 = structured_tensor.StructuredTensor.from_pyval(
        EXAMPLE_STRUCT_VECTOR, EXAMPLE_STRUCT_SPEC1._batch(6))
    self._TestGetItem(struct_vector2, slice_spec, expected)

  # TODO(edloper): Add tests for slicing from matrix StructuredTensors.

  @parameterized.parameters([
      (SLICE_BUILDER[:2], r"Key for indexing a StructuredTensor must be "
       r"a string or a full slice \(':'\)"),
      (SLICE_BUILDER["f4", ...], r"Slicing not supported for Ellipsis"),
      (SLICE_BUILDER["f4", None], r"Slicing not supported for tf.newaxis"),
      (SLICE_BUILDER["f4", :, 0],
       r"Key for indexing a StructuredTensor must be a string"),
  ])
  def testGetItemError(self, slice_spec, error, exception=ValueError):
    """Tests that invalid slice specs on a scalar struct raise `exception`."""
    struct = structured_tensor.StructuredTensor.from_pyval(EXAMPLE_STRUCT)
    with self.assertRaisesRegex(exception, error):
      struct.__getitem__(slice_spec)

  @parameterized.parameters([
      (SLICE_BUILDER[:, 1],
       r"Key for indexing a StructuredTensor must be a string"),
  ])
  def testGetItemFromVectorError(self, slice_spec, error, exception=ValueError):
    """Tests that invalid slice specs on a vector struct raise `exception`."""
    struct = structured_tensor.StructuredTensor.from_pyval(
        EXAMPLE_STRUCT_VECTOR)
    with self.assertRaisesRegex(exception, error):
      struct.__getitem__(slice_spec)


if __name__ == "__main__":
  googletest.main()