Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Lazy load gRPC plugin #353

Merged
merged 7 commits into from
Jun 5, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/bstadlbauer/dask-k8s-operator-go-client v0.1.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v1.5.2
github.com/flyteorg/flyteidl v1.5.9
github.com/flyteorg/flytestdlib v1.0.15
github.com/go-test/deep v1.0.7
github.com/golang/protobuf v1.5.2
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ github.com/evanphx/json-patch v4.12.0+incompatible/go.mod h1:50XU6AFN0ol/bzJsmQL
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/flyteorg/flyteidl v1.5.2 h1:DZPzYkTg92qA4e17fd0ZW1M+gh1gJKh/VOK+F4bYgM8=
github.com/flyteorg/flyteidl v1.5.2/go.mod h1:ckLjB51moX4L0oQml+WTCrPK50zrJf6IZJ6LPC0RB4I=
github.com/flyteorg/flyteidl v1.5.9 h1:jqoenDx6p1Uncja1LMSzWmq3mBrMQ6vOdzN7/Ma3P28=
github.com/flyteorg/flyteidl v1.5.9/go.mod h1:EtE/muM2lHHgBabjYcxqe9TWeJSL0kXwbI0RgVwI4Og=
github.com/flyteorg/flytestdlib v1.0.15 h1:kv9jDQmytbE84caY+pkZN8trJU2ouSAmESzpTEhfTt0=
github.com/flyteorg/flytestdlib v1.0.15/go.mod h1:ghw/cjY0sEWIIbyCtcJnL/Gt7ZS7gf9SUi0CCPhbz3s=
github.com/flyteorg/stow v0.3.6 h1:jt50ciM14qhKBaIrB+ppXXY+SXB59FNREFgTJqCyqIk=
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"time"
Expand Down Expand Up @@ -39,22 +39,22 @@ var (
Value: 50,
},
},
DefaultGrpcEndpoint: "dns:///external-plugin-service.flyte.svc.cluster.local:80",
DefaultGrpcEndpoint: "dns:///flyte-agent.flyte.svc.cluster.local:80",
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
}

configSection = pluginsConfig.MustRegisterSubSection("external-plugin-service", &defaultConfig)
configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
)

// Config is config for 'databricks' plugin
// Config is config for 'agent' plugin
type Config struct {
// WebAPI defines config for the base WebAPI plugin
WebAPI webapi.PluginConfig `json:"webApi" pflag:",Defines config for the base WebAPI plugin."`

// ResourceConstraints defines resource constraints on how many executions to be created per project/overall at any given time
ResourceConstraints core.ResourceConstraintsSpec `json:"resourceConstraints" pflag:"-,Defines resource constraints on how many executions to be created per project/overall at any given time."`

DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of external plugin service."`
DefaultGrpcEndpoint string `json:"defaultGrpcEndpoint" pflag:",The default grpc endpoint of agent service."`

// Maps endpoint to their plugin handler. {TaskType: Endpoint}
EndpointForTaskTypes map[string]string `json:"endpointForTaskTypes" pflag:"-,"`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"testing"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"context"
Expand Down Expand Up @@ -54,11 +54,11 @@ func (m *MockClient) DeleteTask(_ context.Context, _ *service.TaskDeleteRequest,
return &service.TaskDeleteResponse{}, nil
}

func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) {
func mockGetClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AgentServiceClient, error) {
return &MockClient{}, nil
}

