xref: /aosp_15_r20/external/pytorch/test/nn/test_packed_sequence.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: nn"]
2
3import itertools
4import random
5from typing import List
6
7import torch
8import torch.nn.utils.rnn as rnn_utils
9from torch.testing._internal.common_utils import run_tests, TestCase
10
11
class PackedSequenceTest(TestCase):
    """Tests for the `torch.nn.utils.rnn` packing/padding helpers."""

    # Maps a tensor type name to a (constructor, cast-method-name) pair; the
    # cast-method name is looked up with getattr() on PackedSequence instances.
    _type_by_name = {
        "torch.DoubleTensor": (torch.DoubleTensor, "double"),
        "torch.FloatTensor": (torch.FloatTensor, "float"),
        # We leave out `'torch.HalfTensor': (torch.HalfTensor, 'half'),`
        # because of an error in `pad_packed_sequence`
        # > AttributeError: 'torch.HalfTensor' object has no attribute 'fill_'
        "torch.LongTensor": (torch.LongTensor, "long"),
        "torch.IntTensor": (torch.IntTensor, "int"),
        "torch.ShortTensor": (torch.ShortTensor, "short"),
        "torch.CharTensor": (torch.CharTensor, "char"),
        "torch.ByteTensor": (torch.ByteTensor, "byte"),
    }

    def __init__(self, *args, **kwargs):
        """Forward to TestCase and fix the random-sequence dimensions."""
        super().__init__(*args, **kwargs)
        self.batch_size = 5  # number of sequences per generated batch
        self.max_length = 6  # inclusive upper bound on generated sequence length
