Skip to content

Commit

Permalink
refactor to pass smart wallet config to executor and update graphql e…
Browse files Browse the repository at this point in the history
…ndpoint for test (#127)
  • Loading branch information
v9n authored Feb 9, 2025
1 parent ef4472a commit a4d6835
Show file tree
Hide file tree
Showing 13 changed files with 382 additions and 104 deletions.
2 changes: 1 addition & 1 deletion aggregator/task_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func (agg *Aggregator) startTaskEngine(ctx context.Context) {
Prefix: "default",
})
agg.worker = apqueue.NewWorker(agg.queue, agg.db)
taskExecutor := taskengine.NewExecutor(agg.db, agg.logger)
taskExecutor := taskengine.NewExecutor(agg.config.SmartWallet, agg.db, agg.logger)
taskengine.SetMacroVars(agg.config.MacroVars)
taskengine.SetMacroSecrets(agg.config.MacroSecrets)
taskengine.SetCache(agg.cache)
Expand Down
2 changes: 1 addition & 1 deletion core/taskengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ func (n *Engine) TriggerTask(user *model.User, payload *avsproto.UserTriggerTask

if payload.IsBlocking {
// Run the task inline, by pass the queue system
executor := NewExecutor(n.db, n.logger)
executor := NewExecutor(n.smartWalletConfig, n.db, n.logger)
execution, err := executor.RunTask(task, &queueTaskData)
if err == nil {
return &avsproto.UserTriggerTaskResp{
Expand Down
2 changes: 1 addition & 1 deletion core/taskengine/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ func TestTriggerAsync(t *testing.T) {
Prefix: "default",
})
worker := apqueue.NewWorker(n.queue, n.db)
taskExecutor := NewExecutor(n.db, testutil.GetLogger())
taskExecutor := NewExecutor(testutil.GetTestSmartWalletConfig(), db, testutil.GetLogger())
worker.RegisterProcessor(
JobTypeExecuteTask,
taskExecutor,
Expand Down
17 changes: 10 additions & 7 deletions core/taskengine/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,22 @@ import (
sdklogging "github.com/Layr-Labs/eigensdk-go/logging"

"github.com/AvaProtocol/ap-avs/core/apqueue"
"github.com/AvaProtocol/ap-avs/core/config"
"github.com/AvaProtocol/ap-avs/storage"
)

func NewExecutor(db storage.Storage, logger sdklogging.Logger) *TaskExecutor {
func NewExecutor(config *config.SmartWalletConfig, db storage.Storage, logger sdklogging.Logger) *TaskExecutor {
return &TaskExecutor{
db: db,
logger: logger,
db: db,
logger: logger,
smartWalletConfig: config,
}
}

type TaskExecutor struct {
db storage.Storage
logger sdklogging.Logger
db storage.Storage
logger sdklogging.Logger
smartWalletConfig *config.SmartWalletConfig
}

type QueueExecutionData struct {
Expand Down Expand Up @@ -79,8 +82,8 @@ func (x *TaskExecutor) RunTask(task *model.Task, queueData *QueueExecutionData)
}
triggerMetadata := queueData.TriggerMetadata

vm, err := NewVMWithData(task.Id, task.Trigger, triggerMetadata, task.Nodes, task.Edges)
vm.secrets, _ = LoadSecretForTask(x.db, task)
secrets, _ := LoadSecretForTask(x.db, task)
vm, err := NewVMWithData(task, triggerMetadata, x.smartWalletConfig, secrets)

if err != nil {
return nil, err
Expand Down
73 changes: 62 additions & 11 deletions core/taskengine/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@ import (

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/ethclient"
"github.com/expr-lang/expr"

"github.com/AvaProtocol/ap-avs/core/config"
"github.com/AvaProtocol/ap-avs/core/taskengine/macros"
"github.com/AvaProtocol/ap-avs/model"
"github.com/AvaProtocol/ap-avs/pkg/erc20"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
)
Expand Down Expand Up @@ -57,6 +60,7 @@ type VM struct {
TaskNodes map[string]*avsproto.TaskNode
TaskEdges []*avsproto.TaskEdge
TaskTrigger *avsproto.TaskTrigger
TaskOwner common.Address

// executin logs and result per plans
ExecutionLogs []*avsproto.Execution_Step
Expand All @@ -73,7 +77,8 @@ type VM struct {
entrypoint string
instructionCount int64

logger sdklogging.Logger
smartWalletConfig *config.SmartWalletConfig
logger sdklogging.Logger
}

func NewVM() *VM {
Expand Down Expand Up @@ -132,19 +137,21 @@ func (v *VM) GetNodeNameAsVar(nodeID string) string {
return standardized
}

func NewVMWithData(taskID string, trigger *avsproto.TaskTrigger, triggerMetadata *avsproto.TriggerMetadata, nodes []*avsproto.TaskNode, edges []*avsproto.TaskEdge) (*VM, error) {
func NewVMWithData(task *model.Task, triggerMetadata *avsproto.TriggerMetadata, smartWalletConfig *config.SmartWalletConfig, secrets map[string]string) (*VM, error) {
v := &VM{
Status: VMStateInitialize,
TaskEdges: edges,
TaskNodes: make(map[string]*avsproto.TaskNode),
TaskTrigger: trigger,
plans: make(map[string]*Step),
mu: &sync.Mutex{},
instructionCount: 0,
secrets: map[string]string{},
Status: VMStateInitialize,
TaskEdges: task.Edges,
TaskNodes: make(map[string]*avsproto.TaskNode),
TaskTrigger: task.Trigger,
TaskOwner: common.HexToAddress(task.Owner),
plans: make(map[string]*Step),
mu: &sync.Mutex{},
instructionCount: 0,
secrets: secrets,
smartWalletConfig: smartWalletConfig,
}

for _, node := range nodes {
for _, node := range task.Nodes {
v.TaskNodes[node.Id] = node
}

Expand Down Expand Up @@ -340,6 +347,10 @@ func (v *VM) executeNode(node *avsproto.TaskNode) (*avsproto.Execution_Step, err
executionLog, err = v.runGraphQL(node.Id, nodeValue)
} else if nodeValue := node.GetCustomCode(); nodeValue != nil {
executionLog, err = v.runCustomCode(node.Id, nodeValue)
} else if nodeValue := node.GetContractRead(); nodeValue != nil {
executionLog, err = v.runContractRead(node.Id, nodeValue)
} else if nodeValue := node.GetContractWrite(); nodeValue != nil {
executionLog, err = v.runContractWrite(node.Id, nodeValue)
}

return executionLog, err
Expand Down Expand Up @@ -411,6 +422,46 @@ func (v *VM) runGraphQL(stepID string, node *avsproto.GraphQLQueryNode) (*avspro
return executionLog, nil
}

func (v *VM) runContractRead(stepID string, node *avsproto.ContractReadNode) (*avsproto.Execution_Step, error) {
rpcClient, err := ethclient.Dial(v.smartWalletConfig.EthRpcUrl)
defer func() {
rpcClient.Close()
}()

if err != nil {
v.logger.Error("error execute contract read node", "task_id", v.TaskID, "step", stepID, "calldata", node.CallData, "error", err)
return nil, err

}

processor := NewContractReadProcessor(v, rpcClient)
executionLog, err := processor.Execute(stepID, node)

if err != nil {
v.logger.Error("error execute contract read node", "task_id", v.TaskID, "step", stepID, "calldata", node.CallData, "error", err)
return nil, err
}
v.ExecutionLogs = append(v.ExecutionLogs, executionLog)

return executionLog, nil
}

func (v *VM) runContractWrite(stepID string, node *avsproto.ContractWriteNode) (*avsproto.Execution_Step, error) {
rpcClient, err := ethclient.Dial(v.smartWalletConfig.EthRpcUrl)
defer func() {
rpcClient.Close()
}()

processor := NewContractWriteProcessor(v, rpcClient, v.smartWalletConfig, v.TaskOwner)
executionLog, err := processor.Execute(stepID, node)
if err != nil {
v.logger.Error("error execute contract write node", "task_id", v.TaskID, "step", stepID, "calldata", node.CallData, "error", err)
}
v.ExecutionLogs = append(v.ExecutionLogs, executionLog)

return executionLog, nil
}

func (v *VM) runCustomCode(stepID string, node *avsproto.CustomCodeNode) (*avsproto.Execution_Step, error) {
r := NewJSProcessor(v)
executionLog, err := r.Execute(stepID, node)
Expand Down
20 changes: 18 additions & 2 deletions core/taskengine/vm_runner_contract_read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/AvaProtocol/ap-avs/core/testutil"
"github.com/AvaProtocol/ap-avs/model"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
)

Expand Down Expand Up @@ -38,7 +39,14 @@ func TestContractReadSimpleReturn(t *testing.T) {
},
}

vm, err := NewVMWithData("123", trigger, nil, nodes, edges)
vm, err := NewVMWithData(&model.Task{
&avsproto.Task{
Id: "123",
Nodes: nodes,
Edges: edges,
Trigger: trigger,
},
}, nil, testutil.GetTestSmartWalletConfig(), nil)

n := NewContractReadProcessor(vm, testutil.GetRpcClient())

Expand Down Expand Up @@ -93,7 +101,15 @@ func TestContractReadComplexReturn(t *testing.T) {
},
}

vm, err := NewVMWithData("123abc", trigger, nil, nodes, edges)
vm, err := NewVMWithData(&model.Task{
&avsproto.Task{
Id: "123abc",
Nodes: nodes,
Edges: edges,
Trigger: trigger,
},
}, nil, testutil.GetTestSmartWalletConfig(), nil)

n := NewContractReadProcessor(vm, testutil.GetRpcClient())
step, err := n.Execute("123abc", node)

Expand Down
10 changes: 9 additions & 1 deletion core/taskengine/vm_runner_contract_write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/AvaProtocol/ap-avs/core/chainio/aa"
"github.com/AvaProtocol/ap-avs/core/testutil"
"github.com/AvaProtocol/ap-avs/model"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/ethclient"
Expand Down Expand Up @@ -43,7 +44,14 @@ func TestContractWriteSimpleReturn(t *testing.T) {
},
}

vm, err := NewVMWithData("123", trigger, nil, nodes, edges)
vm, err := NewVMWithData(&model.Task{
&avsproto.Task{
Id: "123",
Nodes: nodes,
Edges: edges,
Trigger: trigger,
},
}, nil, testutil.GetTestSmartWalletConfig(), nil)

client, _ := ethclient.Dial(smartWalletConfig.EthRpcUrl)

Expand Down
22 changes: 20 additions & 2 deletions core/taskengine/vm_runner_customcode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"strings"
"testing"

"github.com/AvaProtocol/ap-avs/core/testutil"
"github.com/AvaProtocol/ap-avs/model"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
)

Expand Down Expand Up @@ -34,7 +36,15 @@ func TestRunJavaScript(t *testing.T) {
},
}

vm, err := NewVMWithData("123abc", trigger, nil, nodes, edges)
vm, err := NewVMWithData(&model.Task{
&avsproto.Task{
Id: "123abc",
Nodes: nodes,
Edges: edges,
Trigger: trigger,
},
}, nil, testutil.GetTestSmartWalletConfig(), nil)

n := NewJSProcessor(vm)

step, err := n.Execute("123abc", node)
Expand Down Expand Up @@ -87,7 +97,15 @@ func TestRunJavaScriptComplex(t *testing.T) {
},
}

vm, _ := NewVMWithData("123abc", trigger, nil, nodes, edges)
vm, _ := NewVMWithData(&model.Task{
&avsproto.Task{
Id: "123abc",
Nodes: nodes,
Edges: edges,
Trigger: trigger,
},
}, nil, testutil.GetTestSmartWalletConfig(), nil)

n := NewJSProcessor(vm)

step, _ := n.Execute("123abc", node)
Expand Down
22 changes: 20 additions & 2 deletions core/taskengine/vm_runner_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"strings"
"testing"

"github.com/AvaProtocol/ap-avs/core/testutil"
"github.com/AvaProtocol/ap-avs/model"
avsproto "github.com/AvaProtocol/ap-avs/protobuf"
)

Expand Down Expand Up @@ -40,7 +42,15 @@ func TestFilter(t *testing.T) {
},
}

vm, err := NewVMWithData("abc123", trigger, nil, nodes, edges)
vm, err := NewVMWithData(&model.Task{
&avsproto.Task{
Id: "123abc",
Nodes: nodes,
Edges: edges,
Trigger: trigger,
},
}, nil, testutil.GetTestSmartWalletConfig(), nil)

if err != nil {
t.Errorf("expect vm initialize succesully but failed with error: %v", err)
}
Expand Down Expand Up @@ -100,7 +110,15 @@ func TestFilterComplexLogic(t *testing.T) {
},
}

vm, err := NewVMWithData("abc123", trigger, nil, nodes, edges)
vm, err := NewVMWithData(&model.Task{
&avsproto.Task{
Id: "123abc",
Nodes: nodes,
Edges: edges,
Trigger: trigger,
},
}, nil, testutil.GetTestSmartWalletConfig(), nil)

if err != nil {
t.Errorf("expect vm initialize succesully but failed with error: %v", err)
}
Expand Down
Loading

0 comments on commit a4d6835

Please sign in to comment.