Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[lang]: allow downcasting of bytestrings #3832

Merged
merged 4 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions tests/functional/builtins/codegen/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import eth.codecs.abi.exceptions
import pytest

from vyper.compiler import compile_code
from vyper.exceptions import InvalidLiteral, InvalidType, TypeMismatch
from vyper.semantics.types import AddressT, BoolT, BytesM_T, BytesT, DecimalT, IntegerT, StringT
from vyper.semantics.types.shortcuts import BYTES20_T, BYTES32_T, UINT, UINT160_T, UINT256_T
Expand Down Expand Up @@ -560,23 +561,75 @@ def foo(x: {i_typ}) -> {o_typ}:
assert_compile_failed(lambda: get_contract(code), TypeMismatch)


@pytest.mark.parametrize("typ", sorted(BASE_TYPES))
def test_bytes_too_large_cases(typ):
    """Check that converting a 33-byte bytestring to ``typ`` is rejected.

    Runs once per entry of ``BASE_TYPES``. Both a ``Bytes[33]`` argument and
    a 33-byte literal are expected to make ``compile_code`` raise
    ``TypeMismatch``.
    """
    # case 1: the oversized value arrives as a function argument
    code_1 = f"""
@external
def foo(x: Bytes[33]) -> {typ}:
    return convert(x, {typ})
    """
    with pytest.raises(TypeMismatch):
        compile_code(code_1)

    # case 2: the oversized value is a literal; formatting a bytes object
    # into the f-string renders it as a b'...' literal in the contract source
    bytes_33 = b"1" * 33
    code_2 = f"""
@external
def foo() -> {typ}:
    return convert({bytes_33}, {typ})
    """
    with pytest.raises(TypeMismatch):
        compile_code(code_2)

@pytest.mark.parametrize("cls1,cls2", itertools.product((StringT, BytesT), (StringT, BytesT)))
def test_bytestring_conversions(cls1, cls2, get_contract, tx_failed):
    """Exercise ``convert`` between every String/Bytes pairing.

    The source type is ``cls1(33)`` and the target type is ``cls2(32)``.
    Values of 0..32 characters must round-trip, and a 33-character value must
    be rejected: at runtime for arguments, at compile time for literals.
    """
    typ1 = cls1(33)
    typ2 = cls2(32)

    def bytestring(cls, string):
        # Bytes values are passed/compared as bytes objects, Strings as str
        if cls == BytesT:
            return string.encode("utf-8")
        return string

    code_1 = f"""
@external
def foo(x: {typ1}) -> {typ2}:
    return convert(x, {typ2})
    """
    c = get_contract(code_1)

    for i in range(33):  # inclusive 32
        s = "1" * i
        arg = bytestring(cls1, s)
        out = bytestring(cls2, s)
        assert c.foo(arg) == out

    with tx_failed():
        # TODO: sanity check it is convert which is reverting, not arg clamping
        c.foo(bytestring(cls1, "1" * 33))

    code_2_template = """
@external
def foo() -> {typ}:
    return convert({arg}, {typ})
    """

    # test literals
    for i in range(33):  # inclusive 32
        s = "1" * i
        arg = bytestring(cls1, s)
        out = bytestring(cls2, s)
        code = code_2_template.format(typ=typ2, arg=repr(arg))
        if cls1 == cls2:  # ex.: can't convert "" to String[32]
            with pytest.raises(InvalidType):
                compile_code(code)
        else:
            c = get_contract(code)
            assert c.foo() == out

    # Render the over-long literal with repr(), exactly as the loop above does;
    # without it a String argument is emitted as a bare run of digits rather
    # than a quoted string literal, so the wrong kind of literal gets tested.
    too_long = bytestring(cls1, "1" * 33)
    failing_code = code_2_template.format(typ=typ2, arg=repr(too_long))
    with pytest.raises(TypeMismatch):
        compile_code(failing_code)


@pytest.mark.parametrize("n", range(1, 33))
Expand Down
30 changes: 19 additions & 11 deletions vyper/builtins/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,23 +422,31 @@ def to_address(expr, arg, out_typ):
return IRnode.from_list(ret, out_typ)


def _cast_bytestring(expr, arg, out_typ):
    """Build the IR that reinterprets bytestring ``arg`` as ``out_typ``.

    ``expr`` is the AST node being converted; it is used to detect literals
    and is handed to ``_FAIL`` for error reporting. ``arg`` is the IR value
    being converted, and ``out_typ`` is the target bytestring type.

    Two cases are rejected through ``_FAIL`` (presumably a ``TypeMismatch``,
    which is what the accompanying tests expect):

    * source and target are the same kind of bytestring and the source bound
      is not larger than the target bound, i.e. the cast would be a no-op or
      a widening;
    * ``expr`` is a constant whose length exceeds the target bound.

    Otherwise returns an ``IRnode`` typed as ``out_typ`` that keeps the
    location and encoding of ``arg``. When the target bound is smaller than
    the source bound, the IR first asserts that the runtime length fits.
    """
    # ban same-kind conversions that do not shrink the bound
    # (e.g. Bytes[20] to Bytes[21], or Bytes[20] to Bytes[20])
    if isinstance(arg.typ, out_typ.__class__) and arg.typ.maxlen <= out_typ.maxlen:
        _FAIL(arg.typ, out_typ, expr)
    # can't downcast literals with known length (e.g. b"abc" to Bytes[2])
    if isinstance(expr, vy_ast.Constant) and arg.typ.maxlen > out_typ.maxlen:
        _FAIL(arg.typ, out_typ, expr)

    ret = ["seq"]
    if out_typ.maxlen < arg.typ.maxlen:
        # downcast: the actual length is only known at runtime, so check it there
        ret.append(["assert", ["le", get_bytearray_length(arg), out_typ.maxlen]])
    ret.append(arg)
    # NOTE: this is a pointer cast
    return IRnode.from_list(ret, typ=out_typ, location=arg.location, encoding=arg.encoding)


# question: should we allow bytesM -> String?
@_input_types(BytesT, StringT)
def to_string(expr, arg, out_typ):
    """Cast a Bytes or String ``arg`` to ``out_typ`` via ``_cast_bytestring``.

    The accepted input types are the ones listed in ``_input_types``
    (presumably enforced by that decorator). All validation and IR
    construction is delegated to ``_cast_bytestring``.
    """
    return _cast_bytestring(expr, arg, out_typ)


@_input_types(StringT, BytesT)
def to_bytes(expr, arg, out_typ):
    """Cast a String or Bytes ``arg`` to ``out_typ`` via ``_cast_bytestring``.

    Mirrors ``to_string``: the decorator lists the accepted input types and
    ``_cast_bytestring`` performs the checks and builds the resulting IR.
    """
    return _cast_bytestring(expr, arg, out_typ)


@_input_types(IntegerT)
Expand Down
Loading