Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] RPC support #512

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions chain/fork/remote_state_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package fork

import (
"github.com/ethereum/go-ethereum/common"
"github.com/holiman/uint256"
)

type RemoteStateCache interface {
GetStorageAt(common.Address, common.Hash) (common.Hash, error)
GetStateObject(common.Address) (*uint256.Int, uint64, []byte, error)
}

var _ RemoteStateCache = (*EmptyRemoteStateCache)(nil)

type EmptyRemoteStateCache struct{}

func (d EmptyRemoteStateCache) GetStorageAt(address common.Address, hash common.Hash) (common.Hash, error) {
return common.Hash{}, nil
}

func (d EmptyRemoteStateCache) GetStateObject(address common.Address) (*uint256.Int, uint64, []byte, error) {
return uint256.NewInt(0), 0, nil, nil
}
168 changes: 168 additions & 0 deletions chain/fork/remote_state_provider.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package fork

import (
"fmt"
"github.com/crytic/medusa/chain/types"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/holiman/uint256"
)

var _ state.RemoteStateProvider = (*RemoteStateProvider)(nil)
var _ state.RemoteStateProviderFactory = (*RemoteStateProviderFactory)(nil)

type RemoteStateProvider struct {
cache RemoteStateCache

stateObjBySnapshot map[int][]common.Address
stateSlotBySnapshot map[int]map[common.Address][]common.Hash

stateObjsImported map[common.Address]struct{}
stateSlotsImported map[common.Address]map[common.Hash]struct{}
}

func newRemoteStateProvider(cache RemoteStateCache) *RemoteStateProvider {
return &RemoteStateProvider{
cache: cache,
stateObjBySnapshot: make(map[int][]common.Address),
stateSlotBySnapshot: make(map[int]map[common.Address][]common.Hash),
stateObjsImported: make(map[common.Address]struct{}),
stateSlotsImported: make(map[common.Address]map[common.Hash]struct{}),
}
}

func (s *RemoteStateProvider) ImportStateObject(addr common.Address, snapId int) (bal *uint256.Int, nonce uint64, code []byte, e *state.RemoteStateError) {
if existingSnap, ok := s.stateObjsImported[addr]; ok {
return nil, 0, nil, &state.RemoteStateError{
CannotQueryDirtyAccount: true,
Error: fmt.Errorf("state object %s was already imported in snapshot %d", addr.Hex(), existingSnap),
}
}

bal, nonce, code, err := s.cache.GetStateObject(addr)
if err == nil {
s.recordImportedStateObject(addr, snapId)
return bal, nonce, code, nil
} else {
return uint256.NewInt(0), 0, nil, &state.RemoteStateError{
CannotQueryDirtyAccount: false,
Error: err,
}
}
}

func (s *RemoteStateProvider) ImportStorageAt(addr common.Address, slot common.Hash, snapId int) (common.Hash, *state.RemoteStorageError) {
imported := s.isStateSlotImported(addr, slot)
if imported {
return common.Hash{}, &state.RemoteStorageError{
CannotQueryDirtySlot: true,
Error: fmt.Errorf("state slot %s of address %s was already imported in snapshot %d", slot.Hex(), addr.Hex(), snapId),
}
}
data, err := s.cache.GetStorageAt(addr, slot)
if err == nil {
s.recordImportedStateSlot(addr, slot, snapId)
return data, nil
} else {
return common.Hash{}, &state.RemoteStorageError{
CannotQueryDirtySlot: false,
Error: err,
}
}
}

func (s *RemoteStateProvider) MarkSlotWritten(addr common.Address, slot common.Hash, snapId int) {
s.recordImportedStateSlot(addr, slot, snapId)
}

func (s *RemoteStateProvider) NotifyRevertedToSnapshot(snapId int) {
// purge all records down to and not including the provided snapId

accountsToClear := make([]common.Address, 0)
for sId, accounts := range s.stateObjBySnapshot {
if sId > snapId {
accountsToClear = append(accountsToClear, accounts...)
delete(s.stateObjBySnapshot, sId)
}
}
for _, addr := range accountsToClear {
delete(s.stateObjsImported, addr)
}

accountSlotsToClear := make(map[common.Address][]common.Hash)
for sId, accounts := range s.stateSlotBySnapshot {
if sId > snapId {
for addr, slots := range accounts {
if _, ok := accountSlotsToClear[addr]; !ok {
accountSlotsToClear[addr] = make([]common.Hash, 0, len(slots))
}
accountSlotsToClear[addr] = append(accountSlotsToClear[addr], slots...)
}
delete(s.stateSlotBySnapshot, sId)
}
}

for addr, slots := range accountSlotsToClear {
for _, slot := range slots {
delete(s.stateSlotsImported[addr], slot)
}
}
}

