Skip to content

Commit

Permalink
Watch agent metadata service (#5017)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Jun 8, 2024
1 parent 38883c7 commit cd37d1b
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 71 deletions.
71 changes: 41 additions & 30 deletions flyteplugins/go/tasks/plugins/webapi/agent/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package agent
import (
"context"
"crypto/x509"
"fmt"

"golang.org/x/exp/maps"
"google.golang.org/grpc"
Expand Down Expand Up @@ -98,8 +97,7 @@ func getFinalContext(ctx context.Context, operation string, agent *Deployment) (
return context.WithTimeout(ctx, timeout)
}

func initializeAgentRegistry(cs *ClientSet) (Registry, error) {
logger.Infof(context.Background(), "Initializing agent registry")
func updateAgentRegistry(ctx context.Context, cs *ClientSet) {
agentRegistry := make(Registry)
cfg := GetConfig()
var agentDeployments []*Deployment
Expand All @@ -115,25 +113,31 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) {
}
agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...)
for _, agentDeployment := range agentDeployments {
client := cs.agentMetadataClients[agentDeployment.Endpoint]
client, ok := cs.agentMetadataClients[agentDeployment.Endpoint]
if !ok {
logger.Warningf(ctx, "Agent client not found in the clientSet for the endpoint: %v", agentDeployment.Endpoint)
continue
}

finalCtx, cancel := getFinalContext(context.Background(), "ListAgents", agentDeployment)
finalCtx, cancel := getFinalContext(ctx, "ListAgents", agentDeployment)
defer cancel()

res, err := client.ListAgents(finalCtx, &admin.ListAgentsRequest{})
if err != nil {
grpcStatus, ok := status.FromError(err)
if grpcStatus.Code() == codes.Unimplemented {
// we should not panic here, as we want to continue to support old agent settings
logger.Infof(context.Background(), "list agent method not implemented for agent: [%v]", agentDeployment)
logger.Warningf(finalCtx, "list agent method not implemented for agent: [%v]", agentDeployment.Endpoint)
continue
}

if !ok {
return nil, fmt.Errorf("failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment, err)
logger.Errorf(finalCtx, "failed to list agent: [%v] with a non-gRPC error: [%v]", agentDeployment.Endpoint, err)
continue
}

return nil, fmt.Errorf("failed to list agent: [%v] with error: [%v]", agentDeployment, err)
logger.Errorf(finalCtx, "failed to list agent: [%v] with error: [%v]", agentDeployment.Endpoint, err)
continue
}

for _, agent := range res.GetAgents() {
Expand All @@ -148,20 +152,27 @@ func initializeAgentRegistry(cs *ClientSet) (Registry, error) {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: agent.IsSync}
agentRegistry[supportedCategory.GetName()] = map[int32]*Agent{supportedCategory.GetVersion(): agent}
}
logger.Infof(context.Background(), "[%v] is a sync agent: [%v]", agent.Name, agent.IsSync)
logger.Infof(context.Background(), "[%v] supports task category: [%v]", agent.Name, supportedTaskCategories)
}
// If the agent doesn't implement the metadata service, we construct the registry based on the configuration
for taskType, agentDeploymentID := range cfg.AgentForTaskTypes {
if agentDeployment, ok := cfg.AgentDeployments[agentDeploymentID]; ok {
if _, ok := agentRegistry[taskType]; !ok {
agent := &Agent{AgentDeployment: agentDeployment, IsSync: false}
agentRegistry[taskType] = map[int32]*Agent{defaultTaskTypeVersion: agent}
}
}
}
}

return agentRegistry, nil
logger.Debugf(ctx, "AgentDeployment service supports task types: %v", maps.Keys(agentRegistry))
setAgentRegistry(agentRegistry)
}

func initializeClients(ctx context.Context) (*ClientSet, error) {
logger.Infof(ctx, "Initializing agent clients")

asyncAgentClients := make(map[string]service.AsyncAgentServiceClient)
syncAgentClients := make(map[string]service.SyncAgentServiceClient)
agentMetadataClients := make(map[string]service.AgentMetadataServiceClient)
func getAgentClientSets(ctx context.Context) *ClientSet {
clientSet := &ClientSet{
asyncAgentClients: make(map[string]service.AsyncAgentServiceClient),
syncAgentClients: make(map[string]service.SyncAgentServiceClient),
agentMetadataClients: make(map[string]service.AgentMetadataServiceClient),
}

var agentDeployments []*Deployment
cfg := GetConfig()
Expand All @@ -170,19 +181,19 @@ func initializeClients(ctx context.Context) (*ClientSet, error) {
agentDeployments = append(agentDeployments, &cfg.DefaultAgent)
}
agentDeployments = append(agentDeployments, maps.Values(cfg.AgentDeployments)...)
for _, agentService := range agentDeployments {
conn, err := getGrpcConnection(ctx, agentService)
for _, agentDeployment := range agentDeployments {
if _, ok := clientSet.agentMetadataClients[agentDeployment.Endpoint]; ok {
logger.Infof(ctx, "Agent client already initialized for [%v]", agentDeployment.Endpoint)
continue
}
conn, err := getGrpcConnection(ctx, agentDeployment)
if err != nil {
return nil, err
logger.Errorf(ctx, "failed to create connection to agent: [%v] with error: [%v]", agentDeployment, err)
continue
}
syncAgentClients[agentService.Endpoint] = service.NewSyncAgentServiceClient(conn)
asyncAgentClients[agentService.Endpoint] = service.NewAsyncAgentServiceClient(conn)
agentMetadataClients[agentService.Endpoint] = service.NewAgentMetadataServiceClient(conn)
clientSet.syncAgentClients[agentDeployment.Endpoint] = service.NewSyncAgentServiceClient(conn)
clientSet.asyncAgentClients[agentDeployment.Endpoint] = service.NewAsyncAgentServiceClient(conn)
clientSet.agentMetadataClients[agentDeployment.Endpoint] = service.NewAgentMetadataServiceClient(conn)
}

return &ClientSet{
syncAgentClients: syncAgentClients,
asyncAgentClients: asyncAgentClients,
agentMetadataClients: agentMetadataClients,
}, nil
return clientSet
}
4 changes: 1 addition & 3 deletions flyteplugins/go/tasks/plugins/webapi/agent/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ func TestInitializeClients(t *testing.T) {
ctx := context.Background()
err := SetConfig(&cfg)
assert.NoError(t, err)
cs, err := initializeClients(ctx)
assert.NoError(t, err)
assert.NotNil(t, cs)
cs := getAgentClientSets(ctx)
_, ok := cs.syncAgentClients["y"]
assert.True(t, ok)
_, ok = cs.asyncAgentClients["x"]
Expand Down
4 changes: 4 additions & 0 deletions flyteplugins/go/tasks/plugins/webapi/agent/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ var (
// AsyncPlugin should be registered to at least one task type.
// Reference: https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/pluginmachinery/registry.go#L27
SupportedTaskTypes: []string{"task_type_1", "task_type_2"},
PollInterval: config.Duration{Duration: 10 * time.Second},
}

configSection = pluginsConfig.MustRegisterSubSection("agent-service", &defaultConfig)
Expand All @@ -71,6 +72,9 @@ type Config struct {

// SupportedTaskTypes is a list of task types that are supported by this plugin.
SupportedTaskTypes []string `json:"supportedTaskTypes" pflag:"-,Defines a list of task types that are supported by this plugin."`

// PollInterval is the interval at which the plugin should poll the agent for metadata updates
PollInterval config.Duration `json:"pollInterval" pflag:",The interval at which the plugin should poll the agent for metadata updates."`
}

type Deployment struct {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ import (
)

func TestEndToEnd(t *testing.T) {
agentRegistry = Registry{
"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}},
"spark": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: false}},
}
iter := func(ctx context.Context, tCtx pluginCore.TaskExecutionContext) error {
return nil
}
Expand Down Expand Up @@ -118,6 +122,7 @@ func TestEndToEnd(t *testing.T) {
cfg: GetConfig(),
cs: &ClientSet{
asyncAgentClients: map[string]service.AsyncAgentServiceClient{},
syncAgentClients: map[string]service.SyncAgentServiceClient{},
agentMetadataClients: map[string]service.AgentMetadataServiceClient{},
},
}, nil
Expand Down Expand Up @@ -326,7 +331,6 @@ func newMockSyncAgentPlugin() webapi.PluginEntry {
defaultAgentEndpoint: syncAgentClient,
},
},
agentRegistry: Registry{"openai": {defaultTaskTypeVersion: {AgentDeployment: &Deployment{Endpoint: defaultAgentEndpoint}, IsSync: true}}},
}, nil
},
}
Expand Down
77 changes: 47 additions & 30 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ import (
"context"
"encoding/gob"
"fmt"
"sync"
"time"

"golang.org/x/exp/maps"
"k8s.io/apimachinery/pkg/util/wait"

"github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/admin"
flyteIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
Expand All @@ -24,11 +26,27 @@ import (

type Registry map[string]map[int32]*Agent // map[taskTypeName][taskTypeVersion] => Agent

type Plugin struct {
metricScope promutils.Scope
cfg *Config
cs *ClientSet
var (
agentRegistry Registry
mu sync.RWMutex
)

func getAgentRegistry() Registry {
mu.Lock()
defer mu.Unlock()
return agentRegistry
}

func setAgentRegistry(r Registry) {
mu.Lock()
defer mu.Unlock()
agentRegistry = r
}

type Plugin struct {
metricScope promutils.Scope
cfg *Config
cs *ClientSet
}

type ResourceWrapper struct {
Expand Down Expand Up @@ -95,7 +113,7 @@ func (p Plugin) Create(ctx context.Context, taskCtx webapi.TaskExecutionContextR
outputPrefix := taskCtx.OutputWriter().GetOutputPrefixPath().String()

taskCategory := admin.TaskCategory{Name: taskTemplate.Type, Version: taskTemplate.TaskTypeVersion}
agent, isSync := getFinalAgent(&taskCategory, p.cfg, p.agentRegistry)
agent, isSync := getFinalAgent(&taskCategory, p.cfg)

taskExecutionMetadata := buildTaskExecutionMetadata(taskCtx.TaskExecutionMetadata())

Expand Down Expand Up @@ -193,7 +211,7 @@ func (p Plugin) ExecuteTaskSync(

func (p Plugin) Get(ctx context.Context, taskCtx webapi.GetContext) (latest webapi.Resource, err error) {
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down Expand Up @@ -226,7 +244,7 @@ func (p Plugin) Delete(ctx context.Context, taskCtx webapi.DeleteContext) error
return nil
}
metadata := taskCtx.ResourceMeta().(ResourceMetaWrapper)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg, p.agentRegistry)
agent, _ := getFinalAgent(&metadata.TaskCategory, p.cfg)

client, err := p.getAsyncAgentClient(ctx, agent)
if err != nil {
Expand Down Expand Up @@ -322,6 +340,13 @@ func (p Plugin) getAsyncAgentClient(ctx context.Context, agent *Deployment) (ser
return client, nil
}

func (p Plugin) watchAgents(ctx context.Context) {
go wait.Until(func() {
clientSet := getAgentClientSets(ctx)
updateAgentRegistry(ctx, clientSet)
}, p.cfg.PollInterval.Duration, ctx.Done())
}

func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *flyteIdl.LiteralMap) error {
taskTemplate, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
Expand All @@ -344,11 +369,11 @@ func writeOutput(ctx context.Context, taskCtx webapi.StatusContext, outputs *fly
return taskCtx.OutputWriter().Put(ctx, opReader)
}

func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config, agentRegistry Registry) (*Deployment, bool) {
if agent, exists := agentRegistry[taskCategory.Name][taskCategory.Version]; exists {
func getFinalAgent(taskCategory *admin.TaskCategory, cfg *Config) (*Deployment, bool) {
r := getAgentRegistry()
if agent, exists := r[taskCategory.Name][taskCategory.Version]; exists {
return agent.AgentDeployment, agent.IsSync
}

return &cfg.DefaultAgent, false
}

Expand All @@ -367,38 +392,30 @@ func buildTaskExecutionMetadata(taskExecutionMetadata core.TaskExecutionMetadata
}

func newAgentPlugin() webapi.PluginEntry {
cs, err := initializeClients(context.Background())
if err != nil {
// We should wait for all agents to be up and running before starting the server
panic(fmt.Sprintf("failed to initialize clients with error: %v", err))
}

agentRegistry, err := initializeAgentRegistry(cs)
if err != nil {
panic(fmt.Sprintf("failed to initialize agent registry with error: %v", err))
}

ctx := context.Background()
cfg := GetConfig()
supportedTaskTypes := append(maps.Keys(agentRegistry), cfg.SupportedTaskTypes...)
logger.Infof(context.Background(), "AgentDeployment service supports task types: %v", supportedTaskTypes)

clientSet := getAgentClientSets(ctx)
updateAgentRegistry(ctx, clientSet)
supportedTaskTypes := append(maps.Keys(getAgentRegistry()), cfg.SupportedTaskTypes...)

return webapi.PluginEntry{
ID: "agent-service",
SupportedTaskTypes: supportedTaskTypes,
PluginLoader: func(ctx context.Context, iCtx webapi.PluginSetupContext) (webapi.AsyncPlugin, error) {
return &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: cfg,
cs: cs,
agentRegistry: agentRegistry,
}, nil
plugin := &Plugin{
metricScope: iCtx.MetricsScope(),
cfg: cfg,
cs: clientSet,
}
plugin.watchAgents(ctx)
return plugin, nil
},
}
}

func RegisterAgentPlugin() {
gob.Register(ResourceMetaWrapper{})
gob.Register(ResourceWrapper{})

pluginmachinery.PluginRegistry().RegisterRemotePlugin(newAgentPlugin())
}
13 changes: 6 additions & 7 deletions flyteplugins/go/tasks/plugins/webapi/agent/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,15 @@ func TestPlugin(t *testing.T) {

t.Run("test getFinalAgent", func(t *testing.T) {
agent := &Agent{AgentDeployment: &Deployment{Endpoint: "localhost:80"}}
agentRegistry := Registry{"spark": {defaultTaskTypeVersion: agent}}
agentRegistry = Registry{"spark": {defaultTaskTypeVersion: agent}}
spark := &admin.TaskCategory{Name: "spark", Version: defaultTaskTypeVersion}
foo := &admin.TaskCategory{Name: "foo", Version: defaultTaskTypeVersion}
bar := &admin.TaskCategory{Name: "bar", Version: defaultTaskTypeVersion}
agentDeployment, _ := getFinalAgent(spark, &cfg, agentRegistry)
agentDeployment, _ := getFinalAgent(spark, &cfg)
assert.Equal(t, agentDeployment.Endpoint, "localhost:80")
agentDeployment, _ = getFinalAgent(foo, &cfg, agentRegistry)
agentDeployment, _ = getFinalAgent(foo, &cfg)
assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint)
agentDeployment, _ = getFinalAgent(bar, &cfg, agentRegistry)
agentDeployment, _ = getFinalAgent(bar, &cfg)
assert.Equal(t, agentDeployment.Endpoint, cfg.DefaultAgent.Endpoint)
})

Expand Down Expand Up @@ -318,11 +318,10 @@ func TestInitializeAgentRegistry(t *testing.T) {
cfg.AgentForTaskTypes = map[string]string{"task1": "agent-deployment-1", "task2": "agent-deployment-2"}
err := SetConfig(&cfg)
assert.NoError(t, err)
agentRegistry, err := initializeAgentRegistry(cs)
assert.NoError(t, err)
updateAgentRegistry(context.Background(), cs)

// In golang, the order of keys in a map is random. So, we sort the keys before asserting.
agentRegistryKeys := maps.Keys(agentRegistry)
agentRegistryKeys := maps.Keys(getAgentRegistry())
sort.Strings(agentRegistryKeys)

assert.Equal(t, agentRegistryKeys, []string{"task1", "task2", "task3"})
Expand Down

0 comments on commit cd37d1b

Please sign in to comment.