From b35908b2fe2e05b804228198f42c7bf0752db9ec Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 9 Sep 2021 13:51:14 -0700 Subject: [PATCH] Support parsing nonlocal and global (#6) * support parsing nonlocal and global --- synr/ast.py | 36 ++++++++++++++++++++++++++++++++++++ synr/compiler.py | 14 ++++++++++++++ tests/test_synr.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 93 insertions(+) diff --git a/synr/ast.py b/synr/ast.py index c62545a..89cf2db 100644 --- a/synr/ast.py +++ b/synr/ast.py @@ -536,6 +536,42 @@ def my_function(): call: Call +@attr.s(auto_attribs=True, frozen=True) +class Nonlocal(Stmt): + """A nonlocal statement. + + Example + ------- + .. code-block:: python + x, y = 1, 2 + def foo(): + nonlocal x, y + return x + + In :code:`nonlocal x, y`, :code:`vars` is :code`[x, y]`. + """ + + vars: List[Var] + + +@attr.s(auto_attribs=True, frozen=True) +class Global(Stmt): + """A global statement. + + Example + ------- + .. code-block:: python + x, y = 1, 2 + def foo(): + global x, y + return x + + In :code:`global x, y`, :code:`vars` is :code`[x, y]`. + """ + + vars: List[Var] + + @attr.s(auto_attribs=True, frozen=True) class Block(Node): """A sequence of statements. diff --git a/synr/compiler.py b/synr/compiler.py index afddb65..9e39a29 100644 --- a/synr/compiler.py +++ b/synr/compiler.py @@ -352,6 +352,20 @@ def compile_stmt(self, stmt: py_ast.stmt) -> Stmt: None if stmt.msg is None else self.compile_expr(stmt.msg), ) + elif isinstance(stmt, py_ast.Nonlocal): + # TODO: the variable spans here are incorrect as the Python AST stores each identifier + # as a raw string (with no span information), so we just use the statement's span + return Nonlocal( + stmt_span, [Var(stmt_span, Id(stmt_span, name)) for name in stmt.names] + ) + + elif isinstance(stmt, py_ast.Global): + # TODO: the variable spans here are incorrect as the Python AST stores each identifier + # as a raw string (with no span information), so we just use the statement's span + return Global( + stmt_span, [Var(stmt_span, Id(stmt_span, name)) for name in stmt.names] + ) + else: self.error(f"Found unexpected {type(stmt)} when compiling stmt", stmt_span) return Stmt(Span.invalid()) diff --git a/tests/test_synr.py b/tests/test_synr.py index 273a2e2..5b14e5d 100644 --- a/tests/test_synr.py +++ b/tests/test_synr.py @@ -572,6 +572,45 @@ def bar(): assert bar.decorators[1].span.start_line == start_line + 3 +def test_nonlocal(): + x, y = 1, 2 + + def foo(): + nonlocal x, y + return x + y + + module = to_ast(foo) + fn = assert_one_fn(module, "foo") + nl = fn.body.stmts[0] + assert isinstance(nl, synr.ast.Nonlocal) + assert len(nl.vars) == 2 + x, y = nl.vars + assert isinstance(x, synr.ast.Var) and x.id.name == "x" + assert isinstance(y, synr.ast.Var) and y.id.name == "y" + + _, start_line = inspect.getsourcelines(foo) + assert nl.span.start_line == start_line + 1 + # NOTE: variable spans are a bit hacky so we don't check them here + + +def test_global(): + def foo(): + global x, y + return x + y + + module = to_ast(foo) + fn = assert_one_fn(module, "foo") + gl = fn.body.stmts[0] + assert isinstance(gl, synr.ast.Global) + assert len(gl.vars) == 2 + x, y = gl.vars + assert isinstance(x, synr.ast.Var) and x.id.name == "x" + assert isinstance(y, synr.ast.Var) and y.id.name == "y" + + _, start_line = inspect.getsourcelines(foo) + assert gl.span.start_line == start_line + 1 + + if __name__ == "__main__": test_id_function() test_class() @@ -589,3 +628,7 @@ def bar(): test_constants() test_err_msg() test_scoped_func() + test_local_func() + test_decorators() + test_nonlocal() + test_global()