Skip to content

Commit

Permalink
[x/programs] Add balance handling (#1086)
Browse files Browse the repository at this point in the history
* add balance host functions
  • Loading branch information
dboehm-avalabs authored Jul 9, 2024
1 parent 82fa13c commit 07b3ad2
Show file tree
Hide file tree
Showing 23 changed files with 368 additions and 31 deletions.
60 changes: 54 additions & 6 deletions x/programs/cmd/simulator/cmd/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package cmd
import (
"context"
"crypto/sha256"
"encoding/binary"
"errors"

"github.com/ava-labs/avalanchego/database"
Expand All @@ -30,6 +31,25 @@ type programStateManager struct {
state.Mutable
}

func (s *programStateManager) GetBalance(ctx context.Context, address codec.Address) (uint64, error) {
return getAccountBalance(ctx, s, address)
}

func (s *programStateManager) TransferBalance(ctx context.Context, from codec.Address, to codec.Address, amount uint64) error {
fromBalance, err := getAccountBalance(ctx, s, from)
if err != nil {
return err
}
if fromBalance < amount {
return errors.New("insufficient balance")
}
toBalance, err := getAccountBalance(ctx, s, to)
if err != nil {
return err
}
return setAccountBalance(ctx, s, to, toBalance+amount)
}

func (s *programStateManager) GetAccountProgram(ctx context.Context, account codec.Address) (ids.ID, error) {
programID, exists, err := getAccountProgram(ctx, s, account)
if err != nil {
Expand Down Expand Up @@ -107,11 +127,12 @@ func accountStateKey(key []byte) (k []byte) {
return
}

func accountDataKey(key []byte) (k []byte) {
k = make([]byte, 2+len(key))
func accountDataKey(account []byte, key []byte) (k []byte) {
k = make([]byte, 2+len(account)+len(key))
k[0] = accountPrefix
copy(k[1:], key)
k[len(k)-1] = accountDataPrefix
copy(k[1:], account)
k[1+len(account)] = accountDataPrefix
copy(k[2+len(account):], key)
return
}

Expand All @@ -122,6 +143,33 @@ func programKey(key []byte) (k []byte) {
return
}

func getAccountBalance(
ctx context.Context,
db state.Immutable,
account codec.Address,
) (
uint64,
error,
) {
v, err := db.GetValue(ctx, accountDataKey(account[:], []byte("balance")))
if errors.Is(err, database.ErrNotFound) {
return 0, nil
}
if err != nil {
return 0, err
}
return binary.BigEndian.Uint64(v), nil
}

func setAccountBalance(
ctx context.Context,
mu state.Mutable,
account codec.Address,
amount uint64,
) error {
return mu.Insert(ctx, accountDataKey(account[:], []byte("balance")), binary.BigEndian.AppendUint64(nil, amount))
}

// [programID] -> [programBytes]
func getAccountProgram(
ctx context.Context,
Expand All @@ -132,7 +180,7 @@ func getAccountProgram(
bool, // exists
error,
) {
v, err := db.GetValue(ctx, accountDataKey(account[:]))
v, err := db.GetValue(ctx, accountDataKey(account[:], []byte("program")))
if errors.Is(err, database.ErrNotFound) {
return ids.Empty, false, nil
}
Expand All @@ -148,7 +196,7 @@ func setAccountProgram(
account codec.Address,
programID ids.ID,
) error {
return mu.Insert(ctx, accountDataKey(account[:]), programID[:])
return mu.Insert(ctx, accountDataKey(account[:], []byte("program")), programID[:])
}

// [programID] -> [programBytes]
Expand Down
16 changes: 16 additions & 0 deletions x/programs/cmd/simulator/cmd/storage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package cmd

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestPrefix(t *testing.T) {
require := require.New(t)
stateKey := accountDataKey([]byte{0}, []byte{1, 2, 3})
require.Equal([]byte{accountPrefix, 0, accountDataPrefix, 1, 2, 3}, stateKey)
}
5 changes: 5 additions & 0 deletions x/programs/runtime/call_context.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,8 @@ func (c CallContext) WithTimestamp(timestamp uint64) CallContext {
c.defaultCallInfo.Timestamp = timestamp
return c
}

func (c CallContext) WithValue(value uint64) CallContext {
c.defaultCallInfo.Value = value
return c
}
4 changes: 2 additions & 2 deletions x/programs/runtime/call_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func TestCallContext(t *testing.T) {
NewConfig(),
logging.NoLog{},
).WithDefaults(
&CallInfo{
CallInfo{
State: &test.StateManager{ProgramsMap: map[ids.ID]string{programID: "call_program"}, AccountMap: map[codec.Address]ids.ID{programAccount: programID}},
Program: programAccount,
Fuel: 1000000,
Expand Down Expand Up @@ -74,7 +74,7 @@ func TestCallContextPreventOverwrite(t *testing.T) {
NewConfig(),
logging.NoLog{},
).WithDefaults(
&CallInfo{
CallInfo{
Program: program0Address,
State: &test.StateManager{ProgramsMap: map[ids.ID]string{program0ID: "call_program"}, AccountMap: map[codec.Address]ids.ID{program0Address: program0ID}},
Fuel: 1000000,
Expand Down
45 changes: 45 additions & 0 deletions x/programs/runtime/import_balance.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (C) 2024, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package runtime

import (
"context"

"github.com/ava-labs/hypersdk/codec"
)

const (
sendBalanceCost = 10000
getBalanceCost = 10000
)

type transferBalanceInput struct {
To codec.Address
Amount uint64
}

func NewBalanceModule() *ImportModule {
return &ImportModule{
Name: "balance",
HostFunctions: map[string]HostFunction{
"get": {FuelCost: getBalanceCost, Function: Function[codec.Address, uint64](func(callInfo *CallInfo, address codec.Address) (uint64, error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
return callInfo.State.GetBalance(ctx, address)
})},
"send": {FuelCost: sendBalanceCost, Function: Function[transferBalanceInput, Result[Unit, ProgramCallErrorCode]](func(callInfo *CallInfo, input transferBalanceInput) (Result[Unit, ProgramCallErrorCode], error) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
err := callInfo.State.TransferBalance(ctx, callInfo.Program, input.To, input.Amount)
if err != nil {
if extractedError, ok := ExtractProgramCallErrorCode(err); ok {
return Err[Unit, ProgramCallErrorCode](extractedError), nil
}
return Err[Unit, ProgramCallErrorCode](ExecutionFailure), err
}
return Ok[Unit, ProgramCallErrorCode](Unit{}), nil
})},
},
}
}
49 changes: 49 additions & 0 deletions x/programs/runtime/import_balance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (C) 2023, Ava Labs, Inc. All rights reserved.
// See the file LICENSE for licensing terms.

package runtime

import (
"context"
"testing"

"github.com/ava-labs/avalanchego/ids"
"github.com/stretchr/testify/require"

"github.com/ava-labs/hypersdk/codec"
"github.com/ava-labs/hypersdk/x/programs/test"
)

func TestImportBalanceGetBalance(t *testing.T) {
require := require.New(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
actor := codec.CreateAddress(0, ids.GenerateTestID())
program := newTestProgram(ctx, "balance")
program.Runtime.StateManager.(test.StateManager).Balances[actor] = 3
result, err := program.WithActor(actor).Call("balance")
require.NoError(err)
require.Equal(uint64(3), into[uint64](result))
}

func TestImportBalanceSend(t *testing.T) {
require := require.New(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
actor := codec.CreateAddress(0, ids.GenerateTestID())
program := newTestProgram(ctx, "balance")
program.Runtime.StateManager.(test.StateManager).Balances[program.Address] = 3
result, err := program.Call("send_balance", actor)
require.NoError(err)
require.True(into[bool](result))

result, err = program.WithActor(actor).Call("balance")
require.NoError(err)
require.Equal(uint64(1), into[uint64](result))

result, err = program.WithActor(program.Address).Call("balance")
require.NoError(err)
require.Equal(uint64(2), into[uint64](result))
}
4 changes: 4 additions & 0 deletions x/programs/runtime/import_program.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ const (
ExecutionFailure ProgramCallErrorCode = iota
CallPanicked
OutOfFuel
InsufficientBalance
)

type callProgramInput struct {
Program codec.Address
FunctionName string
Params []byte
Fuel uint64
Value uint64
}

type deployProgramInput struct {
Expand Down Expand Up @@ -72,6 +74,7 @@ func NewProgramModule(r *WasmRuntime) *ImportModule {
newInfo.FunctionName = input.FunctionName
newInfo.Params = input.Params
newInfo.Fuel = input.Fuel
newInfo.Value = input.Value

result, err := r.CallProgram(
context.Background(),
Expand All @@ -85,6 +88,7 @@ func NewProgramModule(r *WasmRuntime) *ImportModule {

// return any remaining fuel to the calling program
callInfo.AddFuel(newInfo.RemainingFuel())
callInfo.Value += newInfo.Value

return Ok[RawBytes, ProgramCallErrorCode](result), nil
})},
Expand Down
12 changes: 10 additions & 2 deletions x/programs/runtime/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type CallInfo struct {
// the serialized parameters that will be passed to the called function
Params []byte

// the amount of fuel allowed to be consumed by wasm for this call
// the maximum amount of fuel allowed to be consumed by wasm for this call
Fuel uint64

// the height of the chain that this call was made from
Expand All @@ -55,6 +55,8 @@ type CallInfo struct {
// the action id that triggered this call
ActionID ids.ID

Value uint64

inst *ProgramInstance
}

Expand Down Expand Up @@ -83,11 +85,17 @@ type ProgramInstance struct {
result []byte
}

func (p *ProgramInstance) call(_ context.Context, callInfo *CallInfo) ([]byte, error) {
func (p *ProgramInstance) call(ctx context.Context, callInfo *CallInfo) ([]byte, error) {
if err := p.store.AddFuel(callInfo.Fuel); err != nil {
return nil, err
}

if callInfo.Value > 0 {
if err := callInfo.State.TransferBalance(ctx, callInfo.Actor, callInfo.Program, callInfo.Value); err != nil {
return nil, errors.New("insufficient balance")
}
}

// create the program context
programCtx := Context{
Program: callInfo.Program,
Expand Down
17 changes: 12 additions & 5 deletions x/programs/runtime/runtime.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,17 @@ type WasmRuntime struct {
}

type StateManager interface {
GetProgramState(address codec.Address) state.Mutable
ProgramStore
BalanceManager
ProgramManager
}

type BalanceManager interface {
GetBalance(ctx context.Context, address codec.Address) (uint64, error)
TransferBalance(ctx context.Context, from codec.Address, to codec.Address, amount uint64) error
}

type ProgramStore interface {
type ProgramManager interface {
GetProgramState(address codec.Address) state.Mutable
GetAccountProgram(ctx context.Context, account codec.Address) (ids.ID, error)
GetProgramBytes(ctx context.Context, programID ids.ID) ([]byte, error)
NewAccountWithProgram(ctx context.Context, programID ids.ID, accountCreationData []byte) (codec.Address, error)
Expand Down Expand Up @@ -62,14 +68,15 @@ func NewRuntime(
}

runtime.AddImportModule(NewLogModule())
runtime.AddImportModule(NewBalanceModule())
runtime.AddImportModule(NewStateAccessModule())
runtime.AddImportModule(NewProgramModule(runtime))

return runtime
}

func (r *WasmRuntime) WithDefaults(callInfo *CallInfo) CallContext {
return CallContext{r: r, defaultCallInfo: *callInfo}
func (r *WasmRuntime) WithDefaults(callInfo CallInfo) CallContext {
return CallContext{r: r, defaultCallInfo: callInfo}
}

func (r *WasmRuntime) AddImportModule(mod *ImportModule) {
Expand Down
29 changes: 29 additions & 0 deletions x/programs/runtime/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/ava-labs/hypersdk/codec"
"github.com/ava-labs/hypersdk/x/programs/test"
)

func BenchmarkRuntimeCallProgramBasic(b *testing.B) {
Expand All @@ -28,6 +29,34 @@ func BenchmarkRuntimeCallProgramBasic(b *testing.B) {
}
}

func TestRuntimeCallProgramBasicAttachValue(t *testing.T) {
require := require.New(t)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

program := newTestProgram(ctx, "simple")
actor := codec.CreateAddress(0, ids.GenerateTestID())
program.Runtime.StateManager.(test.StateManager).Balances[actor] = 10

actorBalance, err := program.Runtime.StateManager.GetBalance(context.Background(), actor)
require.NoError(err)
require.Equal(uint64(10), actorBalance)

// calling a program with a value transfers that amount from the caller to the program
result, err := program.WithActor(actor).WithValue(4).Call("get_value")
require.NoError(err)
require.Equal(uint64(0), into[uint64](result))

actorBalance, err = program.Runtime.StateManager.GetBalance(context.Background(), actor)
require.NoError(err)
require.Equal(uint64(6), actorBalance)

programBalance, err := program.Runtime.StateManager.GetBalance(context.Background(), program.Address)
require.NoError(err)
require.Equal(uint64(4), programBalance)
}

func TestRuntimeCallProgramBasic(t *testing.T) {
require := require.New(t)

Expand Down
2 changes: 2 additions & 0 deletions x/programs/runtime/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,5 @@ func (o Option[T]) Some() (T, bool) {
func (o Option[T]) None() bool {
return o.isNone
}

type Unit struct{}
Loading

0 comments on commit 07b3ad2

Please sign in to comment.