xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/structured/structured_array_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for structured_array_ops."""
15
16
17from absl.testing import parameterized
18
19from tensorflow.python.eager import def_function
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import random_seed
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops import random_ops
28from tensorflow.python.ops.ragged import ragged_factory_ops
29from tensorflow.python.ops.ragged import ragged_tensor
30from tensorflow.python.ops.ragged import row_partition
31from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape
32from tensorflow.python.ops.structured import structured_array_ops
33from tensorflow.python.ops.structured import structured_tensor
34from tensorflow.python.ops.structured.structured_tensor import StructuredTensor
35from tensorflow.python.platform import googletest
36from tensorflow.python.util import nest
37
38
39# TODO(martinz):create StructuredTensorTestCase.
40# pylint: disable=g-long-lambda
41@test_util.run_all_in_graph_and_eager_modes
42class StructuredArrayOpsTest(test_util.TensorFlowTestCase,
43                             parameterized.TestCase):
44
45  def assertAllEqual(self, a, b, msg=None):
46    if not (isinstance(a, structured_tensor.StructuredTensor) or
47            isinstance(b, structured_tensor.StructuredTensor)):
48      return super(StructuredArrayOpsTest, self).assertAllEqual(a, b, msg)
49
50    if not isinstance(a, structured_tensor.StructuredTensor):
51      a = structured_tensor.StructuredTensor.from_pyval(a)
52    elif not isinstance(b, structured_tensor.StructuredTensor):
53      b = structured_tensor.StructuredTensor.from_pyval(b)
54
55    try:
56      nest.assert_same_structure(a, b, expand_composites=True)
57    except (TypeError, ValueError) as e:
58      self.assertIsNone(e, (msg + ": " if msg else "") + str(e))
59    a_tensors = [x for x in nest.flatten(a, expand_composites=True)
60                 if isinstance(x, ops.Tensor)]
61    b_tensors = [x for x in nest.flatten(b, expand_composites=True)
62                 if isinstance(x, ops.Tensor)]
63    self.assertLen(a_tensors, len(b_tensors))
64    a_arrays, b_arrays = self.evaluate((a_tensors, b_tensors))
65    for a_array, b_array in zip(a_arrays, b_arrays):
66      self.assertAllEqual(a_array, b_array, msg)
67
68  def _assertStructuredEqual(self, a, b, msg, check_shape):
69    if check_shape:
70      self.assertEqual(repr(a.shape), repr(b.shape))
71    self.assertEqual(set(a.field_names()), set(b.field_names()))
72    for field in a.field_names():
73      a_value = a.field_value(field)
74      b_value = b.field_value(field)
75      self.assertIs(type(a_value), type(b_value))
76      if isinstance(a_value, structured_tensor.StructuredTensor):
77        self._assertStructuredEqual(a_value, b_value, msg, check_shape)
78      else:
79        self.assertAllEqual(a_value, b_value, msg)
80
81  @parameterized.named_parameters([
82      dict(
83          testcase_name="0D_0",
84          st={"x": 1},
85          axis=0,
86          expected=[{"x": 1}]),
87      dict(
88          testcase_name="0D_minus_1",
89          st={"x": 1},
90          axis=-1,
91          expected=[{"x": 1}]),
92      dict(
93          testcase_name="1D_0",
94          st=[{"x": [1, 3]}, {"x": [2, 7, 9]}],
95          axis=0,
96          expected=[[{"x": [1, 3]}, {"x": [2, 7, 9]}]]),
97      dict(
98          testcase_name="1D_1",
99          st=[{"x": [1]}, {"x": [2, 10]}],
100          axis=1,
101          expected=[[{"x": [1]}], [{"x": [2, 10]}]]),
102      dict(
103          testcase_name="2D_0",
104          st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
105          axis=0,
106          expected=[[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]]]),
107      dict(
108          testcase_name="2D_1",
109          st=[[{"x": 1}, {"x": 2}], [{"x": 3}]],
110          axis=1,
111          expected=[[[{"x": 1}, {"x": 2}]], [[{"x": 3}]]]),
112      dict(
113          testcase_name="2D_2",
114          st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
115          axis=2,
116          expected=[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3, 4]}]]]),
117      dict(
118          testcase_name="3D_0",
119          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
120          axis=0,
121          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
122                     [[{"x": [4, 5]}]]]]),
123      dict(
124          testcase_name="3D_minus_4",
125          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
126          axis=-4,  # same as zero
127          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
128                     [[{"x": [4, 5]}]]]]),
129      dict(
130          testcase_name="3D_1",
131          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
132          axis=1,
133          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
134                    [[[{"x": [4, 5]}]]]]),
135      dict(
136          testcase_name="3D_minus_3",
137          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
138          axis=-3,  # same as 1
139          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
140                    [[[{"x": [4, 5]}]]]]),
141      dict(
142          testcase_name="3D_2",
143          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
144          axis=2,
145          expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
146                    [[[{"x": [4, 5]}]]]]),
147      dict(
148          testcase_name="3D_minus_2",
149          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
150          axis=-2,  # same as 2
151          expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
152                    [[[{"x": [4, 5]}]]]]),
153      dict(
154          testcase_name="3D_3",
155          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
156          axis=3,
157          expected=[[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3]}]]],
158                    [[[{"x": [4, 5]}]]]]),
159  ])  # pyformat: disable
160  def testExpandDims(self, st, axis, expected):
161    st = StructuredTensor.from_pyval(st)
162    result = array_ops.expand_dims(st, axis)
163    self.assertAllEqual(result, expected)
164
165  def testExpandDimsAxisTooBig(self):
166    st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
167    st = StructuredTensor.from_pyval(st)
168    with self.assertRaisesRegex(ValueError,
169                                "axis=4 out of bounds: expected -4<=axis<4"):
170      array_ops.expand_dims(st, 4)
171
172  def testExpandDimsAxisTooSmall(self):
173    st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
174    st = StructuredTensor.from_pyval(st)
175    with self.assertRaisesRegex(ValueError,
176                                "axis=-5 out of bounds: expected -4<=axis<4"):
177      array_ops.expand_dims(st, -5)
178
179  def testExpandDimsScalar(self):
180    # Note that if we expand_dims for the final dimension and there are scalar
181    # fields, then the shape is (2, None, None, 1), whereas if it is constructed
182    # from pyval it is (2, None, None, None).
183    st = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]]
184    st = StructuredTensor.from_pyval(st)
185    result = array_ops.expand_dims(st, 3)
186    expected_shape = tensor_shape.TensorShape([2, None, None, 1])
187    self.assertEqual(repr(expected_shape), repr(result.shape))
188
189  @parameterized.named_parameters([
190      dict(
191          testcase_name="scalar_int32",
192          row_partitions=None,
193          shape=(),
194          dtype=dtypes.int32,
195          expected=1),
196      dict(
197          testcase_name="scalar_int64",
198          row_partitions=None,
199          shape=(),
200          dtype=dtypes.int64,
201          expected=1),
202      dict(
203          testcase_name="list_0_int32",
204          row_partitions=None,
205          shape=(0),
206          dtype=dtypes.int32,
207          expected=0),
208      dict(
209          testcase_name="list_0_0_int32",
210          row_partitions=None,
211          shape=(0, 0),
212          dtype=dtypes.int32,
213          expected=0),
214      dict(
215          testcase_name="list_int32",
216          row_partitions=None,
217          shape=(7),
218          dtype=dtypes.int32,
219          expected=7),
220      dict(
221          testcase_name="list_int64",
222          row_partitions=None,
223          shape=(7),
224          dtype=dtypes.int64,
225          expected=7),
226      dict(
227          testcase_name="matrix_int32",
228          row_partitions=[[0, 3, 6]],
229          shape=(2, 3),
230          dtype=dtypes.int32,
231          expected=6),
232      dict(
233          testcase_name="tensor_int32",
234          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
235          shape=(2, 3, 1),
236          dtype=dtypes.int32,
237          expected=6),
238      dict(
239          testcase_name="ragged_1_int32",
240          row_partitions=[[0, 3, 4]],
241          shape=(2, None),
242          dtype=dtypes.int32,
243          expected=4),
244      dict(
245          testcase_name="ragged_2_float32",
246          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
247          shape=(2, None, None),
248          dtype=dtypes.float32,
249          expected=7),
250  ])  # pyformat: disable
251  def testSizeObject(self, row_partitions, shape, dtype, expected):
252    if row_partitions is not None:
253      row_partitions = [
254          row_partition.RowPartition.from_row_splits(r) for r in row_partitions
255      ]
256    st = StructuredTensor.from_fields({},
257                                      shape=shape,
258                                      row_partitions=row_partitions)
259    # NOTE: size is very robust. There aren't arguments that
260    # should cause this operation to fail.
261    actual = array_ops.size(st, out_type=dtype)
262    self.assertAllEqual(actual, expected)
263
264    actual2 = array_ops.size_v2(st, out_type=dtype)
265    self.assertAllEqual(actual2, expected)
266
267  def test_shape_v2(self):
268    rt = ragged_tensor.RaggedTensor.from_row_lengths(["a", "b", "c"], [1, 2])
269    st = StructuredTensor.from_fields_and_rank({"r": rt}, rank=2)
270    actual = array_ops.shape_v2(st, out_type=dtypes.int64)
271    actual_static_lengths = actual.static_lengths()
272    self.assertAllEqual([2, (1, 2)], actual_static_lengths)
273
274  def test_shape(self):
275    rt = ragged_tensor.RaggedTensor.from_row_lengths(["a", "b", "c"], [1, 2])
276    st = StructuredTensor.from_fields_and_rank({"r": rt}, rank=2)
277    actual = array_ops.shape(st, out_type=dtypes.int64).static_lengths()
278    actual_v2 = array_ops.shape_v2(st, out_type=dtypes.int64).static_lengths()
279    expected = [2, (1, 2)]
280    self.assertAllEqual(expected, actual)
281    self.assertAllEqual(expected, actual_v2)
282
283  @parameterized.named_parameters([
284      dict(
285          testcase_name="list_empty_2_1",
286          values=[[{}, {}], [{}]],
287          dtype=dtypes.int32,
288          expected=3),
289      dict(
290          testcase_name="list_empty_2",
291          values=[{}, {}],
292          dtype=dtypes.int32,
293          expected=2),
294      dict(
295          testcase_name="list_empty_1",
296          values=[{}],
297          dtype=dtypes.int32,
298          expected=1),
299      dict(
300          testcase_name="list_example_1",
301          values=[{"x": [3]}, {"x": [4, 5]}],
302          dtype=dtypes.int32,
303          expected=2),
304      dict(
305          testcase_name="list_example_2",
306          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
307          dtype=dtypes.float32,
308          expected=3),
309      dict(
310          testcase_name="list_example_2_None",
311          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
312          dtype=None,
313          expected=3),
314  ])  # pyformat: disable
315  def testSizeAlt(self, values, dtype, expected):
316    st = StructuredTensor.from_pyval(values)
317    # NOTE: size is very robust. There aren't arguments that
318    # should cause this operation to fail.
319    actual = array_ops.size(st, out_type=dtype)
320    self.assertAllEqual(actual, expected)
321
322    actual2 = array_ops.size_v2(st, out_type=dtype)
323    self.assertAllEqual(actual2, expected)
324
325  @parameterized.named_parameters([
326      dict(
327          testcase_name="scalar_int32",
328          row_partitions=None,
329          shape=(),
330          dtype=dtypes.int32,
331          expected=0),
332      dict(
333          testcase_name="scalar_bool",
334          row_partitions=None,
335          shape=(),
336          dtype=dtypes.bool,
337          expected=False),
338      dict(
339          testcase_name="scalar_int64",
340          row_partitions=None,
341          shape=(),
342          dtype=dtypes.int64,
343          expected=0),
344      dict(
345          testcase_name="scalar_float32",
346          row_partitions=None,
347          shape=(),
348          dtype=dtypes.float32,
349          expected=0.0),
350      dict(
351          testcase_name="list_0_int32",
352          row_partitions=None,
353          shape=(0),
354          dtype=dtypes.int32,
355          expected=[]),
356      dict(
357          testcase_name="list_0_0_int32",
358          row_partitions=None,
359          shape=(0, 0),
360          dtype=dtypes.int32,
361          expected=[]),
362      dict(
363          testcase_name="list_int32",
364          row_partitions=None,
365          shape=(7),
366          dtype=dtypes.int32,
367          expected=[0, 0, 0, 0, 0, 0, 0]),
368      dict(
369          testcase_name="list_int64",
370          row_partitions=None,
371          shape=(7),
372          dtype=dtypes.int64,
373          expected=[0, 0, 0, 0, 0, 0, 0]),
374      dict(
375          testcase_name="list_float32",
376          row_partitions=None,
377          shape=(7),
378          dtype=dtypes.float32,
379          expected=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
380      dict(
381          testcase_name="matrix_int32",
382          row_partitions=[[0, 3, 6]],
383          shape=(2, 3),
384          dtype=dtypes.int32,
385          expected=[[0, 0, 0], [0, 0, 0]]),
386      dict(
387          testcase_name="matrix_float64",
388          row_partitions=[[0, 3, 6]],
389          shape=(2, 3),
390          dtype=dtypes.float64,
391          expected=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
392      dict(
393          testcase_name="tensor_int32",
394          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
395          shape=(2, 3, 1),
396          dtype=dtypes.int32,
397          expected=[[[0], [0], [0]], [[0], [0], [0]]]),
398      dict(
399          testcase_name="tensor_float32",
400          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
401          shape=(2, 3, 1),
402          dtype=dtypes.float32,
403          expected=[[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]),
404      dict(
405          testcase_name="ragged_1_float32",
406          row_partitions=[[0, 3, 4]],
407          shape=(2, None),
408          dtype=dtypes.float32,
409          expected=[[0.0, 0.0, 0.0], [0.0]]),
410      dict(
411          testcase_name="ragged_2_float32",
412          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
413          shape=(2, None, None),
414          dtype=dtypes.float32,
415          expected=[[[0.0, 0.0], [0.0], [0.0, 0.0]], [[0.0, 0.0]]]),
416  ])  # pyformat: disable
417  def testZerosLikeObject(self, row_partitions, shape, dtype, expected):
418    if row_partitions is not None:
419      row_partitions = [
420          row_partition.RowPartition.from_row_splits(r) for r in row_partitions
421      ]
422    st = StructuredTensor.from_fields({},
423                                      shape=shape,
424                                      row_partitions=row_partitions)
425    # NOTE: zeros_like is very robust. There aren't arguments that
426    # should cause this operation to fail.
427    actual = array_ops.zeros_like(st, dtype)
428    self.assertAllEqual(actual, expected)
429
430    actual2 = array_ops.zeros_like_v2(st, dtype)
431    self.assertAllEqual(actual2, expected)
432
433  @parameterized.named_parameters([
434      dict(
435          testcase_name="list_empty_2_1",
436          values=[[{}, {}], [{}]],
437          dtype=dtypes.int32,
438          expected=[[0, 0], [0]]),
439      dict(
440          testcase_name="list_empty_2",
441          values=[{}, {}],
442          dtype=dtypes.int32,
443          expected=[0, 0]),
444      dict(
445          testcase_name="list_empty_1",
446          values=[{}],
447          dtype=dtypes.int32,
448          expected=[0]),
449      dict(
450          testcase_name="list_example_1",
451          values=[{"x": [3]}, {"x": [4, 5]}],
452          dtype=dtypes.int32,
453          expected=[0, 0]),
454      dict(
455          testcase_name="list_example_2",
456          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
457          dtype=dtypes.float32,
458          expected=[[0.0], [0.0, 0.0]]),
459      dict(
460          testcase_name="list_example_2_None",
461          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
462          dtype=None,
463          expected=[[0.0], [0.0, 0.0]]),
464  ])  # pyformat: disable
465  def testZerosLikeObjectAlt(self, values, dtype, expected):
466    st = StructuredTensor.from_pyval(values)
467    # NOTE: zeros_like is very robust. There aren't arguments that
468    # should cause this operation to fail.
469    actual = array_ops.zeros_like(st, dtype)
470    self.assertAllEqual(actual, expected)
471
472    actual2 = array_ops.zeros_like_v2(st, dtype)
473    self.assertAllEqual(actual2, expected)
474
475  @parameterized.named_parameters([
476      dict(
477          testcase_name="scalar_int32",
478          row_partitions=None,
479          shape=(),
480          dtype=dtypes.int32,
481          expected=1),
482      dict(
483          testcase_name="scalar_bool",
484          row_partitions=None,
485          shape=(),
486          dtype=dtypes.bool,
487          expected=True),
488      dict(
489          testcase_name="scalar_int64",
490          row_partitions=None,
491          shape=(),
492          dtype=dtypes.int64,
493          expected=1),
494      dict(
495          testcase_name="scalar_float32",
496          row_partitions=None,
497          shape=(),
498          dtype=dtypes.float32,
499          expected=1.0),
500      dict(
501          testcase_name="list_0_int32",
502          row_partitions=None,
503          shape=(0),
504          dtype=dtypes.int32,
505          expected=[]),
506      dict(
507          testcase_name="list_0_0_int32",
508          row_partitions=None,
509          shape=(0, 0),
510          dtype=dtypes.int32,
511          expected=[]),
512      dict(
513          testcase_name="list_int32",
514          row_partitions=None,
515          shape=(7),
516          dtype=dtypes.int32,
517          expected=[1, 1, 1, 1, 1, 1, 1]),
518      dict(
519          testcase_name="list_int64",
520          row_partitions=None,
521          shape=(7),
522          dtype=dtypes.int64,
523          expected=[1, 1, 1, 1, 1, 1, 1]),
524      dict(
525          testcase_name="list_float32",
526          row_partitions=None,
527          shape=(7),
528          dtype=dtypes.float32,
529          expected=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
530      dict(
531          testcase_name="matrix_int32",
532          row_partitions=[[0, 3, 6]],
533          shape=(2, 3),
534          dtype=dtypes.int32,
535          expected=[[1, 1, 1], [1, 1, 1]]),
536      dict(
537          testcase_name="matrix_float64",
538          row_partitions=[[0, 3, 6]],
539          shape=(2, 3),
540          dtype=dtypes.float64,
541          expected=[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
542      dict(
543          testcase_name="tensor_int32",
544          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
545          shape=(2, 3, 1),
546          dtype=dtypes.int32,
547          expected=[[[1], [1], [1]], [[1], [1], [1]]]),
548      dict(
549          testcase_name="tensor_float32",
550          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
551          shape=(2, 3, 1),
552          dtype=dtypes.float32,
553          expected=[[[1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0]]]),
554      dict(
555          testcase_name="ragged_1_float32",
556          row_partitions=[[0, 3, 4]],
557          shape=(2, None),
558          dtype=dtypes.float32,
559          expected=[[1.0, 1.0, 1.0], [1.0]]),
560      dict(
561          testcase_name="ragged_2_float32",
562          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
563          shape=(2, None, None),
564          dtype=dtypes.float32,
565          expected=[[[1.0, 1.0], [1.0], [1.0, 1.0]], [[1.0, 1.0]]]),
566  ])  # pyformat: disable
567  def testOnesLikeObject(self, row_partitions, shape, dtype, expected):
568    if row_partitions is not None:
569      row_partitions = [
570          row_partition.RowPartition.from_row_splits(r) for r in row_partitions
571      ]
572    st = StructuredTensor.from_fields({},
573                                      shape=shape,
574                                      row_partitions=row_partitions)
575    # NOTE: ones_like is very robust. There aren't arguments that
576    # should cause this operation to fail.
577    actual = array_ops.ones_like(st, dtype)
578    self.assertAllEqual(actual, expected)
579
580    actual2 = array_ops.ones_like_v2(st, dtype)
581    self.assertAllEqual(actual2, expected)
582
583  @parameterized.named_parameters([
584      dict(
585          testcase_name="list_empty_2_1",
586          values=[[{}, {}], [{}]],
587          dtype=dtypes.int32,
588          expected=[[1, 1], [1]]),
589      dict(
590          testcase_name="list_empty_2",
591          values=[{}, {}],
592          dtype=dtypes.int32,
593          expected=[1, 1]),
594      dict(
595          testcase_name="list_empty_1",
596          values=[{}],
597          dtype=dtypes.int32,
598          expected=[1]),
599      dict(
600          testcase_name="list_example_1",
601          values=[{"x": [3]}, {"x": [4, 5]}],
602          dtype=dtypes.int32,
603          expected=[1, 1]),
604      dict(
605          testcase_name="list_example_2",
606          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
607          dtype=dtypes.float32,
608          expected=[[1.0], [1.0, 1.0]]),
609      dict(
610          testcase_name="list_example_2_None",
611          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
612          dtype=None,
613          expected=[[1.0], [1.0, 1.0]]),
614  ])  # pyformat: disable
615  def testOnesLikeObjectAlt(self, values, dtype, expected):
616    st = StructuredTensor.from_pyval(values)
617    # NOTE: ones_like is very robust. There aren't arguments that
618    # should cause this operation to fail.
619    actual = array_ops.ones_like(st, dtype)
620    self.assertAllEqual(actual, expected)
621
622    actual2 = array_ops.ones_like_v2(st, dtype)
623    self.assertAllEqual(actual2, expected)
624
625  @parameterized.named_parameters([
626      dict(
627          testcase_name="scalar",
628          row_partitions=None,
629          shape=(),
630          expected=0),
631      dict(
632          testcase_name="list_0",
633          row_partitions=None,
634          shape=(0,),
635          expected=1),
636      dict(
637          testcase_name="list_0_0",
638          row_partitions=None,
639          shape=(0, 0),
640          expected=2),
641      dict(
642          testcase_name="list_7",
643          row_partitions=None,
644          shape=(7,),
645          expected=1),
646      dict(
647          testcase_name="matrix",
648          row_partitions=[[0, 3, 6]],
649          shape=(2, 3),
650          expected=2),
651      dict(
652          testcase_name="tensor",
653          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
654          shape=(2, 3, 1),
655          expected=3),
656      dict(
657          testcase_name="ragged_1",
658          row_partitions=[[0, 3, 4]],
659          shape=(2, None),
660          expected=2),
661      dict(
662          testcase_name="ragged_2",
663          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
664          shape=(2, None, None),
665          expected=3),
666  ])  # pyformat: disable
667  def testRank(self, row_partitions, shape, expected):
668    if row_partitions is not None:
669      row_partitions = [
670          row_partition.RowPartition.from_row_splits(r) for r in row_partitions
671      ]
672    st = StructuredTensor.from_fields({},
673                                      shape=shape,
674                                      row_partitions=row_partitions)
675
676    # NOTE: rank is very robust. There aren't arguments that
677    # should cause this operation to fail.
678    actual = structured_array_ops.rank(st)
679    self.assertAllEqual(expected, actual)
680
681  @parameterized.named_parameters([
682      dict(
683          testcase_name="list_empty_2_1",
684          values=[[{}, {}], [{}]],
685          expected=2),
686      dict(
687          testcase_name="list_empty_2",
688          values=[{}, {}],
689          expected=1),
690      dict(
691          testcase_name="list_empty_1",
692          values=[{}],
693          expected=1),
694      dict(
695          testcase_name="list_example_1",
696          values=[{"x": [3]}, {"x": [4, 5]}],
697          expected=1),
698      dict(
699          testcase_name="list_example_2",
700          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
701          expected=2),
702  ])  # pyformat: disable
703  def testRankAlt(self, values, expected):
704    st = StructuredTensor.from_pyval(values)
705    # NOTE: rank is very robust. There aren't arguments that
706    # should cause this operation to fail.
707    actual = array_ops.rank(st)
708    self.assertAllEqual(expected, actual)
709
710  @parameterized.named_parameters([
711      dict(
712          testcase_name="list_empty",
713          values=[[{}], [{}]],
714          axis=0,
715          expected=[{}, {}]),
716      dict(
717          testcase_name="list_empty_2_1",
718          values=[[{}, {}], [{}]],
719          axis=0,
720          expected=[{}, {}, {}]),
721      dict(
722          testcase_name="list_with_fields",
723          values=[[{"a": 4, "b": [3, 4]}], [{"a": 5, "b": [5, 6]}]],
724          axis=0,
725          expected=[{"a": 4, "b": [3, 4]}, {"a": 5, "b": [5, 6]}]),
726      dict(
727          testcase_name="list_with_submessages",
728          values=[[{"a": {"foo": 3}, "b": [3, 4]}],
729                  [{"a": {"foo": 4}, "b": [5, 6]}]],
730          axis=0,
731          expected=[{"a": {"foo": 3}, "b": [3, 4]},
732                    {"a": {"foo": 4}, "b": [5, 6]}]),
733      dict(
734          testcase_name="list_with_empty_submessages",
735          values=[[{"a": {}, "b": [3, 4]}],
736                  [{"a": {}, "b": [5, 6]}]],
737          axis=0,
738          expected=[{"a": {}, "b": [3, 4]},
739                    {"a": {}, "b": [5, 6]}]),
740      dict(
741          testcase_name="lists_of_lists",
742          values=[[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
743                   [{"a": {}, "b": [7, 8, 9]}]],
744                  [[{"a": {}, "b": [10]}]]],
745          axis=0,
746          expected=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
747                    [{"a": {}, "b": [7, 8, 9]}],
748                    [{"a": {}, "b": [10]}]]),
749      dict(
750          testcase_name="lists_of_lists_axis_1",
751          values=[[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
752                   [{"a": {}, "b": [7, 8, 9]}]],
753                  [[{"a": {}, "b": []}], [{"a": {}, "b": [3]}]]],
754          axis=1,
755          expected=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]},
756                     {"a": {}, "b": []}],
757                    [{"a": {}, "b": [7, 8, 9]}, {"a": {}, "b": [3]}]]),
758      dict(
759          testcase_name="lists_of_lists_axis_minus_2",
760          values=[[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
761                   [{"a": {}, "b": [7, 8, 9]}]],
762                  [[{"a": {}, "b": [10]}]]],
763          axis=-2,  # Same as axis=0.
764          expected=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
765                    [{"a": {}, "b": [7, 8, 9]}],
766                    [{"a": {}, "b": [10]}]]),
767      dict(
768          testcase_name="from_structured_tensor_util_test",
769          values=[[{"x0": 0, "y": {"z": [[3, 13]]}},
770                   {"x0": 1, "y": {"z": [[3], [4, 13]]}},
771                   {"x0": 2, "y": {"z": [[3, 5], [4]]}}],
772                  [{"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
773                   {"x0": 4, "y": {"z": [[3], [4]]}}]],
774          axis=0,
775          expected=[{"x0": 0, "y": {"z": [[3, 13]]}},
776                    {"x0": 1, "y": {"z": [[3], [4, 13]]}},
777                    {"x0": 2, "y": {"z": [[3, 5], [4]]}},
778                    {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
779                    {"x0": 4, "y": {"z": [[3], [4]]}}]),
780  ])  # pyformat: disable
781  def testConcat(self, values, axis, expected):
782    values = [StructuredTensor.from_pyval(v) for v in values]
783    actual = array_ops.concat(values, axis)
784    self.assertAllEqual(actual, expected)
785
786  def testConcatTuple(self):
787    values = (StructuredTensor.from_pyval([{"a": 3}]),
788              StructuredTensor.from_pyval([{"a": 4}]))
789    actual = array_ops.concat(values, axis=0)
790    self.assertAllEqual(actual, [{"a": 3}, {"a": 4}])
791
792  @parameterized.named_parameters([
793      dict(
794          testcase_name="field_dropped",
795          values=[[{"a": [2]}], [{}]],
796          axis=0,
797          error_type=ValueError,
798          error_regex="a"),
799      dict(
800          testcase_name="field_added",
801          values=[[{"b": [3]}], [{"b": [3], "a": [7]}]],
802          axis=0,
803          error_type=ValueError,
804          error_regex="b"),
805      dict(testcase_name="rank_submessage_change",
806           values=[[{"a": [{"b": [[3]]}]}],
807                   [{"a": [[{"b": [3]}]]}]],
808           axis=0,
809           error_type=ValueError,
810           error_regex="Ranks of sub-message do not match",
811          ),
812      dict(testcase_name="rank_message_change",
813           values=[[{"a": [3]}],
814                   [[{"a": 3}]]],
815           axis=0,
816           error_type=ValueError,
817           error_regex="Ranks of sub-message do not match",
818          ),
819      dict(testcase_name="concat_scalar",
820           values=[{"a": [3]}, {"a": [4]}],
821           axis=0,
822           error_type=ValueError,
823           error_regex="axis=0 out of bounds",
824          ),
825      dict(testcase_name="concat_axis_large",
826           values=[[{"a": [3]}], [{"a": [4]}]],
827           axis=1,
828           error_type=ValueError,
829           error_regex="axis=1 out of bounds",
830          ),
831      dict(testcase_name="concat_axis_large_neg",
832           values=[[{"a": [3]}], [{"a": [4]}]],
833           axis=-2,
834           error_type=ValueError,
835           error_regex="axis=-2 out of bounds",
836          ),
837      dict(testcase_name="concat_deep_rank_wrong",
838           values=[[{"a": [3]}], [{"a": [[4]]}]],
839           axis=0,
840           error_type=ValueError,
841           error_regex="must have rank",
842          ),
843  ])  # pyformat: disable
844  def testConcatError(self, values, axis, error_type, error_regex):
845    values = [StructuredTensor.from_pyval(v) for v in values]
846    with self.assertRaisesRegex(error_type, error_regex):
847      array_ops.concat(values, axis)
848
849  def testConcatWithRagged(self):
850    values = [StructuredTensor.from_pyval({}), array_ops.constant(3)]
851    with self.assertRaisesRegex(ValueError,
852                                "values must be a list of StructuredTensors"):
853      array_ops.concat(values, 0)
854
855  def testConcatNotAList(self):
856    values = StructuredTensor.from_pyval({})
857    with self.assertRaisesRegex(
858        ValueError, "values must be a list of StructuredTensors"):
859      structured_array_ops.concat(values, 0)
860
861  def testConcatEmptyList(self):
862    with self.assertRaisesRegex(ValueError,
863                                "values must not be an empty list"):
864      structured_array_ops.concat([], 0)
865
866  def testExtendOpErrorNotList(self):
867    # Should be a list.
868    values = StructuredTensor.from_pyval({})
869    def leaf_op(values):
870      return values[0]
871    with self.assertRaisesRegex(ValueError, "Expected a list"):
872      structured_array_ops._extend_op(values, leaf_op)
873
874  def testExtendOpErrorEmptyList(self):
875    def leaf_op(values):
876      return values[0]
877    with self.assertRaisesRegex(ValueError, "List cannot be empty"):
878      structured_array_ops._extend_op([], leaf_op)
879
880  def testRandomShuffle2021(self):
881    original = StructuredTensor.from_pyval([
882        {"x0": 0, "y": {"z": [[3, 13]]}},
883        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
884        {"x0": 2, "y": {"z": [[3, 5], [4]]}},
885        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
886        {"x0": 4, "y": {"z": [[3], [4]]}}])  # pyformat: disable
887    random_seed.set_seed(1066)
888    result = random_ops.random_shuffle(original, seed=2021)
889    expected = StructuredTensor.from_pyval([
890        {"x0": 0, "y": {"z": [[3, 13]]}},
891        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
892        {"x0": 4, "y": {"z": [[3], [4]]}},
893        {"x0": 2, "y": {"z": [[3, 5], [4]]}},
894        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},])  # pyformat: disable
895    self.assertAllEqual(result, expected)
896
897  def testRandomShuffle2022Eager(self):
898    original = StructuredTensor.from_pyval([
899        {"x0": 0, "y": {"z": [[3, 13]]}},
900        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
901        {"x0": 2, "y": {"z": [[3, 5], [4]]}},
902        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
903        {"x0": 4, "y": {"z": [[3], [4]]}}])  # pyformat: disable
904    expected = StructuredTensor.from_pyval([
905        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
906        {"x0": 0, "y": {"z": [[3, 13]]}},
907        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
908        {"x0": 4, "y": {"z": [[3], [4]]}},
909        {"x0": 2, "y": {"z": [[3, 5], [4]]}}])  # pyformat: disable
910    random_seed.set_seed(1066)
911    result = structured_array_ops.random_shuffle(original, seed=2022)
912    self.assertAllEqual(result, expected)
913
914  def testRandomShuffleScalarError(self):
915    original = StructuredTensor.from_pyval(
916        {"x0": 2, "y": {"z": [[3, 5], [4]]}})  # pyformat: disable
917
918    with self.assertRaisesRegex(ValueError, "scalar"):
919      random_ops.random_shuffle(original)
920
921  def testStructuredTensorArrayLikeNoRank(self):
922    """Test when the rank is unknown."""
923    @def_function.function
924    def my_fun(foo):
925      bar_shape = math_ops.range(foo)
926      bar = array_ops.zeros(shape=bar_shape)
927      structured_array_ops._structured_tensor_like(bar)
928
929    with self.assertRaisesRegex(ValueError,
930                                "Can't build StructuredTensor w/ unknown rank"):
931      my_fun(array_ops.constant(3))
932
933  def testStructuredTensorArrayRankOneKnownShape(self):
934    """Fully test structured_tensor_array_like."""
935    foo = array_ops.zeros(shape=[4])
936    result = structured_array_ops._structured_tensor_like(foo)
937    self.assertAllEqual([{}, {}, {}, {}], result)
938
939  # Note that we have to be careful about whether the indices are int32
940  # or int64.
941  def testStructuredTensorArrayRankOneUnknownShape(self):
942    """Fully test structured_tensor_array_like."""
943    @def_function.function
944    def my_fun(my_shape):
945      my_zeros = array_ops.zeros(my_shape)
946      return structured_array_ops._structured_tensor_like(my_zeros)
947
948    result = my_fun(array_ops.constant(4))
949    shape = DynamicRaggedShape._from_inner_shape([4], dtype=dtypes.int32)
950    expected = StructuredTensor.from_shape(shape)
951    self.assertAllEqual(expected, result)
952
953  def testStructuredTensorArrayRankTwoUnknownShape(self):
954    """Fully test structured_tensor_array_like."""
955    @def_function.function
956    def my_fun(my_shape):
957      my_zeros = array_ops.zeros(my_shape)
958      return structured_array_ops._structured_tensor_like(my_zeros)
959
960    result = my_fun(array_ops.constant([2, 2]))
961    self.assertAllEqual([[{}, {}], [{}, {}]], result)
962
963  def testStructuredTensorArrayRankZero(self):
964    """Fully test structured_tensor_array_like."""
965    foo = array_ops.zeros(shape=[])
966    result = structured_array_ops._structured_tensor_like(foo)
967    self.assertAllEqual({}, result)
968
969  def testStructuredTensorLikeStructuredTensor(self):
970    """Fully test structured_tensor_array_like."""
971    foo = structured_tensor.StructuredTensor.from_pyval([{"a": 3}, {"a": 7}])
972    result = structured_array_ops._structured_tensor_like(foo)
973    self.assertAllEqual([{}, {}], result)
974
975  def testStructuredTensorArrayLike(self):
976    """There was a bug in a case in a private function.
977
978    This was difficult to reach externally, so I wrote a test
979    to check it directly.
980    """
981    rt = ragged_tensor.RaggedTensor.from_row_splits(
982        array_ops.zeros(shape=[5, 3]), [0, 3, 5])
983    result = structured_array_ops._structured_tensor_like(rt)
984    self.assertEqual(3, result.rank)
985
986  @parameterized.named_parameters([
987      dict(
988          testcase_name="list_empty",
989          params=[{}, {}, {}],
990          indices=[0, 2],
991          axis=0,
992          batch_dims=0,
993          expected=[{}, {}]),
994      dict(
995          testcase_name="list_of_lists_empty",
996          params=[[{}, {}], [{}], [{}, {}, {}]],
997          indices=[2, 0],
998          axis=0,
999          batch_dims=0,
1000          expected=[[{}, {}, {}], [{}, {}]]),
1001      dict(
1002          testcase_name="list_with_fields",
1003          params=[{"a": 4, "b": [3, 4]}, {"a": 5, "b": [5, 6]},
1004                  {"a": 7, "b": [9, 10]}],
1005          indices=[2, 0, 0],
1006          axis=0,
1007          batch_dims=0,
1008          expected=[{"a": 7, "b": [9, 10]}, {"a": 4, "b": [3, 4]},
1009                    {"a": 4, "b": [3, 4]}]),
1010      dict(
1011          testcase_name="list_with_submessages",
1012          params=[{"a": {"foo": 3}, "b": [3, 4]},
1013                  {"a": {"foo": 4}, "b": [5, 6]},
1014                  {"a": {"foo": 7}, "b": [9, 10]}],
1015          indices=[2, 0],
1016          axis=0,
1017          batch_dims=0,
1018          expected=[{"a": {"foo": 7}, "b": [9, 10]},
1019                    {"a": {"foo": 3}, "b": [3, 4]}]),
1020      dict(
1021          testcase_name="list_with_empty_submessages",
1022          params=[{"a": {}, "b": [3, 4]},
1023                  {"a": {}, "b": [5, 6]},
1024                  {"a": {}, "b": [9, 10]}],
1025          indices=[2, 0],
1026          axis=0,
1027          batch_dims=0,
1028          expected=[{"a": {}, "b": [9, 10]},
1029                    {"a": {}, "b": [3, 4]}]),
1030      dict(
1031          testcase_name="lists_of_lists",
1032          params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
1033                  [{"a": {}, "b": [7, 8, 9]}],
1034                  [{"a": {}, "b": []}]],
1035          indices=[2, 0, 0],
1036          axis=0,
1037          batch_dims=0,
1038          expected=[[{"a": {}, "b": []}],
1039                    [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
1040                    [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}]]),
1041      dict(
1042          testcase_name="lists_of_lists_axis_1",
1043          params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
1044                  [{"a": {}, "b": [7, 8, 9]}, {"a": {}, "b": [2, 8, 2]}],
1045                  [{"a": {}, "b": []}, {"a": {}, "b": [4]}]],
1046          indices=[1, 0],
1047          axis=1,
1048          batch_dims=0,
1049          expected=[[{"a": {}, "b": [5]}, {"a": {}, "b": [3, 4]}],
1050                    [{"a": {}, "b": [2, 8, 2]}, {"a": {}, "b": [7, 8, 9]}],
1051                    [{"a": {}, "b": [4]}, {"a": {}, "b": []}]]),
1052      dict(
1053          testcase_name="lists_of_lists_axis_minus_2",
1054          params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
1055                  [{"a": {}, "b": [7, 8, 9]}],
1056                  [{"a": {}, "b": []}]],
1057          indices=[2, 0, 0],
1058          axis=-2,  # same as 0
1059          batch_dims=0,
1060          expected=[[{"a": {}, "b": []}],
1061                    [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
1062                    [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}]]),
1063      dict(
1064          testcase_name="lists_of_lists_axis_minus_1",
1065          params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
1066                  [{"a": {}, "b": [7, 8, 9]}, {"a": {}, "b": [2, 8, 2]}],
1067                  [{"a": {}, "b": []}, {"a": {}, "b": [4]}]],
1068          indices=[1, 0],
1069          axis=-1,  # same as 1
1070          batch_dims=0,
1071          expected=[[{"a": {}, "b": [5]}, {"a": {}, "b": [3, 4]}],
1072                    [{"a": {}, "b": [2, 8, 2]}, {"a": {}, "b": [7, 8, 9]}],
1073                    [{"a": {}, "b": [4]}, {"a": {}, "b": []}]]),
1074      dict(
1075          testcase_name="from_structured_tensor_util_test",
1076          params=[{"x0": 0, "y": {"z": [[3, 13]]}},
1077                  {"x0": 1, "y": {"z": [[3], [4, 13]]}},
1078                  {"x0": 2, "y": {"z": [[3, 5], [4]]}},
1079                  {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
1080                  {"x0": 4, "y": {"z": [[3], [4]]}}],
1081          indices=[1, 0, 4, 3, 2],
1082          axis=0,
1083          batch_dims=0,
1084          expected=[{"x0": 1, "y": {"z": [[3], [4, 13]]}},
1085                    {"x0": 0, "y": {"z": [[3, 13]]}},
1086                    {"x0": 4, "y": {"z": [[3], [4]]}},
1087                    {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
1088                    {"x0": 2, "y": {"z": [[3, 5], [4]]}}]),
1089      dict(
1090          testcase_name="scalar_index_axis_0",
1091          params=[{"x0": 0, "y": {"z": [[3, 13]]}},
1092                  {"x0": 1, "y": {"z": [[3], [4, 13]]}},
1093                  {"x0": 2, "y": {"z": [[3, 5], [4]]}},
1094                  {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
1095                  {"x0": 4, "y": {"z": [[3], [4]]}}],
1096          indices=3,
1097          axis=0,
1098          batch_dims=0,
1099          expected={"x0": 3, "y": {"z": [[3, 7, 1], [4]]}}),
1100      dict(
1101          testcase_name="params_2D_vector_index_axis_1_batch_dims_1",
1102          params=[[{"x0": 0, "y": {"z": [[3, 13]]}},
1103                   {"x0": 1, "y": {"z": [[3], [4, 13]]}}],
1104                  [{"x0": 2, "y": {"z": [[3, 5], [4]]}},
1105                   {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
1106                   {"x0": 4, "y": {"z": [[3], [4]]}}]],
1107          indices=[1, 0],
1108          axis=1,
1109          batch_dims=1,
1110          expected=[{"x0": 1, "y": {"z": [[3], [4, 13]]}},
1111                    {"x0": 2, "y": {"z": [[3, 5], [4]]}}]),
1112  ])  # pyformat: disable
1113  def testGather(self, params, indices, axis, batch_dims, expected):
1114    params = StructuredTensor.from_pyval(params)
1115    # validate_indices isn't actually used, and we aren't testing names
1116    actual = array_ops.gather(
1117        params,
1118        indices,
1119        validate_indices=True,
1120        axis=axis,
1121        name=None,
1122        batch_dims=batch_dims)
1123    self.assertAllEqual(actual, expected)
1124
1125  @parameterized.named_parameters([
1126      dict(
1127          testcase_name="params_2D_index_2D_axis_1_batch_dims_1",
1128          params=[[{"x0": 0, "y": {"z": [[3, 13]]}},
1129                   {"x0": 1, "y": {"z": [[3], [4, 13]]}}],
1130                  [{"x0": 2, "y": {"z": [[3, 5], [4]]}},
1131                   {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
1132                   {"x0": 4, "y": {"z": [[3], [4]]}}]],
1133          indices=[[1, 0], [0, 2]],
1134          axis=1,
1135          batch_dims=1,
1136          expected=[[{"x0": 1, "y": {"z": [[3], [4, 13]]}},
1137                     {"x0": 0, "y": {"z": [[3, 13]]}}],
1138                    [{"x0": 2, "y": {"z": [[3, 5], [4]]}},
1139                     {"x0": 4, "y": {"z": [[3], [4]]}}]]),
1140      dict(
1141          testcase_name="params_1D_index_2D_axis_0_batch_dims_0",
1142          params=[{"x0": 0, "y": {"z": [[3, 13]]}}],
1143          indices=[[0], [0, 0]],
1144          axis=0,
1145          batch_dims=0,
1146          expected=[[{"x0": 0, "y": {"z": [[3, 13]]}}],
1147                    [{"x0": 0, "y": {"z": [[3, 13]]}},
1148                     {"x0": 0, "y": {"z": [[3, 13]]}}]]),
1149  ])  # pyformat: disable
1150  def testGatherRagged(self, params, indices, axis, batch_dims, expected):
1151    params = StructuredTensor.from_pyval(params)
1152    # Shouldn't need to do this, but see cl/366396997
1153    indices = ragged_factory_ops.constant(indices)
1154    # validate_indices isn't actually used, and we aren't testing names
1155    actual = array_ops.gather(
1156        params,
1157        indices,
1158        validate_indices=True,
1159        axis=axis,
1160        name=None,
1161        batch_dims=batch_dims)
1162    self.assertAllEqual(actual, expected)
1163
1164  @parameterized.named_parameters([
1165      dict(testcase_name="params_scalar",
1166           params={"a": [3]},
1167           indices=0,
1168           axis=0,
1169           batch_dims=0,
1170           error_type=ValueError,
1171           error_regex="axis=0 out of bounds",
1172          ),
1173      dict(testcase_name="axis_large",
1174           params=[{"a": [3]}],
1175           indices=0,
1176           axis=1,
1177           batch_dims=0,
1178           error_type=ValueError,
1179           error_regex="axis=1 out of bounds",
1180          ),
1181      dict(testcase_name="axis_large_neg",
1182           params=[{"a": [3]}],
1183           indices=0,
1184           axis=-2,
1185           batch_dims=0,
1186           error_type=ValueError,
1187           error_regex="axis=-2 out of bounds",
1188          ),
1189      dict(testcase_name="batch_large",
1190           params=[[{"a": [3]}]],
1191           indices=0,
1192           axis=0,
1193           batch_dims=1,
1194           error_type=ValueError,
1195           error_regex="batch_dims=1 out of bounds",
1196          ),
1197  ])  # pyformat: disable
1198  def testGatherError(self,
1199                      params,
1200                      indices, axis, batch_dims,
1201                      error_type,
1202                      error_regex):
1203    params = StructuredTensor.from_pyval(params)
1204    with self.assertRaisesRegex(error_type, error_regex):
1205      structured_array_ops.gather(
1206          params,
1207          indices,
1208          validate_indices=True,
1209          axis=axis,
1210          name=None,
1211          batch_dims=batch_dims)
1212
1213
if __name__ == "__main__":
  # Entry point when this file is run directly: hand off to googletest.
  googletest.main()
1216