30
31    def _ordered_sequence(self, tensor_type):
32        """Create ordered list of random sequences"""
33        seqs = [
34            tensor_type(random.randint(1, self.max_length))
35            for _ in range(self.batch_size)
36        ]
37        if tensor_type == torch.ByteTensor:
38            seqs = [s.random_(0, 256) for s in seqs]
39        else:
40            seqs = [s.random_(-128, 128) for s in seqs]
41        ordered = sorted(seqs, key=len, reverse=True)
42        return ordered
43
44    def _padded_sequence(self, tensor_type):
45        """Create Tensor of random padded sequences"""
46        ordered = self._ordered_sequence(tensor_type)
47        lengths = [len(i) for i in ordered]
48        padded_tensor = rnn_utils.pad_sequence(ordered)
49        return padded_tensor, lengths
50
51    def test_type_casts(self):
52        """Test type casting of `PackedSequence` against type casting of tensor"""
53        for input_type, _ in self._type_by_name.values():
54            for expected_type_str, (_, cast_str) in self._type_by_name.items():
55                for enforce_sorted in [True, False]:
56                    padded, lengths = self._padded_sequence(input_type)
57                    packed = rnn_utils.pack_padded_sequence(
58                        padded, lengths, enforce_sorted=enforce_sorted
59                    )
60                    # Apply cast to `PackedSequence` instance and unpack
61                    masked = getattr(packed, cast_str)()
62                    unpacked, lengths_out = rnn_utils.pad_packed_sequence(masked)
63                    self.assertEqual(unpacked.type(), expected_type_str)
64
65    def test_wrong_order(self):
66        a = torch.ones(25, 300)
67        b = torch.ones(22, 300)
68        b_a = rnn_utils.pad_sequence([b, a])
69        self.assertRaises(
70            RuntimeError,
71            lambda: rnn_utils.pack_padded_sequence(b_a, [22, 25], enforce_sorted=True),
72        )
73
74    def test_pad_sequence_with_tensor_sequences(self):
75        seq_tuple_input = torch.nn.utils.rnn.pad_sequence(
76            (torch.tensor([[7, 6]]), torch.tensor([[-7, -1]]))
77        )
78        seq_tensor_input = torch.nn.utils.rnn.pad_sequence(
79            torch.tensor([[[7, 6]], [[-7, -1]]])
80        )
81        self.assertEqual(seq_tuple_input, seq_tensor_input)
82        self.assertEqual(seq_tuple_input.shape, torch.Size([1, 2, 2]))
83
84    def test_pad_sequence_with_non_iterable_sequences(self):
85        msg = r"Expected iterable for input sequences, but got arg of type"
86        with self.assertRaisesRegex(RuntimeError, msg):
87            torch.nn.utils.rnn.pad_sequence(5)
88
89    def test_total_length(self):
90        padded, lengths = self._padded_sequence(torch.FloatTensor)
91        max_length = max(lengths)
92        packed = rnn_utils.pack_padded_sequence(padded, lengths)
93        # test ValueError if total_length < max_length
94        for total_length in (-1, 0, max_length - 1):
95            for batch_first in (True, False):
96
97                def err_fn():
98                    rnn_utils.pad_packed_sequence(
99                        packed, batch_first=batch_first, total_length=total_length
100                    )
101
102            self.assertRaisesRegex(
103                ValueError,
104                r"Expected total_length to be at least the "
105                r"length of the longest sequence in input",
106                err_fn,
107            )
108        # test that pad_packed_sequence returns results of correct length
109        for batch_first in (True, False):
110            no_extra_pad, _ = rnn_utils.pad_packed_sequence(
111                packed, batch_first=batch_first
112            )
113            for total_length_delta in (0, 1, 8):
114                total_length = max_length + total_length_delta
115                unpacked, lengths_out = rnn_utils.pad_packed_sequence(
116                    packed, batch_first=batch_first, total_length=total_length
117                )
118                self.assertEqual(lengths, lengths_out)
119                self.assertEqual(unpacked.size(1 if batch_first else 0), total_length)
120                if total_length_delta == 0:
121                    ref_output = no_extra_pad
122                elif batch_first:
123                    extra_pad = no_extra_pad.new_zeros(
124                        self.batch_size, total_length_delta
125                    )
126                    ref_output = torch.cat([no_extra_pad, extra_pad], 1)
127                else:
128                    extra_pad = no_extra_pad.new_zeros(
129                        total_length_delta, self.batch_size
130                    )
131                    ref_output = torch.cat([no_extra_pad, extra_pad], 0)
132                self.assertEqual(unpacked, ref_output)
133
134    def test_to(self):
135        for enforce_sorted in (True, False):
136            padded, lengths = self._padded_sequence(torch.IntTensor)
137            a = rnn_utils.pack_padded_sequence(
138                padded, lengths, enforce_sorted=enforce_sorted
139            ).cpu()
140
141            self.assertIs(a, a.to("cpu"))
142            self.assertIs(a, a.cpu())
143            self.assertIs(a, a.to("cpu", dtype=torch.int32))
144            self.assertEqual(a.long(), a.to(torch.int64))
145
146            if torch.cuda.is_available():
147                for cuda in [
148                    "cuda",
149                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
150                ]:
151                    b = a.cuda(device=cuda)
152                    self.assertIs(b, b.to(cuda))
153                    self.assertIs(b, b.cuda())
154                    self.assertEqual(a, b.to("cpu"))
155                    self.assertEqual(b, a.to(cuda))
156                    self.assertEqual(a, b.to("cpu", dtype=torch.int32))
157                    self.assertIs(b, b.to(dtype=torch.int32))
158                    self.assertEqual(b.long(), b.to(dtype=torch.int64))
159
160    def test_to_memory_format(self):
161        m = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, bias=True)
162        m = m.to(memory_format=torch.channels_last)
163        for param in m.parameters():
164            if param.dim() == 4:
165                self.assertTrue(param.is_contiguous(memory_format=torch.channels_last))
166
167    def test_pad_sequence(self):
168        def pad(tensor, length):
169            return torch.cat(
170                [
171                    tensor.data,
172                    tensor.data.new(
173                        length - tensor.size(0), *tensor.size()[1:]
174                    ).zero_(),
175                ]
176            )
177
178        # single dimensional
179        a = torch.tensor([1, 2, 3])
180        b = torch.tensor([4, 5])
181        c = torch.tensor([6])
182
183        # batch_first = true
184        expected = torch.tensor([[4, 5, 0], [1, 2, 3], [6, 0, 0]])
185        padded = rnn_utils.pad_sequence([b, a, c], True)
186        self.assertEqual(padded, expected)
187
188        # batch_first = false
189        padded = rnn_utils.pad_sequence([b, a, c])
190        self.assertEqual(padded, expected.transpose(0, 1))
191
192        # padding_side = "left", batch_first=True
193        expected = torch.tensor([[0, 4, 5], [1, 2, 3], [0, 0, 6]])
194        padded = rnn_utils.pad_sequence(
195            [b, a, c],
196            batch_first=True,
197            padding_side="left",
198        )
199        self.assertEqual(padded, expected)
200
201        # padding_side = "left", batch_first=False
202        padded = rnn_utils.pad_sequence(
203            [b, a, c],
204            batch_first=False,
205            padding_side="left",
206        )
207        self.assertEqual(padded, expected.transpose(0, 1))
208
209        # pad with non-zero value
210        expected = torch.tensor([[4, 5, 1], [1, 2, 3], [6, 1, 1]])
211        padded = rnn_utils.pad_sequence([b, a, c], True, 1)
212        self.assertEqual(padded, expected)
213
214        # Test pad sorted sequence
215        expected = torch.tensor([[1, 2, 3], [4, 5, 0], [6, 0, 0]])
216        padded = rnn_utils.pad_sequence([a, b, c], True)
217        self.assertEqual(padded, expected)
218
219        # more dimensions
220        maxlen = 9
221        for num_dim in (0, 1, 2, 3):
222            sequences: List[torch.Tensor] = []
223            trailing_dims = [4] * num_dim
224            for i in range(1, maxlen + 1):
225                seq_len = i * i
226                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
227            random.shuffle(sequences)
228            # batch first = true
229            expected = torch.stack([pad(seq, maxlen * maxlen) for seq in sequences])
230            padded = rnn_utils.pad_sequence(sequences, True)
231            self.assertEqual(padded, expected)
232
233            # batch first = false
234            padded = rnn_utils.pad_sequence(sequences)
235            self.assertEqual(padded, expected.transpose(0, 1))
236
237            # padding_side = "left", batch_first=True
238            expected = torch.stack(
239                [pad(seq.flip(0), maxlen * maxlen).flip(0) for seq in sequences]
240            )
241            padded = rnn_utils.pad_sequence(
242                sequences,
243                batch_first=True,
244                padding_side="left",
245            )
246            self.assertEqual(padded, expected)
247
248            # padding_side = "left", batch_first=False
249            padded = rnn_utils.pad_sequence(
250                sequences,
251                batch_first=False,
252                padding_side="left",
253            )
254            self.assertEqual(padded, expected.transpose(0, 1))
255
256    def test_unpad_sequence(self):
257        # single dimensional
258        a = torch.tensor([1, 2, 3])
259        b = torch.tensor([4, 5])
260        c = torch.tensor([6])
261        sequences = [a, b, c]
262
263        lengths = torch.as_tensor([v.size(0) for v in sequences])
264        for batch_first in [True, False]:
265            padded_sequences = rnn_utils.pad_sequence(
266                sequences, batch_first=batch_first
267            )
268            unpadded_sequences = rnn_utils.unpad_sequence(
269                padded_sequences, lengths, batch_first=batch_first
270            )
271            self.assertEqual(sequences, unpadded_sequences)
272
273        # more dimensions
274        maxlen = 9
275        for num_dim in (0, 1, 2, 3):
276            sequences = []
277            trailing_dims = [4] * num_dim
278            for i in range(1, maxlen + 1):
279                seq_len = i * i
280                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
281            random.shuffle(sequences)
282
283            lengths = torch.as_tensor([v.size(0) for v in sequences])
284            padded_sequences = rnn_utils.pad_sequence(
285                sequences, batch_first=batch_first
286            )
287            unpadded_sequences = rnn_utils.unpad_sequence(
288                padded_sequences, lengths, batch_first=batch_first
289            )
290            self.assertEqual(sequences, unpadded_sequences)
291
292    def test_pack_sequence(self):
293        def _compatibility_test(sequences, lengths, batch_first, enforce_sorted=False):
294            padded = rnn_utils.pad_sequence(sequences, batch_first)
295            packed = rnn_utils.pack_sequence(sequences, enforce_sorted)
296            unpacked = rnn_utils.pad_packed_sequence(packed, batch_first)
297            self.assertEqual(padded, unpacked[0])
298            pack_padded = rnn_utils.pack_padded_sequence(
299                padded, lengths, batch_first, enforce_sorted
300            )
301            self.assertEqual(packed, pack_padded)
302
303        # single dimensional
304        a = torch.tensor([1, 2, 3])
305        b = torch.tensor([4, 5])
306        c = torch.tensor([6])
307        packed = rnn_utils.pack_sequence([a, b, c], enforce_sorted=False)
308        expected = torch.tensor([1, 4, 6, 2, 5, 3])
309        self.assertEqual(packed.batch_sizes, [3, 2, 1])
310        self.assertEqual(packed.data.data, expected)
311        self.assertEqual(packed.sorted_indices, [0, 1, 2])
312        self.assertEqual(packed.unsorted_indices, [0, 1, 2])
313
314        packed_unsorted = rnn_utils.pack_sequence([b, c, a], enforce_sorted=False)
315        self.assertEqual(packed_unsorted.batch_sizes, [3, 2, 1])
316        self.assertEqual(packed_unsorted.data.data, expected)
317        self.assertEqual(packed_unsorted.sorted_indices, [2, 0, 1])
318        self.assertEqual(packed_unsorted.unsorted_indices, [1, 2, 0])
319
320        # single dimensional, enforce_sorted = True
321        packed_enforce_sorted = rnn_utils.pack_sequence([a, b, c], enforce_sorted=True)
322        self.assertEqual(packed_enforce_sorted.batch_sizes, [3, 2, 1])
323        self.assertEqual(packed_enforce_sorted.data.data, expected)
324        self.assertTrue(packed_enforce_sorted.sorted_indices is None)
325        self.assertTrue(packed_enforce_sorted.unsorted_indices is None)
326
327        with self.assertRaisesRegex(RuntimeError, "must be sorted in decreasing order"):
328            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
329
330        with self.assertRaisesRegex(
331            RuntimeError, "You can pass `enforce_sorted=False`"
332        ):
333            rnn_utils.pack_sequence([b, c, a], enforce_sorted=True)
334
335        # more dimensions
336        maxlen = 9
337        for num_dim in (0, 1, 2, 3):
338            sequences = []
339            lengths = []
340            trailing_dims = [4] * num_dim
341            for i in range(maxlen, 0, -1):
342                seq_len = i * i
343                lengths.append(seq_len)
344                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
345            unsorted_sequences = [s.clone() for s in sequences]
346            random.shuffle(unsorted_sequences)
347            unsorted_sequences_lengths = [t.size(0) for t in unsorted_sequences]
348
349            # compatibility with other utilities
350            for batch_first in (True, False):
351                for enforce_sorted in (True, False):
352                    _compatibility_test(sequences, lengths, batch_first, enforce_sorted)
353                _compatibility_test(
354                    unsorted_sequences, unsorted_sequences_lengths, batch_first
355                )
356
357    def test_unpack_sequence(self):
358        # single dimensional
359        a = torch.tensor([1, 2, 3])
360        b = torch.tensor([4, 5])
361        c = torch.tensor([6])
362        sequences = [a, b, c]
363
364        packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
365        unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
366        self.assertEqual(sequences, unpacked_sequences)
367
368        # more dimensions
369        maxlen = 9
370        for num_dim in (0, 1, 2, 3):
371            sequences = []
372            trailing_dims = [4] * num_dim
373            for i in range(1, maxlen + 1):
374                seq_len = i * i
375                sequences.append(torch.rand(seq_len, 5, *trailing_dims))
376            random.shuffle(sequences)
377
378            packed_sequences = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
379            unpacked_sequences = rnn_utils.unpack_sequence(packed_sequences)
380            self.assertEqual(sequences, unpacked_sequences)
381
382    def test_pack_padded_sequence(self):
383        def generate_test_case(sorted_lengths, should_shuffle):
384            def pad(tensor, length):
385                return torch.cat(
386                    [
387                        tensor,
388                        tensor.new(length - tensor.size(0), *tensor.size()[1:]).zero_(),
389                    ]
390                )
391
392            max_length = sorted_lengths[0]
393            batch_sizes = [
394                sum(map(bool, filter(lambda x: x >= i, sorted_lengths)))
395                for i in range(1, max_length + 1)
396            ]
397            offset = 0
398            padded = torch.cat(
399                [
400                    pad(
401                        i * 100 + torch.arange(1.0, 5 * l + 1).view(l, 1, 5), max_length
402                    )
403                    for i, l in enumerate(sorted_lengths, 1)
404                ],
405                1,
406            )
407            expected_data = [
408                [
409                    torch.arange(1.0, 6) + (i + 1) * 100 + 5 * n
410                    for i in range(batch_size)
411                ]
412                for n, batch_size in enumerate(batch_sizes)
413            ]
414            expected_data = list(itertools.chain.from_iterable(expected_data))
415            expected_data = torch.stack(expected_data, dim=0)
416
417            if should_shuffle:
418                # Shuffle the padded sequence to create an unsorted sequence
419                permutation = list(range(len(sorted_lengths)))
420                random.shuffle(permutation)
421
422                unsorted_indices = torch.tensor(permutation)
423                padded = padded.index_select(1, unsorted_indices)
424                lengths = torch.tensor(sorted_lengths).index_select(0, unsorted_indices)
425            else:
426                unsorted_indices = None
427                lengths = sorted_lengths
428
429            return (
430                padded.requires_grad_(),
431                lengths,
432                expected_data,
433                batch_sizes,
434                unsorted_indices,
435            )
436
437        test_cases = [
438            # sorted_lengths, should_shuffle
439            [[10, 8, 4, 2, 2, 2, 1], False],
440            [[11, 10, 8, 6, 4, 3, 1], False],
441            [[11, 10, 8, 6, 4, 3, 1], True],
442        ]
443
444        for test_case, batch_first in itertools.product(test_cases, (True, False)):
445            sorted_lengths, should_shuffle = test_case
446            (
447                padded,
448                lengths,
449                expected_data,
450                batch_sizes,
451                unsorted_indices,
452            ) = generate_test_case(sorted_lengths, should_shuffle)
453
454            src = padded
455            if batch_first:
456                src = src.transpose(0, 1)
457
458            # check output
459            packed = rnn_utils.pack_padded_sequence(
460                src, lengths, batch_first=batch_first, enforce_sorted=not should_shuffle
461            )
462            self.assertEqual(packed.data.data, expected_data)
463            self.assertEqual(packed.batch_sizes, batch_sizes)
464            self.assertEqual(packed.unsorted_indices, unsorted_indices)
465
466            # test inverse
467            unpacked, unpacked_len = rnn_utils.pad_packed_sequence(
468                packed, batch_first=batch_first
469            )
470            self.assertEqual(unpacked, src)
471            self.assertEqual(unpacked_len, lengths)
472
473            # check grad
474            if padded.grad is not None:
475                padded.grad.data.zero_()
476            grad_output = unpacked.data.clone().normal_()
477            unpacked.backward(grad_output)
478            if batch_first:
479                grad_output.transpose_(0, 1)
480            for i, l in enumerate(lengths):
481                self.assertEqual(padded.grad.data[:l, i], grad_output[:l, i])
482                if l < 10:
483                    self.assertEqual(padded.grad.data[l:, i].abs().sum(), 0)
484
485        # test error messages
486        with self.assertRaisesRegex(
487            RuntimeError, "You can pass `enforce_sorted=False`"
488        ):
489            packed = rnn_utils.pack_padded_sequence(torch.randn(3, 3), [1, 3, 2])
490        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
491            packed = rnn_utils.pack_padded_sequence(torch.randn(0, 0), [])
492        with self.assertRaisesRegex(RuntimeError, "empty tensor"):
493            packed = rnn_utils.pack_padded_sequence(
494                torch.randn([0, 1, 10]), torch.randn([11, 14, 14, 2]), True
495            )
496
497
# Run via PyTorch's common test harness when executed as a script.
if __name__ == "__main__":
    run_tests()
500