Skip to content

Commit

Permalink
Merge pull request #840 from multiversx/refactor-async-composability-…
Browse files Browse the repository at this point in the history
…todos

refactor output in case of error for async callback and empty function name check
  • Loading branch information
sasurobert authored Apr 12, 2024
2 parents 1cb1def + c4854af commit d5b3efd
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 45 deletions.
74 changes: 74 additions & 0 deletions vmhost/contexts/async_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package contexts

import (
"errors"
"github.com/multiversx/mx-chain-core-go/core"
"math/big"
"testing"

Expand Down Expand Up @@ -496,6 +497,79 @@ func TestAsyncContext_UpdateCurrentCallStatus(t *testing.T) {
require.Equal(t, vmhost.AsyncCallRejected, asyncCall.Status)
}

func TestAsyncContext_OutputInCaseOfErrorInCallback(t *testing.T) {
user := []byte("user")
contractA := []byte("contractA")
contractB := []byte("contractB")

host, _ := initializeVMAndWasmerAsyncContext(t)
host.EnableEpochsHandlerField = &worldmock.EnableEpochsHandlerStub{
IsFlagEnabledCalled: func(flag core.EnableEpochFlag) bool {
return flag == vmhost.AsyncV3Flag
},
}

async := makeAsyncContext(t, host, contractA)
host.Storage().SetAddress(contractA)
host.AsyncContext = async

vmInput := &vmcommon.ContractCallInput{
VMInput: vmcommon.VMInput{
CallerAddr: user,
Arguments: [][]byte{{0}},
CallType: vm.DirectCall,
},
RecipientAddr: contractA,
}
host.Runtime().InitStateFromContractCallInput(vmInput)

err := async.RegisterAsyncCall("", &vmhost.AsyncCall{
Destination: contractB,
Data: []byte("function"),
})
require.Nil(t, err)

err = async.Save()
require.Nil(t, err)

asyncCallId := async.GetCallID()
asyncStoragePrefix := host.Storage().GetVmProtectedPrefix(vmhost.AsyncDataPrefix)
asyncCallKey := vmhost.CustomStorageKey(string(asyncStoragePrefix), asyncCallId)

data, _, _, _ := host.Storage().GetStorageUnmetered(asyncCallKey)
require.NotEqual(t, len(data), 0)

vmInput = &vmcommon.ContractCallInput{
VMInput: vmcommon.VMInput{
CallerAddr: contractB,
Arguments: [][]byte{{0}},
CallType: vm.AsynchronousCallBack,
},
RecipientAddr: contractA,
}
host.Runtime().InitStateFromContractCallInput(vmInput)

async.callbackAsyncInitiatorCallID = asyncCallId
async.callType = vmInput.CallType
err = async.LoadParentContext()
require.Nil(t, err)

vmOutput := host.Output().CreateVMOutputInCaseOfError(vmhost.ErrNotEnoughGas)
outputAccount := vmOutput.OutputAccounts[string(contractA)]

require.NotNil(t, outputAccount)

storageUpdates := outputAccount.StorageUpdates
require.Equal(t, len(storageUpdates), 1)

asyncContextDeletionUpdate := storageUpdates[string(asyncCallKey)]
require.NotNil(t, asyncContextDeletionUpdate)
require.Equal(t, len(asyncContextDeletionUpdate.Data), 0)

data, _, _, _ = host.Storage().GetStorageUnmetered(asyncCallKey)
require.Equal(t, len(data), 0)
}

func TestAsyncContext_SendAsyncCallCrossShard(t *testing.T) {
host, world := initializeVMAndWasmerAsyncContext(t)
world.AcctMap.PutAccount(&worldmock.Account{
Expand Down
35 changes: 34 additions & 1 deletion vmhost/contexts/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -562,22 +562,55 @@ func (context *outputContext) DeployCode(input vmhost.CodeDeployInput) {
context.codeUpdates[string(input.ContractAddress)] = empty
}

// createVMOutputInCaseOfErrorOfAsyncCallback appends the deletion of the async context to the output
func (context *outputContext) createVMOutputInCaseOfErrorOfAsyncCallback(returnCode vmcommon.ReturnCode, returnMessage string) *vmcommon.VMOutput {
async := context.host.Async()
metering := context.host.Metering()

callId := async.GetCallbackAsyncInitiatorCallID()

context.outputState = &vmcommon.VMOutput{
GasRemaining: 0,
GasRefund: big.NewInt(0),
ReturnCode: returnCode,
ReturnMessage: returnMessage,
OutputAccounts: make(map[string]*vmcommon.OutputAccount),
}

err := async.DeleteFromCallID(callId)
if err != nil {
logOutput.Trace("failed to delete Async Context", "callId", callId, "err", err)
}

metering.UpdateGasStateOnFailure(context.outputState)

return context.outputState
}

// CreateVMOutputInCaseOfError creates a new vmOutput with the given error set as return message.
func (context *outputContext) CreateVMOutputInCaseOfError(err error) *vmcommon.VMOutput {
runtime := context.host.Runtime()
metering := context.host.Metering()

callType := runtime.GetVMInput().CallType

runtime.AddError(err, runtime.FunctionName())

returnCode := context.resolveReturnCodeFromError(err)
returnMessage := context.resolveReturnMessageFromError(err)

if context.host.EnableEpochsHandler().IsFlagEnabled(vmhost.AsyncV3Flag) && callType == vm.AsynchronousCallBack {
return context.createVMOutputInCaseOfErrorOfAsyncCallback(returnCode, returnMessage)
}

vmOutput := &vmcommon.VMOutput{
GasRemaining: 0,
GasRefund: big.NewInt(0),
ReturnCode: returnCode,
ReturnMessage: returnMessage,
}

context.host.Metering().UpdateGasStateOnFailure(vmOutput)
metering.UpdateGasStateOnFailure(vmOutput)

return vmOutput
}
Expand Down
2 changes: 2 additions & 0 deletions vmhost/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ const (
FixOOGReturnCodeFlag core.EnableEpochFlag = "FixOOGReturnCodeFlag"
// DynamicGasCostForDataTrieStorageLoadFlag defines the flag that activates the dynamic gas cost for data trie storage load
DynamicGasCostForDataTrieStorageLoadFlag core.EnableEpochFlag = "DynamicGasCostForDataTrieStorageLoadFlag"
// AsyncV3Flag defines the flag that activates async v3
AsyncV3Flag core.EnableEpochFlag = "AsyncV3Flag"
)
96 changes: 52 additions & 44 deletions vmhost/hostCore/execution.go
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,14 @@ func (host *vmHost) checkFinalGasAfterExit() error {
return nil
}

func (host *vmHost) checkValidFunctionName(name string) error {
if name == "" {
return executor.ErrInvalidFunction
}

return nil
}

func (host *vmHost) callInitFunction() error {
return host.callSCFunction(vmhost.InitFunctionName)
}
Expand All @@ -1154,12 +1162,18 @@ func (host *vmHost) callUpgradeFunction() error {
}

func (host *vmHost) callSCFunction(functionName string) error {
err := host.checkValidFunctionName(functionName)
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "checkValidFunctionName")
return err
}

runtime := host.Runtime()
if !runtime.HasFunction(functionName) {
return executor.ErrFuncNotFound
}

err := runtime.CallSCFunction(functionName)
err = runtime.CallSCFunction(functionName)
if err != nil {
err = host.handleBreakpointIfAny(err)
}
Expand Down Expand Up @@ -1236,12 +1250,6 @@ func (host *vmHost) callSCMethodAsynchronousCallBack() error {
metering.UseGas(metering.GasLeft())
}

// TODO matei-p R2 Returning an error here will cause the VMOutput to be
// empty (due to CreateVMOutputInCaseOfError()). But in release 2 of
// Promises, CreateVMOutputInCaseOfError() should still contain storage
// deletions caused by AsyncContext cleanup, even if callbackErr != nil and
// was returned here. The storage deletions MUST be persisted in the data
// trie once R2 goes live.
if !isCallComplete {
return callbackErr
}
Expand All @@ -1263,47 +1271,47 @@ func (host *vmHost) callFunctionAndExecuteAsync() (bool, error) {
runtime := host.Runtime()
async := host.Async()

// TODO refactor this, and apply this condition in other places where a
// function is called
if runtime.FunctionName() != "" {
err := host.verifyAllowedFunctionCall()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "verifyAllowedFunctionCall")
return false, err
}
err := host.checkValidFunctionName(runtime.FunctionName())
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "checkValidFunctionName")
return false, err
}

functionName, err := runtime.FunctionNameChecked()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "FunctionNameChecked")
return false, err
}
err = host.verifyAllowedFunctionCall()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "verifyAllowedFunctionCall")
return false, err
}

