xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/structured/structured_tensor_spec_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for StructuredTensor.Spec."""
16
17from absl.testing import parameterized
18import numpy as np
19
20
21from tensorflow.python.data.ops import dataset_ops
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_spec
26from tensorflow.python.framework import test_util
27from tensorflow.python.framework import type_spec
28from tensorflow.python.ops.ragged import ragged_factory_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.ops.ragged import row_partition
31from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape
32from tensorflow.python.ops.structured import structured_tensor
33from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
34from tensorflow.python.platform import googletest
35
36
# TypeSpec constants for field types.  The name suffix spells out the shape
# passed to the spec (`N` stands for a `None` dimension).
T_3 = tensor_spec.TensorSpec(shape=[3])
T_1_2 = tensor_spec.TensorSpec(shape=[1, 2])
T_1_2_8 = tensor_spec.TensorSpec(shape=[1, 2, 8])
T_1_2_3_4 = tensor_spec.TensorSpec(shape=[1, 2, 3, 4])
T_2_3 = tensor_spec.TensorSpec(shape=[2, 3])
R_1_N = ragged_tensor.RaggedTensorSpec(shape=[1, None])
R_2_N = ragged_tensor.RaggedTensorSpec(shape=[2, None])
R_1_N_N = ragged_tensor.RaggedTensorSpec(shape=[1, None, None])
R_2_1_N = ragged_tensor.RaggedTensorSpec(shape=[2, 1, None])

# Specs for nrows & row_splits in the _to_components encoding.
NROWS_SPEC = tensor_spec.TensorSpec(shape=[], dtype=dtypes.int64)
PARTITION_SPEC = row_partition.RowPartitionSpec()
51
52
53# pylint: disable=g-long-lambda
@test_util.run_all_in_graph_and_eager_modes
class StructuredTensorSpecTest(test_util.TensorFlowTestCase,
                               parameterized.TestCase):
  """Tests for constructing and round-tripping `StructuredTensor.Spec`."""

  def _uniform_shape_spec(self, static_inner_shape):
    """Returns an int64 `DynamicRaggedShape.Spec` with no row partitions.

    Args:
      static_inner_shape: Value forwarded unchanged as the spec's
        `static_inner_shape` (the tests pass lists or a `TensorShape`).
    """
    return DynamicRaggedShape.Spec(
        row_partitions=[],
        static_inner_shape=static_inner_shape,
        dtype=dtypes.int64)

  # TODO(edloper): Add a subclass of TensorFlowTestCase that overrides
  # assertAllEqual etc to work with StructuredTensors.
  def assertAllEqual(self, a, b, msg=None):
    """Asserts `a == b`, with support for comparing two StructuredTensors.

    If neither argument is a StructuredTensor, this defers to the base class.
    If both are, their shapes (compared via `repr`), their sets of field names
    and, recursively, each of their field values must match.

    Args:
      a: First value to compare.
      b: Second value to compare.
      msg: Optional message forwarded to the underlying assertions.

    Returns:
      Whatever the base-class `assertAllEqual` returns when neither argument
      is a StructuredTensor; otherwise None.

    Raises:
      ValueError: If exactly one of `a` and `b` is a StructuredTensor.
    """
    a_is_struct = isinstance(a, structured_tensor.StructuredTensor)
    b_is_struct = isinstance(b, structured_tensor.StructuredTensor)
    if not (a_is_struct or b_is_struct):
      return super(StructuredTensorSpecTest, self).assertAllEqual(a, b, msg)
    if not (a_is_struct and b_is_struct):
      # TODO(edloper) Add support for this once structured_factory_ops is added.
      raise ValueError('Not supported yet')

    self.assertEqual(repr(a.shape), repr(b.shape), msg)
    self.assertEqual(set(a.field_names()), set(b.field_names()), msg)
    for field in a.field_names():
      self.assertAllEqual(a.field_value(field), b.field_value(field), msg)

  def assertAllTensorsEqual(self, x, y):
    """Asserts that dicts `x` and `y` have equal key sets and equal values."""
    # Use unittest assertions rather than bare `assert` so the type checks are
    # not stripped when Python runs with optimizations enabled.
    self.assertIsInstance(x, dict)
    self.assertIsInstance(y, dict)
    self.assertEqual(set(x), set(y))
    for key in x:
      self.assertAllEqual(x[key], y[key])

  def testConstruction(self):
    """A Spec exposes the shape and field specs it was constructed with."""
    spec1_fields = dict(a=T_1_2_3_4)
    spec1 = StructuredTensor.Spec(
        _ragged_shape=self._uniform_shape_spec(
            tensor_shape.TensorShape([1, 2, 3])),
        _fields=spec1_fields)
    self.assertEqual(spec1._shape, (1, 2, 3))
    self.assertEqual(spec1._field_specs, spec1_fields)

    # A spec may itself be used as the spec of a nested field (`s`).
    spec2_fields = dict(a=T_1_2, b=T_1_2_8, c=R_1_N, d=R_1_N_N, s=spec1)
    spec2 = StructuredTensor.Spec(
        _ragged_shape=self._uniform_shape_spec(
            tensor_shape.TensorShape([1, 2])),
        _fields=spec2_fields)
    self.assertEqual(spec2._shape, (1, 2))
    self.assertEqual(spec2._field_specs, spec2_fields)

  # Note that there is no error for creating a spec without known rank.
  @parameterized.parameters([
      (None, r'fields: expected mapping, got None'),
      ({1: tensor_spec.TensorSpec(None)},
       r'expected str, got 1'),
      ({'x': 0},
       r'got 0'),
  ])
  def testConstructionErrors(self, field_specs, error):
    """Invalid `_fields` arguments raise a TypeError matching `error`."""
    with self.assertRaisesRegex(TypeError, error):
      StructuredTensor.Spec(
          _ragged_shape=self._uniform_shape_spec([]),
          _fields=field_specs)

  def testValueType(self):
    """The spec's `value_type` is `StructuredTensor`."""
    spec1 = StructuredTensor.Spec(
        _ragged_shape=self._uniform_shape_spec([1, 2]),
        _fields=dict(a=T_1_2))
    self.assertEqual(spec1.value_type, StructuredTensor)

  @parameterized.parameters([
      {
          'shape': [],
          'fields': dict(x=[[1.0, 2.0]]),
          'field_specs': dict(x=T_1_2),
      },
      {
          'shape': [2],
          'fields': dict(
              a=ragged_factory_ops.constant_value([[1.0], [2.0, 3.0]]),
              b=[[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]),
          'field_specs': dict(a=R_2_N, b=T_2_3),
      },
  ])  # pyformat: disable
  def testToFromComponents(self, shape, fields, field_specs):
    """`_to_components` followed by `_from_components` is a round trip."""
    struct = StructuredTensor.from_fields(fields, shape)
    spec = StructuredTensor.Spec(
        _ragged_shape=self._uniform_shape_spec(shape),
        _fields=field_specs)
    components = spec._to_components(struct)
    reconstructed = spec._from_components(components)
    self.assertAllEqual(struct, reconstructed)

  def testToFromComponentsEmptyScalar(self):
    """Round trip for a scalar StructuredTensor that has no fields."""
    struct = StructuredTensor.from_fields(fields={}, shape=[])
    spec = struct._type_spec
    components = spec._to_components(struct)
    reconstructed = spec._from_components(components)
    self.assertAllEqual(struct, reconstructed)

  def testToFromComponentsEmptyTensor(self):
    """Round trip for a shape-[1, 2, 3] StructuredTensor with no fields."""
    struct = StructuredTensor.from_fields(fields={}, shape=[1, 2, 3])
    spec = struct._type_spec
    components = spec._to_components(struct)
    reconstructed = spec._from_components(components)
    self.assertAllEqual(struct, reconstructed)

  @parameterized.parameters([
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields({'a': 1, 'b': [5, 6]}),
              StructuredTensor.from_fields({'a': 2, 'b': [7, 8]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': [[5, 6], [7, 8]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [1, 2, 3],
                  'b': [[5, 6], [6, 7], [7, 8]]}),
              StructuredTensor.from_fields(shape=[3], fields={
                  'a': [2, 3, 4],
                  'b': [[2, 2], [3, 3], [4, 4]]})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2, 3], fields={
              'a': [[1, 2, 3], [2, 3, 4]],
              'b': [[[5, 6], [6, 7], [7, 8]],
                    [[2, 2], [3, 3], [4, 4]]]}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 1,
                  'b': StructuredTensor.from_fields({'x': [5]})}),
              StructuredTensor.from_fields(shape=[], fields={
                  'a': 2,
                  'b': StructuredTensor.from_fields({'x': [6]})})],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'a': [1, 2],
              'b': StructuredTensor.from_fields(shape=[2], fields={
                  'x': [[5], [6]]})}),
      },
      {
          'unbatched': lambda: [
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1, 2], [3]]),
                  'Ragged2d': ragged_factory_ops.constant_value([1]),
              }),
              StructuredTensor.from_fields(shape=[], fields={
                  'Ragged3d': ragged_factory_ops.constant_value([[1]]),
                  'Ragged2d': ragged_factory_ops.constant_value([2, 3]),
              })],
          'batch_size': 2,
          'batched': lambda: StructuredTensor.from_fields(shape=[2], fields={
              'Ragged3d': ragged_factory_ops.constant_value(
                  [[[1, 2], [3]], [[1]]]),
              'Ragged2d': ragged_factory_ops.constant_value([[1], [2, 3]]),
          }),
          'use_only_batched_spec': True,
      },
  ])  # pyformat: disable
  def testBatchUnbatchValues(self,
                             unbatched,
                             batch_size,
                             batched,
                             use_only_batched_spec=False):
    """Checks that StructuredTensors survive dataset unbatch/batch.

    Two paths are exercised, and their results are only compared when
    executing eagerly:
      1. `batched` is unbatched by a dataset and re-batched.
      2. Unless `use_only_batched_spec` is set, the `unbatched` values are fed
         through `Dataset.from_generator` and batched.
    In both cases the first batch must equal `batched`.

    Args:
      unbatched: Zero-arg callable returning the list of unbatched values.
      batch_size: Batch size used when (re-)batching the datasets.
      batched: Zero-arg callable returning the expected batched value.
      use_only_batched_spec: If True, skip the `from_generator` path.
    """
    batched = batched()  # Deferred init because it creates tensors.
    unbatched = unbatched()  # Deferred init because it creates tensors.

    def unbatch_gen():
      yield from unbatched

    ds = dataset_ops.Dataset.from_tensors(batched)
    ds2 = ds.unbatch()
    if context.executing_eagerly():
      v = list(ds2.batch(batch_size))
      self.assertAllEqual(v[0], batched)

    if not use_only_batched_spec:
      unbatched_spec = type_spec.type_spec_from_value(unbatched[0])

      dsu = dataset_ops.Dataset.from_generator(
          unbatch_gen, output_signature=unbatched_spec)
      dsu2 = dsu.batch(batch_size)
      if context.executing_eagerly():
        v = list(dsu2)
        self.assertAllEqual(v[0], batched)

  def _lambda_for_fields(self):
    """Returns a zero-arg callable that builds a dict of sample field values.

    The dict mixes numpy arrays ('a', 'b'), ragged tensors ('c' - 'f') and
    nested StructuredTensors ('g', 'h').  Nothing is created until the
    returned callable is invoked.

    NOTE(review): not referenced by any test visible in this file.
    """
    return lambda: {
        'a':
            np.ones([1, 2, 3, 1]),
        'b':
            np.ones([1, 2, 3, 1, 5]),
        'c':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1], dtype=np.uint8), dtype=dtypes.uint8),
        'd':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 3]).tolist(), ragged_rank=1),
        'e':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 2, 2]).tolist(), ragged_rank=2),
        'f':
            ragged_factory_ops.constant(
                np.zeros([1, 2, 3, 1, 3]), dtype=dtypes.float32),
        'g':
            StructuredTensor.from_pyval([[
                [  # pylint: disable=g-complex-comprehension
                    [{
                        'x': j,
                        'y': k
                    }] for k in range(3)
                ] for j in range(2)
            ]]),
        'h':
            StructuredTensor.from_pyval([[
                [  # pylint: disable=g-complex-comprehension
                    [[
                        {
                            'x': j,
                            'y': k,
                            'z': z
                        } for z in range(j)
                    ]] for k in range(3)
                ] for j in range(2)
            ]]),
    }
290
291
# Run the test cases through googletest when this file is executed directly.
if __name__ == '__main__':
  googletest.main()
294