import ast
import difflib
import io
import textwrap
import unittest

from test import test_tools
from tokenize import TokenInfo, NAME, NEWLINE, NUMBER, OP
from typing import Any, Dict, Type
10
11test_tools.skip_if_missing("peg_generator")
12with test_tools.imports_under_tool("peg_generator"):
13    from pegen.grammar_parser import GeneratedParser as GrammarParser
14    from pegen.testutil import parse_string, generate_parser, make_parser
15    from pegen.grammar import GrammarVisitor, GrammarError, Grammar
16    from pegen.grammar_visualizer import ASTGrammarPrinter
17    from pegen.parser import Parser
18    from pegen.parser_generator import compute_nullables, compute_left_recursives
19    from pegen.python_generator import PythonParserGenerator
20
21
22class TestPegen(unittest.TestCase):
23    def test_parse_grammar(self) -> None:
24        grammar_source = """
25        start: sum NEWLINE
26        sum: t1=term '+' t2=term { action } | term
27        term: NUMBER
28        """
29        expected = """
30        start: sum NEWLINE
31        sum: term '+' term | term
32        term: NUMBER
33        """
34        grammar: Grammar = parse_string(grammar_source, GrammarParser)
35        rules = grammar.rules
36        self.assertEqual(str(grammar), textwrap.dedent(expected).strip())
37        # Check the str() and repr() of a few rules; AST nodes don't support ==.
38        self.assertEqual(str(rules["start"]), "start: sum NEWLINE")
39        self.assertEqual(str(rules["sum"]), "sum: term '+' term | term")
40        expected_repr = (
41            "Rule('term', None, Rhs([Alt([NamedItem(None, NameLeaf('NUMBER'))])]))"
42        )
43        self.assertEqual(repr(rules["term"]), expected_repr)
44
45    def test_long_rule_str(self) -> None:
46        grammar_source = """
47        start: zero | one | one zero | one one | one zero zero | one zero one | one one zero | one one one
48        """
49        expected = """
50        start:
51            | zero
52            | one
53            | one zero
54            | one one
55            | one zero zero
56            | one zero one
57            | one one zero
58            | one one one
59        """
60        grammar: Grammar = parse_string(grammar_source, GrammarParser)
61        self.assertEqual(str(grammar.rules["start"]), textwrap.dedent(expected).strip())
62
63    def test_typed_rules(self) -> None:
64        grammar = """
65        start[int]: sum NEWLINE
66        sum[int]: t1=term '+' t2=term { action } | term
67        term[int]: NUMBER
68        """
69        rules = parse_string(grammar, GrammarParser).rules
70        # Check the str() and repr() of a few rules; AST nodes don't support ==.
71        self.assertEqual(str(rules["start"]), "start: sum NEWLINE")
72        self.assertEqual(str(rules["sum"]), "sum: term '+' term | term")
73        self.assertEqual(
74            repr(rules["term"]),
75            "Rule('term', 'int', Rhs([Alt([NamedItem(None, NameLeaf('NUMBER'))])]))",
76        )
77
    def test_gather(self) -> None:
        """The gather operator ','.thing+ collects the separated items into a
        list, discarding the separator tokens."""
        grammar = """
        start: ','.thing+ NEWLINE
        thing: NUMBER
        """
        rules = parse_string(grammar, GrammarParser).rules
        self.assertEqual(str(rules["start"]), "start: ','.thing+ NEWLINE")
        self.assertTrue(
            repr(rules["start"]).startswith(
                "Rule('start', None, Rhs([Alt([NamedItem(None, Gather(StringLeaf(\"','\"), NameLeaf('thing'"
            )
        )
        self.assertEqual(str(rules["thing"]), "thing: NUMBER")
        parser_class = make_parser(grammar)
        # Smoke test only: a single element must parse without raising;
        # the result is deliberately overwritten and checked for the
        # two-element case below.
        node = parse_string("42\n", parser_class)
        node = parse_string("1, 2\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(
                        NUMBER, string="1", start=(1, 0), end=(1, 1), line="1, 2\n"
                    ),
                    TokenInfo(
                        NUMBER, string="2", start=(1, 3), end=(1, 4), line="1, 2\n"
                    ),
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 4), end=(1, 5), line="1, 2\n"
                ),
            ],
        )
