Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: more frontend optimizations #3785

Merged
merged 17 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 44 additions & 39 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"node_id",
"_metadata",
"_original_node",
"_cache_descendants",
)
NODE_SRC_ATTRIBUTES = (
"col_offset",
Expand Down Expand Up @@ -211,15 +212,17 @@ def _node_filter(node, filters):
return True


def _sort_nodes(node_iterable):
# sorting function for VyperNode.get_children
def _apply_filters(node_iter, node_type, filters, reverse):
ret = node_iter
if node_type is not None:
ret = (i for i in ret if isinstance(i, node_type))
if filters is not None:
ret = (i for i in ret if _node_filter(i, filters))

def sortkey(key):
return float("inf") if key is None else key

return sorted(
node_iterable, key=lambda k: (sortkey(k.lineno), sortkey(k.col_offset), k.node_id)
)
ret = list(ret)
if reverse:
ret.reverse()
return ret


def _raise_syntax_exc(error_msg: str, ast_struct: dict) -> None:
Expand Down Expand Up @@ -257,10 +260,13 @@ class VyperNode:
"""

__slots__ = NODE_BASE_ATTRIBUTES + NODE_SRC_ATTRIBUTES

_public_slots = [i for i in __slots__ if not i.startswith("_")]
_only_empty_fields: tuple = ()
_translated_fields: dict = {}

def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
# this function is performance-sensitive
"""
AST node initializer method.

Expand All @@ -275,21 +281,19 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):
Dictionary of fields to be included within the node.
"""
self.set_parent(parent)
self._children: set = set()
self._children: list = []
self._metadata: NodeMetadata = NodeMetadata()
self._original_node = None
self._cache_descendants = None

for field_name in NODE_SRC_ATTRIBUTES:
# when a source offset is not available, use the parent's source offset
value = kwargs.get(field_name)
if kwargs.get(field_name) is None:
value = kwargs.pop(field_name, None)
if value is None:
value = getattr(parent, field_name, None)
setattr(self, field_name, value)

for field_name, value in kwargs.items():
if field_name in NODE_SRC_ATTRIBUTES:
continue

if field_name in self._translated_fields:
field_name = self._translated_fields[field_name]

Expand All @@ -309,7 +313,7 @@ def __init__(self, parent: Optional["VyperNode"] = None, **kwargs: dict):

# add to children of parent last to ensure an accurate hash is generated
if parent is not None:
parent._children.add(self)
parent._children.append(self)

# set parent, can be useful when inserting copied nodes into the AST
def set_parent(self, parent: "VyperNode"):
Expand Down Expand Up @@ -338,7 +342,7 @@ def from_node(cls, node: "VyperNode", **kwargs) -> "VyperNode":
-------
Vyper node instance
"""
ast_struct = {i: getattr(node, i) for i in VyperNode.__slots__ if not i.startswith("_")}
ast_struct = {i: getattr(node, i) for i in VyperNode._public_slots}
ast_struct.update(ast_type=cls.__name__, **kwargs)
return cls(**ast_struct)

Expand All @@ -355,10 +359,11 @@ def get_fields(cls) -> set:
return set(i for i in slot_fields if not i.startswith("_"))

def __hash__(self):
values = [getattr(self, i, None) for i in VyperNode.__slots__ if not i.startswith("_")]
values = [getattr(self, i, None) for i in VyperNode._public_slots]
return hash(tuple(values))

def __deepcopy__(self, memo):
# default implementation of deepcopy is a hotspot
return pickle.loads(pickle.dumps(self))

def __eq__(self, other):
Expand Down Expand Up @@ -537,14 +542,7 @@ def get_children(
list
Child nodes matching the filter conditions.
"""
children = _sort_nodes(self._children)
if node_type is not None:
children = [i for i in children if isinstance(i, node_type)]
if reverse:
children.reverse()
if filters is None:
return children
return [i for i in children if _node_filter(i, filters)]
return _apply_filters(iter(self._children), node_type, filters, reverse)

def get_descendants(
self,
Expand All @@ -553,6 +551,7 @@ def get_descendants(
include_self: bool = False,
reverse: bool = False,
) -> list:
# this function is performance-sensitive
"""
Return a list of descendant nodes of this node which match the given filter(s).

Expand Down Expand Up @@ -589,19 +588,25 @@ def get_descendants(
list
Descendant nodes matching the filter conditions.
"""
children = self.get_children(node_type, filters)
for node in self.get_children():
children.extend(node.get_descendants(node_type, filters))
if (
include_self
and (not node_type or isinstance(self, node_type))
and _node_filter(self, filters)
):
children.append(self)
result = _sort_nodes(children)
if reverse:
result.reverse()
return result
ret = self._get_descendants(include_self)
return _apply_filters(ret, node_type, filters, reverse)

def _get_descendants(self, include_self=True):
# get descendants in topsort order
if self._cache_descendants is None:
ret = [self]
for node in self._children:
ret.extend(node._get_descendants())

self._cache_descendants = ret

ret = iter(self._cache_descendants)

if not include_self:
s = next(ret) # pop
assert s is self

return ret

def get(self, field_str: str) -> Any:
"""
Expand Down Expand Up @@ -669,7 +674,7 @@ def add_to_body(self, node: VyperNode) -> None:
self.body.append(node)
node._depth = self._depth + 1
node._parent = self
self._children.add(node)
self._children.append(node)

def remove_from_body(self, node: VyperNode) -> None:
"""
Expand Down
1 change: 1 addition & 0 deletions vyper/semantics/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def get_expr_info(node: vy_ast.ExprNode, is_callable: bool = False) -> ExprInfo:


def get_common_types(*nodes: vy_ast.VyperNode, filter_fn: Callable = None) -> List:
# this function is a performance hotspot
"""
Return a list of common possible types between one or more nodes.

Expand Down
16 changes: 11 additions & 5 deletions vyper/semantics/types/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,17 @@ def abi_type(self) -> ABIType:
return ABI_GIntM(self.bits, self.is_signed)

def compare_type(self, other: VyperType) -> bool:
if not super().compare_type(other):
return False
assert isinstance(other, IntegerT) # mypy

return self.is_signed == other.is_signed and self.bits == other.bits
# this function is performance sensitive
# originally:
# if not super().compare_type(other):
# return False
# return self.is_signed == other.is_signed and self.bits == other.bits

return ( # noqa: E721
self.__class__ == other.__class__
and self.is_signed == other.is_signed # type: ignore
and self.bits == other.bits # type: ignore
)


# helper function for readability.
Expand Down
Loading