Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[codegen]: fix some hardcoded references to STORAGE location #4015

Merged
merged 20 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions vyper/codegen/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def make_byte_array_copier(dst, src):
return STORE(dst, 0)

with src.cache_when_complex("src") as (b1, src):
has_storage = STORAGE in (src.location, dst.location)
has_storage = any(loc.word_addressable for loc in (src.location, dst.location))
is_memory_copy = dst.location == src.location == MEMORY
batch_uses_identity = is_memory_copy and not version_check(begin="cancun")
if src.typ.maxlen <= 32 and (has_storage or batch_uses_identity):
Expand Down Expand Up @@ -934,7 +934,7 @@ def _complex_make_setter(left, right):
assert left.encoding == Encoding.VYPER
len_ = left.typ.memory_bytes_required

has_storage = STORAGE in (left.location, right.location)
has_storage = any(loc.word_addressable for loc in (left.location, right.location))
if has_storage:
if _opt_codesize():
# assuming PUSH2, a single sstore(dst (sload src)) is 8 bytes,
Expand Down
8 changes: 5 additions & 3 deletions vyper/codegen/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from vyper.codegen.expr import Expr
from vyper.codegen.return_ import make_return_stmt
from vyper.evm.address_space import MEMORY, STORAGE
from vyper.evm.address_space import MEMORY
from vyper.exceptions import CodegenPanic, StructureException, TypeCheckFailure, tag_exceptions
from vyper.semantics.types import DArrayT
from vyper.semantics.types.shortcuts import UINT256_T
Expand Down Expand Up @@ -315,12 +315,14 @@ def _get_target(self, target):
if isinstance(target, vy_ast.Tuple):
target = Expr(target, self.context).ir_node
for node in target.args:
if (node.location == STORAGE and self.context.is_constant()) or not node.mutable:
if (
node.location.word_addressable and self.context.is_constant()
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved
) or not node.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

target = Expr.parse_pointer_expr(target, self.context)
if (target.location == STORAGE and self.context.is_constant()) or not target.mutable:
if (target.location.word_addressable and self.context.is_constant()) or not target.mutable:
raise TypeCheckFailure(f"Failed constancy check\n{_dbg_expr}")
return target

Expand Down
38 changes: 19 additions & 19 deletions vyper/semantics/analysis/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,10 +621,28 @@ def visit_VariableDecl(self, node):
assert isinstance(node.target, vy_ast.Name)
name = node.target.id

data_loc = (
DataLocation.CODE
if node.is_immutable
else DataLocation.UNSET
if node.is_constant
else DataLocation.TRANSIENT
if node.is_transient
else DataLocation.STORAGE
)

modifiability = (
Modifiability.RUNTIME_CONSTANT
if node.is_immutable
else Modifiability.CONSTANT
if node.is_constant
else Modifiability.MODIFIABLE
)

if node.is_public:
# generate function type and add to metadata
# we need this when building the public getter
func_t = ContractFunctionT.getter_from_VariableDecl(node)
func_t = ContractFunctionT.getter_from_VariableDecl(node, data_loc)
node._metadata["getter_type"] = func_t
self._add_exposed_function(func_t, node)

Expand All @@ -648,24 +666,6 @@ def visit_VariableDecl(self, node):
)
raise ImmutableViolation(message, node)

data_loc = (
DataLocation.CODE
if node.is_immutable
else DataLocation.UNSET
if node.is_constant
else DataLocation.TRANSIENT
if node.is_transient
else DataLocation.STORAGE
)

modifiability = (
Modifiability.RUNTIME_CONSTANT
if node.is_immutable
else Modifiability.CONSTANT
if node.is_constant
else Modifiability.MODIFIABLE
)

type_ = type_from_annotation(node.annotation, data_loc)

if node.is_transient and not version_check(begin="cancun"):
Expand Down
11 changes: 9 additions & 2 deletions vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,9 @@ def set_reentrancy_key_position(self, position: VarOffset) -> None:
self.reentrancy_key_position = position

@classmethod
def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctionT":
def getter_from_VariableDecl(
cls, node: vy_ast.VariableDecl, data_loc: DataLocation
) -> "ContractFunctionT":
"""
Generate a `ContractFunctionT` object from an `VariableDecl` node.

Expand All @@ -453,14 +455,19 @@ def getter_from_VariableDecl(cls, node: vy_ast.VariableDecl) -> "ContractFunctio
---------
node : VariableDecl
Vyper ast node to generate the function definition from.
data_loc : DataLocation

Returns
-------
ContractFunctionT
"""
if not node.is_public:
raise CompilerPanic("getter generated for non-public function")
type_ = type_from_annotation(node.annotation, DataLocation.STORAGE)

assert data_loc not in (DataLocation.MEMORY, DataLocation.CALLDATA)
charles-cooper marked this conversation as resolved.
Show resolved Hide resolved

type_ = type_from_annotation(node.annotation, data_loc)

arguments, return_type = type_.getter_signature
args = []
for i, item in enumerate(arguments):
Expand Down
2 changes: 1 addition & 1 deletion vyper/semantics/types/subscriptable.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class HashMapT(_SubscriptableT):

_equality_attrs = ("key_type", "value_type")

# disallow everything but storage
# disallow everything but storage or transient
_invalid_locations = (
DataLocation.UNSET,
DataLocation.CALLDATA,
Expand Down
Loading