110
111    def test_expr_grammar(self) -> None:
112        grammar = """
113        start: sum NEWLINE
114        sum: term '+' term | term
115        term: NUMBER
116        """
117        parser_class = make_parser(grammar)
118        node = parse_string("42\n", parser_class)
119        self.assertEqual(
120            node,
121            [
122                TokenInfo(NUMBER, string="42", start=(1, 0), end=(1, 2), line="42\n"),
123                TokenInfo(NEWLINE, string="\n", start=(1, 2), end=(1, 3), line="42\n"),
124            ],
125        )
126
    def test_optional_operator(self) -> None:
        """An optional group ('+' term)? yields its sub-list when present
        and None when absent."""
        grammar = """
        start: sum NEWLINE
        sum: term ('+' term)?
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 + 2\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(
                        NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2\n"
                    ),
                    [
                        TokenInfo(
                            OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2\n"
                        ),
                        TokenInfo(
                            NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2\n"
                        ),
                    ],
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 + 2\n"
                ),
            ],
        )
        # Missing optional part parses to None.
        node = parse_string("1\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"),
                    None,
                ],
                TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
            ],
        )
167
    def test_optional_literal(self) -> None:
        """An optional bare literal ('+' ?) yields the token when present
        and None when absent."""
        grammar = """
        start: sum NEWLINE
        sum: term '+' ?
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1+\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(
                        NUMBER, string="1", start=(1, 0), end=(1, 1), line="1+\n"
                    ),
                    TokenInfo(OP, string="+", start=(1, 1), end=(1, 2), line="1+\n"),
                ],
                TokenInfo(NEWLINE, string="\n", start=(1, 2), end=(1, 3), line="1+\n"),
            ],
        )
        # Missing optional literal parses to None.
        node = parse_string("1\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"),
                    None,
                ],
                TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
            ],
        )
199
    def test_alt_optional_operator(self) -> None:
        """The bracket form ['+' term] behaves like the ? operator: sub-list
        when matched, None when not."""
        grammar = """
        start: sum NEWLINE
        sum: term ['+' term]
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 + 2\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(
                        NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2\n"
                    ),
                    [
                        TokenInfo(
                            OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2\n"
                        ),
                        TokenInfo(
                            NUMBER, string="2", start=(1, 4), end=(1, 5), line="1 + 2\n"
                        ),
                    ],
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 + 2\n"
                ),
            ],
        )
        # Missing optional group parses to None.
        node = parse_string("1\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"),
                    None,
                ],
                TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
            ],
        )
240
    def test_repeat_0_simple(self) -> None:
        """thing* collects zero or more matches into a list ([] when none)."""
        grammar = """
        start: thing thing* NEWLINE
        thing: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 2 3\n", parser_class)
        self.assertEqual(
            node,
            [
                TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 2 3\n"),
                [
                    TokenInfo(
                        NUMBER, string="2", start=(1, 2), end=(1, 3), line="1 2 3\n"
                    ),
                    TokenInfo(
                        NUMBER, string="3", start=(1, 4), end=(1, 5), line="1 2 3\n"
                    ),
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 2 3\n"
                ),
            ],
        )
        # Zero repetitions produce an empty list, not None.
        node = parse_string("1\n", parser_class)
        self.assertEqual(
            node,
            [
                TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1\n"),
                [],
                TokenInfo(NEWLINE, string="\n", start=(1, 1), end=(1, 2), line="1\n"),
            ],
        )
274
    def test_repeat_0_complex(self) -> None:
        """A grouped repetition ('+' term)* collects each [op, term] pair."""
        grammar = """
        start: term ('+' term)* NEWLINE
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 + 2 + 3\n", parser_class)
        self.assertEqual(
            node,
            [
                TokenInfo(
                    NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2 + 3\n"
                ),
                [
                    [
                        TokenInfo(
                            OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n"
                        ),
                        TokenInfo(
                            NUMBER,
                            string="2",
                            start=(1, 4),
                            end=(1, 5),
                            line="1 + 2 + 3\n",
                        ),
                    ],
                    [
                        TokenInfo(
                            OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n"
                        ),
                        TokenInfo(
                            NUMBER,
                            string="3",
                            start=(1, 8),
                            end=(1, 9),
                            line="1 + 2 + 3\n",
                        ),
                    ],
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n"
                ),
            ],
        )
319
    def test_repeat_1_simple(self) -> None:
        """thing+ requires at least one match; zero matches is a SyntaxError."""
        grammar = """
        start: thing thing+ NEWLINE
        thing: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 2 3\n", parser_class)
        self.assertEqual(
            node,
            [
                TokenInfo(NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 2 3\n"),
                [
                    TokenInfo(
                        NUMBER, string="2", start=(1, 2), end=(1, 3), line="1 2 3\n"
                    ),
                    TokenInfo(
                        NUMBER, string="3", start=(1, 4), end=(1, 5), line="1 2 3\n"
                    ),
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 5), end=(1, 6), line="1 2 3\n"
                ),
            ],
        )
        # Only one NUMBER: the inner thing+ has nothing to match.
        with self.assertRaises(SyntaxError):
            parse_string("1\n", parser_class)