func (s *RemoteStateProvider) isStateSlotImported(addr common.Address, slot common.Hash) bool {
if _, ok := s.stateSlotsImported[addr]; !ok {
return false
} else {
if _, ok := s.stateSlotsImported[addr][slot]; !ok {
return false
} else {
return true
}
}
}

func (s *RemoteStateProvider) recordImportedStateObject(addr common.Address, snapId int) {
s.stateObjsImported[addr] = struct{}{}
if _, ok := s.stateObjBySnapshot[snapId]; !ok {
s.stateObjBySnapshot[snapId] = make([]common.Address, 0)
}
s.stateObjBySnapshot[snapId] = append(s.stateObjBySnapshot[snapId], addr)
}

func (s *RemoteStateProvider) recordImportedStateSlot(addr common.Address, slot common.Hash, snapId int) {
if _, ok := s.stateSlotsImported[addr]; !ok {
s.stateSlotsImported[addr] = make(map[common.Hash]struct{})
}
s.stateSlotsImported[addr][slot] = struct{}{}
if _, ok := s.stateSlotBySnapshot[snapId]; !ok {
s.stateSlotBySnapshot[snapId] = make(map[common.Address][]common.Hash)
}
if _, ok := s.stateSlotBySnapshot[snapId][addr]; !ok {
s.stateSlotBySnapshot[snapId][addr] = make([]common.Hash, 0)
}
s.stateSlotBySnapshot[snapId][addr] = append(s.stateSlotBySnapshot[snapId][addr], slot)
}

type RemoteStateProviderFactory struct {
RemoteStateCache
}

func NewRemoteStateProviderFactory(cache RemoteStateCache) *RemoteStateProviderFactory {
return &RemoteStateProviderFactory{cache}
}

func (r RemoteStateProviderFactory) New() state.RemoteStateProvider {
return newRemoteStateProvider(r.RemoteStateCache)
}

type MedusaStateFactory struct {
*RemoteStateProviderFactory
}

func NewMedusaStateFactory(remoteStateFactory *RemoteStateProviderFactory) *MedusaStateFactory {
return &MedusaStateFactory{remoteStateFactory}
}

