Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Add for and while to the new frontend AST builder #3353

Merged
merged 12 commits into from
Nov 8, 2021
46 changes: 43 additions & 3 deletions python/taichi/lang/ast_builder_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
from enum import Enum

from taichi.lang.exception import TaichiSyntaxError

Expand Down Expand Up @@ -91,7 +92,7 @@ def check_loop_var(self, loop_var):
.format(loop_var))


class IRScopeGuard:
class VariableScopeGuard:
def __init__(self, scopes, stmt_block=None):
self.scopes = scopes
self.stmt_block = stmt_block
Expand All @@ -104,6 +105,29 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.scopes.pop()


class LoopStatus(Enum):
Normal = 0
Break = 1
Continue = 2


class ControlScopeAttribute:
def __init__(self):
self.is_static = False
self.status = LoopStatus.Normal


class ControlScopeGuard:
def __init__(self, scopes):
self.scopes = scopes

def __enter__(self):
self.scopes.append(ControlScopeAttribute())

def __exit__(self, exc_type, exc_val, exc_tb):
self.scopes.pop()


class IRBuilderContext:
def __init__(self,
excluded_parameters=(),
Expand All @@ -125,18 +149,34 @@ def __init__(self,

# e.g.: FunctionDef, Module, Global
def variable_scope_guard(self, *args):
return IRScopeGuard(self.local_scopes, *args)
return VariableScopeGuard(self.local_scopes, *args)

# e.g.: For, While
def control_scope_guard(self):
return ScopeGuard(self.control_scopes)
return ControlScopeGuard(self.control_scopes)

def current_scope(self):
return self.local_scopes[-1]

def current_control_scope(self):
return self.control_scopes[-1]

def loop_status(self):
if len(self.control_scopes):
return self.control_scopes[-1].status
return LoopStatus.Normal

def set_loop_status(self, status):
self.control_scopes[-1].status = status

def set_static_loop(self):
self.control_scopes[-1].is_static = True

def is_in_static(self):
if len(self.control_scopes):
return self.control_scopes[-1].is_static
return False

def is_var_declared(self, name):
for s in self.local_scopes:
if name in s:
Expand Down
Loading