func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) {
func mockGetBadClientFunc(_ context.Context, _ string, _ map[string]*grpc.ClientConn) (service.AgentServiceClient, error) {
return nil, fmt.Errorf("error")
}

Expand Down Expand Up @@ -98,7 +98,7 @@ func TestEndToEnd(t *testing.T) {
basePrefix := storage.DataReference("fake://bucket/prefix/")

t.Run("run a job", func(t *testing.T) {
pluginEntry := pluginmachinery.CreateRemotePlugin(newMockGrpcPlugin())
pluginEntry := pluginmachinery.CreateRemotePlugin(newMockAgentPlugin())
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test1"))
assert.NoError(t, err)

Expand All @@ -107,8 +107,8 @@ func TestEndToEnd(t *testing.T) {
})

t.Run("failed to create a job", func(t *testing.T) {
grpcPlugin := newMockGrpcPlugin()
grpcPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
agentPlugin := newMockAgentPlugin()
agentPlugin.PluginLoader = func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &MockPlugin{
Plugin{
metricScope: iCtx.MetricsScope(),
Expand All @@ -117,7 +117,7 @@ func TestEndToEnd(t *testing.T) {
},
}, nil
}
pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin)
pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin)
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test2"))
assert.NoError(t, err)

Expand All @@ -144,8 +144,8 @@ func TestEndToEnd(t *testing.T) {
tr.OnRead(context.Background()).Return(nil, fmt.Errorf("read fail"))
tCtx.OnTaskReader().Return(tr)

grpcPlugin := newMockGrpcPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin)
agentPlugin := newAgentPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin)
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test3"))
assert.NoError(t, err)

Expand All @@ -165,8 +165,8 @@ func TestEndToEnd(t *testing.T) {
inputReader.OnGetMatch(mock.Anything).Return(nil, fmt.Errorf("read fail"))
tCtx.OnInputReader().Return(inputReader)

grpcPlugin := newMockGrpcPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(grpcPlugin)
agentPlugin := newMockAgentPlugin()
pluginEntry := pluginmachinery.CreateRemotePlugin(agentPlugin)
plugin, err := pluginEntry.LoadPlugin(context.TODO(), newFakeSetupContext("test4"))
assert.NoError(t, err)

Expand Down Expand Up @@ -239,7 +239,7 @@ func getTaskContext(t *testing.T) *pluginCoreMocks.TaskExecutionContext {
return tCtx
}

func newMockGrpcPlugin() webapi.PluginEntry {
func newMockAgentPlugin() webapi.PluginEntry {
return webapi.PluginEntry{
ID: "external-plugin-service",
SupportedTaskTypes: []core.TaskType{"bigquery_query_job_task"},
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"context"
Expand All @@ -19,7 +19,7 @@ import (
"google.golang.org/grpc"
)

type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error)
type GetClientFunc func(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AgentServiceClient, error)

type Plugin struct {
metricScope promutils.Scope
Expand Down Expand Up @@ -67,7 +67,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
endpoint := getFinalEndpoint(taskTemplate.Type, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, nil, fmt.Errorf("failed to connect external plugin service with error: %v", err)
return nil, nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

res, err := client.CreateTask(ctx, &service.TaskCreateRequest{Inputs: inputs, Template: taskTemplate, OutputPrefix: outputPrefix})
Expand All @@ -89,7 +89,7 @@ func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest weba
endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return nil, fmt.Errorf("failed to connect external plugin service with error: %v", err)
return nil, fmt.Errorf("failed to connect to agent with error: %v", err)
}

res, err := client.GetTask(ctx, &service.TaskGetRequest{TaskType: metadata.TaskType, JobId: metadata.JobID})
Expand All @@ -112,7 +112,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
endpoint := getFinalEndpoint(metadata.TaskType, p.cfg.DefaultGrpcEndpoint, p.cfg.EndpointForTaskTypes)
client, err := p.getClient(ctx, endpoint, p.connectionCache)
if err != nil {
return fmt.Errorf("failed to connect external plugin service with error: %v", err)
return fmt.Errorf("failed to connect to agent with error: %v", err)
}

_, err = client.DeleteTask(ctx, &service.TaskDeleteRequest{TaskType: metadata.TaskType, JobId: metadata.JobID})
Expand Down Expand Up @@ -150,10 +150,10 @@ func getFinalEndpoint(taskType, defaultEndpoint string, endpointForTaskTypes map
return defaultEndpoint
}

func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.ExternalPluginServiceClient, error) {
func getClientFunc(ctx context.Context, endpoint string, connectionCache map[string]*grpc.ClientConn) (service.AgentServiceClient, error) {
conn, ok := connectionCache[endpoint]
if ok {
return service.NewExternalPluginServiceClient(conn), nil
return service.NewAgentServiceClient(conn), nil
}
var opts []grpc.DialOption
var err error
Expand All @@ -178,14 +178,14 @@ func getClientFunc(ctx context.Context, endpoint string, connectionCache map[str
}
}()
}()
return service.NewExternalPluginServiceClient(conn), nil
return service.NewAgentServiceClient(conn), nil
}

func newGrpcPlugin() webapi.PluginEntry {
func newAgentPlugin() webapi.PluginEntry {
supportedTaskTypes := GetConfig().SupportedTaskTypes

return webapi.PluginEntry{
ID: "external-plugin-service",
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &Plugin{
Expand All @@ -198,9 +198,9 @@ func newGrpcPlugin() webapi.PluginEntry {
}
}

func init() {
func RegisterAgentPlugin() {
gob.Register(ResourceMetaWrapper{})
gob.Register(ResourceWrapper{})

pluginmachinery.PluginRegistry().RegisterRemotePlugin(newGrpcPlugin())
pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin())
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package grpc
package agent

import (
"context"
Expand All @@ -25,7 +25,7 @@ func TestPlugin(t *testing.T) {
cfg := defaultConfig
cfg.WebAPI.Caching.Workers = 1
cfg.WebAPI.Caching.ResyncInterval.Duration = 5 * time.Second
cfg.DefaultGrpcEndpoint = "test-service.flyte.svc.cluster.local:80"
cfg.DefaultGrpcEndpoint = "test-agent.flyte.svc.cluster.local:80"
cfg.EndpointForTaskTypes = map[string]string{"spark": "localhost:80"}
err := SetConfig(&cfg)
assert.NoError(t, err)
Expand All @@ -38,10 +38,10 @@ func TestPlugin(t *testing.T) {
assert.Equal(t, plugin.cfg.ResourceConstraints, constraints)
})

t.Run("tet newGrpcPlugin", func(t *testing.T) {
p := newGrpcPlugin()
t.Run("tet newAgentPlugin", func(t *testing.T) {
p := newAgentPlugin()
assert.NotNil(t, p)
assert.Equal(t, p.ID, "external-plugin-service")
assert.Equal(t, p.ID, "agent-service")
assert.NotNil(t, p.PluginLoader)
})

Expand Down