1import os.path
2import token
3from typing import IO, Any, Dict, Optional, Sequence, Set, Text, Tuple
4
5from pegen import grammar
6from pegen.grammar import (
7    Alt,
8    Cut,
9    Forced,
10    Gather,
11    GrammarVisitor,
12    Group,
13    Lookahead,
14    NamedItem,
15    NameLeaf,
16    NegativeLookahead,
17    Opt,
18    PositiveLookahead,
19    Repeat0,
20    Repeat1,
21    Rhs,
22    Rule,
23    StringLeaf,
24)
25from pegen.parser_generator import ParserGenerator
26
# Default header of the generated parser module.  It is used unless the
# grammar defines a ``header`` meta; ``generate`` fills in ``{filename}`` with
# the base name of the output file (a ``None`` header suppresses it entirely).
MODULE_PREFIX = """\
#!/usr/bin/env python3.8
# @generated by pegen from {filename}

import ast
import sys
import tokenize

from typing import Any, Optional

from pegen.parser import memoize, memoize_left_rec, logger, Parser

"""
# Default trailer of the generated parser module, used unless the grammar
# defines a ``trailer`` meta.  ``generate`` replaces ``{class_name}`` with the
# generated parser class so that running the module directly hands that class
# to ``simple_parser_main``.
MODULE_SUFFIX = """

if __name__ == '__main__':
    from pegen.parser import simple_parser_main
    simple_parser_main({class_name})
"""
46
47
class InvalidNodeVisitor(GrammarVisitor):
    """Report whether a grammar node references an ``invalid*`` rule.

    Every ``visit_*`` method returns ``True`` when the visited node, or any node
    nested inside it, is a ``NameLeaf`` whose name starts with ``"invalid"``.
    ``PythonParserGenerator.visit_Alt`` uses this to give action-less
    alternatives that mention such a rule an ``UNREACHABLE`` action.
    """

    def visit_NameLeaf(self, node: NameLeaf) -> bool:
        return node.value.startswith("invalid")

    def visit_StringLeaf(self, node: StringLeaf) -> bool:
        # A string literal never names a rule.
        return False

    def visit_NamedItem(self, node: NamedItem) -> bool:
        return self.visit(node.item)

    def visit_Rhs(self, node: Rhs) -> bool:
        return any(self.visit(alt) for alt in node.alts)

    def visit_Alt(self, node: Alt) -> bool:
        return any(self.visit(item) for item in node.items)

    def lookahead_call_helper(self, node: Lookahead) -> bool:
        return self.visit(node.node)

    def visit_PositiveLookahead(self, node: PositiveLookahead) -> bool:
        return self.lookahead_call_helper(node)

    def visit_NegativeLookahead(self, node: NegativeLookahead) -> bool:
        return self.lookahead_call_helper(node)

    def visit_Opt(self, node: Opt) -> bool:
        return self.visit(node.node)

    def visit_Repeat(self, node: Any) -> bool:
        # Shared implementation for both repeat kinds.  Kept under its original
        # name for backward compatibility; the visitor presumably dispatches on
        # the concrete class name (``visit_Repeat0`` / ``visit_Repeat1``), so
        # the two methods below are what actually route repeats here.
        return self.visit(node.node)

    def visit_Repeat0(self, node: Repeat0) -> bool:
        return self.visit_Repeat(node)

    def visit_Repeat1(self, node: Repeat1) -> bool:
        return self.visit_Repeat(node)

    def visit_Gather(self, node: Gather) -> bool:
        return self.visit(node.node)

    def visit_Group(self, node: Group) -> bool:
        return self.visit(node.rhs)

    def visit_Cut(self, node: Cut) -> bool:
        return False

    def visit_Forced(self, node: Forced) -> bool:
        return self.visit(node.node)
