Skip to content

Commit

Permalink
Implement EIP-4200
Browse files Browse the repository at this point in the history
  • Loading branch information
gurukamath committed Jun 26, 2024
1 parent 343146c commit c5318cc
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 14 deletions.
106 changes: 95 additions & 11 deletions src/ethereum/prague/vm/eof.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Set

from . import EOF, EOF_MAGIC, EOF_MAGIC_LENGTH
from .exceptions import InvalidEOF
Expand Down Expand Up @@ -324,39 +324,123 @@ def validate_body(code: bytes, eof_header: EOFHeader) -> None:
raise InvalidEOF("Stray bytes found after data section")


def validate_code_section(code: bytes) -> None:
def get_valid_jump_destinations(code: bytes) -> Set[int]:
"""
Validate a code section of the EOF container.
Get the valid jump destinations for the code. The immediate bytes
of the PUSH, RJUMP, RJUMPI, RJUMPV opcodes are invalid as jump
destinations.
Parameters
----------
code : bytes
The code section to validate.
The code section of the EOF container.
Raises
------
InvalidEOF
If the code section is invalid.
Returns
-------
valid_jump_destinations : Set[int]
The valid jump destinations in the code.
"""
counter = 0
valid_jump_destinations = set()
while counter < len(code):
try:
opcode = get_opcode(code[counter], EOF.EOF1)
except ValueError:
raise InvalidEOF("Invalid opcode in code section")
valid_jump_destinations.add(counter)

counter += 1

if (
opcode.value >= Ops.PUSH1.value
and opcode.value <= Ops.PUSH32.value
):
push_data_size = opcode.value - Ops.PUSH1.value + 1
if len(code) < counter + push_data_size + 1:
if len(code) < counter + push_data_size:
raise InvalidEOF("Push data missing")
counter += push_data_size
continue

counter += push_data_size + 1
if opcode in (Ops.RJUMP, Ops.RJUMPI):
if len(code) < counter + 2:
raise InvalidEOF("Relative jump offset missing")
counter += 2
continue

counter += 1
if opcode == Ops.RJUMPV:
if len(code) < counter + 1:
raise InvalidEOF("max_index missing for RJUMPV")
max_index = code[counter]
num_relative_indices = max_index + 1
counter += 1

for _ in range(num_relative_indices):
if len(code) < counter + 2:
raise InvalidEOF("Relative jump indices missing")
counter += 2
continue

return valid_jump_destinations


def validate_code_section(code: bytes) -> None:
"""
Validate a code section of the EOF container.
Parameters
----------
code : bytes
The code section to validate.
Raises
------
InvalidEOF
If the code section is invalid.
"""
counter = 0
valid_jump_destinations = get_valid_jump_destinations(code)

for counter in valid_jump_destinations:
opcode = get_opcode(code[counter], EOF.EOF1)

# Make sure the bytes encoding relative offset
# are available
if opcode in (Ops.RJUMP, Ops.RJUMPI):
relative_offset = int.from_bytes(
code[counter + 1 : counter + 3], "big", signed=True
)
pc_post_instruction = counter + 3
jump_destination = pc_post_instruction + relative_offset
if (
jump_destination < 0
or len(code) < jump_destination + 1
or jump_destination not in valid_jump_destinations
):
raise InvalidEOF("Invalid jump destination")

elif opcode == Ops.RJUMPV:
num_relative_indices = code[counter + 1] + 1
# pc_post_instruction will be
# counter + 1 <- for normal pc increment to next opcode
# + 1 <- for the 1 byte max_index
# + 2 * num_relative_indices <- for the 2 bytes of each offset
pc_post_instruction = counter + 2 + 2 * num_relative_indices

index_position = counter + 2
for _ in range(num_relative_indices):
relative_offset = int.from_bytes(
code[index_position : index_position + 2],
"big",
signed=True,
)
index_position += 2
jump_destination = pc_post_instruction + relative_offset
if (
jump_destination < 0
or len(code) < jump_destination + 1
or jump_destination not in valid_jump_destinations
):
raise InvalidEOF("Invalid jump destination")


def validate_eof_code(code: bytes, eof_header: EOFHeader) -> None:
Expand Down
3 changes: 3 additions & 0 deletions src/ethereum/prague/vm/gas.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@
GAS_INIT_CODE_WORD_COST = 2
GAS_BLOBHASH_OPCODE = Uint(3)
GAS_POINT_EVALUATION = Uint(50000)
GAS_RJUMP = Uint(2)
GAS_RJUMPI = Uint(4)
GAS_RJUMPV = Uint(4)

