Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: fix some hardcoded references to STORAGE location #4015

Merged
merged 20 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ jobs:
- evm-version: paris
- evm-version: shanghai

# test pre-cancun with opt-codesize and opt-none
- evm-version: shanghai
opt-mode: none
- evm-version: shanghai
opt-mode: codesize

# test py-evm
- evm-backend: py-evm
evm-version: shanghai
Expand Down
1 change: 1 addition & 0 deletions vyper/ast/nodes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ class VariableDecl(VyperNode):
is_constant: bool = ...
is_public: bool = ...
is_immutable: bool = ...
is_transient: bool = ...
_expanded_getter: FunctionDef = ...

class AugAssign(VyperNode):
Expand Down
33 changes: 25 additions & 8 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
STORAGE,
TRANSIENT,
AddrSpace,
legal_in_staticcall,
)
from vyper.evm.opcodes import version_check
from vyper.exceptions import CompilerPanic, TypeCheckFailure, TypeMismatch
Expand Down Expand Up @@ -136,6 +137,14 @@ def address_space_to_data_location(s: AddrSpace) -> DataLocation:
raise CompilerPanic("unreachable!") # pragma: nocover


def writeable(context, ir_node):
assert ir_node.is_pointer # sanity check

if context.is_constant() and not legal_in_staticcall(ir_node.location):
return False
return ir_node.mutable


# Copy byte array word-for-word (including layout)
# TODO make this a private function
def make_byte_array_copier(dst, src):
Expand All @@ -150,12 +159,9 @@ def make_byte_array_copier(dst, src):
return STORE(dst, 0)

with src.cache_when_complex("src") as (b1, src):
has_storage = STORAGE in (src.location, dst.location)
is_memory_copy = dst.location == src.location == MEMORY
batch_uses_identity = is_memory_copy and not version_check(begin="cancun")
if src.typ.maxlen <= 32 and (has_storage or batch_uses_identity):
if src.typ.maxlen <= 32 and not copy_opcode_available(dst, src):
# if there is no batch copy opcode available,
# it's cheaper to run two load/stores instead of copy_bytes

ret = ["seq"]
# store length word
len_ = get_bytearray_length(src)
Expand Down Expand Up @@ -914,6 +920,15 @@ def make_setter(left, right):
return _complex_make_setter(left, right)


# locations with no dedicated copy opcode
# (i.e. storage and transient storage)
def copy_opcode_available(left, right):
if left.location == MEMORY and right.location == MEMORY:
return version_check(begin="cancun")

return left.location == MEMORY and right.location.has_copy_opcode


def _complex_make_setter(left, right):
if right.value == "~empty" and left.location == MEMORY:
# optimized memzero
Expand All @@ -935,8 +950,10 @@ def _complex_make_setter(left, right):
assert left.encoding == Encoding.VYPER
len_ = left.typ.memory_bytes_required

has_storage = STORAGE in (left.location, right.location)
if has_storage:
# special logic for identity precompile (pre-cancun) in the else branch
mem2mem = left.location == right.location == MEMORY

if not copy_opcode_available(left, right) and not mem2mem:
if _opt_codesize():
# assuming PUSH2, a single sstore(dst (sload src)) is 8 bytes,
# sstore(add (dst ofst), (sload (add (src ofst)))) is 16 bytes,
Expand Down Expand Up @@ -983,7 +1000,7 @@ def _complex_make_setter(left, right):
base_unroll_cost + (nth_word_cost * (n_words - 1)) >= identity_base_cost
)

# calldata to memory, code to memory, cancun, or codesize -
# calldata to memory, code to memory, cancun, or opt-codesize -
# batch copy is always better.
else:
should_batch_copy = True
Expand Down
13 changes: 7 additions & 6 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
get_type_for_exact_size,
make_setter,
wrap_value_for_external_return,
writeable,
)
from vyper.codegen.expr import Expr
from vyper.codegen.return_ import make_return_stmt
from vyper.evm.address_space import MEMORY, STORAGE
from vyper.evm.address_space import MEMORY
from vyper.exceptions import CodegenPanic, StructureException, TypeCheckFailure, tag_exceptions
from vyper.semantics.types import DArrayT
from vyper.semantics.types.shortcuts import UINT256_T
Expand Down Expand Up @@ -312,18 +313,18 @@ def parse_Return(self):
def _get_target(self, target):
_dbg_expr = target

if isinstance(target, vy_ast.Name) and target.id in self.context.forvars:
if isinstance(target, vy_ast.Name) and target.id in self.context.forvars: # pragma: nocover
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")

if isinstance(target, vy_ast.Tuple):
target = Expr(target, self.context).ir_node
for node in target.args:
if (node.location == STORAGE and self.context.is_constant()) or not node.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
items = target.args
if any(not writeable(self.context, item) for item in items): # pragma: nocover
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

target = Expr.parse_pointer_expr(target, self.context)
if (target.location == STORAGE and self.context.is_constant()) or not target.mutable:
if not writeable(self.context, target): # pragma: nocover
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

