-
Notifications
You must be signed in to change notification settings - Fork 125
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[x/programs] Add CallContext to programs to allow setting of default …
…values (#1058) * Add CallContext
- Loading branch information
1 parent
5c2cb96
commit 5b7a43e
Showing
3 changed files
with
188 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved. | ||
// See the file LICENSE for licensing terms. | ||
|
||
package runtime | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"reflect" | ||
|
||
"github.com/ava-labs/avalanchego/ids" | ||
|
||
"github.com/ava-labs/hypersdk/codec" | ||
) | ||
|
||
var ( | ||
callInfoTypeInfo = reflect.TypeOf(CallInfo{}) | ||
|
||
errCannotOverwrite = errors.New("trying to overwrite set field") | ||
) | ||
|
||
type CallContext struct { | ||
r *WasmRuntime | ||
defaultCallInfo CallInfo | ||
} | ||
|
||
func (c CallContext) createCallInfo(callInfo *CallInfo) (*CallInfo, error) { | ||
newCallInfo := *callInfo | ||
resultInfo := reflect.ValueOf(&newCallInfo) | ||
defaults := reflect.ValueOf(c.defaultCallInfo) | ||
for i := 0; i < defaults.NumField(); i++ { | ||
defaultField := defaults.Field(i) | ||
if !defaultField.IsZero() { | ||
resultField := resultInfo.Elem().Field(i) | ||
if !resultField.IsZero() { | ||
return nil, fmt.Errorf("%w %s", errCannotOverwrite, callInfoTypeInfo.Field(i).Name) | ||
} | ||
resultField.Set(defaultField) | ||
} | ||
} | ||
return &newCallInfo, nil | ||
} | ||
|
||
func (c CallContext) CallProgram(ctx context.Context, info *CallInfo) ([]byte, error) { | ||
newInfo, err := c.createCallInfo(info) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return c.r.CallProgram(ctx, newInfo) | ||
} | ||
|
||
func (c CallContext) WithStateManager(manager StateManager) CallContext { | ||
c.defaultCallInfo.State = manager | ||
return c | ||
} | ||
|
||
func (c CallContext) WithActor(address codec.Address) CallContext { | ||
c.defaultCallInfo.Actor = address | ||
return c | ||
} | ||
|
||
func (c CallContext) WithFunction(s string) CallContext { | ||
c.defaultCallInfo.FunctionName = s | ||
return c | ||
} | ||
|
||
func (c CallContext) WithProgram(address codec.Address) CallContext { | ||
c.defaultCallInfo.Program = address | ||
return c | ||
} | ||
|
||
func (c CallContext) WithFuel(u uint64) CallContext { | ||
c.defaultCallInfo.Fuel = u | ||
return c | ||
} | ||
|
||
func (c CallContext) WithParams(bytes []byte) CallContext { | ||
c.defaultCallInfo.Params = bytes | ||
return c | ||
} | ||
|
||
func (c CallContext) WithHeight(height uint64) CallContext { | ||
c.defaultCallInfo.Height = height | ||
return c | ||
} | ||
|
||
func (c CallContext) WithActionID(actionID ids.ID) CallContext { | ||
c.defaultCallInfo.ActionID = actionID | ||
return c | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved. | ||
// See the file LICENSE for licensing terms. | ||
|
||
package runtime | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/ava-labs/avalanchego/ids" | ||
"github.com/ava-labs/avalanchego/utils/logging" | ||
"github.com/bytecodealliance/wasmtime-go/v14" | ||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/ava-labs/hypersdk/codec" | ||
"github.com/ava-labs/hypersdk/x/programs/test" | ||
) | ||
|
||
func TestCallContext(t *testing.T) { | ||
require := require.New(t) | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
programID := ids.GenerateTestID() | ||
programAccount := codec.CreateAddress(0, programID) | ||
r := NewRuntime( | ||
NewConfig(), | ||
logging.NoLog{}, | ||
).WithDefaults( | ||
&CallInfo{ | ||
State: &test.StateManager{ProgramsMap: map[ids.ID]string{programID: "call_program"}, AccountMap: map[codec.Address]ids.ID{programAccount: programID}}, | ||
Program: programAccount, | ||
Fuel: 1000000, | ||
}) | ||
actor := codec.CreateAddress(1, ids.GenerateTestID()) | ||
|
||
result, err := r.WithActor(actor).CallProgram( | ||
ctx, | ||
&CallInfo{ | ||
FunctionName: "actor_check", | ||
}) | ||
require.NoError(err) | ||
require.Equal(actor, into[codec.Address](result)) | ||
|
||
result, err = r.WithActor(codec.CreateAddress(2, ids.GenerateTestID())).CallProgram( | ||
ctx, | ||
&CallInfo{ | ||
FunctionName: "actor_check", | ||
}) | ||
require.NoError(err) | ||
require.NotEqual(actor, into[codec.Address](result)) | ||
|
||
result, err = r.WithFuel(0).CallProgram( | ||
ctx, | ||
&CallInfo{ | ||
FunctionName: "actor_check", | ||
}) | ||
require.Equal(wasmtime.OutOfFuel, *err.(*wasmtime.Trap).Code()) | ||
require.Nil(result) | ||
} | ||
|
||
func TestCallContextPreventOverwrite(t *testing.T) { | ||
require := require.New(t) | ||
|
||
ctx, cancel := context.WithCancel(context.Background()) | ||
defer cancel() | ||
|
||
program0ID := ids.GenerateTestID() | ||
program0Address := codec.CreateAddress(0, program0ID) | ||
program1ID := ids.GenerateTestID() | ||
program1Address := codec.CreateAddress(1, program1ID) | ||
|
||
r := NewRuntime( | ||
NewConfig(), | ||
logging.NoLog{}, | ||
).WithDefaults( | ||
&CallInfo{ | ||
Program: program0Address, | ||
State: &test.StateManager{ProgramsMap: map[ids.ID]string{program0ID: "call_program"}, AccountMap: map[codec.Address]ids.ID{program0Address: program0ID}}, | ||
Fuel: 1000000, | ||
}) | ||
|
||
// try to use a context that has a default program with a different program | ||
result, err := r.CallProgram( | ||
ctx, | ||
&CallInfo{ | ||
Program: program1Address, | ||
State: &test.StateManager{ProgramsMap: map[ids.ID]string{program1ID: "call_program"}, AccountMap: map[codec.Address]ids.ID{program1Address: program1ID}}, | ||
FunctionName: "actor_check", | ||
}) | ||
require.ErrorIs(err, errCannotOverwrite) | ||
require.Nil(result) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters