xref: /aosp_15_r20/external/pytorch/test/jit/test_scriptmod_ann.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 # Owner(s): ["oncall: jit"]
2 
3 import os
4 import sys
5 import unittest
6 import warnings
7 from typing import Dict, List, Optional
8 
9 import torch
10 
11 
# Make the helper files in test/ importable: resolve this file's real path,
# then step up two directory levels (test/jit/<file> -> test/jit -> test/).
pytorch_test_dir = os.path.dirname(os.path.realpath(__file__))
pytorch_test_dir = os.path.dirname(pytorch_test_dir)
sys.path.append(pytorch_test_dir)
15 from torch.testing._internal.jit_utils import JitTestCase
16 
17 
if __name__ == "__main__":
    # This module only defines test cases for the test/test_jit.py driver;
    # executing it directly is rejected with a usage hint instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
24 
25 
26 class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
27     # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
28     # reassigning a non-empty Tuple to an attribute previously typed
29     # as containing an empty Tuple SHOULD fail. See note in `_check.py`
30 
31     def test_annotated_falsy_base_type(self):
32         class M(torch.nn.Module):
33             def __init__(self) -> None:
34                 super().__init__()
35                 self.x: int = 0
36 
37             def forward(self, x: int):
38                 self.x = x
39                 return 1
40 
41         with warnings.catch_warnings(record=True) as w:
42             self.checkModule(M(), (1,))
43         assert len(w) == 0
44 
45     def test_annotated_nonempty_container(self):
46         class M(torch.nn.Module):
47             def __init__(self) -> None:
48                 super().__init__()
49                 self.x: List[int] = [1, 2, 3]
50 
51             def forward(self, x: List[int]):
52                 self.x = x
53                 return 1
54 
55         with warnings.catch_warnings(record=True) as w:
56             self.checkModule(M(), ([1, 2, 3],))
57         assert len(w) == 0
58 
59     def test_annotated_empty_tensor(self):
60         class M(torch.nn.Module):
61             def __init__(self) -> None:
62                 super().__init__()
63                 self.x: torch.Tensor = torch.empty(0)
64 
65             def forward(self, x: torch.Tensor):
66                 self.x = x
67                 return self.x
68 
69         with warnings.catch_warnings(record=True) as w:
70             self.checkModule(M(), (torch.rand(2, 3),))
71         assert len(w) == 0
72 
73     def test_annotated_with_jit_attribute(self):
74         class M(torch.nn.Module):
75             def __init__(self) -> None:
76                 super().__init__()
77                 self.x = torch.jit.Attribute([], List[int])
78 
79             def forward(self, x: List[int]):
80                 self.x = x
81                 return self.x
82 
83         with warnings.catch_warnings(record=True) as w:
84             self.checkModule(M(), ([1, 2, 3],))
85         assert len(w) == 0
86 
87     def test_annotated_class_level_annotation_only(self):
88         class M(torch.nn.Module):
89             x: List[int]
90 
91             def __init__(self) -> None:
92                 super().__init__()
93                 self.x = []
94 
95             def forward(self, y: List[int]):
96                 self.x = y
97                 return self.x
98 
99         with warnings.catch_warnings(record=True) as w:
100             self.checkModule(M(), ([1, 2, 3],))
101         assert len(w) == 0
102 
103     def test_annotated_class_level_annotation_and_init_annotation(self):
104         class M(torch.nn.Module):
105             x: List[int]
106 
107             def __init__(self) -> None:
108                 super().__init__()
109                 self.x: List[int] = []
110 
111             def forward(self, y: List[int]):
112                 self.x = y
113                 return self.x
114 
115         with warnings.catch_warnings(record=True) as w:
116             self.checkModule(M(), ([1, 2, 3],))
117         assert len(w) == 0
118 
119     def test_annotated_class_level_jit_annotation(self):
120         class M(torch.nn.Module):
121             x: List[int]
122 
123             def __init__(self) -> None:
124                 super().__init__()
125                 self.x: List[int] = torch.jit.annotate(List[int], [])
126 
127             def forward(self, y: List[int]):
128                 self.x = y
129                 return self.x
130 
131         with warnings.catch_warnings(record=True) as w:
132             self.checkModule(M(), ([1, 2, 3],))
133         assert len(w) == 0
134 
135     def test_annotated_empty_list(self):
136         class M(torch.nn.Module):
137             def __init__(self) -> None:
138                 super().__init__()
139                 self.x: List[int] = []
140 
141             def forward(self, x: List[int]):
142                 self.x = x
143                 return 1
144 
145         with self.assertRaisesRegexWithHighlight(
146             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
147         ):
148             with self.assertWarnsRegex(
149                 UserWarning,
150                 "doesn't support "
151                 "instance-level annotations on "
152                 "empty non-base types",
153             ):
154                 torch.jit.script(M())
155 
156     @unittest.skipIf(
157         sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
158     )
159     def test_annotated_empty_list_lowercase(self):
160         class M(torch.nn.Module):
161             def __init__(self) -> None:
162                 super().__init__()
163                 self.x: list[int] = []
164 
165             def forward(self, x: list[int]):
166                 self.x = x
167                 return 1
168 
169         with self.assertRaisesRegexWithHighlight(
170             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
171         ):
172             with self.assertWarnsRegex(
173                 UserWarning,
174                 "doesn't support "
175                 "instance-level annotations on "
176                 "empty non-base types",
177             ):
178                 torch.jit.script(M())
179 
180     def test_annotated_empty_dict(self):
181         class M(torch.nn.Module):
182             def __init__(self) -> None:
183                 super().__init__()
184                 self.x: Dict[str, int] = {}
185 
186             def forward(self, x: Dict[str, int]):
187                 self.x = x
188                 return 1
189 
190         with self.assertRaisesRegexWithHighlight(
191             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
192         ):
193             with self.assertWarnsRegex(
194                 UserWarning,
195                 "doesn't support "
196                 "instance-level annotations on "
197                 "empty non-base types",
198             ):
199                 torch.jit.script(M())
200 
201     @unittest.skipIf(
202         sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
203     )
204     def test_annotated_empty_dict_lowercase(self):
205         class M(torch.nn.Module):
206             def __init__(self) -> None:
207                 super().__init__()
208                 self.x: dict[str, int] = {}
209 
210             def forward(self, x: dict[str, int]):
211                 self.x = x
212                 return 1
213 
214         with self.assertRaisesRegexWithHighlight(
215             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
216         ):
217             with self.assertWarnsRegex(
218                 UserWarning,
219                 "doesn't support "
220                 "instance-level annotations on "
221                 "empty non-base types",
222             ):
223                 torch.jit.script(M())
224 
225     def test_annotated_empty_optional(self):
226         class M(torch.nn.Module):
227             def __init__(self) -> None:
228                 super().__init__()
229                 self.x: Optional[str] = None
230 
231             def forward(self, x: Optional[str]):
232                 self.x = x
233                 return 1
234 
235         with self.assertRaisesRegexWithHighlight(
236             RuntimeError, "Wrong type for attribute assignment", "self.x = x"
237         ):
238             with self.assertWarnsRegex(
239                 UserWarning,
240                 "doesn't support "
241                 "instance-level annotations on "
242                 "empty non-base types",
243             ):
244                 torch.jit.script(M())
245 
246     def test_annotated_with_jit_empty_list(self):
247         class M(torch.nn.Module):
248             def __init__(self) -> None:
249                 super().__init__()
250                 self.x = torch.jit.annotate(List[int], [])
251 
252             def forward(self, x: List[int]):
253                 self.x = x
254                 return 1
255 
256         with self.assertRaisesRegexWithHighlight(
257             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
258         ):
259             with self.assertWarnsRegex(
260                 UserWarning,
261                 "doesn't support "
262                 "instance-level annotations on "
263                 "empty non-base types",
264             ):
265                 torch.jit.script(M())
266 
267     @unittest.skipIf(
268         sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
269     )
270     def test_annotated_with_jit_empty_list_lowercase(self):
271         class M(torch.nn.Module):
272             def __init__(self) -> None:
273                 super().__init__()
274                 self.x = torch.jit.annotate(list[int], [])
275 
276             def forward(self, x: list[int]):
277                 self.x = x
278                 return 1
279 
280         with self.assertRaisesRegexWithHighlight(
281             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
282         ):
283             with self.assertWarnsRegex(
284                 UserWarning,
285                 "doesn't support "
286                 "instance-level annotations on "
287                 "empty non-base types",
288             ):
289                 torch.jit.script(M())
290 
291     def test_annotated_with_jit_empty_dict(self):
292         class M(torch.nn.Module):
293             def __init__(self) -> None:
294                 super().__init__()
295                 self.x = torch.jit.annotate(Dict[str, int], {})
296 
297             def forward(self, x: Dict[str, int]):
298                 self.x = x
299                 return 1
300 
301         with self.assertRaisesRegexWithHighlight(
302             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
303         ):
304             with self.assertWarnsRegex(
305                 UserWarning,
306                 "doesn't support "
307                 "instance-level annotations on "
308                 "empty non-base types",
309             ):
310                 torch.jit.script(M())
311 
312     @unittest.skipIf(
313         sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
314     )
315     def test_annotated_with_jit_empty_dict_lowercase(self):
316         class M(torch.nn.Module):
317             def __init__(self) -> None:
318                 super().__init__()
319                 self.x = torch.jit.annotate(dict[str, int], {})
320 
321             def forward(self, x: dict[str, int]):
322                 self.x = x
323                 return 1
324 
325         with self.assertRaisesRegexWithHighlight(
326             RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
327         ):
328             with self.assertWarnsRegex(
329                 UserWarning,
330                 "doesn't support "
331                 "instance-level annotations on "
332                 "empty non-base types",
333             ):
334                 torch.jit.script(M())
335 
336     def test_annotated_with_jit_empty_optional(self):
337         class M(torch.nn.Module):
338             def __init__(self) -> None:
339                 super().__init__()
340                 self.x = torch.jit.annotate(Optional[str], None)
341 
342             def forward(self, x: Optional[str]):
343                 self.x = x
344                 return 1
345 
346         with self.assertRaisesRegexWithHighlight(
347             RuntimeError, "Wrong type for attribute assignment", "self.x = x"
348         ):
349             with self.assertWarnsRegex(
350                 UserWarning,
351                 "doesn't support "
352                 "instance-level annotations on "
353                 "empty non-base types",
354             ):
355                 torch.jit.script(M())
356 
357     def test_annotated_with_torch_jit_import(self):
358         from torch import jit
359 
360         class M(torch.nn.Module):
361             def __init__(self) -> None:
362                 super().__init__()
363                 self.x = jit.annotate(Optional[str], None)
364 
365             def forward(self, x: Optional[str]):
366                 self.x = x
367                 return 1
368 
369         with self.assertRaisesRegexWithHighlight(
370             RuntimeError, "Wrong type for attribute assignment", "self.x = x"
371         ):
372             with self.assertWarnsRegex(
373                 UserWarning,
374                 "doesn't support "
375                 "instance-level annotations on "
376                 "empty non-base types",
377             ):
378                 torch.jit.script(M())
379