346
    def test_repeat_1_complex(self) -> None:
        """A grouped one-or-more repetition ('+' term)+ collects [op, term]
        pairs and rejects input with no pair at all."""
        grammar = """
        start: term ('+' term)+ NEWLINE
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1 + 2 + 3\n", parser_class)
        self.assertEqual(
            node,
            [
                TokenInfo(
                    NUMBER, string="1", start=(1, 0), end=(1, 1), line="1 + 2 + 3\n"
                ),
                [
                    [
                        TokenInfo(
                            OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n"
                        ),
                        TokenInfo(
                            NUMBER,
                            string="2",
                            start=(1, 4),
                            end=(1, 5),
                            line="1 + 2 + 3\n",
                        ),
                    ],
                    [
                        TokenInfo(
                            OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n"
                        ),
                        TokenInfo(
                            NUMBER,
                            string="3",
                            start=(1, 8),
                            end=(1, 9),
                            line="1 + 2 + 3\n",
                        ),
                    ],
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n"
                ),
            ],
        )
        # A bare term with no ('+' term) group must fail.
        with self.assertRaises(SyntaxError):
            parse_string("1\n", parser_class)
393
    def test_repeat_with_sep_simple(self) -> None:
        """','.thing+ gathers three comma-separated items, dropping commas."""
        grammar = """
        start: ','.thing+ NEWLINE
        thing: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("1, 2, 3\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    TokenInfo(
                        NUMBER, string="1", start=(1, 0), end=(1, 1), line="1, 2, 3\n"
                    ),
                    TokenInfo(
                        NUMBER, string="2", start=(1, 3), end=(1, 4), line="1, 2, 3\n"
                    ),
                    TokenInfo(
                        NUMBER, string="3", start=(1, 6), end=(1, 7), line="1, 2, 3\n"
                    ),
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 7), end=(1, 8), line="1, 2, 3\n"
                ),
            ],
        )