TARGET_BLOB_GAS_PER_BLOCK = U64(393216)
GAS_PER_BLOB = Uint(2**17)
Expand Down
16 changes: 15 additions & 1 deletion src/ethereum/prague/vm/instructions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,11 @@ class Ops(enum.Enum):
LOG3 = 0xA3
LOG4 = 0xA4

# Static Relative Jumps
RJUMP = 0xE0
RJUMPI = 0xE1
RJUMPV = 0xE2

# System Operations
CREATE = 0xF0
CALL = 0xF1
Expand Down Expand Up @@ -355,6 +360,9 @@ class Ops(enum.Enum):
Ops.LOG2: log_instructions.log2,
Ops.LOG3: log_instructions.log3,
Ops.LOG4: log_instructions.log4,
Ops.RJUMP: control_flow_instructions.rjump,
Ops.RJUMPI: control_flow_instructions.rjumpi,
Ops.RJUMPV: control_flow_instructions.rjumpv,
Ops.CREATE: system_instructions.create,
Ops.RETURN: system_instructions.return_,
Ops.CALL: system_instructions.call,
Expand All @@ -367,7 +375,13 @@ class Ops(enum.Enum):
}


OPCODES_INVALID_IN_LEGACY = (Ops.INVALID,)
OPCODES_INVALID_IN_LEGACY = (
Ops.INVALID,
# Relative Jump instructions
Ops.RJUMP,
Ops.RJUMPI,
Ops.RJUMPV,
)

OPCODES_INVALID_IN_EOF1 = (
# Control Flow Ops
Expand Down
110 changes: 109 additions & 1 deletion src/ethereum/prague/vm/instructions/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,16 @@

from ethereum.base_types import U256, Uint

from ...vm.gas import GAS_BASE, GAS_HIGH, GAS_JUMPDEST, GAS_MID, charge_gas
from ...vm.gas import (
GAS_BASE,
GAS_HIGH,
GAS_JUMPDEST,
GAS_MID,
GAS_RJUMP,
GAS_RJUMPI,
GAS_RJUMPV,
charge_gas,
)
from .. import Evm
from ..exceptions import InvalidJumpDestError
from ..stack import pop, push
Expand Down Expand Up @@ -169,3 +178,102 @@ def jumpdest(evm: Evm) -> None:

# PROGRAM COUNTER
evm.pc += 1


def rjump(evm: Evm) -> None:
"""
Jump to a relative offset.
Parameters
----------
evm :
The current EVM frame.
"""
# STACK
pass

# GAS
charge_gas(evm, GAS_RJUMP)

# OPERATION
pass

# PROGRAM COUNTER
relative_offset = int.from_bytes(
evm.code[evm.pc + 1 : evm.pc + 3], "big", signed=True
)
# pc + 1 + 2 bytes of relative offset
pc_post_instruction = int(evm.pc) + 3
evm.pc = Uint(pc_post_instruction + relative_offset)


def rjumpi(evm: Evm) -> None:
"""
Jump to a relative offset given a condition.
Parameters
----------
evm :
The current EVM frame.
"""
# STACK
condition = pop(evm.stack)

# GAS
charge_gas(evm, GAS_RJUMPI)

# OPERATION
pass

# PROGRAM COUNTER
relative_offset = int.from_bytes(
evm.code[evm.pc + 1 : evm.pc + 3], "big", signed=True
)
# pc + 1 + 2 bytes of relative offset
pc_post_instruction = int(evm.pc) + 3
if condition == 0:
evm.pc = Uint(pc_post_instruction)
else:
evm.pc = Uint(pc_post_instruction + relative_offset)


def rjumpv(evm: Evm) -> None:
"""
Jump to a relative offset via jump table.
Parameters
----------
evm :
The current EVM frame.
"""
# STACK
case = pop(evm.stack)

# GAS
charge_gas(evm, GAS_RJUMPV)

# OPERATION
pass

# PROGRAM COUNTER
max_index = evm.code[evm.pc + 1]
num_relative_indices = max_index + 1
# pc_post_instruction will be
# counter + 1 <- for normal pc increment to next opcode
# + 1 <- for the 1 byte max_index
# + 2 * num_relative_indices <- for the 2 bytes of each offset
pc_post_instruction = int(evm.pc) + 2 + 2 * num_relative_indices

if case > max_index:
evm.pc = Uint(pc_post_instruction)
else:
relative_offset_position = evm.pc + 2 + 2 * case
relative_offset = int.from_bytes(
evm.code[relative_offset_position : relative_offset_position + 2],
"big",
signed=True,
)
evm.pc = Uint(pc_post_instruction + relative_offset)
6 changes: 5 additions & 1 deletion whitelist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -438,4 +438,8 @@ req
predeploy

eof
eof1
eof1

RJUMP
RJUMPI
RJUMPV

0 comments on commit c5318cc

Please sign in to comment.