91
92
class PythonCallMakerVisitor(GrammarVisitor):
    """Translate grammar items into ``(name, call)`` pairs for generated code.

    Every ``visit_*`` method returns a 2-tuple: a suggested name for the local
    variable that will hold the item's result (``None`` when the result is not
    bound, as for lookaheads) and the text of the Python expression the
    generated parser evaluates to match the item.

    Results for ``Rhs``, ``Repeat0``, ``Repeat1`` and ``Gather`` nodes are cached
    per node, so each such node asks the parser generator for its artificial
    helper rule at most once.
    """

    def __init__(self, parser_generator: ParserGenerator):
        # Used to create the artificial helper rules for groups, repeats and
        # gathers.
        self.gen = parser_generator
        # Maps already-visited nodes to their (name, call) pair.
        self.cache: Dict[Any, Any] = {}

    def visit_NameLeaf(self, node: NameLeaf) -> Tuple[Optional[str], str]:
        """Return the call matching a token type or a reference to a rule."""
        name = node.value
        if name == "SOFT_KEYWORD":
            return "soft_keyword", "self.soft_keyword()"
        if name in ("NAME", "NUMBER", "STRING", "OP", "TYPE_COMMENT"):
            # Token types matched through a dedicated lowercase method.
            name = name.lower()
            return name, f"self.{name}()"
        if name in ("NEWLINE", "DEDENT", "INDENT", "ENDMARKER", "ASYNC", "AWAIT"):
            # Avoid using names that can be Python keywords
            return "_" + name.lower(), f"self.expect({name!r})"
        # Anything else refers to a rule, i.e. a method of the generated parser.
        return name, f"self.{name}()"

    def visit_StringLeaf(self, node: StringLeaf) -> Tuple[str, str]:
        """Return the call matching a literal.

        ``node.value`` is spliced in verbatim, so it is expected to already be
        a quoted Python string literal.
        """
        return "literal", f"self.expect({node.value})"

    def visit_Rhs(self, node: Rhs) -> Tuple[Optional[str], str]:
        """Return the call for a set of alternatives (e.g. a parenthesized group).

        A single alternative with a single item is inlined; anything else is
        turned into an artificial rule that is called by name.
        """
        if node in self.cache:
            return self.cache[node]
        if len(node.alts) == 1 and len(node.alts[0].items) == 1:
            self.cache[node] = self.visit(node.alts[0].items[0])
        else:
            name = self.gen.artifical_rule_from_rhs(node)
            self.cache[node] = name, f"self.{name}()"
        return self.cache[node]

    def visit_NamedItem(self, node: NamedItem) -> Tuple[Optional[str], str]:
        """Return the item's call, preferring the grammar-supplied name."""
        name, call = self.visit(node.item)
        if node.name:
            name = node.name
        return name, call

    def lookahead_call_helper(self, node: Lookahead) -> Tuple[str, str]:
        """Split the lookahead target's call into ``(callee, arguments)``.

        For example ``self.expect(':')`` becomes ``("self.expect", "':'")`` so
        the callee and its arguments can be passed separately to
        ``self.positive_lookahead`` / ``self.negative_lookahead``.
        """
        _, call = self.visit(node.node)
        head, tail = call.split("(", 1)
        assert tail[-1] == ")"
        tail = tail[:-1]
        return head, tail

    def visit_PositiveLookahead(self, node: PositiveLookahead) -> Tuple[None, str]:
        head, tail = self.lookahead_call_helper(node)
        return None, f"self.positive_lookahead({head}, {tail})"

    def visit_NegativeLookahead(self, node: NegativeLookahead) -> Tuple[None, str]:
        head, tail = self.lookahead_call_helper(node)
        return None, f"self.negative_lookahead({head}, {tail})"

    def visit_Opt(self, node: Opt) -> Tuple[str, str]:
        """Return the call for an optional item, with a trailing comma.

        The comma turns the generated ``(name := call,)`` into a one-element
        tuple, which is always truthy, so the alternative's condition still
        succeeds when the optional item does not match.
        """
        _, call = self.visit(node.node)
        # The call may already end with a comma, for example when rules have
        # both repeat0 and optional markers, e.g: [rule*]
        if call.endswith(","):
            return "opt", call
        return "opt", f"{call},"

    def visit_Repeat0(self, node: Repeat0) -> Tuple[str, str]:
        """Return the call for a zero-or-more repeat (an artificial loop rule)."""
        if node in self.cache:
            return self.cache[node]
        name = self.gen.artificial_rule_from_repeat(node.node, False)
        # Trailing comma: zero matches must not fail the condition, so the
        # result is wrapped in an always-truthy one-element tuple.
        self.cache[node] = name, f"self.{name}(),"
        return self.cache[node]

    def visit_Repeat1(self, node: Repeat1) -> Tuple[str, str]:
        """Return the call for a one-or-more repeat (an artificial loop rule)."""
        if node in self.cache:
            return self.cache[node]
        name = self.gen.artificial_rule_from_repeat(node.node, True)
        # No trailing comma: at least one match is required, so a falsy
        # result is allowed to fail the condition.
        self.cache[node] = name, f"self.{name}()"
        return self.cache[node]

    def visit_Gather(self, node: Gather) -> Tuple[str, str]:
        """Return the call for a separator-delimited gather (an artificial rule)."""
        if node in self.cache:
            return self.cache[node]
        name = self.gen.artifical_rule_from_gather(node)
        self.cache[node] = name, f"self.{name}()"  # No trailing comma here either!
        return self.cache[node]

    def visit_Group(self, node: Group) -> Tuple[Optional[str], str]:
        return self.visit(node.rhs)

    def visit_Cut(self, node: Cut) -> Tuple[str, str]:
        # Generates ``(cut := True)``, setting the flag the alternative checks
        # after a failed match.
        return "cut", "True"

    def visit_Forced(self, node: Forced) -> Tuple[str, str]:
        """Return a call that raises via ``self.expect_forced`` when unmatched.

        The second argument to ``expect_forced`` is the text describing what
        was expected.
        """
        if isinstance(node.node, Group):
            _, val = self.visit(node.node.rhs)
            return "forced", f"self.expect_forced({val}, '''({node.node.rhs!s})''')"
        # Delegate to the regular leaf visitor.  A string literal yields
        # ``self.expect(<literal>)`` exactly as before, while a token or rule
        # name gets its proper call instead of an undefined bare name.
        _, val = self.visit(node.node)
        return "forced", f"self.expect_forced({val}, {node.node.value!r})"