420
    def test_left_recursive(self) -> None:
        """Direct left recursion is detected per rule and produces a
        left-associative parse tree ((1 + 2) + 3)."""
        grammar_source = """
        start: expr NEWLINE
        expr: ('-' term | expr '+' term | term)
        term: NUMBER
        foo: NAME+
        bar: NAME*
        baz: NAME?
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        parser_class = generate_parser(grammar)
        rules = grammar.rules
        # Only expr refers to itself in leading position.
        self.assertFalse(rules["start"].left_recursive)
        self.assertTrue(rules["expr"].left_recursive)
        self.assertFalse(rules["term"].left_recursive)
        self.assertFalse(rules["foo"].left_recursive)
        self.assertFalse(rules["bar"].left_recursive)
        self.assertFalse(rules["baz"].left_recursive)
        node = parse_string("1 + 2 + 3\n", parser_class)
        self.assertEqual(
            node,
            [
                [
                    [
                        TokenInfo(
                            NUMBER,
                            string="1",
                            start=(1, 0),
                            end=(1, 1),
                            line="1 + 2 + 3\n",
                        ),
                        TokenInfo(
                            OP, string="+", start=(1, 2), end=(1, 3), line="1 + 2 + 3\n"
                        ),
                        TokenInfo(
                            NUMBER,
                            string="2",
                            start=(1, 4),
                            end=(1, 5),
                            line="1 + 2 + 3\n",
                        ),
                    ],
                    TokenInfo(
                        OP, string="+", start=(1, 6), end=(1, 7), line="1 + 2 + 3\n"
                    ),
                    TokenInfo(
                        NUMBER, string="3", start=(1, 8), end=(1, 9), line="1 + 2 + 3\n"
                    ),
                ],
                TokenInfo(
                    NEWLINE, string="\n", start=(1, 9), end=(1, 10), line="1 + 2 + 3\n"
                ),
            ],
        )
475
    def test_python_expr(self) -> None:
        """Grammar actions build a real Python ast.Expression that compiles
        and evaluates to the arithmetic result of the input."""
        grammar = """
        start: expr NEWLINE? $ { ast.Expression(expr, lineno=1, col_offset=0) }
        expr: ( expr '+' term { ast.BinOp(expr, ast.Add(), term, lineno=expr.lineno, col_offset=expr.col_offset, end_lineno=term.end_lineno, end_col_offset=term.end_col_offset) }
            | expr '-' term { ast.BinOp(expr, ast.Sub(), term, lineno=expr.lineno, col_offset=expr.col_offset, end_lineno=term.end_lineno, end_col_offset=term.end_col_offset) }
            | term { term }
            )
        term: ( l=term '*' r=factor { ast.BinOp(l, ast.Mult(), r, lineno=l.lineno, col_offset=l.col_offset, end_lineno=r.end_lineno, end_col_offset=r.end_col_offset) }
            | l=term '/' r=factor { ast.BinOp(l, ast.Div(), r, lineno=l.lineno, col_offset=l.col_offset, end_lineno=r.end_lineno, end_col_offset=r.end_col_offset) }
            | factor { factor }
            )
        factor: ( '(' expr ')' { expr }
                | atom { atom }
                )
        atom: ( n=NAME { ast.Name(id=n.string, ctx=ast.Load(), lineno=n.start[0], col_offset=n.start[1], end_lineno=n.end[0], end_col_offset=n.end[1]) }
            | n=NUMBER { ast.Constant(value=ast.literal_eval(n.string), lineno=n.start[0], col_offset=n.start[1], end_lineno=n.end[0], end_col_offset=n.end[1]) }
            )
        """
        parser_class = make_parser(grammar)
        node = parse_string("(1 + 2*3 + 5)/(6 - 2)\n", parser_class)
        # The parse result is an AST, so compile and evaluate it:
        # (1 + 6 + 5) / 4 == 3.0.
        code = compile(node, "", "eval")
        val = eval(code)
        self.assertEqual(val, 3.0)
499
500    def test_nullable(self) -> None:
501        grammar_source = """
502        start: sign NUMBER
503        sign: ['-' | '+']
504        """
505        grammar: Grammar = parse_string(grammar_source, GrammarParser)
506        rules = grammar.rules
507        nullables = compute_nullables(rules)
508        self.assertNotIn(rules["start"], nullables)  # Not None!
509        self.assertIn(rules["sign"], nullables)
510
511    def test_advanced_left_recursive(self) -> None:
512        grammar_source = """
513        start: NUMBER | sign start
514        sign: ['-']
515        """
516        grammar: Grammar = parse_string(grammar_source, GrammarParser)
517        rules = grammar.rules
518        nullables = compute_nullables(rules)
519        compute_left_recursives(rules)
520        self.assertNotIn(rules["start"], nullables)  # Not None!
521        self.assertIn(rules["sign"], nullables)
522        self.assertTrue(rules["start"].left_recursive)
523        self.assertFalse(rules["sign"].left_recursive)
524
    def test_mutually_left_recursive(self) -> None:
        """Rules foo and bar are mutually left-recursive through each other;
        the generated parser still produces correct left-nested trees."""
        grammar_source = """
        start: foo 'E'
        foo: bar 'A' | 'B'
        bar: foo 'C' | 'D'
        """
        grammar: Grammar = parse_string(grammar_source, GrammarParser)
        out = io.StringIO()
        genr = PythonParserGenerator(grammar, out)
        rules = grammar.rules
        # Both members of the cycle are flagged; start merely uses them.
        self.assertFalse(rules["start"].left_recursive)
        self.assertTrue(rules["foo"].left_recursive)
        self.assertTrue(rules["bar"].left_recursive)
        genr.generate("<string>")
        ns: Dict[str, Any] = {}
        exec(out.getvalue(), ns)
        parser_class: Type[Parser] = ns["GeneratedParser"]
        node = parse_string("D A C A E", parser_class)

        self.assertEqual(
            node,
            [
                [
                    [
                        [
                            TokenInfo(
                                type=NAME,
                                string="D",
                                start=(1, 0),
                                end=(1, 1),
                                line="D A C A E",
                            ),
                            TokenInfo(
                                type=NAME,
                                string="A",
                                start=(1, 2),
                                end=(1, 3),
                                line="D A C A E",
                            ),
                        ],
                        TokenInfo(
                            type=NAME,
                            string="C",
                            start=(1, 4),
                            end=(1, 5),
                            line="D A C A E",
                        ),
                    ],
                    TokenInfo(
                        type=NAME,
                        string="A",
                        start=(1, 6),
                        end=(1, 7),
                        line="D A C A E",
                    ),
                ],
                TokenInfo(
                    type=NAME, string="E", start=(1, 8), end=(1, 9), line="D A C A E"
                ),
            ],
        )
        # A shorter derivation starting from the 'B' alternative of foo.
        node = parse_string("B C A E", parser_class)
        self.assertEqual(
            node,
            [
                [
                    [
                        TokenInfo(
                            type=NAME,
                            string="B",
                            start=(1, 0),
                            end=(1, 1),
                            line="B C A E",
                        ),
                        TokenInfo(
                            type=NAME,
                            string="C",
                            start=(1, 2),
                            end=(1, 3),
                            line="B C A E",
                        ),
                    ],
                    TokenInfo(
                        type=NAME, string="A", start=(1, 4), end=(1, 5), line="B C A E"
                    ),
                ],
                TokenInfo(
                    type=NAME, string="E", start=(1, 6), end=(1, 7), line="B C A E"
                ),
            ],
        )
616
617    def test_nasty_mutually_left_recursive(self) -> None:
618        # This grammar does not recognize 'x - + =', much to my chagrin.
619        # But that's the way PEG works.
620        # [Breathlessly]
621        # The problem is that the toplevel target call
622        # recurses into maybe, which recognizes 'x - +',
623        # and then the toplevel target looks for another '+',
624        # which fails, so it retreats to NAME,
625        # which succeeds, so we end up just recognizing 'x',
626        # and then start fails because there's no '=' after that.
627        grammar_source = """
628        start: target '='
629        target: maybe '+' | NAME
630        maybe: maybe '-' | target
631        """
632        grammar: Grammar = parse_string(grammar_source, GrammarParser)
633        out = io.StringIO()
634        genr = PythonParserGenerator(grammar, out)
635        genr.generate("<string>")
636        ns: Dict[str, Any] = {}
637        exec(out.getvalue(), ns)
638        parser_class = ns["GeneratedParser"]
639        with self.assertRaises(SyntaxError):
640            parse_string("x - + =", parser_class)
641
    def test_lookahead(self) -> None:
        """Positive (&) and negative (!) lookaheads steer alternative choice
        without consuming tokens; here !(target '=') forces assign_stmt."""
        grammar = """
        start: (expr_stmt | assign_stmt) &'.'
        expr_stmt: !(target '=') expr
        assign_stmt: target '=' expr
        expr: term ('+' term)*
        target: NAME
        term: NUMBER
        """
        parser_class = make_parser(grammar)
        node = parse_string("foo = 12 + 12 .", parser_class)
        self.assertEqual(
            node,
            [
                TokenInfo(
                    NAME, string="foo", start=(1, 0), end=(1, 3), line="foo = 12 + 12 ."
                ),
                TokenInfo(
                    OP, string="=", start=(1, 4), end=(1, 5), line="foo = 12 + 12 ."
                ),
                [
                    TokenInfo(
                        NUMBER,
                        string="12",
                        start=(1, 6),
                        end=(1, 8),
                        line="foo = 12 + 12 .",
                    ),
                    [
                        [
                            TokenInfo(
                                OP,
                                string="+",
                                start=(1, 9),
                                end=(1, 10),
                                line="foo = 12 + 12 .",
                            ),
                            TokenInfo(
                                NUMBER,
                                string="12",
                                start=(1, 11),
                                end=(1, 13),
                                line="foo = 12 + 12 .",
                            ),
                        ]
                    ],
                ],
            ],
        )
691
692    def test_named_lookahead_error(self) -> None:
693        grammar = """
694        start: foo=!'x' NAME
695        """
696        with self.assertRaises(SyntaxError):
697            make_parser(grammar)
698
699    def test_start_leader(self) -> None:
700        grammar = """
701        start: attr | NAME
702        attr: start '.' NAME
703        """
704        # Would assert False without a special case in compute_left_recursives().
705        make_parser(grammar)
706
707    def test_opt_sequence(self) -> None:
708        grammar = """
709        start: [NAME*]
710        """
711        # This case was failing because of a double trailing comma at the end
712        # of a line in the generated source. See bpo-41044
713        make_parser(grammar)
714
715    def test_left_recursion_too_complex(self) -> None:
716        grammar = """
717        start: foo
718        foo: bar '+' | baz '+' | '+'
719        bar: baz '-' | foo '-' | '-'
720        baz: foo '*' | bar '*' | '*'
721        """
722        with self.assertRaises(ValueError) as errinfo:
723            make_parser(grammar)
724            self.assertTrue("no leader" in str(errinfo.exception.value))
725
726    def test_cut(self) -> None:
727        grammar = """
728        start: '(' ~ expr ')'
729        expr: NUMBER
730        """
731        parser_class = make_parser(grammar)
732        node = parse_string("(1)", parser_class)
733        self.assertEqual(
734            node,
735            [
736                TokenInfo(OP, string="(", start=(1, 0), end=(1, 1), line="(1)"),
737                TokenInfo(NUMBER, string="1", start=(1, 1), end=(1, 2), line="(1)"),
738                TokenInfo(OP, string=")", start=(1, 2), end=(1, 3), line="(1)"),
739            ],
740        )
741
742    def test_dangling_reference(self) -> None:
743        grammar = """
744        start: foo ENDMARKER
745        foo: bar NAME
746        """
747        with self.assertRaises(GrammarError):
748            parser_class = make_parser(grammar)
749
750    def test_bad_token_reference(self) -> None:
751        grammar = """
752        start: foo
753        foo: NAMEE
754        """
755        with self.assertRaises(GrammarError):
756            parser_class = make_parser(grammar)
757
758    def test_missing_start(self) -> None:
759        grammar = """
760        foo: NAME
761        """
762        with self.assertRaises(GrammarError):
763            parser_class = make_parser(grammar)
764
765    def test_invalid_rule_name(self) -> None:
766        grammar = """
767        start: _a b
768        _a: 'a'
769        b: 'b'
770        """
771        with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_a'"):
772            parser_class = make_parser(grammar)
773
774    def test_invalid_variable_name(self) -> None:
775        grammar = """
776        start: a b
777        a: _x='a'
778        b: 'b'
779        """
780        with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_x'"):
781            parser_class = make_parser(grammar)
782
783    def test_invalid_variable_name_in_temporal_rule(self) -> None:
784        grammar = """
785        start: a b
786        a: (_x='a' | 'b') | 'c'
787        b: 'b'
788        """
789        with self.assertRaisesRegex(GrammarError, "cannot start with underscore: '_x'"):
790            parser_class = make_parser(grammar)
791
792    def test_soft_keyword(self) -> None:
793        grammar = """
794        start:
795            | "number" n=NUMBER { eval(n.string) }
796            | "string" n=STRING { n.string }
797            | SOFT_KEYWORD l=NAME n=(NUMBER | NAME | STRING) { f"{l.string} = {n.string}"}
798        """
799        parser_class = make_parser(grammar)
800        self.assertEqual(parse_string("number 1", parser_class), 1)
801        self.assertEqual(parse_string("string 'b'", parser_class), "'b'")
802        self.assertEqual(
803            parse_string("number test 1", parser_class), "test = 1"
804        )
805        assert (
806            parse_string("string test 'b'", parser_class) == "test = 'b'"
807        )
808        with self.assertRaises(SyntaxError):
809            parse_string("test 1", parser_class)
810
811    def test_forced(self) -> None:
812        grammar = """
813        start: NAME &&':' | NAME
814        """
815        parser_class = make_parser(grammar)
816        self.assertTrue(parse_string("number :", parser_class))
817        with self.assertRaises(SyntaxError) as e:
818            parse_string("a", parser_class)
819
820        self.assertIn("expected ':'", str(e.exception))
821
822    def test_forced_with_group(self) -> None:
823        grammar = """
824        start: NAME &&(':' | ';') | NAME
825        """
826        parser_class = make_parser(grammar)
827        self.assertTrue(parse_string("number :", parser_class))
828        self.assertTrue(parse_string("number ;", parser_class))
829        with self.assertRaises(SyntaxError) as e:
830            parse_string("a", parser_class)
831        self.assertIn("expected (':' | ';')", e.exception.args[0])
832
833    def test_unreachable_explicit(self) -> None:
834        source = """
835        start: NAME { UNREACHABLE }
836        """
837        grammar = parse_string(source, GrammarParser)
838        out = io.StringIO()
839        genr = PythonParserGenerator(
840            grammar, out, unreachable_formatting="This is a test"
841        )
842        genr.generate("<string>")
843        self.assertIn("This is a test", out.getvalue())
844
845    def test_unreachable_implicit1(self) -> None:
846        source = """
847        start: NAME | invalid_input
848        invalid_input: NUMBER { None }
849        """
850        grammar = parse_string(source, GrammarParser)
851        out = io.StringIO()
852        genr = PythonParserGenerator(
853            grammar, out, unreachable_formatting="This is a test"
854        )
855        genr.generate("<string>")
856        self.assertIn("This is a test", out.getvalue())
857
858    def test_unreachable_implicit2(self) -> None:
859        source = """
860        start: NAME | '(' invalid_input ')'
861        invalid_input: NUMBER { None }
862        """
863        grammar = parse_string(source, GrammarParser)
864        out = io.StringIO()
865        genr = PythonParserGenerator(
866            grammar, out, unreachable_formatting="This is a test"
867        )
868        genr.generate("<string>")
869        self.assertIn("This is a test", out.getvalue())
870
871    def test_unreachable_implicit3(self) -> None:
872        source = """
873        start: NAME | invalid_input { None }
874        invalid_input: NUMBER
875        """
876        grammar = parse_string(source, GrammarParser)
877        out = io.StringIO()
878        genr = PythonParserGenerator(
879            grammar, out, unreachable_formatting="This is a test"
880        )
881        genr.generate("<string>")
882        self.assertNotIn("This is a test", out.getvalue())
883
884    def test_locations_in_alt_action_and_group(self) -> None:
885        grammar = """
886        start: t=term NEWLINE? $ { ast.Expression(t, LOCATIONS) }
887        term:
888            | l=term '*' r=factor { ast.BinOp(l, ast.Mult(), r, LOCATIONS) }
889            | l=term '/' r=factor { ast.BinOp(l, ast.Div(), r, LOCATIONS) }
890            | factor
891        factor:
892            | (
893                n=NAME { ast.Name(id=n.string, ctx=ast.Load(), LOCATIONS) } |
894                n=NUMBER { ast.Constant(value=ast.literal_eval(n.string), LOCATIONS) }
895            )
896        """
897        parser_class = make_parser(grammar)
898        source = "2*3\n"
899        o = ast.dump(parse_string(source, parser_class).body, include_attributes=True)
900        p = ast.dump(ast.parse(source).body[0].value, include_attributes=True).replace(
901            " kind=None,", ""
902        )
903        diff = "\n".join(
904            difflib.unified_diff(
905                o.split("\n"), p.split("\n"), "cpython", "python-pegen"
906            )
907        )
908        self.assertFalse(diff)
909
910
class TestGrammarVisitor(unittest.TestCase):
    """Verify GrammarVisitor traverses every node of a parsed grammar.

    Bug fix: this class previously did not inherit from unittest.TestCase,
    so unittest never collected its test methods (and ``self.assertEqual``
    would not have existed had they run).
    """

    class Visitor(GrammarVisitor):
        """Visitor that simply counts every node it is dispatched to."""

        def __init__(self) -> None:
            # Number of nodes visited so far.
            self.n_nodes = 0

        def visit(self, node: Any, *args: Any, **kwargs: Any) -> None:
            # Count this node, then recurse into its children.
            self.n_nodes += 1
            super().visit(node, *args, **kwargs)

    def test_parse_trivial_grammar(self) -> None:
        grammar = """
        start: 'a'
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/StringLeaf -> 6 nodes.
        self.assertEqual(visitor.n_nodes, 6)

    def test_parse_or_grammar(self) -> None:
        grammar = """
        start: rule
        rule: 'a' | 'b'
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/NameLeaf   -> 6
        #         Rule/Rhs/                         -> 2
        #                  Alt/NamedItem/StringLeaf -> 3
        #                  Alt/NamedItem/StringLeaf -> 3

        self.assertEqual(visitor.n_nodes, 14)

    def test_parse_repeat1_grammar(self) -> None:
        grammar = """
        start: 'a'+
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/Repeat1/StringLeaf -> 7 nodes.
        self.assertEqual(visitor.n_nodes, 7)

    def test_parse_repeat0_grammar(self) -> None:
        grammar = """
        start: 'a'*
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/Repeat0/StringLeaf -> 7 nodes.
        self.assertEqual(visitor.n_nodes, 7)

    def test_parse_optional_grammar(self) -> None:
        grammar = """
        start: 'a' ['b']
        """
        rules = parse_string(grammar, GrammarParser)
        visitor = self.Visitor()

        visitor.visit(rules)

        # Grammar/Rule/Rhs/Alt/NamedItem/StringLeaf                       -> 6
        #                      NamedItem/Opt/Rhs/Alt/NamedItem/Stringleaf -> 6

        self.assertEqual(visitor.n_nodes, 12)
986
987
988class TestGrammarVisualizer(unittest.TestCase):
989    def test_simple_rule(self) -> None:
990        grammar = """
991        start: 'a' 'b'
992        """
993        rules = parse_string(grammar, GrammarParser)
994
995        printer = ASTGrammarPrinter()
996        lines: List[str] = []
997        printer.print_grammar_ast(rules, printer=lines.append)
998
999        output = "\n".join(lines)
1000        expected_output = textwrap.dedent(
1001            """\
1002        └──Rule
1003           └──Rhs
1004              └──Alt
1005                 ├──NamedItem
1006                 │  └──StringLeaf("'a'")
1007                 └──NamedItem
1008                    └──StringLeaf("'b'")
1009        """
1010        )
1011
1012        self.assertEqual(output, expected_output)
1013
1014    def test_multiple_rules(self) -> None:
1015        grammar = """
1016        start: a b
1017        a: 'a'
1018        b: 'b'
1019        """
1020        rules = parse_string(grammar, GrammarParser)
1021
1022        printer = ASTGrammarPrinter()
1023        lines: List[str] = []
1024        printer.print_grammar_ast(rules, printer=lines.append)
1025
1026        output = "\n".join(lines)
1027        expected_output = textwrap.dedent(
1028            """\
1029        └──Rule
1030           └──Rhs
1031              └──Alt
1032                 ├──NamedItem
1033                 │  └──NameLeaf('a')
1034                 └──NamedItem
1035                    └──NameLeaf('b')
1036
1037        └──Rule
1038           └──Rhs
1039              └──Alt
1040                 └──NamedItem
1041                    └──StringLeaf("'a'")
1042
1043        └──Rule
1044           └──Rhs
1045              └──Alt
1046                 └──NamedItem
1047                    └──StringLeaf("'b'")
1048                        """
1049        )
1050
1051        self.assertEqual(output, expected_output)
1052
1053    def test_deep_nested_rule(self) -> None:
1054        grammar = """
1055        start: 'a' ['b'['c'['d']]]
1056        """
1057        rules = parse_string(grammar, GrammarParser)
1058
1059        printer = ASTGrammarPrinter()
1060        lines: List[str] = []
1061        printer.print_grammar_ast(rules, printer=lines.append)
1062
1063        output = "\n".join(lines)
1064        expected_output = textwrap.dedent(
1065            """\
1066        └──Rule
1067           └──Rhs
1068              └──Alt
1069                 ├──NamedItem
1070                 │  └──StringLeaf("'a'")
1071                 └──NamedItem
1072                    └──Opt
1073                       └──Rhs
1074                          └──Alt
1075                             ├──NamedItem
1076                             │  └──StringLeaf("'b'")
1077                             └──NamedItem
1078                                └──Opt
1079                                   └──Rhs
1080                                      └──Alt
1081                                         ├──NamedItem
1082                                         │  └──StringLeaf("'c'")
1083                                         └──NamedItem
1084                                            └──Opt
1085                                               └──Rhs
1086                                                  └──Alt
1087                                                     └──NamedItem
1088                                                        └──StringLeaf("'d'")
1089                                """
1090        )
1091
1092        self.assertEqual(output, expected_output)
1093