func (f *MedusaStateFactory) New(root common.Hash, db state.Database) (types.MedusaStateDB, error) {
return state.NewProxyDB(root, db, f.RemoteStateProviderFactory)
}
29 changes: 20 additions & 9 deletions chain/test_chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chain
import (
"errors"
"fmt"
"github.com/crytic/medusa/chain/fork"
"math/big"
"sort"

Expand Down Expand Up @@ -63,10 +64,10 @@ type TestChain struct {
// genesisDefinition represents the Genesis information used to generate the chain's initial state.
genesisDefinition *core.Genesis

// state represents the current Ethereum world state.StateDB. It tracks all state across the chain and dummyChain
// and is the subject of state changes when executing new transactions. This does not track the current block
// head or anything of that nature and simply tracks accounts, balances, code, storage, etc.
state *state.StateDB
// state represents the current Ethereum world (interface implementing state.StateDB). It tracks all state across
// the chain and dummyChain and is the subject of state changes when executing new transactions. This does not
// track the current block head or anything of that nature and simply tracks accounts, balances, code, storage, etc.
state chainTypes.MedusaStateDB

// stateDatabase refers to the database object which state uses to store data. It is constructed over db.
stateDatabase state.Database
Expand All @@ -85,6 +86,10 @@ type TestChain struct {

// Events defines the event system for the TestChain.
Events TestChainEvents

// stateDbFactory used to construct state databases from db/root. Abstracts away the backing RPC when running in
// fork mode.
stateDbFactory *fork.MedusaStateFactory
}

// NewTestChain creates a simulated Ethereum backend used for testing, or returns an error if one occurred.
Expand Down Expand Up @@ -179,6 +184,11 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC
transactionTracerRouter := NewTestChainTracerRouter()
callTracerRouter := NewTestChainTracerRouter()

// Set up the state factory
remoteCache := fork.EmptyRemoteStateCache{}
rspf := fork.NewRemoteStateProviderFactory(remoteCache)
sf := fork.NewMedusaStateFactory(rspf)

// Create our instance
chain := &TestChain{
genesisDefinition: genesisDefinition,
Expand All @@ -193,6 +203,7 @@ func NewTestChain(genesisAlloc types.GenesisAlloc, testChainConfig *config.TestC
testChainConfig: testChainConfig,
chainConfig: genesisDefinition.Config,
vmConfigExtensions: vmConfigExtensions,
stateDbFactory: sf,
}

// Add our internal tracers to this chain.
Expand Down Expand Up @@ -297,7 +308,7 @@ func (t *TestChain) GenesisDefinition() *core.Genesis {
}

// State returns the current state.StateDB of the chain.
func (t *TestChain) State() *state.StateDB {
func (t *TestChain) State() chainTypes.MedusaStateDB {
return t.state
}

Expand Down Expand Up @@ -460,9 +471,9 @@ func (t *TestChain) BlockHashFromNumber(blockNumber uint64) (common.Hash, error)

// StateFromRoot obtains a state from a given state root hash.
// Returns the state, or an error if one occurred.
func (t *TestChain) StateFromRoot(root common.Hash) (*state.StateDB, error) {
func (t *TestChain) StateFromRoot(root common.Hash) (chainTypes.MedusaStateDB, error) {
// Load our state from the database
stateDB, err := state.New(root, t.stateDatabase, nil)
stateDB, err := t.stateDbFactory.New(root, t.stateDatabase)
if err != nil {
return nil, err
}
Expand All @@ -486,7 +497,7 @@ func (t *TestChain) StateRootAfterBlockNumber(blockNumber uint64) (common.Hash,

// StateAfterBlockNumber obtains the Ethereum world state after processing all transactions in the provided block
// number. Returns the state, or an error if one occurs.
func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (*state.StateDB, error) {
func (t *TestChain) StateAfterBlockNumber(blockNumber uint64) (chainTypes.MedusaStateDB, error) {
// Obtain our block's post-execution state root hash
root, err := t.StateRootAfterBlockNumber(blockNumber)
if err != nil {
Expand Down Expand Up @@ -558,7 +569,7 @@ func (t *TestChain) RevertToBlockNumber(blockNumber uint64) error {
// It takes an optional state argument, which is the state to execute the message over. If not provided, the
// current pending state (or committed state if none is pending) will be used instead.
// The state executed over may be a pending block state.
func (t *TestChain) CallContract(msg *core.Message, state *state.StateDB, additionalTracers ...*TestChainTracer) (*core.ExecutionResult, error) {
func (t *TestChain) CallContract(msg *core.Message, state chainTypes.MedusaStateDB, additionalTracers ...*TestChainTracer) (*core.ExecutionResult, error) {
// If our provided state is nil, use our current chain state.
if state == nil {
state = t.state
Expand Down
28 changes: 28 additions & 0 deletions chain/types/medusa_statedb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package types

import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/tracing"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/holiman/uint256"
)

var _ MedusaStateDB = (*state.StateDB)(nil)
var _ MedusaStateDB = (*state.ForkStateDb)(nil)

type MedusaStateDB interface {
vm.StateDB
// geth's built-in statedb interface is not complete.
// We need to add the extra methods that Medusa uses.
IntermediateRoot(bool) common.Hash
Finalise(bool)
Logs() []*types.Log
GetLogs(common.Hash, uint64, common.Hash) []*types.Log
TxIndex() int
SetBalance(common.Address, *uint256.Int, tracing.BalanceChangeReason)
SetTxContext(common.Hash, int)
Commit(uint64, bool) (common.Hash, error)
SetLogger(*tracing.Hooks)
}
4 changes: 2 additions & 2 deletions chain/vendored/apply_transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ package vendored

import (
"github.com/crytic/medusa/chain/config"
types2 "github.com/crytic/medusa/chain/types"
"github.com/ethereum/go-ethereum/common"
. "github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
Expand All @@ -36,7 +36,7 @@ import (
// This executes on an underlying EVM and returns a transaction receipt, or an error if one occurs.
// Additional changes:
// - Exposed core.ExecutionResult as a return value.
func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConfig *config.TestChainConfig, author *common.Address, gp *GasPool, statedb *state.StateDB, blockNumber *big.Int, blockHash common.Hash, tx *types.Transaction, usedGas *uint64, evm *vm.EVM) (receipt *types.Receipt, result *ExecutionResult, err error) {
func EVMApplyTransaction(msg *Message, config *params.ChainConfig, testChainConfig *config.TestChainConfig, author *common.Address, gp *GasPool, statedb types2.MedusaStateDB, blockNumber *big.Int, blockHash common.Hash, tx *types.Transaction, usedGas *uint64, evm *vm.EVM) (receipt *types.Receipt, result *ExecutionResult, err error) {
// Apply the OnTxStart and OnTxEnd hooks
if evm.Config.Tracer != nil && evm.Config.Tracer.OnTxStart != nil {
evm.Config.Tracer.OnTxStart(evm.GetVMContext(), tx, msg.From)
Expand Down
6 changes: 3 additions & 3 deletions fuzzing/executiontracer/execution_tracer.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package executiontracer

import (
"github.com/crytic/medusa/chain/types"
"math/big"

"github.com/crytic/medusa/chain"
"github.com/crytic/medusa/fuzzing/contracts"
"github.com/crytic/medusa/utils"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/tracing"
coretypes "github.com/ethereum/go-ethereum/core/types"

Expand All @@ -20,7 +20,7 @@ import (
// CallWithExecutionTrace obtains an execution trace for a given call, on the provided chain, using the state
// provided. If a nil state is provided, the current chain state will be used.
// Returns the ExecutionTrace for the call or an error if one occurs.
func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state *state.StateDB) (*core.ExecutionResult, *ExecutionTrace, error) {
func CallWithExecutionTrace(testChain *chain.TestChain, contractDefinitions contracts.Contracts, msg *core.Message, state types.MedusaStateDB) (*core.ExecutionResult, *ExecutionTrace, error) {
// Create an execution tracer
executionTracer := NewExecutionTracer(contractDefinitions, testChain.CheatCodeContracts())
defer executionTracer.Close()
Expand Down Expand Up @@ -302,7 +302,7 @@ func (t *ExecutionTracer) OnOpcode(pc uint64, op byte, gas, cost uint64, scope t
// TODO: Move this to OnLog
if op == byte(vm.LOG0) || op == byte(vm.LOG1) || op == byte(vm.LOG2) || op == byte(vm.LOG3) || op == byte(vm.LOG4) {
t.onNextCaptureState = append(t.onNextCaptureState, func() {
logs := t.evmContext.StateDB.(*state.StateDB).Logs()
logs := t.evmContext.StateDB.(types.MedusaStateDB).Logs()
if len(logs) > 0 {
t.currentCallFrame.Operations = append(t.currentCallFrame.Operations, logs[len(logs)-1])
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,4 @@ require (
rsc.io/tmplfunc v0.0.3 // indirect
)

replace github.com/ethereum/go-ethereum => github.com/crytic/medusa-geth v0.0.0-20240919134035-0fd368c28419
replace github.com/ethereum/go-ethereum => github.com/crytic/medusa-geth v0.0.0-20241130192903-8ab947767bf4
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ github.com/crate-crypto/go-kzg-4844 v1.0.0/go.mod h1:1kMhvPgI0Ky3yIa+9lFySEBUBXk
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/crytic/medusa-geth v0.0.0-20240919134035-0fd368c28419 h1:MJXzWPObZtF0EMRqX64JkzJDj+GMLPxg3XK5xb12FFU=
github.com/crytic/medusa-geth v0.0.0-20240919134035-0fd368c28419/go.mod h1:ajGCVsk6ctffGwe9TSDQqj4HIUUQ1WdUit5tWFNl8Tw=
github.com/crytic/medusa-geth v0.0.0-20241130173605-b90d9e750c68 h1:omfbSnk8EEIr/B+Sv1iHqTK5sZPw4JfNojbAgFMR9g4=
github.com/crytic/medusa-geth v0.0.0-20241130173605-b90d9e750c68/go.mod h1:ajGCVsk6ctffGwe9TSDQqj4HIUUQ1WdUit5tWFNl8Tw=
github.com/crytic/medusa-geth v0.0.0-20241130192903-8ab947767bf4 h1:hUCM94+Pa65FkkRynJUktVinTHvAdxxcIeuLEkC1/bc=
github.com/crytic/medusa-geth v0.0.0-20241130192903-8ab947767bf4/go.mod h1:ajGCVsk6ctffGwe9TSDQqj4HIUUQ1WdUit5tWFNl8Tw=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
Expand Down