err = runtime.CallSCFunction(functionName)
if err != nil {
err = host.handleBreakpointIfAny(err)
log.Trace("breakpoint detected and handled", "err", err)
}
if err == nil {
err = host.checkFinalGasAfterExit()
}
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "sc function")
return true, err
}
functionName, err := runtime.FunctionNameChecked()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "FunctionNameChecked")
return false, err
}

err = async.Execute()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "async execution")
return false, err
}
err = runtime.CallSCFunction(functionName)
if err != nil {
err = host.handleBreakpointIfAny(err)
log.Trace("breakpoint detected and handled", "err", err)
}
if err == nil {
err = host.checkFinalGasAfterExit()
}
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "sc function")
return true, err
}

if !async.IsComplete() || async.HasLegacyGroup() {
async.SetResults(host.Output().GetVMOutput())
err = async.Save()
return false, err
}
} else {
return false, executor.ErrInvalidFunction
err = async.Execute()
if err != nil {
log.Trace("call SC method failed", "error", err, "src", "async execution")
return false, err
}

if !async.IsComplete() || async.HasLegacyGroup() {
async.SetResults(host.Output().GetVMOutput())
err = async.Save()
return false, err
}

return true, nil
Expand Down

0 comments on commit d5b3efd

Please sign in to comment.