xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/structured/structured_tensor_slice_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for StructuredTensor."""
16
17from absl.testing import parameterized
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import tensor_shape
22from tensorflow.python.framework import tensor_spec
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops.structured import structured_tensor
26from tensorflow.python.platform import googletest
27
28
# TODO(edloper): Move this to a common util package (forked from ragged).
class _SliceBuilder:
  """Helper to construct arguments for __getitem__.

  Usage: `_SliceBuilder()[<expr>]` returns the slice_spec that Python generates
  for `<expr>`, unchanged.  For example, `_SliceBuilder()["f4", 1:, "f4_2"]`
  returns `("f4", slice(1, None, None), "f4_2")`.
  """

  def __getitem__(self, slice_spec):
    # Python has already turned the bracket expression into a key object
    # (str, int, slice, Ellipsis, None, or a tuple of these); hand it back.
    return slice_spec


# TODO(edloper): Move this to a common util package (forked from ragged).
# Shared instance, so test parameters can be written as SLICE_BUILDER[<expr>].
SLICE_BUILDER = _SliceBuilder()
42
43
# TODO(edloper): Move this to a common util package (forked from ragged).
def _make_tensor_slice_spec(slice_spec, use_constant=True):
  """Returns a copy of an extended slice spec with its ints replaced by tensors.

  Used to check that slicing gives the same result when the indices in the
  slice spec are tensors rather than Python integers.

  Args:
    slice_spec: The extended slice spec: either a single piece (int, slice,
      string, Ellipsis, None, ...) or a tuple of pieces.
    use_constant: If true, each integer `i` is replaced with
      `constant_op.constant(i)`.  If false, it is replaced with
      `array_ops.placeholder_with_default(constant_op.constant(i), [])`.

  Returns:
    `slice_spec` with every int piece -- including the start, stop and step
    of any `slice` piece -- wrapped as a scalar tensor.  A tuple spec is
    rebuilt element by element; any other kind of piece (strings, lists,
    Ellipsis, None, ...) is returned as-is.
  """

  def wrap(piece):
    # Slices are rebuilt with each of their (possibly None) bounds wrapped.
    if isinstance(piece, slice):
      return slice(wrap(piece.start), wrap(piece.stop), wrap(piece.step))
    if not isinstance(piece, int):
      return piece
    scalar = constant_op.constant(piece)
    if use_constant:
      return scalar
    # Route the value through a placeholder so it is not known at graph
    # construction time.
    return array_ops.placeholder_with_default(scalar, [])

  if isinstance(slice_spec, tuple):
    return tuple(map(wrap, slice_spec))
  return wrap(slice_spec)
78
79
# Example scalar structure with one field of each kind exercised by the tests:
# a scalar value, a matrix value, and scalar/vector/matrix sub-structures.
EXAMPLE_STRUCT = {
    # f1: scalar value field
    "f1": 1,
    # f2: matrix field
    "f2": [[1, 2], [3, 4]],
    # f3: scalar structure field
    "f3": {"f3_1": 1},
    # f4: vector structure field
    "f4": [{"f4_1": 1, "f4_2": b"a"}, {"f4_1": 2, "f4_2": b"b"}],
    # f5: matrix structure field
    "f5": [[{"f5_1": 1}, {"f5_1": 2}], [{"f5_1": 3}, {"f5_1": 4}]],
}

# Same fields and nesting as EXAMPLE_STRUCT, but with different values, so a
# slice that reaches the last element of EXAMPLE_STRUCT_VECTOR is
# distinguishable from one that does not.
EXAMPLE_STRUCT_2 = {
    # f1: scalar value field
    "f1": 5,
    # f2: matrix field
    "f2": [[6, 7], [8, 9]],
    # f3: scalar structure field
    "f3": {"f3_1": 9},
    # f4: vector structure field
    "f4": [{"f4_1": 5, "f4_2": b"A"}, {"f4_1": 6, "f4_2": b"B"}],
    # f5: matrix structure field
    "f5": [[{"f5_1": 6}, {"f5_1": 7}], [{"f5_1": 8}, {"f5_1": 9}]],
}

# Six structures: elements 0-4 are the same EXAMPLE_STRUCT dict object (list
# repetition aliases it), and element 5 is EXAMPLE_STRUCT_2.
EXAMPLE_STRUCT_VECTOR = [EXAMPLE_STRUCT] * 5 + [EXAMPLE_STRUCT_2]
107
# Explicit TypeSpec describing EXAMPLE_STRUCT.  Passing it to
# `StructuredTensor.from_pyval` converts the nested lists to Tensors with the
# static shapes and dtypes listed here, instead of the default conversion of
# lists to RaggedTensors.  Each field's shape starts with the shape of the
# StructuredTensorSpec that contains it; e.g. "f4" has shape [2], and its
# fields "f4_1" and "f4_2" also have shape [2].
EXAMPLE_STRUCT_SPEC1 = structured_tensor.StructuredTensorSpec([], {
    "f1": tensor_spec.TensorSpec([], dtypes.int32),
    "f2": tensor_spec.TensorSpec([2, 2], dtypes.int32),
    "f3": structured_tensor.StructuredTensorSpec(
        [], {"f3_1": tensor_spec.TensorSpec([], dtypes.int32)}),
    "f4": structured_tensor.StructuredTensorSpec(
        [2], {"f4_1": tensor_spec.TensorSpec([2], dtypes.int32),
              "f4_2": tensor_spec.TensorSpec([2], dtypes.string)}),
    "f5": structured_tensor.StructuredTensorSpec(
        [2, 2], {"f5_1": tensor_spec.TensorSpec([2, 2], dtypes.int32)}),
})
119
120
@test_util.run_all_in_graph_and_eager_modes
class StructuredTensorSliceTest(test_util.TensorFlowTestCase,
                                parameterized.TestCase):
  """Tests for `StructuredTensor.__getitem__` (indexing and slicing)."""

  def assertAllEqual(self, a, b, msg=None):
    """Asserts that `a` and `b` are equal, recursing into StructuredTensors.

    Extends `TensorFlowTestCase.assertAllEqual` so that either argument may be
    a `StructuredTensor`:

      * If neither is a StructuredTensor, defers to the base implementation.
      * If both are, checks that their shapes are compatible and that they
        have the same field names, then compares each field's value.
      * If only one is, that one is compared against the other (Python)
        value: a dict keyed by field name when the StructuredTensor has rank
        0, otherwise a list or tuple whose i-th item is compared with `a[i]`.

    Args:
      a: A value to compare (tensor, StructuredTensor, or Python value).
      b: The other value to compare.
      msg: Optional message reported on failure.  It is forwarded to every
        nested comparison so that context (e.g. the slice spec) is kept.
    """
    a_is_struct = isinstance(a, structured_tensor.StructuredTensor)
    b_is_struct = isinstance(b, structured_tensor.StructuredTensor)

    if not (a_is_struct or b_is_struct):
      super().assertAllEqual(a, b, msg)
    elif a_is_struct and b_is_struct:
      a_shape = tensor_shape.as_shape(a.shape)
      b_shape = tensor_shape.as_shape(b.shape)
      a_shape.assert_is_compatible_with(b_shape)
      self.assertEqual(set(a.field_names()), set(b.field_names()), msg)
      for field in a.field_names():
        self.assertAllEqual(a.field_value(field), b.field_value(field), msg)
    elif b_is_struct:
      # Swap the arguments so the StructuredTensor is always `a`.
      self.assertAllEqual(b, a, msg)
    elif a.rank == 0:
      # A scalar StructuredTensor is compared against a dict of field values.
      self.assertIsInstance(b, dict, msg)
      self.assertEqual(set(a.field_names()), set(b), msg)
      for key, b_val in b.items():
        self.assertAllEqual(a.field_value(key), b_val, msg)
    else:
      # A non-scalar StructuredTensor is compared item by item along its
      # outermost dimension against a list or tuple.
      self.assertIsInstance(b, (list, tuple), msg)
      a.shape[:1].assert_is_compatible_with([len(b)])
      for i, b_item in enumerate(b):
        self.assertAllEqual(a[i], b_item, msg)

  def _TestGetItem(self, struct, slice_spec, expected):
    """Checks that `struct.__getitem__(slice_spec)` returns `expected`.

    The slice spec is checked in three configurations:

      * as-is (with Python int values);
      * with every int in the slice spec wrapped in `tf.constant()`;
      * with every int in the slice spec wrapped in
        `tf.compat.v1.placeholder_with_default()` (so the value is not known
        at graph construction time).

    Args:
      struct: The StructuredTensor to test.
      slice_spec: The slice spec.
      expected: The expected value of `struct.__getitem__(slice_spec)`, as a
        Python value (scalar, list, or dict).
    """
    msg = "slice_spec=%s" % (slice_spec,)
    specs = [
        slice_spec,
        _make_tensor_slice_spec(slice_spec, True),
        _make_tensor_slice_spec(slice_spec, False),
    ]
    # Build all three results before checking any of them.
    values = [struct.__getitem__(spec) for spec in specs]
    for value in values:
      self.assertAllEqual(value, expected, msg)

  @parameterized.parameters([
      # Simple indexing
      (SLICE_BUILDER["f1"], EXAMPLE_STRUCT["f1"]),
      (SLICE_BUILDER["f2"], EXAMPLE_STRUCT["f2"]),
      (SLICE_BUILDER["f3"], EXAMPLE_STRUCT["f3"]),
      (SLICE_BUILDER["f4"], EXAMPLE_STRUCT["f4"]),
      (SLICE_BUILDER["f5"], EXAMPLE_STRUCT["f5"]),
      # Multidimensional indexing
      (SLICE_BUILDER["f2", 1], EXAMPLE_STRUCT["f2"][1]),
      (SLICE_BUILDER["f3", "f3_1"], EXAMPLE_STRUCT["f3"]["f3_1"]),
      (SLICE_BUILDER["f4", 1], EXAMPLE_STRUCT["f4"][1]),
      (SLICE_BUILDER["f4", 1, "f4_2"], EXAMPLE_STRUCT["f4"][1]["f4_2"]),
      (SLICE_BUILDER["f5", 0, 1], EXAMPLE_STRUCT["f5"][0][1]),
      (SLICE_BUILDER["f5", 0, 1, "f5_1"], EXAMPLE_STRUCT["f5"][0][1]["f5_1"]),
      # Multidimensional slicing
      (SLICE_BUILDER["f2", 1:], EXAMPLE_STRUCT["f2"][1:]),
      (SLICE_BUILDER["f4", :1], EXAMPLE_STRUCT["f4"][:1]),
      (SLICE_BUILDER["f4", 1:, "f4_2"], [b"b"]),
      (SLICE_BUILDER["f4", :, "f4_2"], [b"a", b"b"]),
      (SLICE_BUILDER["f5", :, :, "f5_1"], [[1, 2], [3, 4]]),
      # Full slice (':') over the keys selects every field.
      (SLICE_BUILDER[:], EXAMPLE_STRUCT),
      # List-valued key.
      (["f2", 1], EXAMPLE_STRUCT["f2"][1]),
  ])
  def testGetitemFromScalarStruct(self, slice_spec, expected):
    # By default, lists are converted to RaggedTensors.
    struct = structured_tensor.StructuredTensor.from_pyval(EXAMPLE_STRUCT)
    self._TestGetItem(struct, slice_spec, expected)

    # Using an explicit TypeSpec, we can convert them to Tensors instead.
    struct2 = structured_tensor.StructuredTensor.from_pyval(
        EXAMPLE_STRUCT, EXAMPLE_STRUCT_SPEC1)
    self._TestGetItem(struct2, slice_spec, expected)

  @parameterized.parameters([
      (SLICE_BUILDER[2], EXAMPLE_STRUCT_VECTOR[2]),
      (SLICE_BUILDER[5], EXAMPLE_STRUCT_VECTOR[5]),
      (SLICE_BUILDER[-2], EXAMPLE_STRUCT_VECTOR[-2]),
      (SLICE_BUILDER[-1], EXAMPLE_STRUCT_VECTOR[-1]),
      (SLICE_BUILDER[2, "f1"], EXAMPLE_STRUCT_VECTOR[2]["f1"]),
      (SLICE_BUILDER[-1, "f1"], EXAMPLE_STRUCT_VECTOR[-1]["f1"]),
      (SLICE_BUILDER[5:], EXAMPLE_STRUCT_VECTOR[5:]),
      (SLICE_BUILDER[3:, "f1"], [1, 1, 5]),
      (SLICE_BUILDER[::2, "f1"], [1, 1, 1]),
      (SLICE_BUILDER[1::2, "f1"], [1, 1, 5]),
      (SLICE_BUILDER[4:, "f5", 0, 1, "f5_1"], [2, 7], True),  # requires typespec
      (SLICE_BUILDER[4:, "f5", :, :, "f5_1"],
       [[[1, 2], [3, 4]], [[6, 7], [8, 9]]]),
  ])  # pyformat: disable
  def testGetitemFromVectorStruct(self, slice_spec, expected,
                                  test_requires_typespec=False):
    # By default, lists are converted to RaggedTensors.
    if not test_requires_typespec:
      struct_vector = structured_tensor.StructuredTensor.from_pyval(
          EXAMPLE_STRUCT_VECTOR)
      self._TestGetItem(struct_vector, slice_spec, expected)

    # Using an explicit TypeSpec, we can convert them to Tensors instead.
    struct_vector2 = structured_tensor.StructuredTensor.from_pyval(
        EXAMPLE_STRUCT_VECTOR, EXAMPLE_STRUCT_SPEC1._batch(6))
    self._TestGetItem(struct_vector2, slice_spec, expected)

  # TODO(edloper): Add tests for slicing from matrix StructuredTensors.

  @parameterized.parameters([
      (SLICE_BUILDER[:2], r"Key for indexing a StructuredTensor must be "
       r"a string or a full slice \(':'\)"),
      (SLICE_BUILDER["f4", ...], r"Slicing not supported for Ellipsis"),
      (SLICE_BUILDER["f4", None], r"Slicing not supported for tf.newaxis"),
      (SLICE_BUILDER["f4", :, 0],
       r"Key for indexing a StructuredTensor must be a string"),
  ])
  def testGetItemError(self, slice_spec, error, exception=ValueError):
    # Invalid slice specs on a scalar StructuredTensor must raise `exception`
    # with a message matching the `error` regex.
    struct = structured_tensor.StructuredTensor.from_pyval(EXAMPLE_STRUCT)
    with self.assertRaisesRegex(exception, error):
      struct.__getitem__(slice_spec)

  @parameterized.parameters([
      (SLICE_BUILDER[:, 1],
       r"Key for indexing a StructuredTensor must be a string"),
  ])
  def testGetItemFromVectorError(self, slice_spec, error, exception=ValueError):
    # Same as testGetItemError, but for a vector StructuredTensor.
    struct = structured_tensor.StructuredTensor.from_pyval(
        EXAMPLE_STRUCT_VECTOR)
    with self.assertRaisesRegex(exception, error):
      struct.__getitem__(slice_spec)
267
268
if __name__ == "__main__":
  # Run the tests defined in this module when it is executed directly.
  googletest.main()
271