Skip to content

Commit

Permalink
[refactor] Add for and while to the new frontend AST builder (#3353)
Browse files Browse the repository at this point in the history
* add static group for

* format

* add range for

* fix import

* add ndrange for

* add grouped ndrange for

* add grouped struct for

* format

* add break, continue and pass

* format

* add while

* add test
  • Loading branch information
lin-hitonami authored Nov 8, 2021
1 parent def95bd commit 10a8f7a
Show file tree
Hide file tree
Showing 3 changed files with 566 additions and 169 deletions.
46 changes: 43 additions & 3 deletions python/taichi/lang/ast_builder_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
from enum import Enum

from taichi.lang.exception import TaichiSyntaxError

Expand Down Expand Up @@ -91,7 +92,7 @@ def check_loop_var(self, loop_var):
.format(loop_var))


class IRScopeGuard:
class VariableScopeGuard:
def __init__(self, scopes, stmt_block=None):
self.scopes = scopes
self.stmt_block = stmt_block
Expand All @@ -104,6 +105,29 @@ def __exit__(self, exc_type, exc_val, exc_tb):
self.scopes.pop()


class LoopStatus(Enum):
Normal = 0
Break = 1
Continue = 2


class ControlScopeAttribute:
def __init__(self):
self.is_static = False
self.status = LoopStatus.Normal


class ControlScopeGuard:
def __init__(self, scopes):
self.scopes = scopes

def __enter__(self):
self.scopes.append(ControlScopeAttribute())

def __exit__(self, exc_type, exc_val, exc_tb):
self.scopes.pop()


class IRBuilderContext:
def __init__(self,
excluded_parameters=(),
Expand All @@ -125,18 +149,34 @@ def __init__(self,

# e.g.: FunctionDef, Module, Global
def variable_scope_guard(self, *args):
return IRScopeGuard(self.local_scopes, *args)
return VariableScopeGuard(self.local_scopes, *args)

# e.g.: For, While
def control_scope_guard(self):
return ScopeGuard(self.control_scopes)
return ControlScopeGuard(self.control_scopes)

def current_scope(self):
return self.local_scopes[-1]

def current_control_scope(self):
return self.control_scopes[-1]

def loop_status(self):
if len(self.control_scopes):
return self.control_scopes[-1].status
return LoopStatus.Normal

def set_loop_status(self, status):
self.control_scopes[-1].status = status

def set_static_loop(self):
self.control_scopes[-1].is_static = True

def is_in_static(self):
if len(self.control_scopes):
return self.control_scopes[-1].is_static
return False

def is_var_declared(self, name):
for s in self.local_scopes:
if name in s:
Expand Down
Loading

0 comments on commit 10a8f7a

Please sign in to comment.