xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/iter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 # mypy: ignore-errors
2 
3 import itertools
4 import operator
5 import sys
6 from typing import Dict, List, Optional, TYPE_CHECKING, Union
7 
8 from .. import polyfills, variables
9 from ..bytecode_transformation import create_call_function, create_instruction
10 from ..exc import (
11     handle_observed_exception,
12     ObservedUserStopIteration,
13     raise_observed_exception,
14     unimplemented,
15     UserError,
16 )
17 from .base import MutableLocal, VariableTracker
18 from .constant import ConstantVariable
19 
20 
21 if TYPE_CHECKING:
22     from torch._dynamo.symbolic_convert import InstructionTranslator
23 
24 
# Upper bound on the number of items CycleIteratorVariable buffers from its
# input iterator; once the buffer grows past this, tracing the cycle is
# reported as unsupported via `unimplemented`.
MAX_ITERATOR_LIMIT = 100 * 1024  # 102,400 items ("100k")
26 
27 
class ItertoolsVariable(VariableTracker):
    """Tracks a callable taken from the ``itertools`` module.

    ``value`` holds the real ``itertools`` object (for example
    ``itertools.product``).  ``call_function`` dispatches on the identity of
    that object and models the call with tracker objects.
    """

    def __init__(self, value, **kwargs) -> None:
        super().__init__(**kwargs)
        self.value = value

    def __repr__(self) -> str:
        return f"ItertoolsVariable({self.value})"

    def as_python_constant(self):
        """Return the wrapped ``itertools`` callable itself."""
        return self.value

    def call_function(
        self,
        tx: "InstructionTranslator",
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Model a call to the wrapped ``itertools`` function.

        Handled here: ``product`` and ``combinations`` (positional arguments
        only), ``accumulate``, ``groupby``, ``repeat``, ``count``, ``cycle``,
        ``dropwhile`` and ``zip_longest``.  Anything else, including
        ``product``/``combinations`` calls that do not match the supported
        shape, is delegated to ``VariableTracker.call_function``.

        Argument combinations that are recognised but not supported are
        reported through ``unimplemented``.
        """
        if (
            self.value is itertools.product
            and not kwargs
            and all(arg.has_unpack_var_sequence(tx) for arg in args)
        ):
            seqs = [arg.unpack_var_sequence(tx) for arg in args]
            items = [
                variables.TupleVariable(list(item))
                for item in itertools.product(*seqs)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.accumulate:
            from .builtin import BuiltinVariable

            unsupported = set(kwargs.keys()) - {"initial", "func"}
            if unsupported:
                unimplemented(
                    "Unsupported kwargs for itertools.accumulate: "
                    f"{','.join(sorted(unsupported))}"
                )

            if len(args) not in (1, 2) or not args[0].has_unpack_var_sequence(tx):
                unimplemented("Unsupported arguments for itertools.accumulate")

            # The function may be given positionally or by keyword, not both;
            # previously the keyword was silently ignored in that case.
            if len(args) == 2 and "func" in kwargs:
                unimplemented(
                    "itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
                )

            seq = args[0].unpack_var_sequence(tx)
            if len(args) == 2:
                func = args[1].call_function
            elif "func" in kwargs:
                func = kwargs["func"].call_function
            else:
                # Default to operator.add
                func = BuiltinVariable(operator.add).call_function

            acc = kwargs.get("initial")
            # itertools.accumulate treats `initial=None` as "no initial value".
            if (
                acc is not None
                and acc.is_python_constant()
                and acc.as_python_constant() is None
            ):
                acc = None

            items = []
            if acc is not None:
                items.append(acc)
            for item in seq:
                if acc is None:
                    acc = item
                else:
                    try:
                        acc = func(tx, [acc, item], {})
                    except Exception as e:
                        unimplemented(
                            "Unexpected failure in invoking function during accumulate. "
                            f"Failed running func {func}({acc}, {item})",
                            from_exc=e,
                        )
                items.append(acc)

            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif (
            self.value is itertools.combinations
            and not kwargs
            and len(args) == 2
            and args[0].has_unpack_var_sequence(tx)
            and args[1].is_python_constant()
        ):
            iterable = args[0].unpack_var_sequence(tx)
            r = args[1].as_python_constant()

            items = [
                variables.TupleVariable(list(item))
                for item in itertools.combinations(iterable, r)
            ]
            return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
        elif self.value is itertools.groupby:
            unsupported = set(kwargs.keys()) - {"key"}
            if unsupported:
                unimplemented(
                    "Unsupported kwargs for itertools.groupby: "
                    f"{','.join(sorted(unsupported))}"
                )

            def retrieve_const_key(key):
                # Group keys must reduce to plain Python values so that the
                # real itertools.groupby below can compare them.
                if isinstance(key, variables.SymNodeVariable):
                    return key.evaluate_expr()
                elif isinstance(key, variables.ConstantVariable):
                    return key.as_python_constant()
                else:
                    unimplemented(
                        "Unsupported key type for itertools.groupby: " + str(type(key))
                    )

            if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
                seq = args[0].unpack_var_sequence(tx)
                if "key" in kwargs:
                    key_var = kwargs["key"]

                    def keyfunc(x):
                        return retrieve_const_key(key_var.call_function(tx, [x], {}))

                else:
                    keyfunc = None
            else:
                unimplemented("Unsupported arguments for itertools.groupby")

            result = []
            try:
                for k, v in itertools.groupby(seq, key=keyfunc):
                    result.append(
                        variables.TupleVariable(
                            [
                                variables.ConstantVariable.create(k)
                                if variables.ConstantVariable.is_literal(k)
                                else k,
                                variables.ListIteratorVariable(
                                    list(v), mutable_local=MutableLocal()
                                ),
                            ],
                            mutable_local=MutableLocal(),
                        )
                    )
            except Exception as e:
                unimplemented(
                    "Unexpected failure when calling itertools.groupby",
                    from_exc=e,
                )
            return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
        elif self.value is itertools.repeat:
            # Only the bare `repeat(obj)` form is unbounded.
            if len(args) == 1 and not kwargs:
                return variables.RepeatIteratorVariable(
                    args[0], mutable_local=MutableLocal()
                )

            if len(args) == 1 and set(kwargs.keys()) == {"times"}:
                # `repeat(obj, times=n)`: forward `times` positionally so it is
                # not dropped (it used to yield an unbounded iterator).
                args = [args[0], kwargs["times"]]
                kwargs = {}
            elif len(args) < 2:
                unimplemented("Unsupported arguments for itertools.repeat")

            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs
            )
        elif self.value is itertools.count:
            # itertools.count(start=0, step=1) accepts both parameters by
            # keyword; only names not already supplied positionally are valid.
            params = ("start", "step")
            if len(args) > len(params) or any(
                name not in params[len(args) :] for name in kwargs
            ):
                unimplemented("Unsupported arguments for itertools.count")

            count_kwargs = {}
            if "start" in kwargs:
                # CountIteratorVariable names the running value `item`.
                count_kwargs["item"] = kwargs["start"]
            if "step" in kwargs:
                count_kwargs["step"] = kwargs["step"]
            return variables.CountIteratorVariable(
                *args, mutable_local=MutableLocal(), **count_kwargs
            )
        elif self.value is itertools.cycle:
            # itertools.cycle takes exactly one positional argument; extra
            # positionals must not reach CycleIteratorVariable's internal
            # `saved` / `saved_index` / `item` parameters.
            if len(args) != 1 or kwargs:
                unimplemented("Unsupported arguments for itertools.cycle")
            return variables.CycleIteratorVariable(
                args[0], mutable_local=MutableLocal()
            )
        elif self.value is itertools.dropwhile:
            return variables.UserFunctionVariable(polyfills.dropwhile).call_function(
                tx, args, kwargs
            )
        elif self.value is itertools.zip_longest:
            return variables.UserFunctionVariable(polyfills.zip_longest).call_function(
                tx, args, kwargs
            )
        else:
            return super().call_function(tx, args, kwargs)
194 
195 
class IteratorVariable(VariableTracker):
    """Base class for trackers that produce items one at a time.

    Subclasses override ``next_variable``; the implementation here only
    reports the call as unsupported.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

    def next_variable(self, tx):
        """Produce the next item; must be overridden by subclasses."""
        unimplemented("abstract method, must implement")

    def force_unpack_var_sequence(self, tx) -> List[VariableTracker]:
        """Drain this iterator eagerly and return every item as a list.

        NOTE: call this only when eagerly consuming the iterator is safe;
        iterators are normally consumed lazily.
          * safe:   list(map(f, seq))
          * unsafe: list(islice(map(f, seq), 5))
        """
        collected = []
        try:
            while True:
                collected.append(self.next_variable(tx))
        except ObservedUserStopIteration:
            # Exhaustion is the normal way out of the loop.
            handle_observed_exception(tx)
        return collected

    def has_force_unpack_var_sequence(self, tx) -> bool:
        """Always True here.

        This deliberately does not probe by calling
        ``force_unpack_var_sequence``, because that call can mutate the
        iterator's state.
        """
        return True
221 
222 
class RepeatIteratorVariable(IteratorVariable):
    """Tracks ``itertools.repeat(item)`` called with a single argument."""

    def __init__(self, item: VariableTracker, **kwargs) -> None:
        super().__init__(**kwargs)
        self.item = item

    def next_variable(self, tx):
        """Return the repeated item; no state changes between calls."""
        return self.item

    def reconstruct(self, codegen):
        """Emit code that rebuilds this iterator as ``itertools.repeat(item)``."""

        def load_repeat():
            # Push `itertools.repeat` as the callable.
            return codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("repeat"),
                ]
            )

        codegen.add_push_null(load_repeat)
        codegen(self.item)
        call_instructions = create_call_function(1, False)
        codegen.extend_output(call_instructions)