190
191
192class PythonParserGenerator(ParserGenerator, GrammarVisitor):
193    def __init__(
194        self,
195        grammar: grammar.Grammar,
196        file: Optional[IO[Text]],
197        tokens: Set[str] = set(token.tok_name.values()),
198        location_formatting: Optional[str] = None,
199        unreachable_formatting: Optional[str] = None,
200    ):
201        tokens.add("SOFT_KEYWORD")
202        super().__init__(grammar, tokens, file)
203        self.callmakervisitor: PythonCallMakerVisitor = PythonCallMakerVisitor(self)
204        self.invalidvisitor: InvalidNodeVisitor = InvalidNodeVisitor()
205        self.unreachable_formatting = unreachable_formatting or "None  # pragma: no cover"
206        self.location_formatting = (
207            location_formatting
208            or "lineno=start_lineno, col_offset=start_col_offset, "
209            "end_lineno=end_lineno, end_col_offset=end_col_offset"
210        )
211
212    def generate(self, filename: str) -> None:
213        self.collect_rules()
214        header = self.grammar.metas.get("header", MODULE_PREFIX)
215        if header is not None:
216            basename = os.path.basename(filename)
217            self.print(header.rstrip("\n").format(filename=basename))
218        subheader = self.grammar.metas.get("subheader", "")
219        if subheader:
220            self.print(subheader)
221        cls_name = self.grammar.metas.get("class", "GeneratedParser")
222        self.print("# Keywords and soft keywords are listed at the end of the parser definition.")
223        self.print(f"class {cls_name}(Parser):")
224        for rule in self.all_rules.values():
225            self.print()
226            with self.indent():
227                self.visit(rule)
228
229        self.print()
230        with self.indent():
231            self.print(f"KEYWORDS = {tuple(self.keywords)}")
232            self.print(f"SOFT_KEYWORDS = {tuple(self.soft_keywords)}")
233
234        trailer = self.grammar.metas.get("trailer", MODULE_SUFFIX.format(class_name=cls_name))
235        if trailer is not None:
236            self.print(trailer.rstrip("\n"))
237
238    def alts_uses_locations(self, alts: Sequence[Alt]) -> bool:
239        for alt in alts:
240            if alt.action and "LOCATIONS" in alt.action:
241                return True
242            for n in alt.items:
243                if isinstance(n.item, Group) and self.alts_uses_locations(n.item.rhs.alts):
244                    return True
245        return False
246
247    def visit_Rule(self, node: Rule) -> None:
248        is_loop = node.is_loop()
249        is_gather = node.is_gather()
250        rhs = node.flatten()
251        if node.left_recursive:
252            if node.leader:
253                self.print("@memoize_left_rec")
254            else:
255                # Non-leader rules in a cycle are not memoized,
256                # but they must still be logged.
257                self.print("@logger")
258        else:
259            self.print("@memoize")
260        node_type = node.type or "Any"
261        self.print(f"def {node.name}(self) -> Optional[{node_type}]:")
262        with self.indent():
263            self.print(f"# {node.name}: {rhs}")
264            self.print("mark = self._mark()")
265            if self.alts_uses_locations(node.rhs.alts):
266                self.print("tok = self._tokenizer.peek()")
267                self.print("start_lineno, start_col_offset = tok.start")
268            if is_loop:
269                self.print("children = []")
270            self.visit(rhs, is_loop=is_loop, is_gather=is_gather)
271            if is_loop:
272                self.print("return children")
273            else:
274                self.print("return None")
275
276    def visit_NamedItem(self, node: NamedItem) -> None:
277        name, call = self.callmakervisitor.visit(node.item)
278        if node.name:
279            name = node.name
280        if not name:
281            self.print(call)
282        else:
283            if name != "cut":
284                name = self.dedupe(name)
285            self.print(f"({name} := {call})")
286
287    def visit_Rhs(self, node: Rhs, is_loop: bool = False, is_gather: bool = False) -> None:
288        if is_loop:
289            assert len(node.alts) == 1
290        for alt in node.alts:
291            self.visit(alt, is_loop=is_loop, is_gather=is_gather)
292
293    def visit_Alt(self, node: Alt, is_loop: bool, is_gather: bool) -> None:
294        has_cut = any(isinstance(item.item, Cut) for item in node.items)
295        with self.local_variable_context():
296            if has_cut:
297                self.print("cut = False")
298            if is_loop:
299                self.print("while (")
300            else:
301                self.print("if (")
302            with self.indent():
303                first = True
304                for item in node.items:
305                    if first:
306                        first = False
307                    else:
308                        self.print("and")
309                    self.visit(item)
310                    if is_gather:
311                        self.print("is not None")
312
313            self.print("):")
314            with self.indent():
315                action = node.action
316                if not action:
317                    if is_gather:
318                        assert len(self.local_variable_names) == 2
319                        action = (
320                            f"[{self.local_variable_names[0]}] + {self.local_variable_names[1]}"
321                        )
322                    else:
323                        if self.invalidvisitor.visit(node):
324                            action = "UNREACHABLE"
325                        elif len(self.local_variable_names) == 1:
326                            action = f"{self.local_variable_names[0]}"
327                        else:
328                            action = f"[{', '.join(self.local_variable_names)}]"
329                elif "LOCATIONS" in action:
330                    self.print("tok = self._tokenizer.get_last_non_whitespace_token()")
331                    self.print("end_lineno, end_col_offset = tok.end")
332                    action = action.replace("LOCATIONS", self.location_formatting)
333
334                if is_loop:
335                    self.print(f"children.append({action})")
336                    self.print(f"mark = self._mark()")
337                else:
338                    if "UNREACHABLE" in action:
339                        action = action.replace("UNREACHABLE", self.unreachable_formatting)
340                    self.print(f"return {action}")
341
342            self.print("self._reset(mark)")
343            # Skip remaining alternatives if a cut was reached.
344            if has_cut:
345                self.print("if cut: return None")
346