Expand Down
17 changes: 14 additions & 3 deletions vyper/evm/address_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,25 @@ class AddrSpace:
load_op: the opcode for loading a word from this address space
store_op: the opcode for storing a word to this address space
(an address space is read-only if store_op is None)
copy_op: the opcode for batch-copying from this address space
to memory
"""

name: str
word_scale: int
load_op: str
# TODO maybe make positional instead of defaulting to None
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'd prefer positional here

store_op: Optional[str] = None
copy_op: Optional[str] = None

@property
def word_addressable(self) -> bool:
return self.word_scale == 1

@property
def has_copy_opcode(self):
return self.copy_op is not None


# alternative:
# class Memory(AddrSpace):
Expand All @@ -42,13 +49,17 @@ def word_addressable(self) -> bool:
#
# MEMORY = Memory()

MEMORY = AddrSpace("memory", 32, "mload", "mstore")
MEMORY = AddrSpace("memory", 32, "mload", "mstore", "mcopy")
STORAGE = AddrSpace("storage", 1, "sload", "sstore")
TRANSIENT = AddrSpace("transient", 1, "tload", "tstore")
CALLDATA = AddrSpace("calldata", 32, "calldataload")
CALLDATA = AddrSpace("calldata", 32, "calldataload", None, "calldatacopy")
# immutables address space: "immutables" section of memory
# which is read-write in deploy code but then gets turned into
# the "data" section of the runtime code
IMMUTABLES = AddrSpace("immutables", 32, "iload", "istore")
# data addrspace: "data" section of runtime code, read-only.
DATA = AddrSpace("data", 32, "dload")
DATA = AddrSpace("data", 32, "dload", None, "dloadbytes")


def legal_in_staticcall(location: AddrSpace):
return location not in (STORAGE, TRANSIENT)
20 changes: 10 additions & 10 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,13 +621,6 @@ def visit_VariableDecl(self, node):
assert isinstance(node.target, vy_ast.Name)
name = node.target.id

if node.is_public:
# generate function type and add to metadata
# we need this when building the public getter
func_t = ContractFunctionT.getter_from_VariableDecl(node)
node._metadata["getter_type"] = func_t
self._add_exposed_function(func_t, node)

# TODO: move this check to local analysis
if node.is_immutable:
# mutability is checked automatically preventing assignment
Expand All @@ -648,7 +641,7 @@ def visit_VariableDecl(self, node):
)
raise ImmutableViolation(message, node)

data_loc = (
location = (
DataLocation.CODE
if node.is_immutable
else DataLocation.UNSET
Expand All @@ -666,21 +659,28 @@ def visit_VariableDecl(self, node):
else Modifiability.MODIFIABLE
)

type_ = type_from_annotation(node.annotation, data_loc)
type_ = type_from_annotation(node.annotation, location)

if node.is_transient and not version_check(begin="cancun"):
raise EvmVersionException("`transient` is not available pre-cancun", node.annotation)

var_info = VarInfo(
type_,
decl_node=node,
location=data_loc,
location=location,
modifiability=modifiability,
is_public=node.is_public,
)
node.target._metadata["varinfo"] = var_info # TODO maybe put this in the global namespace
node._metadata["type"] = type_

if node.is_public:
# generate function type and add to metadata
# we need this when building the public getter
func_t = ContractFunctionT.getter_from_VariableDecl(node)
node._metadata["getter_type"] = func_t
self._add_exposed_function(func_t, node)

def _finalize():
# add the variable name to `self` namespace if the variable is either
# 1. a public constant or immutable; or
Expand Down
5 changes: 4 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,10 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio
"""
if not node.is_public:
raise CompilerPanic("getter generated for non-public function")
type_ = type_from_annotation(node.annotation, DataLocation.STORAGE)

# calculated by caller (ModuleAnalyzer.visit_VariableDecl)
type_ = node.target._metadata["varinfo"].typ

arguments, return_type = type_.getter_signature
args = []
for i, item in enumerate(arguments):
Expand Down
5 changes: 3 additions & 2 deletions vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class HashMapT(_SubscriptableT):

_equality_attrs = ("key_type", "value_type")

# disallow everything but storage
# disallow everything but storage or transient
_invalid_locations = (
DataLocation.UNSET,
DataLocation.CALLDATA,
Expand Down Expand Up @@ -84,10 +84,11 @@ def from_annotation(cls, node: vy_ast.Subscript) -> "HashMapT":
)

k_ast, v_ast = node.slice.elements
key_type = type_from_annotation(k_ast, DataLocation.STORAGE)
key_type = type_from_annotation(k_ast)
if not key_type._as_hashmap_key:
raise InvalidType("can only use primitive types as HashMap key!", k_ast)

# TODO: thread through actual location - might also be TRANSIENT
value_type = type_from_annotation(v_ast, DataLocation.STORAGE)

return cls(key_type, value_type)
Expand Down
Loading