Skip to content

Commit

Permalink
[x/programs] Add CallContext to programs to allow setting of default …
Browse files Browse the repository at this point in the history
…values (#1058)

* Add CallContext
  • Loading branch information
dboehm-avalabs authored Jul 2, 2024
1 parent 5c2cb96 commit 5b7a43e
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 0 deletions.
91 changes: 91 additions & 0 deletions x/programs/runtime/call_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package runtime

import (
"context"
"errors"
"fmt"
"reflect"

"github.com/ava-labs/avalanchego/ids"

"github.com/ava-labs/hypersdk/codec"
)

var (
callInfoTypeInfo = reflect.TypeOf(CallInfo{})

errCannotOverwrite = errors.New("trying to overwrite set field")
)

type CallContext struct {
r *WasmRuntime
defaultCallInfo CallInfo
}

func (c CallContext) createCallInfo(callInfo *CallInfo) (*CallInfo, error) {
newCallInfo := *callInfo
resultInfo := reflect.ValueOf(&newCallInfo)
defaults := reflect.ValueOf(c.defaultCallInfo)
for i := 0; i < defaults.NumField(); i++ {
defaultField := defaults.Field(i)
if !defaultField.IsZero() {
resultField := resultInfo.Elem().Field(i)
if !resultField.IsZero() {
return nil, fmt.Errorf("%w %s", errCannotOverwrite, callInfoTypeInfo.Field(i).Name)
}
resultField.Set(defaultField)
}
}
return &newCallInfo, nil
}

func (c CallContext) CallProgram(ctx context.Context, info *CallInfo) ([]byte, error) {
newInfo, err := c.createCallInfo(info)
if err != nil {
return nil, err
}
return c.r.CallProgram(ctx, newInfo)
}

func (c CallContext) WithStateManager(manager StateManager) CallContext {
c.defaultCallInfo.State = manager
return c
}

func (c CallContext) WithActor(address codec.Address) CallContext {
c.defaultCallInfo.Actor = address
return c
}

func (c CallContext) WithFunction(s string) CallContext {
c.defaultCallInfo.FunctionName = s
return c
}

func (c CallContext) WithProgram(address codec.Address) CallContext {
c.defaultCallInfo.Program = address
return c
}

func (c CallContext) WithFuel(u uint64) CallContext {
c.defaultCallInfo.Fuel = u
return c
}

func (c CallContext) WithParams(bytes []byte) CallContext {
c.defaultCallInfo.Params = bytes
return c
}

func (c CallContext) WithHeight(height uint64) CallContext {
c.defaultCallInfo.Height = height
return c
}

func (c CallContext) WithActionID(actionID ids.ID) CallContext {
c.defaultCallInfo.ActionID = actionID
return c
}
93 changes: 93 additions & 0 deletions x/programs/runtime/call_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package runtime

import (
"context"
"testing"

"github.com/ava-labs/avalanchego/ids"
"github.com/ava-labs/avalanchego/utils/logging"
"github.com/bytecodealliance/wasmtime-go/v14"
"github.com/stretchr/testify/require"

"github.com/ava-labs/hypersdk/codec"
"github.com/ava-labs/hypersdk/x/programs/test"
)

func TestCallContext(t *testing.T) {
require := require.New(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
programID := ids.GenerateTestID()
programAccount := codec.CreateAddress(0, programID)
r := NewRuntime(
NewConfig(),
logging.NoLog{},
).WithDefaults(
&CallInfo{
State: &test.StateManager{ProgramsMap: map[ids.ID]string{programID: "call_program"}, AccountMap: map[codec.Address]ids.ID{programAccount: programID}},
Program: programAccount,
Fuel: 1000000,
})
actor := codec.CreateAddress(1, ids.GenerateTestID())

result, err := r.WithActor(actor).CallProgram(
ctx,
&CallInfo{
FunctionName: "actor_check",
})
require.NoError(err)
require.Equal(actor, into[codec.Address](result))

result, err = r.WithActor(codec.CreateAddress(2, ids.GenerateTestID())).CallProgram(
ctx,
&CallInfo{
FunctionName: "actor_check",
})
require.NoError(err)
require.NotEqual(actor, into[codec.Address](result))

result, err = r.WithFuel(0).CallProgram(
ctx,
&CallInfo{
FunctionName: "actor_check",
})
require.Equal(wasmtime.OutOfFuel, *err.(*wasmtime.Trap).Code())
require.Nil(result)
}

func TestCallContextPreventOverwrite(t *testing.T) {
require := require.New(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

program0ID := ids.GenerateTestID()
program0Address := codec.CreateAddress(0, program0ID)
program1ID := ids.GenerateTestID()
program1Address := codec.CreateAddress(1, program1ID)

r := NewRuntime(
NewConfig(),
logging.NoLog{},
).WithDefaults(
&CallInfo{
Program: program0Address,
State: &test.StateManager{ProgramsMap: map[ids.ID]string{program0ID: "call_program"}, AccountMap: map[codec.Address]ids.ID{program0Address: program0ID}},
Fuel: 1000000,
})

// try to use a context that has a default program with a different program
result, err := r.CallProgram(
ctx,
&CallInfo{
Program: program1Address,
State: &test.StateManager{ProgramsMap: map[ids.ID]string{program1ID: "call_program"}, AccountMap: map[codec.Address]ids.ID{program1Address: program1ID}},
FunctionName: "actor_check",
})
require.ErrorIs(err, errCannotOverwrite)
require.Nil(result)
}
4 changes: 4 additions & 0 deletions x/programs/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ func NewRuntime(
return runtime
}

func (r *WasmRuntime) WithDefaults(callInfo *CallInfo) CallContext {
return CallContext{r: r, defaultCallInfo: *callInfo}
}

func (r *WasmRuntime) AddImportModule(mod *ImportModule) {
r.hostImports.AddModule(mod)
r.linkerNeedsInitialization = true
Expand Down

0 comments on commit 5b7a43e

Please sign in to comment.