1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Tests for structured_array_ops.""" 15 16 17from absl.testing import parameterized 18 19from tensorflow.python.eager import def_function 20from tensorflow.python.framework import dtypes 21from tensorflow.python.framework import ops 22from tensorflow.python.framework import random_seed 23from tensorflow.python.framework import tensor_shape 24from tensorflow.python.framework import test_util 25from tensorflow.python.ops import array_ops 26from tensorflow.python.ops import math_ops 27from tensorflow.python.ops import random_ops 28from tensorflow.python.ops.ragged import ragged_factory_ops 29from tensorflow.python.ops.ragged import ragged_tensor 30from tensorflow.python.ops.ragged import row_partition 31from tensorflow.python.ops.ragged.dynamic_ragged_shape import DynamicRaggedShape 32from tensorflow.python.ops.structured import structured_array_ops 33from tensorflow.python.ops.structured import structured_tensor 34from tensorflow.python.ops.structured.structured_tensor import StructuredTensor 35from tensorflow.python.platform import googletest 36from tensorflow.python.util import nest 37 38 39# TODO(martinz):create StructuredTensorTestCase. 
# pylint: disable=g-long-lambda
@test_util.run_all_in_graph_and_eager_modes
class StructuredArrayOpsTest(test_util.TensorFlowTestCase,
                             parameterized.TestCase):
  """Tests array-style ops (expand_dims, size, concat, ...) on StructuredTensor."""

  def assertAllEqual(self, a, b, msg=None):
    """Asserts `a` equals `b`, with support for StructuredTensor operands.

    When neither operand is a StructuredTensor this defers to the base class.
    Otherwise the non-StructuredTensor operand (if any) is converted with
    `StructuredTensor.from_pyval`, the two nested structures are required to
    match, and their component tensors are evaluated and compared pairwise.

    Args:
      a: A StructuredTensor, or a value accepted by the base `assertAllEqual`
        (or by `StructuredTensor.from_pyval` when `b` is a StructuredTensor).
      b: Same as `a`.
      msg: Optional message used as a prefix / passed to nested assertions.
    """
    if not (isinstance(a, structured_tensor.StructuredTensor) or
            isinstance(b, structured_tensor.StructuredTensor)):
      return super(StructuredArrayOpsTest, self).assertAllEqual(a, b, msg)

    # At least one side is a StructuredTensor here; lift the other side.
    if not isinstance(a, structured_tensor.StructuredTensor):
      a = structured_tensor.StructuredTensor.from_pyval(a)
    elif not isinstance(b, structured_tensor.StructuredTensor):
      b = structured_tensor.StructuredTensor.from_pyval(b)

    try:
      nest.assert_same_structure(a, b, expand_composites=True)
    except (TypeError, ValueError) as e:
      # A structure mismatch is a test failure, not an error.
      self.fail((msg + ": " if msg else "") + str(e))
    a_tensors = [x for x in nest.flatten(a, expand_composites=True)
                 if isinstance(x, ops.Tensor)]
    b_tensors = [x for x in nest.flatten(b, expand_composites=True)
                 if isinstance(x, ops.Tensor)]
    self.assertLen(a_tensors, len(b_tensors))
    a_arrays, b_arrays = self.evaluate((a_tensors, b_tensors))
    for a_array, b_array in zip(a_arrays, b_arrays):
      self.assertAllEqual(a_array, b_array, msg)

  def _assertStructuredEqual(self, a, b, msg, check_shape):
    """Recursively asserts two StructuredTensors have equal fields.

    Args:
      a: A StructuredTensor.
      b: A StructuredTensor.
      msg: Message forwarded to the leaf `assertAllEqual` calls.
      check_shape: If true, also require `repr(a.shape) == repr(b.shape)` at
        every level of nesting.
    """
    if check_shape:
      self.assertEqual(repr(a.shape), repr(b.shape))
    self.assertEqual(set(a.field_names()), set(b.field_names()))
    for field in a.field_names():
      a_value = a.field_value(field)
      b_value = b.field_value(field)
      self.assertIs(type(a_value), type(b_value))
      if isinstance(a_value, structured_tensor.StructuredTensor):
        self._assertStructuredEqual(a_value, b_value, msg, check_shape)
      else:
        self.assertAllEqual(a_value, b_value, msg)

  def _make_fieldless_struct(self, row_partitions, shape):
    """Builds a StructuredTensor with no fields from row splits and a shape.

    Args:
      row_partitions: None, or a list of row-splits lists; each is converted
        with `RowPartition.from_row_splits`.
      shape: The shape passed to `StructuredTensor.from_fields`.

    Returns:
      The StructuredTensor returned by `StructuredTensor.from_fields({}, ...)`.
    """
    if row_partitions is not None:
      row_partitions = [
          row_partition.RowPartition.from_row_splits(r) for r in row_partitions
      ]
    return StructuredTensor.from_fields({},
                                        shape=shape,
                                        row_partitions=row_partitions)

  @parameterized.named_parameters([
      dict(
          testcase_name="0D_0",
          st={"x": 1},
          axis=0,
          expected=[{"x": 1}]),
      dict(
          testcase_name="0D_minus_1",
          st={"x": 1},
          axis=-1,
          expected=[{"x": 1}]),
      dict(
          testcase_name="1D_0",
          st=[{"x": [1, 3]}, {"x": [2, 7, 9]}],
          axis=0,
          expected=[[{"x": [1, 3]}, {"x": [2, 7, 9]}]]),
      dict(
          testcase_name="1D_1",
          st=[{"x": [1]}, {"x": [2, 10]}],
          axis=1,
          expected=[[{"x": [1]}], [{"x": [2, 10]}]]),
      dict(
          testcase_name="2D_0",
          st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
          axis=0,
          expected=[[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]]]),
      dict(
          testcase_name="2D_1",
          st=[[{"x": 1}, {"x": 2}], [{"x": 3}]],
          axis=1,
          expected=[[[{"x": 1}, {"x": 2}]], [[{"x": 3}]]]),
      dict(
          testcase_name="2D_2",
          st=[[{"x": [1]}, {"x": [2]}], [{"x": [3, 4]}]],
          axis=2,
          expected=[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3, 4]}]]]),
      dict(
          testcase_name="3D_0",
          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
          axis=0,
          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
                     [[{"x": [4, 5]}]]]]),
      dict(
          testcase_name="3D_minus_4",
          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
          axis=-4,  # same as zero
          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]],
                     [[{"x": [4, 5]}]]]]),
      dict(
          testcase_name="3D_1",
          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
          axis=1,
          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
                    [[[{"x": [4, 5]}]]]]),
      dict(
          testcase_name="3D_minus_3",
          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
          axis=-3,  # same as 1
          expected=[[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]]],
                    [[[{"x": [4, 5]}]]]]),
      dict(
          testcase_name="3D_2",
          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
          axis=2,
          expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
                    [[[{"x": [4, 5]}]]]]),
      dict(
          testcase_name="3D_minus_2",
          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
          axis=-2,  # same as 2
          expected=[[[[{"x": [1]}, {"x": [2]}]], [[{"x": [3]}]]],
                    [[[{"x": [4, 5]}]]]]),
      dict(
          testcase_name="3D_3",
          st=[[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]],
          axis=3,
          expected=[[[[{"x": [1]}], [{"x": [2]}]], [[{"x": [3]}]]],
                    [[[{"x": [4, 5]}]]]]),
  ])  # pyformat: disable
  def testExpandDims(self, st, axis, expected):
    """expand_dims on a StructuredTensor matches the expected nested value."""
    st = StructuredTensor.from_pyval(st)
    result = array_ops.expand_dims(st, axis)
    self.assertAllEqual(result, expected)

  def testExpandDimsAxisTooBig(self):
    """expand_dims with axis=4 is rejected as out of bounds."""
    st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
    st = StructuredTensor.from_pyval(st)
    with self.assertRaisesRegex(ValueError,
                                "axis=4 out of bounds: expected -4<=axis<4"):
      array_ops.expand_dims(st, 4)

  def testExpandDimsAxisTooSmall(self):
    """expand_dims with axis=-5 is rejected as out of bounds."""
    st = [[[{"x": [1]}, {"x": [2]}], [{"x": [3]}]], [[{"x": [4, 5]}]]]
    st = StructuredTensor.from_pyval(st)
    with self.assertRaisesRegex(ValueError,
                                "axis=-5 out of bounds: expected -4<=axis<4"):
      array_ops.expand_dims(st, -5)

  def testExpandDimsScalar(self):
    """expand_dims on the last axis yields a static size-1 final dimension."""
    # Note that if we expand_dims for the final dimension and there are scalar
    # fields, then the shape is (2, None, None, 1), whereas if it is constructed
    # from pyval it is (2, None, None, None).
    st = [[[{"x": 1}, {"x": 2}], [{"x": 3}]], [[{"x": 4}]]]
    st = StructuredTensor.from_pyval(st)
    result = array_ops.expand_dims(st, 3)
    expected_shape = tensor_shape.TensorShape([2, None, None, 1])
    self.assertEqual(repr(expected_shape), repr(result.shape))

  @parameterized.named_parameters([
      dict(
          testcase_name="scalar_int32",
          row_partitions=None,
          shape=(),
          dtype=dtypes.int32,
          expected=1),
      dict(
          testcase_name="scalar_int64",
          row_partitions=None,
          shape=(),
          dtype=dtypes.int64,
          expected=1),
      dict(
          testcase_name="list_0_int32",
          row_partitions=None,
          shape=(0,),
          dtype=dtypes.int32,
          expected=0),
      dict(
          testcase_name="list_0_0_int32",
          row_partitions=None,
          shape=(0, 0),
          dtype=dtypes.int32,
          expected=0),
      dict(
          testcase_name="list_int32",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.int32,
          expected=7),
      dict(
          testcase_name="list_int64",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.int64,
          expected=7),
      dict(
          testcase_name="matrix_int32",
          row_partitions=[[0, 3, 6]],
          shape=(2, 3),
          dtype=dtypes.int32,
          expected=6),
      dict(
          testcase_name="tensor_int32",
          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
          shape=(2, 3, 1),
          dtype=dtypes.int32,
          expected=6),
      dict(
          testcase_name="ragged_1_int32",
          row_partitions=[[0, 3, 4]],
          shape=(2, None),
          dtype=dtypes.int32,
          expected=4),
      dict(
          testcase_name="ragged_2_float32",
          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
          shape=(2, None, None),
          dtype=dtypes.float32,
          expected=7),
  ])  # pyformat: disable
  def testSizeObject(self, row_partitions, shape, dtype, expected):
    """size / size_v2 of a field-less StructuredTensor built from partitions."""
    st = self._make_fieldless_struct(row_partitions, shape)
    # NOTE: size is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = array_ops.size(st, out_type=dtype)
    self.assertAllEqual(actual, expected)

    actual2 = array_ops.size_v2(st, out_type=dtype)
    self.assertAllEqual(actual2, expected)

  def test_shape_v2(self):
    """shape_v2 of a ragged rank-2 StructuredTensor reports static lengths."""
    rt = ragged_tensor.RaggedTensor.from_row_lengths(["a", "b", "c"], [1, 2])
    st = StructuredTensor.from_fields_and_rank({"r": rt}, rank=2)
    actual = array_ops.shape_v2(st, out_type=dtypes.int64)
    actual_static_lengths = actual.static_lengths()
    self.assertAllEqual([2, (1, 2)], actual_static_lengths)

  def test_shape(self):
    """shape and shape_v2 agree on the static lengths of a ragged struct."""
    rt = ragged_tensor.RaggedTensor.from_row_lengths(["a", "b", "c"], [1, 2])
    st = StructuredTensor.from_fields_and_rank({"r": rt}, rank=2)
    actual = array_ops.shape(st, out_type=dtypes.int64).static_lengths()
    actual_v2 = array_ops.shape_v2(st, out_type=dtypes.int64).static_lengths()
    expected = [2, (1, 2)]
    self.assertAllEqual(expected, actual)
    self.assertAllEqual(expected, actual_v2)

  @parameterized.named_parameters([
      dict(
          testcase_name="list_empty_2_1",
          values=[[{}, {}], [{}]],
          dtype=dtypes.int32,
          expected=3),
      dict(
          testcase_name="list_empty_2",
          values=[{}, {}],
          dtype=dtypes.int32,
          expected=2),
      dict(
          testcase_name="list_empty_1",
          values=[{}],
          dtype=dtypes.int32,
          expected=1),
      dict(
          testcase_name="list_example_1",
          values=[{"x": [3]}, {"x": [4, 5]}],
          dtype=dtypes.int32,
          expected=2),
      dict(
          testcase_name="list_example_2",
          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
          dtype=dtypes.float32,
          expected=3),
      dict(
          testcase_name="list_example_2_None",
          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
          dtype=None,
          expected=3),
  ])  # pyformat: disable
  def testSizeAlt(self, values, dtype, expected):
    """size / size_v2 of a StructuredTensor built with from_pyval."""
    st = StructuredTensor.from_pyval(values)
    # NOTE: size is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = array_ops.size(st, out_type=dtype)
    self.assertAllEqual(actual, expected)

    actual2 = array_ops.size_v2(st, out_type=dtype)
    self.assertAllEqual(actual2, expected)

  @parameterized.named_parameters([
      dict(
          testcase_name="scalar_int32",
          row_partitions=None,
          shape=(),
          dtype=dtypes.int32,
          expected=0),
      dict(
          testcase_name="scalar_bool",
          row_partitions=None,
          shape=(),
          dtype=dtypes.bool,
          expected=False),
      dict(
          testcase_name="scalar_int64",
          row_partitions=None,
          shape=(),
          dtype=dtypes.int64,
          expected=0),
      dict(
          testcase_name="scalar_float32",
          row_partitions=None,
          shape=(),
          dtype=dtypes.float32,
          expected=0.0),
      dict(
          testcase_name="list_0_int32",
          row_partitions=None,
          shape=(0,),
          dtype=dtypes.int32,
          expected=[]),
      dict(
          testcase_name="list_0_0_int32",
          row_partitions=None,
          shape=(0, 0),
          dtype=dtypes.int32,
          expected=[]),
      dict(
          testcase_name="list_int32",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.int32,
          expected=[0, 0, 0, 0, 0, 0, 0]),
      dict(
          testcase_name="list_int64",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.int64,
          expected=[0, 0, 0, 0, 0, 0, 0]),
      dict(
          testcase_name="list_float32",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.float32,
          expected=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
      dict(
          testcase_name="matrix_int32",
          row_partitions=[[0, 3, 6]],
          shape=(2, 3),
          dtype=dtypes.int32,
          expected=[[0, 0, 0], [0, 0, 0]]),
      dict(
          testcase_name="matrix_float64",
          row_partitions=[[0, 3, 6]],
          shape=(2, 3),
          dtype=dtypes.float64,
          expected=[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]),
      dict(
          testcase_name="tensor_int32",
          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
          shape=(2, 3, 1),
          dtype=dtypes.int32,
          expected=[[[0], [0], [0]], [[0], [0], [0]]]),
      dict(
          testcase_name="tensor_float32",
          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
          shape=(2, 3, 1),
          dtype=dtypes.float32,
          expected=[[[0.0], [0.0], [0.0]], [[0.0], [0.0], [0.0]]]),
      dict(
          testcase_name="ragged_1_float32",
          row_partitions=[[0, 3, 4]],
          shape=(2, None),
          dtype=dtypes.float32,
          expected=[[0.0, 0.0, 0.0], [0.0]]),
      dict(
          testcase_name="ragged_2_float32",
          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
          shape=(2, None, None),
          dtype=dtypes.float32,
          expected=[[[0.0, 0.0], [0.0], [0.0, 0.0]], [[0.0, 0.0]]]),
  ])  # pyformat: disable
  def testZerosLikeObject(self, row_partitions, shape, dtype, expected):
    """zeros_like / zeros_like_v2 of a field-less StructuredTensor."""
    st = self._make_fieldless_struct(row_partitions, shape)
    # NOTE: zeros_like is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = array_ops.zeros_like(st, dtype)
    self.assertAllEqual(actual, expected)

    actual2 = array_ops.zeros_like_v2(st, dtype)
    self.assertAllEqual(actual2, expected)

  @parameterized.named_parameters([
      dict(
          testcase_name="list_empty_2_1",
          values=[[{}, {}], [{}]],
          dtype=dtypes.int32,
          expected=[[0, 0], [0]]),
      dict(
          testcase_name="list_empty_2",
          values=[{}, {}],
          dtype=dtypes.int32,
          expected=[0, 0]),
      dict(
          testcase_name="list_empty_1",
          values=[{}],
          dtype=dtypes.int32,
          expected=[0]),
      dict(
          testcase_name="list_example_1",
          values=[{"x": [3]}, {"x": [4, 5]}],
          dtype=dtypes.int32,
          expected=[0, 0]),
      dict(
          testcase_name="list_example_2",
          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
          dtype=dtypes.float32,
          expected=[[0.0], [0.0, 0.0]]),
      dict(
          testcase_name="list_example_2_None",
          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
          dtype=None,
          expected=[[0.0], [0.0, 0.0]]),
  ])  # pyformat: disable
  def testZerosLikeObjectAlt(self, values, dtype, expected):
    """zeros_like / zeros_like_v2 of a StructuredTensor built with from_pyval."""
    st = StructuredTensor.from_pyval(values)
    # NOTE: zeros_like is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = array_ops.zeros_like(st, dtype)
    self.assertAllEqual(actual, expected)

    actual2 = array_ops.zeros_like_v2(st, dtype)
    self.assertAllEqual(actual2, expected)

  @parameterized.named_parameters([
      dict(
          testcase_name="scalar_int32",
          row_partitions=None,
          shape=(),
          dtype=dtypes.int32,
          expected=1),
      dict(
          testcase_name="scalar_bool",
          row_partitions=None,
          shape=(),
          dtype=dtypes.bool,
          expected=True),
      dict(
          testcase_name="scalar_int64",
          row_partitions=None,
          shape=(),
          dtype=dtypes.int64,
          expected=1),
      dict(
          testcase_name="scalar_float32",
          row_partitions=None,
          shape=(),
          dtype=dtypes.float32,
          expected=1.0),
      dict(
          testcase_name="list_0_int32",
          row_partitions=None,
          shape=(0,),
          dtype=dtypes.int32,
          expected=[]),
      dict(
          testcase_name="list_0_0_int32",
          row_partitions=None,
          shape=(0, 0),
          dtype=dtypes.int32,
          expected=[]),
      dict(
          testcase_name="list_int32",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.int32,
          expected=[1, 1, 1, 1, 1, 1, 1]),
      dict(
          testcase_name="list_int64",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.int64,
          expected=[1, 1, 1, 1, 1, 1, 1]),
      dict(
          testcase_name="list_float32",
          row_partitions=None,
          shape=(7,),
          dtype=dtypes.float32,
          expected=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]),
      dict(
          testcase_name="matrix_int32",
          row_partitions=[[0, 3, 6]],
          shape=(2, 3),
          dtype=dtypes.int32,
          expected=[[1, 1, 1], [1, 1, 1]]),
      dict(
          testcase_name="matrix_float64",
          row_partitions=[[0, 3, 6]],
          shape=(2, 3),
          dtype=dtypes.float64,
          expected=[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]),
      dict(
          testcase_name="tensor_int32",
          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
          shape=(2, 3, 1),
          dtype=dtypes.int32,
          expected=[[[1], [1], [1]], [[1], [1], [1]]]),
      dict(
          testcase_name="tensor_float32",
          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
          shape=(2, 3, 1),
          dtype=dtypes.float32,
          expected=[[[1.0], [1.0], [1.0]], [[1.0], [1.0], [1.0]]]),
      dict(
          testcase_name="ragged_1_float32",
          row_partitions=[[0, 3, 4]],
          shape=(2, None),
          dtype=dtypes.float32,
          expected=[[1.0, 1.0, 1.0], [1.0]]),
      dict(
          testcase_name="ragged_2_float32",
          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
          shape=(2, None, None),
          dtype=dtypes.float32,
          expected=[[[1.0, 1.0], [1.0], [1.0, 1.0]], [[1.0, 1.0]]]),
  ])  # pyformat: disable
  def testOnesLikeObject(self, row_partitions, shape, dtype, expected):
    """ones_like / ones_like_v2 of a field-less StructuredTensor."""
    st = self._make_fieldless_struct(row_partitions, shape)
    # NOTE: ones_like is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = array_ops.ones_like(st, dtype)
    self.assertAllEqual(actual, expected)

    actual2 = array_ops.ones_like_v2(st, dtype)
    self.assertAllEqual(actual2, expected)

  @parameterized.named_parameters([
      dict(
          testcase_name="list_empty_2_1",
          values=[[{}, {}], [{}]],
          dtype=dtypes.int32,
          expected=[[1, 1], [1]]),
      dict(
          testcase_name="list_empty_2",
          values=[{}, {}],
          dtype=dtypes.int32,
          expected=[1, 1]),
      dict(
          testcase_name="list_empty_1",
          values=[{}],
          dtype=dtypes.int32,
          expected=[1]),
      dict(
          testcase_name="list_example_1",
          values=[{"x": [3]}, {"x": [4, 5]}],
          dtype=dtypes.int32,
          expected=[1, 1]),
      dict(
          testcase_name="list_example_2",
          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
          dtype=dtypes.float32,
          expected=[[1.0], [1.0, 1.0]]),
      dict(
          testcase_name="list_example_2_None",
          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
          dtype=None,
          expected=[[1.0], [1.0, 1.0]]),
  ])  # pyformat: disable
  def testOnesLikeObjectAlt(self, values, dtype, expected):
    """ones_like / ones_like_v2 of a StructuredTensor built with from_pyval."""
    st = StructuredTensor.from_pyval(values)
    # NOTE: ones_like is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = array_ops.ones_like(st, dtype)
    self.assertAllEqual(actual, expected)

    actual2 = array_ops.ones_like_v2(st, dtype)
    self.assertAllEqual(actual2, expected)

  @parameterized.named_parameters([
      dict(
          testcase_name="scalar",
          row_partitions=None,
          shape=(),
          expected=0),
      dict(
          testcase_name="list_0",
          row_partitions=None,
          shape=(0,),
          expected=1),
      dict(
          testcase_name="list_0_0",
          row_partitions=None,
          shape=(0, 0),
          expected=2),
      dict(
          testcase_name="list_7",
          row_partitions=None,
          shape=(7,),
          expected=1),
      dict(
          testcase_name="matrix",
          row_partitions=[[0, 3, 6]],
          shape=(2, 3),
          expected=2),
      dict(
          testcase_name="tensor",
          row_partitions=[[0, 3, 6], [0, 1, 2, 3, 4, 5, 6]],
          shape=(2, 3, 1),
          expected=3),
      dict(
          testcase_name="ragged_1",
          row_partitions=[[0, 3, 4]],
          shape=(2, None),
          expected=2),
      dict(
          testcase_name="ragged_2",
          row_partitions=[[0, 3, 4], [0, 2, 3, 5, 7]],
          shape=(2, None, None),
          expected=3),
  ])  # pyformat: disable
  def testRank(self, row_partitions, shape, expected):
    """structured_array_ops.rank of a field-less StructuredTensor."""
    st = self._make_fieldless_struct(row_partitions, shape)

    # NOTE: rank is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = structured_array_ops.rank(st)
    self.assertAllEqual(expected, actual)

  @parameterized.named_parameters([
      dict(
          testcase_name="list_empty_2_1",
          values=[[{}, {}], [{}]],
          expected=2),
      dict(
          testcase_name="list_empty_2",
          values=[{}, {}],
          expected=1),
      dict(
          testcase_name="list_empty_1",
          values=[{}],
          expected=1),
      dict(
          testcase_name="list_example_1",
          values=[{"x": [3]}, {"x": [4, 5]}],
          expected=1),
      dict(
          testcase_name="list_example_2",
          values=[[{"x": [3]}], [{"x": [4, 5]}, {"x": []}]],
          expected=2),
  ])  # pyformat: disable
  def testRankAlt(self, values, expected):
    """array_ops.rank of a StructuredTensor built with from_pyval."""
    st = StructuredTensor.from_pyval(values)
    # NOTE: rank is very robust. There aren't arguments that
    # should cause this operation to fail.
    actual = array_ops.rank(st)
    self.assertAllEqual(expected, actual)

  @parameterized.named_parameters([
      dict(
          testcase_name="list_empty",
          values=[[{}], [{}]],
          axis=0,
          expected=[{}, {}]),
      dict(
          testcase_name="list_empty_2_1",
          values=[[{}, {}], [{}]],
          axis=0,
          expected=[{}, {}, {}]),
      dict(
          testcase_name="list_with_fields",
          values=[[{"a": 4, "b": [3, 4]}], [{"a": 5, "b": [5, 6]}]],
          axis=0,
          expected=[{"a": 4, "b": [3, 4]}, {"a": 5, "b": [5, 6]}]),
      dict(
          testcase_name="list_with_submessages",
          values=[[{"a": {"foo": 3}, "b": [3, 4]}],
                  [{"a": {"foo": 4}, "b": [5, 6]}]],
          axis=0,
          expected=[{"a": {"foo": 3}, "b": [3, 4]},
                    {"a": {"foo": 4}, "b": [5, 6]}]),
      dict(
          testcase_name="list_with_empty_submessages",
          values=[[{"a": {}, "b": [3, 4]}],
                  [{"a": {}, "b": [5, 6]}]],
          axis=0,
          expected=[{"a": {}, "b": [3, 4]},
                    {"a": {}, "b": [5, 6]}]),
      dict(
          testcase_name="lists_of_lists",
          values=[[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                   [{"a": {}, "b": [7, 8, 9]}]],
                  [[{"a": {}, "b": [10]}]]],
          axis=0,
          expected=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                    [{"a": {}, "b": [7, 8, 9]}],
                    [{"a": {}, "b": [10]}]]),
      dict(
          testcase_name="lists_of_lists_axis_1",
          values=[[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                   [{"a": {}, "b": [7, 8, 9]}]],
                  [[{"a": {}, "b": []}], [{"a": {}, "b": [3]}]]],
          axis=1,
          expected=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]},
                     {"a": {}, "b": []}],
                    [{"a": {}, "b": [7, 8, 9]}, {"a": {}, "b": [3]}]]),
      dict(
          testcase_name="lists_of_lists_axis_minus_2",
          values=[[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                   [{"a": {}, "b": [7, 8, 9]}]],
                  [[{"a": {}, "b": [10]}]]],
          axis=-2,  # Same as axis=0.
          expected=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                    [{"a": {}, "b": [7, 8, 9]}],
                    [{"a": {}, "b": [10]}]]),
      dict(
          testcase_name="from_structured_tensor_util_test",
          values=[[{"x0": 0, "y": {"z": [[3, 13]]}},
                   {"x0": 1, "y": {"z": [[3], [4, 13]]}},
                   {"x0": 2, "y": {"z": [[3, 5], [4]]}}],
                  [{"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
                   {"x0": 4, "y": {"z": [[3], [4]]}}]],
          axis=0,
          expected=[{"x0": 0, "y": {"z": [[3, 13]]}},
                    {"x0": 1, "y": {"z": [[3], [4, 13]]}},
                    {"x0": 2, "y": {"z": [[3, 5], [4]]}},
                    {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
                    {"x0": 4, "y": {"z": [[3], [4]]}}]),
  ])  # pyformat: disable
  def testConcat(self, values, axis, expected):
    """concat of a list of StructuredTensors along `axis`."""
    values = [StructuredTensor.from_pyval(v) for v in values]
    actual = array_ops.concat(values, axis)
    self.assertAllEqual(actual, expected)

  def testConcatTuple(self):
    """concat accepts a tuple (not only a list) of StructuredTensors."""
    values = (StructuredTensor.from_pyval([{"a": 3}]),
              StructuredTensor.from_pyval([{"a": 4}]))
    actual = array_ops.concat(values, axis=0)
    self.assertAllEqual(actual, [{"a": 3}, {"a": 4}])

  @parameterized.named_parameters([
      dict(
          testcase_name="field_dropped",
          values=[[{"a": [2]}], [{}]],
          axis=0,
          error_type=ValueError,
          error_regex="a"),
      dict(
          testcase_name="field_added",
          values=[[{"b": [3]}], [{"b": [3], "a": [7]}]],
          axis=0,
          error_type=ValueError,
          error_regex="b"),
      dict(testcase_name="rank_submessage_change",
           values=[[{"a": [{"b": [[3]]}]}],
                   [{"a": [[{"b": [3]}]]}]],
           axis=0,
           error_type=ValueError,
           error_regex="Ranks of sub-message do not match",
          ),
      dict(testcase_name="rank_message_change",
           values=[[{"a": [3]}],
                   [[{"a": 3}]]],
           axis=0,
           error_type=ValueError,
           error_regex="Ranks of sub-message do not match",
          ),
      dict(testcase_name="concat_scalar",
           values=[{"a": [3]}, {"a": [4]}],
           axis=0,
           error_type=ValueError,
           error_regex="axis=0 out of bounds",
          ),
      dict(testcase_name="concat_axis_large",
           values=[[{"a": [3]}], [{"a": [4]}]],
           axis=1,
           error_type=ValueError,
           error_regex="axis=1 out of bounds",
          ),
      dict(testcase_name="concat_axis_large_neg",
           values=[[{"a": [3]}], [{"a": [4]}]],
           axis=-2,
           error_type=ValueError,
           error_regex="axis=-2 out of bounds",
          ),
      dict(testcase_name="concat_deep_rank_wrong",
           values=[[{"a": [3]}], [{"a": [[4]]}]],
           axis=0,
           error_type=ValueError,
           error_regex="must have rank",
          ),
  ])  # pyformat: disable
  def testConcatError(self, values, axis, error_type, error_regex):
    """concat raises `error_type` matching `error_regex` on bad inputs."""
    values = [StructuredTensor.from_pyval(v) for v in values]
    with self.assertRaisesRegex(error_type, error_regex):
      array_ops.concat(values, axis)

  def testConcatWithRagged(self):
    """concat rejects a list mixing a StructuredTensor with a plain constant."""
    values = [StructuredTensor.from_pyval({}), array_ops.constant(3)]
    with self.assertRaisesRegex(ValueError,
                                "values must be a list of StructuredTensors"):
      array_ops.concat(values, 0)

  def testConcatNotAList(self):
    """concat rejects a bare StructuredTensor passed instead of a list."""
    values = StructuredTensor.from_pyval({})
    with self.assertRaisesRegex(
        ValueError, "values must be a list of StructuredTensors"):
      structured_array_ops.concat(values, 0)

  def testConcatEmptyList(self):
    """concat rejects an empty list of values."""
    with self.assertRaisesRegex(ValueError,
                                "values must not be an empty list"):
      structured_array_ops.concat([], 0)

  def testExtendOpErrorNotList(self):
    """_extend_op rejects a bare StructuredTensor passed instead of a list."""
    # Should be a list.
    values = StructuredTensor.from_pyval({})
    def leaf_op(values):
      return values[0]
    with self.assertRaisesRegex(ValueError, "Expected a list"):
      structured_array_ops._extend_op(values, leaf_op)

  def testExtendOpErrorEmptyList(self):
    """_extend_op rejects an empty list."""
    def leaf_op(values):
      return values[0]
    with self.assertRaisesRegex(ValueError, "List cannot be empty"):
      structured_array_ops._extend_op([], leaf_op)

  def testRandomShuffle2021(self):
    """random_shuffle with global seed 1066 and op seed 2021 is reproducible."""
    original = StructuredTensor.from_pyval([
        {"x0": 0, "y": {"z": [[3, 13]]}},
        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
        {"x0": 2, "y": {"z": [[3, 5], [4]]}},
        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
        {"x0": 4, "y": {"z": [[3], [4]]}}])  # pyformat: disable
    random_seed.set_seed(1066)
    result = random_ops.random_shuffle(original, seed=2021)
    expected = StructuredTensor.from_pyval([
        {"x0": 0, "y": {"z": [[3, 13]]}},
        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
        {"x0": 4, "y": {"z": [[3], [4]]}},
        {"x0": 2, "y": {"z": [[3, 5], [4]]}},
        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},])  # pyformat: disable
    self.assertAllEqual(result, expected)

  def testRandomShuffle2022Eager(self):
    """structured_array_ops.random_shuffle with op seed 2022 is reproducible."""
    original = StructuredTensor.from_pyval([
        {"x0": 0, "y": {"z": [[3, 13]]}},
        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
        {"x0": 2, "y": {"z": [[3, 5], [4]]}},
        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
        {"x0": 4, "y": {"z": [[3], [4]]}}])  # pyformat: disable
    expected = StructuredTensor.from_pyval([
        {"x0": 1, "y": {"z": [[3], [4, 13]]}},
        {"x0": 0, "y": {"z": [[3, 13]]}},
        {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
        {"x0": 4, "y": {"z": [[3], [4]]}},
        {"x0": 2, "y": {"z": [[3, 5], [4]]}}])  # pyformat: disable
    random_seed.set_seed(1066)
    result = structured_array_ops.random_shuffle(original, seed=2022)
    self.assertAllEqual(result, expected)
913 914 def testRandomShuffleScalarError(self): 915 original = StructuredTensor.from_pyval( 916 {"x0": 2, "y": {"z": [[3, 5], [4]]}}) # pyformat: disable 917 918 with self.assertRaisesRegex(ValueError, "scalar"): 919 random_ops.random_shuffle(original) 920 921 def testStructuredTensorArrayLikeNoRank(self): 922 """Test when the rank is unknown.""" 923 @def_function.function 924 def my_fun(foo): 925 bar_shape = math_ops.range(foo) 926 bar = array_ops.zeros(shape=bar_shape) 927 structured_array_ops._structured_tensor_like(bar) 928 929 with self.assertRaisesRegex(ValueError, 930 "Can't build StructuredTensor w/ unknown rank"): 931 my_fun(array_ops.constant(3)) 932 933 def testStructuredTensorArrayRankOneKnownShape(self): 934 """Fully test structured_tensor_array_like.""" 935 foo = array_ops.zeros(shape=[4]) 936 result = structured_array_ops._structured_tensor_like(foo) 937 self.assertAllEqual([{}, {}, {}, {}], result) 938 939 # Note that we have to be careful about whether the indices are int32 940 # or int64. 
def testStructuredTensorArrayRankOneUnknownShape(self):
  """Checks _structured_tensor_like on rank-1 zeros with a tensor-valued shape.

  The shape is passed as a tensor argument to a `def_function.function`, and
  the result is compared against a StructuredTensor built from an inner shape
  of [4].
  """
  @def_function.function
  def like_zeros(dims):
    zeros = array_ops.zeros(dims)
    return structured_array_ops._structured_tensor_like(zeros)

  actual = like_zeros(array_ops.constant(4))
  expected = StructuredTensor.from_shape(
      DynamicRaggedShape._from_inner_shape([4], dtype=dtypes.int32))
  self.assertAllEqual(expected, actual)

def testStructuredTensorArrayRankTwoUnknownShape(self):
  """Checks _structured_tensor_like on rank-2 zeros with a tensor-valued shape.

  A [2, 2] shape tensor is expected to produce a 2x2 grid of empty structs.
  """
  @def_function.function
  def like_zeros(dims):
    zeros = array_ops.zeros(dims)
    return structured_array_ops._structured_tensor_like(zeros)

  actual = like_zeros(array_ops.constant([2, 2]))
  self.assertAllEqual([[{}, {}], [{}, {}]], actual)

def testStructuredTensorArrayRankZero(self):
  """Checks that a scalar tensor maps to a single empty struct."""
  scalar = array_ops.zeros(shape=[])
  actual = structured_array_ops._structured_tensor_like(scalar)
  self.assertAllEqual({}, actual)

def testStructuredTensorLikeStructuredTensor(self):
  """Checks that a StructuredTensor input yields field-less structs.

  The input holds two structs with an "a" field; the expected output is two
  empty structs.
  """
  source = structured_tensor.StructuredTensor.from_pyval(
      [{"a": 3}, {"a": 7}])
  actual = structured_array_ops._structured_tensor_like(source)
  self.assertAllEqual([{}, {}], actual)

def testStructuredTensorArrayLike(self):
  """Regression test for a bug in the private _structured_tensor_like helper.

  The failing case was hard to reach through the public API, so the helper is
  exercised directly on a ragged tensor and only the result's rank is checked.
  """
  ragged = ragged_tensor.RaggedTensor.from_row_splits(
      array_ops.zeros(shape=[5, 3]), [0, 3, 5])
  actual = structured_array_ops._structured_tensor_like(ragged)
  self.assertEqual(3, actual.rank)

@parameterized.named_parameters([
    dict(
        testcase_name="list_empty",
        params=[{}, {}, {}],
        indices=[0, 2],
        axis=0,
        batch_dims=0,
        expected=[{}, {}]),
    dict(
        testcase_name="list_of_lists_empty",
        params=[[{}, {}], [{}], [{}, {}, {}]],
        indices=[2, 0],
        axis=0,
        batch_dims=0,
        expected=[[{}, {}, {}], [{}, {}]]),
    dict(
        testcase_name="list_with_fields",
        params=[{"a": 4, "b": [3, 4]}, {"a": 5, "b": [5, 6]},
                {"a": 7, "b": [9, 10]}],
        indices=[2, 0, 0],
        axis=0,
        batch_dims=0,
        expected=[{"a": 7, "b": [9, 10]}, {"a": 4, "b": [3, 4]},
                  {"a": 4, "b": [3, 4]}]),
    dict(
        testcase_name="list_with_submessages",
        params=[{"a": {"foo": 3}, "b": [3, 4]},
                {"a": {"foo": 4}, "b": [5, 6]},
                {"a": {"foo": 7}, "b": [9, 10]}],
        indices=[2, 0],
        axis=0,
        batch_dims=0,
        expected=[{"a": {"foo": 7}, "b": [9, 10]},
                  {"a": {"foo": 3}, "b": [3, 4]}]),
    dict(
        testcase_name="list_with_empty_submessages",
        params=[{"a": {}, "b": [3, 4]},
                {"a": {}, "b": [5, 6]},
                {"a": {}, "b": [9, 10]}],
        indices=[2, 0],
        axis=0,
        batch_dims=0,
        expected=[{"a": {}, "b": [9, 10]},
                  {"a": {}, "b": [3, 4]}]),
    dict(
        testcase_name="lists_of_lists",
        params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                [{"a": {}, "b": [7, 8, 9]}],
                [{"a": {}, "b": []}]],
        indices=[2, 0, 0],
        axis=0,
        batch_dims=0,
        expected=[[{"a": {}, "b": []}],
                  [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                  [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}]]),
    dict(
        testcase_name="lists_of_lists_axis_1",
        params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                [{"a": {}, "b": [7, 8, 9]}, {"a": {}, "b": [2, 8, 2]}],
                [{"a": {}, "b": []}, {"a": {}, "b": [4]}]],
        indices=[1, 0],
        axis=1,
        batch_dims=0,
        expected=[[{"a": {}, "b": [5]}, {"a": {}, "b": [3, 4]}],
                  [{"a": {}, "b": [2, 8, 2]}, {"a": {}, "b": [7, 8, 9]}],
                  [{"a": {}, "b": [4]}, {"a": {}, "b": []}]]),
    dict(
        testcase_name="lists_of_lists_axis_minus_2",
        params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                [{"a": {}, "b": [7, 8, 9]}],
                [{"a": {}, "b": []}]],
        indices=[2, 0, 0],
        axis=-2,  # same as 0
        batch_dims=0,
        expected=[[{"a": {}, "b": []}],
                  [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                  [{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}]]),
    dict(
        testcase_name="lists_of_lists_axis_minus_1",
        params=[[{"a": {}, "b": [3, 4]}, {"a": {}, "b": [5]}],
                [{"a": {}, "b": [7, 8, 9]}, {"a": {}, "b": [2, 8, 2]}],
                [{"a": {}, "b": []}, {"a": {}, "b": [4]}]],
        indices=[1, 0],
        axis=-1,  # same as 1
        batch_dims=0,
        expected=[[{"a": {}, "b": [5]}, {"a": {}, "b": [3, 4]}],
                  [{"a": {}, "b": [2, 8, 2]}, {"a": {}, "b": [7, 8, 9]}],
                  [{"a": {}, "b": [4]}, {"a": {}, "b": []}]]),
    dict(
        testcase_name="from_structured_tensor_util_test",
        params=[{"x0": 0, "y": {"z": [[3, 13]]}},
                {"x0": 1, "y": {"z": [[3], [4, 13]]}},
                {"x0": 2, "y": {"z": [[3, 5], [4]]}},
                {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
                {"x0": 4, "y": {"z": [[3], [4]]}}],
        indices=[1, 0, 4, 3, 2],
        axis=0,
        batch_dims=0,
        expected=[{"x0": 1, "y": {"z": [[3], [4, 13]]}},
                  {"x0": 0, "y": {"z": [[3, 13]]}},
                  {"x0": 4, "y": {"z": [[3], [4]]}},
                  {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
                  {"x0": 2, "y": {"z": [[3, 5], [4]]}}]),
    dict(
        testcase_name="scalar_index_axis_0",
        params=[{"x0": 0, "y": {"z": [[3, 13]]}},
                {"x0": 1, "y": {"z": [[3], [4, 13]]}},
                {"x0": 2, "y": {"z": [[3, 5], [4]]}},
                {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
                {"x0": 4, "y": {"z": [[3], [4]]}}],
        indices=3,
        axis=0,
        batch_dims=0,
        expected={"x0": 3, "y": {"z": [[3, 7, 1], [4]]}}),
    dict(
        testcase_name="params_2D_vector_index_axis_1_batch_dims_1",
        params=[[{"x0": 0, "y": {"z": [[3, 13]]}},
                 {"x0": 1, "y": {"z": [[3], [4, 13]]}}],
                [{"x0": 2, "y": {"z": [[3, 5], [4]]}},
                 {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
                 {"x0": 4, "y": {"z": [[3], [4]]}}]],
        indices=[1, 0],
        axis=1,
        batch_dims=1,
        expected=[{"x0": 1, "y": {"z": [[3], [4, 13]]}},
                  {"x0": 2, "y": {"z": [[3, 5], [4]]}}]),
])  # pyformat: disable
def testGather(self, params, indices, axis, batch_dims, expected):
  """Checks array_ops.gather on a StructuredTensor built from `params`.

  Args:
    params: Python value converted with `StructuredTensor.from_pyval`.
    indices: Indices passed straight through to `array_ops.gather`.
    axis: Axis to gather along.
    batch_dims: Number of batch dimensions.
    expected: Python value the gathered result must equal.
  """
  st_params = StructuredTensor.from_pyval(params)
  # validate_indices has no effect here, and op naming is not under test.
  gathered = array_ops.gather(
      st_params,
      indices,
      validate_indices=True,
      axis=axis,
      name=None,
      batch_dims=batch_dims)
  self.assertAllEqual(gathered, expected)

@parameterized.named_parameters([
    dict(
        testcase_name="params_2D_index_2D_axis_1_batch_dims_1",
        params=[[{"x0": 0, "y": {"z": [[3, 13]]}},
                 {"x0": 1, "y": {"z": [[3], [4, 13]]}}],
                [{"x0": 2, "y": {"z": [[3, 5], [4]]}},
                 {"x0": 3, "y": {"z": [[3, 7, 1], [4]]}},
                 {"x0": 4, "y": {"z": [[3], [4]]}}]],
        indices=[[1, 0], [0, 2]],
        axis=1,
        batch_dims=1,
        expected=[[{"x0": 1, "y": {"z": [[3], [4, 13]]}},
                   {"x0": 0, "y": {"z": [[3, 13]]}}],
                  [{"x0": 2, "y": {"z": [[3, 5], [4]]}},
                   {"x0": 4, "y": {"z": [[3], [4]]}}]]),
    dict(
        testcase_name="params_1D_index_2D_axis_0_batch_dims_0",
        params=[{"x0": 0, "y": {"z": [[3, 13]]}}],
        indices=[[0], [0, 0]],
        axis=0,
        batch_dims=0,
        expected=[[{"x0": 0, "y": {"z": [[3, 13]]}}],
                  [{"x0": 0, "y": {"z": [[3, 13]]}},
                   {"x0": 0, "y": {"z": [[3, 13]]}}]]),
])  # pyformat: disable
def testGatherRagged(self, params, indices, axis, batch_dims, expected):
  """Checks array_ops.gather on a StructuredTensor with ragged indices.

  Args:
    params: Python value converted with `StructuredTensor.from_pyval`.
    indices: Nested list converted with `ragged_factory_ops.constant`.
    axis: Axis to gather along.
    batch_dims: Number of batch dimensions.
    expected: Python value the gathered result must equal.
  """
  st_params = StructuredTensor.from_pyval(params)
  # Explicit conversion shouldn't be necessary, but see cl/366396997.
  ragged_indices = ragged_factory_ops.constant(indices)
  # validate_indices has no effect here, and op naming is not under test.
  gathered = array_ops.gather(
      st_params,
      ragged_indices,
      validate_indices=True,
      axis=axis,
      name=None,
      batch_dims=batch_dims)
  self.assertAllEqual(gathered, expected)

@parameterized.named_parameters([
    dict(testcase_name="params_scalar",
         params={"a": [3]},
         indices=0,
         axis=0,
         batch_dims=0,
         error_type=ValueError,
         error_regex="axis=0 out of bounds",
        ),
    dict(testcase_name="axis_large",
         params=[{"a": [3]}],
         indices=0,
         axis=1,
         batch_dims=0,
         error_type=ValueError,
         error_regex="axis=1 out of bounds",
        ),
    dict(testcase_name="axis_large_neg",
         params=[{"a": [3]}],
         indices=0,
         axis=-2,
         batch_dims=0,
         error_type=ValueError,
         error_regex="axis=-2 out of bounds",
        ),
    dict(testcase_name="batch_large",
         params=[[{"a": [3]}]],
         indices=0,
         axis=0,
         batch_dims=1,
         error_type=ValueError,
         error_regex="batch_dims=1 out of bounds",
        ),
])  # pyformat: disable
def testGatherError(self,
                    params,
                    indices, axis, batch_dims,
                    error_type,
                    error_regex):
  """Checks that structured_array_ops.gather rejects out-of-range arguments.

  Args:
    params: Python value converted with `StructuredTensor.from_pyval`.
    indices: Indices passed to `structured_array_ops.gather`.
    axis: Axis argument under test.
    batch_dims: Batch-dims argument under test.
    error_type: Exception class the call must raise.
    error_regex: Pattern the exception message must match.
  """
  st_params = StructuredTensor.from_pyval(params)
  gather_kwargs = dict(
      validate_indices=True,
      axis=axis,
      name=None,
      batch_dims=batch_dims)
  with self.assertRaisesRegex(error_type, error_regex):
    structured_array_ops.gather(st_params, indices, **gather_kwargs)


if __name__ == "__main__":
  googletest.main()