# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for StructuredTensor.Spec."""

from absl.testing import parameterized
import numpy as np


from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import row_partition
from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape
from tensorflow.python.ops.structured import structured_tensor
from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
from tensorflow.python.platform import googletest


# TypeSpecs consts for fields types.
T_3 = tensor_spec.TensorSpec([3])
T_1_2 = tensor_spec.TensorSpec([1, 2])
T_1_2_8 = tensor_spec.TensorSpec([1, 2, 8])
T_1_2_3_4 = tensor_spec.TensorSpec([1, 2, 3, 4])
T_2_3 = tensor_spec.TensorSpec([2, 3])
R_1_N = ragged_tensor.RaggedTensorSpec([1, None])
R_2_N = ragged_tensor.RaggedTensorSpec([2, None])
R_1_N_N = ragged_tensor.RaggedTensorSpec([1, None, None])
R_2_1_N = ragged_tensor.RaggedTensorSpec([2, 1, None])

# TensorSpecs for nrows & row_splits in the _to_components encoding.
NROWS_SPEC = tensor_spec.TensorSpec([], dtypes.int64)
PARTITION_SPEC = row_partition.RowPartitionSpec()


# pylint: disable=g-long-lambda
@test_util.run_all_in_graph_and_eager_modes
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Tests construction, component round-trips and batching of the Spec."""

  # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
  # assertAllEqual etc to work with StructuredTensors.
  def assertAllEqual(self, a, b, msg=None):
    """Asserts that `a` equals `b`, with support for StructuredTensors.

    When neither argument is a StructuredTensor the comparison is delegated
    to the base-class `assertAllEqual`. When both are, the repr of their
    shapes and their sets of field names are compared, and then each field
    value is compared by calling this method again.

    Args:
      a: The first value to compare.
      b: The second value to compare.
      msg: Optional message forwarded to the underlying assertions.

    Returns:
      The base-class result when neither argument is a StructuredTensor;
      otherwise None.

    Raises:
      ValueError: If exactly one of `a` and `b` is a StructuredTensor.
    """
    if not (isinstance(a, structured_tensor.StructuredTensor) or
            isinstance(b, structured_tensor.StructuredTensor)):
      return super(StructuredTensorSpecTest, self).assertAllEqual(a, b, msg)
    if not (isinstance(a, structured_tensor.StructuredTensor) and
            isinstance(b, structured_tensor.StructuredTensor)):
      # TODO(edloper) Add support for this once structured_factory_ops is added.
      raise ValueError('Not supported yet')

    # Forward `msg` so that a caller-supplied message is not lost on the
    # StructuredTensor path.
    self.assertEqual(repr(a.shape), repr(b.shape), msg)
    self.assertEqual(set(a.field_names()), set(b.field_names()), msg)
    for field in a.field_names():
      self.assertAllEqual(a.field_value(field), b.field_value(field), msg)

  def assertAllTensorsEqual(self, x, y):
    """Asserts that two dicts have the same keys and pairwise-equal values.

    Args:
      x: A dict of values.
      y: A dict of values, compared against `x` key by key.
    """
    # Use unittest assertions rather than bare `assert` statements so the
    # argument checks are not stripped when Python runs with -O.
    self.assertIsInstance(x, dict)
    self.assertIsInstance(y, dict)
    self.assertEqual(set(x), set(y))
    for key in x:
      self.assertAllEqual(x[key], y[key])

  def testConstruction(self):
    """Checks `_shape` and `_field_specs` of directly constructed specs."""
    spec1_fields = dict(a=T_1_2_3_4)
    spec1 = StructuredTensor.Spec(
        _ragged_shape=DynamicRaggedShape.Spec(
            row_partitions=[],
            static_inner_shape=tensor_shape.TensorShape([1, 2, 3]),
            dtype=dtypes.int64),
        _fields=spec1_fields)
    self.assertEqual(spec1._shape, (1, 2, 3))
    self.assertEqual(spec1._field_specs, spec1_fields)

    # The second spec nests the first one as field `s`.
    spec2_fields = dict(a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=spec1)
    spec2 = StructuredTensor.Spec(
        _ragged_shape=DynamicRaggedShape.Spec(
            row_partitions=[],
            static_inner_shape=tensor_shape.TensorShape([1, 2]),
            dtype=dtypes.int64),
        _fields=spec2_fields)
    self.assertEqual(spec2._shape, (1, 2))
    self.assertEqual(spec2._field_specs, spec2_fields)

  # Note that there is no error for creating a spec without known rank.
  @parameterized.parameters([
      (None, r'fields: expected mapping, got None'),
      ({1: tensor_spec.TensorSpec(None)},
       r'expected str, got 1'),
      ({'x': 0},
       r'got 0'),
  ])
  def testConstructionErrors(self, field_specs, error):
    """Checks that invalid `_fields` arguments raise a matching TypeError.

    Args:
      field_specs: The invalid value passed as `_fields`.
      error: Regular expression the TypeError message must match.
    """
    with self.assertRaisesRegex(TypeError, error):
      structured_tensor.StructuredTensor.Spec(
          _ragged_shape=DynamicRaggedShape.Spec(
              row_partitions=[],
              static_inner_shape=[],
              dtype=dtypes.int64),
          _fields=field_specs)

  def testValueType(self):
    """Checks that the spec reports StructuredTensor as its value type."""
    spec1 = StructuredTensor.Spec(
        _ragged_shape=DynamicRaggedShape.Spec(
            row_partitions=[],
            static_inner_shape=[1, 2],
            dtype=dtypes.int64),
        _fields=dict(a=T_1_2))
    self.assertEqual(spec1.value_type, StructuredTensor)

  @parameterized.parameters([
      {
          'shape': [],
          'fields': dict(x=[[1.0, 2.0]]),
          'field_specs': dict(x=T_1_2),
      },
      {
          'shape': [2],
          'fields': dict(
              a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
              b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
          'field_specs': dict(a=R_2_N, b=T_2_3),
      },
  ])  # pyformat: disable
  def testToFromComponents(self, shape, fields, field_specs):
    """Round-trips a value through `_to_components`/`_from_components`.

    Args:
      shape: Shape used for both the StructuredTensor and the spec.
      fields: Field values passed to `StructuredTensor.from_fields`.
      field_specs: Field specs passed to the spec as `_fields`.
    """
    struct = StructuredTensor.from_fields(fields, shape)
    spec = StructuredTensor.Spec(_ragged_shape=DynamicRaggedShape.Spec(
        row_partitions=[],
        static_inner_shape=shape,
        dtype=dtypes.int64), _fields=field_specs)
    actual_components = spec._to_components(struct)
    struct_reconstructed = spec._from_components(actual_components)
    self.assertAllEqual(struct, struct_reconstructed)

  def testToFromComponentsEmptyScalar(self):
    """Round-trips a field-less StructuredTensor of shape []."""
    struct = StructuredTensor.from_fields(fields={}, shape=[])
    spec = struct._type_spec
    components = spec._to_components(struct)
    struct_reconstructed = spec._from_components(components)
    self.assertAllEqual(struct, struct_reconstructed)

  def testToFromComponentsEmptyTensor(self):
    """Round-trips a field-less StructuredTensor of shape [1, 2, 3]."""
    struct = StructuredTensor.from_fields(fields={}, shape=[1, 2, 3])
    spec = struct._type_spec
    components = spec._to_components(struct)
    struct_reconstructed = spec._from_components(components)
    self.assertAllEqual(struct, struct_reconstructed)

  @parameterized.parameters([
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
              StructuredTensor.from_fields({'a': 2, 'b': [7, 8]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': [[5, 6], [7, 8]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [1, 2, 3],
                  'b': [[5, 6], [6, 7], [7, 8]]}),
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [2, 3, 4],
                  'b': [[2, 2], [3, 3], [4, 4]]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2, 3], fields={
              'a': [[1, 2, 3], [2, 3, 4]],
              'b': [[[5, 6], [6, 7], [7, 8]],
                    [[2, 2], [3, 3], [4, 4]]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 1,
                  'b': StructuredTensor.from_fields({'x': [5]})}),
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 2,
                  'b': StructuredTensor.from_fields({'x': [6]})})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': StructuredTensor.from_fields(shape=[2], fields={
                  'x': [[5], [6]]})}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1, 2], [3]]),
                  'Ragged2d': ragged_factory_ops.constant_value([1]),
              }),
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                  'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
              })],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'Ragged3d': ragged_factory_ops.constant_value(
                  [[[1, 2], [3]], [[1]]]),
              'Ragged2d': ragged_factory_ops.constant_value([[1], [2, 3]]),
          }),
          'use_only_batched_spec': True,
      },
  ])  # pyformat: disable
  def testBatchUnbatchValues(self,
                             unbatched,
                             batch_size,
                             batched,
                             use_only_batched_spec=False):
    """Checks that datasets of StructuredTensors unbatch and re-batch.

    The resulting batches are only compared against `batched` when executing
    eagerly; in graph mode the datasets are built but not evaluated.

    Args:
      unbatched: Zero-argument callable returning the list of unbatched
        StructuredTensors.
      batch_size: Batch size used when re-batching the datasets.
      batched: Zero-argument callable returning the expected batched
        StructuredTensor.
      use_only_batched_spec: If True, skip the check that builds a dataset
        from the unbatched values with `Dataset.from_generator`.
    """
    batched = batched()  # Deferred init because it creates tensors.
    unbatched = unbatched()  # Deferred init because it creates tensors.

    def unbatch_gen():
      yield from unbatched

    ds = dataset_ops.Dataset.from_tensors(batched)
    ds2 = ds.unbatch()
    if context.executing_eagerly():
      # Re-batch with the parameterized size instead of a hard-coded constant.
      v = list(ds2.batch(batch_size))
      self.assertAllEqual(v[0], batched)

    if not use_only_batched_spec:
      unbatched_spec = type_spec.type_spec_from_value(unbatched[0])

      dsu = dataset_ops.Dataset.from_generator(
          unbatch_gen, output_signature=unbatched_spec)
      dsu2 = dsu.batch(batch_size)
      if context.executing_eagerly():
        v = list(dsu2)
        self.assertAllEqual(v[0], batched)

  def _lambda_for_fields(self):
    """Returns a zero-argument callable that builds a dict of field values.

    The dict maps the keys 'a' through 'h' to numpy arrays ('a', 'b'), ragged
    constants ('c' through 'f') and StructuredTensors built with `from_pyval`
    ('g', 'h'). Construction is deferred behind a lambda, so nothing is
    created until the returned callable is invoked. No test in this file
    references this helper.
    """
    return lambda: {
        'a':
            np.ones([1, 2, 3, 1]),
        'b':
            np.ones([1, 2, 3, 1, 5]),
        'c':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1], dtype=np.uint8), dtype=dtypes.uint8),
        'd':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1),
        'e':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
        'f':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 3]), dtype=dtypes.float32),
        'g':
            StructuredTensor.from_pyval([[
                [  # pylint: disable=g-complex-comprehension
                    [{
                        'x': j,
                        'y': k
                    }] for k in range(3)
                ] for j in range(2)
            ]]),
        'h':
            StructuredTensor.from_pyval([[
                [  # pylint: disable=g-complex-comprehension
                    [[
                        {
                            'x': j,
                            'y': k,
                            'z': z
                        } for z in range(j)
                    ]] for k in range(3)
                ] for j in range(2)
            ]]),
    }


if __name__ == '__main__':
  googletest.main()