# Owner(s): ["module: nn"]

import itertools
import random
from typing import List

import torch
import torch.nn.utils.rnn as rnn_utils
from torch.testing._internal.common_utils import run_tests, TestCase


class PackedSequenceTest(TestCase):
    _type_by_name = {
        "torch.DoubleTensor": (torch.DoubleTensor, "double"),
        "torch.FloatTensor": (torch.FloatTensor, "float"),
        # We leave out `'torch.HalfTensor': (torch.HalfTensor, 'half'),`
        # because of an error in `pad_packed_sequence`
        # > AttributeError: 'torch.HalfTensor' object has no attribute 'fill_'
        "torch.LongTensor": (torch.LongTensor, "long"),
        "torch.IntTensor": (torch.IntTensor, "int"),
        "torch.ShortTensor": (torch.ShortTensor, "short"),
        "torch.CharTensor": (torch.CharTensor, "char"),
        "torch.ByteTensor": (torch.ByteTensor, "byte"),
    }

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.batch_size = 5
        self.max_length = 6

    def _ordered_sequence(self, tensor_type):
        """Create ordered list of random sequences"""
        seqs = [
            tensor_type(random.randint(1, self.max_length))
            for _ in range(self.batch_size)
        ]
        if tensor_type == torch.ByteTensor:
            seqs = [s.random_(0, 256) for s in seqs]
        else:
            seqs = [s.random_(-128, 128) for s in seqs]
        ordered = sorted(seqs, key=len, reverse=True)
        return ordered

    def _padded_sequence(self, tensor_type):
        """Create Tensor of random padded sequences"""
        ordered = self._ordered_sequence(tensor_type)
        lengths = [len(i) for i in ordered]
        padded_tensor = rnn_utils.pad_sequence(ordered)
        return padded_tensor, lengths

    def test_type_casts(self):
        """Test type casting of `PackedSequence` against type casting of tensor"""
        for input_type, _ in self._type_by_name.values():
            for expected_type_str, (_, cast_str) in self._type_by_name.items():
                for enforce_sorted in [True, False]:
                    padded, lengths = self._padded_sequence(input_type)
                    packed = rnn_utils.pack_padded_sequence(
                        padded, lengths, enforce_sorted=enforce_sorted
                    )
                    # Apply cast to `PackedSequence` instance and unpack
                    masked = getattr(packed, cast_str)()
                    unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
                    self.assertEqual(unpacked.type(), expected_type_str)

    def test_wrong_order(self):
        a = torch.ones(25, 300)
        b = torch.ones(22, 300)
        b_a = rnn_utils.pad_sequence([b, a])
        self.assertRaises(
            RuntimeError,
            lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True),
        )

    def test_pad_sequence_with_tensor_sequences(self):
        seq_tuple_input = torch.nn.utils.rnn.pad_sequence(
            (torch.tensor([[7, 6]]), torch.tensor([[-7, -1]]))
        )
        seq_tensor_input = torch.nn.utils.rnn.pad_sequence(
            torch.tensor([[[7, 6]], [[-7, -1]]])
        )
        self.assertEqual(seq_tuple_input, seq_tensor_input)
        self.assertEqual(seq_tuple_input.shape, torch.Size([1, 2, 2]))

    def test_pad_sequence_with_non_iterable_sequences(self):
        msg = r"Expected iterable for input sequences, but got arg of type"
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.utils.rnn.pad_sequence(5)
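
    # `total_length` lets `pad_packed_sequence` pad beyond the longest sequence
    # in the batch, which is useful when downstream code expects every batch to
    # share a fixed length; asking for a length shorter than the longest
    # sequence must raise.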
    def test_total_length(self):
        padded, lengths = self._padded_sequence(torch.FloatTensor)
        max_length = max(lengths)
        packed = rnn_utils.pack_padded_sequence(padded, lengths)
        # test ValueError if total_length < max_length
        for total_length in (-1, 0, max_length - 1):
            for batch_first in (True, False):

                def err_fn():
                    rnn_utils.pad_packed_sequence(
                        packed, batch_first=batch_first, total_length=total_length
                    )

                self.assertRaisesRegex(
                    ValueError,
                    r"Expected total_length to be at least the "
                    r"length of the longest sequence in input",
                    err_fn,
                )
        # test that pad_packed_sequence returns results of correct length
        for batch_first in (True, False):
            no_extra_pad, _ = rnn_utils.pad_packed_sequence(
                packed, batch_first=batch_first
            )
            for total_length_delta in (0, 1, 8):
                total_length = max_length + total_length_delta
                unpacked, lengths_out = rnn_utils.pad_packed_sequence(
                    packed, batch_first=batch_first, total_length=total_length
                )
                self.assertEqual(lengths, lengths_out)
                self.assertEqual(unpacked.size(1 if batch_first else 0), total_length)
                if total_length_delta == 0:
                    ref_output = no_extra_pad
                elif batch_first:
                    extra_pad = no_extra_pad.new_zeros(
                        self.batch_size, total_length_delta
                    )
                    ref_output = torch.cat([no_extra_pad, extra_pad], 1)
                else:
                    extra_pad = no_extra_pad.new_zeros(
                        total_length_delta, self.batch_size
                    )
                    ref_output = torch.cat([no_extra_pad, extra_pad], 0)
                self.assertEqual(unpacked, ref_output)

    def test_to(self):
        for enforce_sorted in (True, False):
            padded, lengths = self._padded_sequence(torch.IntTensor)
            a = rnn_utils.pack_padded_sequence(
                padded, lengths, enforce_sorted=enforce_sorted
            ).cpu()

            self.assertIs(a, a.to("cpu"))
            self.assertIs(a, a.cpu())
            self.assertIs(a, a.to("cpu", dtype=torch.int32))
            self.assertEqual(a.long(), a.to(torch.int64))

            if torch.cuda.is_available():
                for cuda in [
                    "cuda",
                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
                ]:
                    b = a.cuda(device=cuda)
                    self.assertIs(b, b.to(cuda))
                    self.assertIs(b, b.cuda())
                    self.assertEqual(a, b.to("cpu"))
                    self.assertEqual(b, a.to(cuda))
                    self.assertEqual(a, b.to("cpu", dtype=torch.int32))
                    self.assertIs(b, b.to(dtype=torch.int32))
                    self.assertEqual(b.long(), b.to(dtype=torch.int64))

    def test_to_memory_format(self):
        m = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, bias=True)
        m = m.to(memory_format=torch.channels_last)
        for param in m.parameters():
            if param.dim() == 4:
                self.assertTrue(param.is_contiguous(memory_format=torch.channels_last))

    def test_pad_sequence(self):
        def pad(tensor, length):
            return torch.cat(
                [
                    tensor.data,
                    tensor.data.new(
                        length - tensor.size(0), *tensor.size()[1:]
                    ).zero_(),
                ]
            )

        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])

        # batch_first = true
        expected = torch.tensor([[4, 5, 0], [1, 2, 3], [6, 0, 0]])
        padded = rnn_utils.pad_sequence([b, a, c], True)
        self.assertEqual(padded, expected)

        # batch_first = false
        padded = rnn_utils.pad_sequence([b, a, c])
        self.assertEqual(padded, expected.transpose(0, 1))

        # padding_side = "left", batch_first=True
        expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]])
        padded = rnn_utils.pad_sequence(
            [b, a, c],
            batch_first=True,
            padding_side="left",
        )
        self.assertEqual(padded, expected)

        # padding_side = "left", batch_first=False
        padded = rnn_utils.pad_sequence(
            [b, a, c],
            batch_first=False,
            padding_side="left",
        )
        self.assertEqual(padded, expected.transpose(0, 1))

        # pad with non-zero value
        expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
        padded = rnn_utils.pad_sequence([b, a, c], True, 1)
        self.assertEqual(padded, expected)

        # Test pad sorted sequence
        expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
        padded = rnn_utils.pad_sequence([a, b, c], True)
        self.assertEqual(padded, expected)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences: List[torch.Tensor] = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)
            # batch first = true
            expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences])
            padded = rnn_utils.pad_sequence(sequences, True)
            self.assertEqual(padded, expected)

            # batch first = false
            padded = rnn_utils.pad_sequence(sequences)
            self.assertEqual(padded, expected.transpose(0, 1))

            # padding_side = "left", batch_first=True
            expected = torch.stack(
                [pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences]
            )
            padded = rnn_utils.pad_sequence(
                sequences,
                batch_first=True,
                padding_side="left",
            )
            self.assertEqual(padded, expected)

            # padding_side = "left", batch_first=False
            padded = rnn_utils.pad_sequence(
                sequences,
                batch_first=False,
                padding_side="left",
            )
            self.assertEqual(padded, expected.transpose(0, 1))
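
    # unpad_sequence is the inverse of pad_sequence: given the padded batch and
    # the original lengths, it recovers the original list of variable-length
    # tensors.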
    def test_unpad_sequence(self):
        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        sequences = [a, b, c]

        lengths = torch.as_tensor([v.size(0) for v in sequences])
        for batch_first in [True, False]:
            padded_sequences = rnn_utils.pad_sequence(
                sequences, batch_first=batch_first
            )
            unpadded_sequences = rnn_utils.unpad_sequence(
                padded_sequences, lengths, batch_first=batch_first
            )
            self.assertEqual(sequences, unpadded_sequences)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)

            lengths = torch.as_tensor([v.size(0) for v in sequences])
            padded_sequences = rnn_utils.pad_sequence(
                sequences, batch_first=batch_first
            )
            unpadded_sequences = rnn_utils.unpad_sequence(
                padded_sequences, lengths, batch_first=batch_first
            )
            self.assertEqual(sequences, unpadded_sequences)
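
    # pack_sequence stores data in a time-major, interleaved layout: step 0 of
    # every sequence comes first, then step 1 of every still-active sequence,
    # and so on, with batch_sizes recording how many sequences remain at each
    # step. The assertions below rely on this layout.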
    def test_pack_sequence(self):
        def _compatibility_test(sequences, lengths, batch_first, enforce_sorted=False):
            padded = rnn_utils.pad_sequence(sequences, batch_first)
            packed = rnn_utils.pack_sequence(sequences, enforce_sorted)
            unpacked = rnn_utils.pad_packed_sequence(packed, batch_first)
            self.assertEqual(padded, unpacked[0])
            pack_padded = rnn_utils.pack_padded_sequence(
                padded, lengths, batch_first, enforce_sorted
            )
            self.assertEqual(packed, pack_padded)

        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)
        expected = torch.tensor([1, 4, 6, 2, 5, 3])
        self.assertEqual(packed.batch_sizes, [3, 2, 1])
        self.assertEqual(packed.data.data, expected)
        self.assertEqual(packed.sorted_indices, [0, 1, 2])
        self.assertEqual(packed.unsorted_indices, [0, 1, 2])

        packed_unsorted = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
        self.assertEqual(packed_unsorted.batch_sizes, [3, 2, 1])
        self.assertEqual(packed_unsorted.data.data, expected)
        self.assertEqual(packed_unsorted.sorted_indices, [2, 0, 1])
        self.assertEqual(packed_unsorted.unsorted_indices, [1, 2, 0])

        # single dimensional, enforce_sorted = True
        packed_enforce_sorted = rnn_utils.pack_sequence([a, b, c], enforce_sorted=True)
        self.assertEqual(packed_enforce_sorted.batch_sizes, [3, 2, 1])
        self.assertEqual(packed_enforce_sorted.data.data, expected)
        self.assertTrue(packed_enforce_sorted.sorted_indices is None)
        self.assertTrue(packed_enforce_sorted.unsorted_indices is None)

        with self.assertRaisesRegex(RuntimeError, "must be sorted in decreasing order"):
            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)

        with self.assertRaisesRegex(
            RuntimeError, "You can pass `enforce_sorted=False`"
        ):
            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            lengths = []
            trailing_dims = [4] * num_dim
            for i in range(maxlen, 0, -1):
                seq_len = i * i
                lengths.append(seq_len)
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            unsorted_sequences = [s.clone() for s in sequences]
            random.shuffle(unsorted_sequences)
            unsorted_sequences_lengths = [t.size(0) for t in unsorted_sequences]

            # compatibility with other utilities
            for batch_first in (True, False):
                for enforce_sorted in (True, False):
                    _compatibility_test(sequences, lengths, batch_first, enforce_sorted)
                _compatibility_test(
                    unsorted_sequences, unsorted_sequences_lengths, batch_first
                )

    def test_unpack_sequence(self):
        # single dimensional
        a = torch.tensor([1, 2, 3])
        b = torch.tensor([4, 5])
        c = torch.tensor([6])
        sequences = [a, b, c]

        packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
        self.assertEqual(sequences, unpacked_sequences)

        # more dimensions
        maxlen = 9
        for num_dim in (0, 1, 2, 3):
            sequences = []
            trailing_dims = [4] * num_dim
            for i in range(1, maxlen + 1):
                seq_len = i * i
                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
            random.shuffle(sequences)

            packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
            unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
            self.assertEqual(sequences, unpacked_sequences)
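
    # pack_padded_sequence takes an already padded batch plus per-sequence
    # lengths. generate_test_case below builds such a batch together with the
    # flattened data and batch_sizes that packing is expected to produce,
    # optionally shuffling the batch to exercise enforce_sorted=False.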
    def test_pack_padded_sequence(self):
        def generate_test_case(sorted_lengths, should_shuffle):
            def pad(tensor, length):
                return torch.cat(
                    [
                        tensor,
                        tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_(),
                    ]
                )

            max_length = sorted_lengths[0]
            batch_sizes = [
                sum(map(bool, filter(lambda x: x >= i, sorted_lengths)))
                for i in range(1, max_length + 1)
            ]
            offset = 0
            padded = torch.cat(
                [
                    pad(
                        i * 100 + torch.arange(1.0, 5 * l + 1).view(l, 1, 5), max_length
                    )
                    for i, l in enumerate(sorted_lengths, 1)
                ],
                1,
            )
            expected_data = [
                [
                    torch.arange(1.0, 6) + (i + 1) * 100 + 5 * n
                    for i in range(batch_size)
                ]
                for n, batch_size in enumerate(batch_sizes)
            ]
            expected_data = list(itertools.chain.from_iterable(expected_data))
            expected_data = torch.stack(expected_data, dim=0)

            if should_shuffle:
                # Shuffle the padded sequence to create an unsorted sequence
                permutation = list(range(len(sorted_lengths)))
                random.shuffle(permutation)

                unsorted_indices = torch.tensor(permutation)
                padded = padded.index_select(1, unsorted_indices)
                lengths = torch.tensor(sorted_lengths).index_select(0, unsorted_indices)
            else:
                unsorted_indices = None
                lengths = sorted_lengths

            return (
                padded.requires_grad_(),
                lengths,
                expected_data,
                batch_sizes,
                unsorted_indices,
            )

        test_cases = [
            # sorted_lengths, should_shuffle
            [[10, 8, 4, 2, 2, 2, 1], False],
            [[11, 10, 8, 6, 4, 3, 1], False],
            [[11, 10, 8, 6, 4, 3, 1], True],
        ]

        for test_case, batch_first in itertools.product(test_cases, (True, False)):
            sorted_lengths, should_shuffle = test_case
            (
                padded,
                lengths,
                expected_data,
                batch_sizes,
                unsorted_indices,
            ) = generate_test_case(sorted_lengths, should_shuffle)

            src = padded
            if batch_first:
                src = src.transpose(0, 1)

            # check output
            packed = rnn_utils.pack_padded_sequence(
                src, lengths, batch_first=batch_first, enforce_sorted=not should_shuffle
            )
            self.assertEqual(packed.data.data, expected_data)
            self.assertEqual(packed.batch_sizes, batch_sizes)
            self.assertEqual(packed.unsorted_indices, unsorted_indices)

            # test inverse
            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(
                packed, batch_first=batch_first
            )
            self.assertEqual(unpacked, src)
            self.assertEqual(unpacked_len, lengths)

            # check grad
            if padded.grad is not None:
                padded.grad.data.zero_()
            grad_output = unpacked.data.clone().normal_()
            unpacked.backward(grad_output)
            if batch_first:
                grad_output.transpose_(0, 1)
            for i, l in enumerate(lengths):
                self.assertEqual(padded.grad.data[:l, i], grad_output[:l, i])
                if l < 10:
                    self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)

        # test error messages
        with self.assertRaisesRegex(
            RuntimeError, "You can pass `enforce_sorted=False`"
        ):
            packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
            packed = rnn_utils.pack_padded_sequence(torch.randn(0, 0), [])
        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
            packed = rnn_utils.pack_padded_sequence(
                torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
            )


if __name__ == "__main__":
    run_tests()