243 
244 
class CountIteratorVariable(IteratorVariable):
    """Tracks ``itertools.count``: ``item`` is the next value, ``step`` the increment."""

    def __init__(self, item: int = 0, step: int = 1, **kwargs) -> None:
        super().__init__(**kwargs)
        # Plain Python values are wrapped so both fields are always trackers.
        self.item = (
            item if isinstance(item, VariableTracker) else ConstantVariable.create(item)
        )
        self.step = (
            step if isinstance(step, VariableTracker) else ConstantVariable.create(step)
        )

    def next_variable(self, tx):
        """Return the current value and advance the counter by ``step``."""
        assert self.mutable_local
        current = self.item
        tx.output.side_effects.mutation(self)
        self.item = current.call_method(tx, "__add__", [self.step], {})
        return current

    def reconstruct(self, codegen):
        """Emit code that rebuilds this iterator as ``itertools.count(item, step)``."""

        def load_count():
            # Push `itertools.count` as the callable.
            return codegen.extend_output(
                [
                    codegen.create_load_python_module(itertools),
                    codegen.create_load_attr("count"),
                ]
            )

        codegen.add_push_null(load_count)
        for operand in (self.item, self.step):
            codegen(operand)
        codegen.extend_output(create_call_function(2, False))
274 
275 
class CycleIteratorVariable(IteratorVariable):
    """Tracks ``itertools.cycle(iterator)``.

    Items pulled from ``iterator`` are returned and also recorded in
    ``saved``.  Once ``iterator`` is exhausted it is set to ``None`` and the
    recorded items are replayed in order, indefinitely, with ``saved_index``
    pointing at the next item to replay.  ``item`` is the most recently
    returned item.
    """

    def __init__(
        self,
        iterator: IteratorVariable,
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.iterator = iterator
        # Fresh list per instance when none is supplied.
        self.saved = [] if saved is None else saved
        self.saved_index = saved_index
        self.item = item

    def next_variable(self, tx):
        """Return the next item of the cycle.

        Raises an observed ``StopIteration`` (via ``raise_observed_exception``)
        only when the source iterator produced no items at all.  Reports the
        cycle as unsupported via ``unimplemented`` when more than
        ``MAX_ITERATOR_LIMIT`` items have been buffered.
        """
        assert self.mutable_local

        if self.iterator is not None:
            # First pass: still draining the source iterator.
            try:
                new_item = self.iterator.next_variable(tx)
            except ObservedUserStopIteration:
                handle_observed_exception(tx)
                # Source exhausted; switch to replaying `saved`.
                self.iterator = None
                return self.next_variable(tx)

            if len(self.saved) > MAX_ITERATOR_LIMIT:
                unimplemented("input iterator to itertools.cycle has too many items")
            tx.output.side_effects.mutation(self)
            self.saved.append(new_item)
            self.item = new_item
            if self.item is None:
                return self.next_variable(tx)
            return self.item

        if self.saved:
            # Replay pass: return the item at `saved_index`, then advance with
            # wrap-around.  (Previously this returned the stale `self.item`,
            # i.e. the last item of the first pass, on every call.)
            tx.output.side_effects.mutation(self)
            self.item = self.saved[self.saved_index]
            self.saved_index = (self.saved_index + 1) % len(self.saved)
            return self.item

        # Source was empty: cycling over nothing ends immediately.
        raise_observed_exception(StopIteration, tx, self)
319 
320 
class ZipVariable(IteratorVariable):
    """
    Represents zip(*iterables)

    Each entry of ``iterables`` is either a Python list of trackers (consumed
    by position via ``index``) or a tracker that is advanced with
    ``next_variable``.  ``strict`` mirrors ``zip(..., strict=True)``.
    """

    _nonvar_fields = {
        "index",
        "strict",
        *IteratorVariable._nonvar_fields,
    }

    def __init__(
        self,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        strict: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        assert isinstance(iterables, list)
        # can be list[Variable] or VariableTracker (with next_variable implemented)
        self.iterables = iterables
        # Number of tuples already produced; offset into list-typed iterables.
        self.index = 0
        self.strict = strict

    def python_type(self):
        return zip

    def has_unpack_var_sequence(self, tx) -> bool:
        """True when every iterable is a list or can itself be unpacked."""
        return all(
            isinstance(it, list) or it.has_unpack_var_sequence(tx)
            for it in self.iterables
        )

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return all remaining zipped tuples as ``TupleVariable`` objects."""
        assert self.has_unpack_var_sequence(tx)
        iterables = []
        for it in self.iterables:
            if isinstance(it, list):
                # Skip the items already handed out by next_variable.
                iterables.append(it[self.index :])
            else:
                iterables.append(it.unpack_var_sequence(tx))
        # Pass `strict` only when set, so the keyword is never sent to a
        # zip() that does not accept it.
        kwargs = {"strict": self.strict} if self.strict else {}
        zipped = zip(*iterables, **kwargs)
        return [variables.TupleVariable(list(var)) for var in zipped]

    def next_variable(self, tx):
        """Return the next zipped tuple as a ``TupleVariable``.

        An observed ``StopIteration`` propagates when the zip is exhausted.
        In strict mode, a ``UserError`` wrapping ``ValueError`` is raised when
        the iterables turn out to have different lengths.
        """
        assert self.mutable_local

        # zip() with no iterables is exhausted immediately.  Without this
        # guard the loop below never sees a StopIteration and an empty tuple
        # would be returned on every call.
        if not self.iterables:
            raise_observed_exception(StopIteration, tx, self)

        old_index = self.index
        args = []

        def get_item(it):
            if isinstance(it, list):
                if old_index >= len(it):
                    raise_observed_exception(StopIteration, tx, self)
                return it[old_index]
            else:
                return it.next_variable(tx)

        try:
            for idx, it in enumerate(self.iterables):
                args.append(get_item(it))
        except ObservedUserStopIteration:
            if self.strict:
                if idx == 0:
                    # all other iterables should be exhausted
                    for it in self.iterables:
                        try:
                            get_item(it)
                        except ObservedUserStopIteration:
                            handle_observed_exception(tx)
                            continue
                        # no ObservedUserStopIteration - fall through to UserError
                        break
                    else:
                        # all iterables exhausted, raise original error
                        raise
                # Either a later iterable ended first (idx > 0) or one still
                # had items after the first ended: a length mismatch.
                handle_observed_exception(tx)
                raise UserError(
                    ValueError,
                    "zip() has one argument of len differing from others",
                ) from None
            raise

        tx.output.side_effects.mutation(self)
        self.index += 1
        return variables.TupleVariable(args)

    def reconstruct_items(self, codegen):
        """Emit one stack value per iterable.

        List-typed iterables become a tuple of their not-yet-consumed items;
        tracker iterables are emitted through ``codegen`` directly.
        """
        for it in self.iterables:
            if isinstance(it, list):
                remaining_items = it[self.index :]
                codegen.foreach(remaining_items)
                codegen.append_output(
                    create_instruction("BUILD_TUPLE", arg=len(remaining_items))
                )
            else:
                codegen(it)

    def reconstruct(self, codegen):
        """Emit code that rebuilds this object as ``zip(*items)``.

        On Python 3.10+ the call also passes ``strict=self.strict``.
        """
        codegen.add_push_null(
            lambda: codegen.load_import_from("builtins", "zip"), call_function_ex=True
        )
        self.reconstruct_items(codegen)
        # Pack the per-iterable values into the positional-args tuple.
        codegen.append_output(
            create_instruction("BUILD_TUPLE", arg=len(self.iterables))
        )
        if sys.version_info >= (3, 10):
            codegen.extend_output(
                [
                    codegen.create_load_const("strict"),
                    codegen.create_load_const(self.strict),
                    create_instruction("BUILD_MAP", arg=1),
                    create_instruction("CALL_FUNCTION_EX", arg=1),
                ]
            )
        else:
            codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=0))
438 
439 
class MapVariable(ZipVariable):
    """
    Represents map(fn, *iterables)

    Reuses ZipVariable to step the iterables together and applies ``fn`` to
    each resulting group of items.
    """

    def __init__(
        self,
        fn: VariableTracker,
        iterables: List[Union[List[VariableTracker], VariableTracker]],
        **kwargs,
    ) -> None:
        super().__init__(iterables, **kwargs)
        self.fn = fn

    def python_type(self):
        return map

    def has_unpack_var_sequence(self, tx) -> bool:
        """Always False: items are only produced through ``next_variable``."""
        return False

    def next_variable(self, tx):
        """Advance the underlying zip by one step and call ``fn`` on its items."""
        zipped = super().next_variable(tx)
        call_args = zipped.items
        return self.fn.call_function(tx, call_args, {})

    def reconstruct(self, codegen):
        """Emit code that rebuilds this object as ``map(fn, *items)``."""

        def load_map():
            return codegen.load_import_from("builtins", "map")

        codegen.add_push_null(load_map, call_function_ex=True)
        codegen(self.fn)
        self.reconstruct_items(codegen)
        # The args tuple holds `fn` followed by one value per iterable.
        positional_count = len(self.iterables) + 1
        codegen.extend_output(
            [
                create_instruction("BUILD_TUPLE", arg=positional_count),
                create_instruction("CALL_FUNCTION_EX", arg=0),
            ]
        )
476