Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: avoid payload limitation [MD-266] #9164

Merged
merged 18 commits into from
Jun 10, 2024
44 changes: 33 additions & 11 deletions master/internal/stream/model_versions.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,27 @@ type ModelVersionMsg struct {
}

// SeqNum gets the SeqNum from a ModelVersionMsg.
func (pm *ModelVersionMsg) SeqNum() int64 {
return pm.Seq
func (mm *ModelVersionMsg) SeqNum() int64 {
return mm.Seq
}

// GetID gets the ID from a ModelVersionMsg.
func (mm *ModelVersionMsg) GetID() int {
return mm.ID
}

// UpsertMsg creates a ModelVersion stream upsert message.
func (pm *ModelVersionMsg) UpsertMsg() stream.UpsertMsg {
return stream.UpsertMsg{
func (mm *ModelVersionMsg) UpsertMsg() *stream.UpsertMsg {
return &stream.UpsertMsg{
JSONKey: ModelVersionsUpsertKey,
Msg: pm,
Msg: mm,
}
}

// DeleteMsg creates a ModelVersion stream delete message.
func (pm *ModelVersionMsg) DeleteMsg() stream.DeleteMsg {
deleted := strconv.FormatInt(int64(pm.ID), 10)
return stream.DeleteMsg{
func (mm *ModelVersionMsg) DeleteMsg() *stream.DeleteMsg {
deleted := strconv.Itoa(mm.ID)
return &stream.DeleteMsg{
Key: ModelVersionsDeleteKey,
Deleted: deleted,
}
Expand Down Expand Up @@ -158,7 +163,7 @@ func ModelVersionCollectStartupMsgs(
}
missing, appeared, err := processQuery(ctx, createQuery, spec.Since, known, "m")
if err != nil {
return nil, err
return nil, fmt.Errorf("processing known: %w", err)
}

// step 2: hydrate appeared IDs into full ModelVersionMsgs
Expand All @@ -171,14 +176,14 @@ func ModelVersionCollectStartupMsgs(
query = modelVersionPermFilterQuery(query, accessScopes)
}
err := query.Scan(ctx, &mvMsgs)
if err != nil && errors.Cause(err) != sql.ErrNoRows {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
log.Errorf("error: %v\n", err)
return nil, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: is there a reason we aren't wrapping this error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will wrap it.

}
}

// step 3: emit deletions and updates to the client
out = append(out, stream.DeleteMsg{
out = append(out, &stream.DeleteMsg{
Key: ModelVersionsDeleteKey,
Deleted: missing,
})
Expand Down Expand Up @@ -258,3 +263,20 @@ func ModelVersionMakePermissionFilter(ctx context.Context, user model.User) (fun
}, nil
}
}

// ModelVersionMakeHydrator returns a function that gets properties of a model version by
// its id.
func ModelVersionMakeHydrator() func(*ModelVersionMsg) (*ModelVersionMsg, error) {
return func(msg *ModelVersionMsg) (*ModelVersionMsg, error) {
var saturatedMsg ModelVersionMsg
query := db.Bun().NewSelect().Model(&saturatedMsg).Where("id = ?", msg.GetID()).ExcludeColumn("workspace_id")
corban-beaird marked this conversation as resolved.
Show resolved Hide resolved
err := query.Scan(context.Background(), &saturatedMsg)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("error in model version hydrator: %w", err)
}
saturatedMsg.WorkspaceID = msg.WorkspaceID
return &saturatedMsg, nil
}
}
43 changes: 32 additions & 11 deletions master/internal/stream/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,27 @@ type ModelMsg struct {
}

// SeqNum gets the SeqNum from a ModelMsg.
func (pm *ModelMsg) SeqNum() int64 {
return pm.Seq
func (mm *ModelMsg) SeqNum() int64 {
return mm.Seq
}

// GetID gets the ID from a ModelMsg.
func (mm *ModelMsg) GetID() int {
return mm.ID
}

// UpsertMsg creates a model stream upsert message.
func (pm *ModelMsg) UpsertMsg() stream.UpsertMsg {
return stream.UpsertMsg{
func (mm *ModelMsg) UpsertMsg() *stream.UpsertMsg {
return &stream.UpsertMsg{
JSONKey: ModelsUpsertKey,
Msg: pm,
Msg: mm,
}
}

// DeleteMsg creates a model stream delete message.
func (pm *ModelMsg) DeleteMsg() stream.DeleteMsg {
deleted := strconv.FormatInt(int64(pm.ID), 10)
return stream.DeleteMsg{
func (mm *ModelMsg) DeleteMsg() *stream.DeleteMsg {
deleted := strconv.Itoa(mm.ID)
return &stream.DeleteMsg{
Key: ModelsDeleteKey,
Deleted: deleted,
}
Expand Down Expand Up @@ -152,7 +157,7 @@ func ModelCollectStartupMsgs(
}
missing, appeared, err := processQuery(ctx, createQuery, spec.Since, known, "m")
if err != nil {
return nil, err
return nil, fmt.Errorf("processing known: %w", err)
}

// step 2: hydrate appeared IDs into full ModelMsgs
Expand All @@ -163,14 +168,14 @@ func ModelCollectStartupMsgs(
query = permFilterQuery(query, accessScopes)
}
err := query.Scan(ctx, &modelMsgs)
if err != nil && errors.Cause(err) != sql.ErrNoRows {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
log.Errorf("error: %v\n", err)
return nil, err
}
}

// step 3: emit deletions and updates to the client
out = append(out, stream.DeleteMsg{
out = append(out, &stream.DeleteMsg{
Key: ModelsDeleteKey,
Deleted: missing,
})
Expand Down Expand Up @@ -246,3 +251,19 @@ func ModelMakePermissionFilter(ctx context.Context, user model.User) (func(*Mode
}, nil
}
}

// ModelMakeHydrator returns a function that gets properties of a model by
// its id.
func ModelMakeHydrator() func(*ModelMsg) (*ModelMsg, error) {
return func(msg *ModelMsg) (*ModelMsg, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comments here as the other hydrators

var saturatedMsg ModelMsg
query := db.Bun().NewSelect().Model(&saturatedMsg).Where("id = ?", msg.GetID())
err := query.Scan(context.Background(), &saturatedMsg)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("error in model hydrator: %w", err)
}
return &saturatedMsg, nil
}
}
37 changes: 29 additions & 8 deletions master/internal/stream/projects.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,23 @@ func (pm *ProjectMsg) SeqNum() int64 {
return pm.Seq
}

// GetID gets the ID from a ProjectMsg.
func (pm *ProjectMsg) GetID() int {
return pm.ID
}

// UpsertMsg creates a Project stream upsert message.
func (pm *ProjectMsg) UpsertMsg() stream.UpsertMsg {
return stream.UpsertMsg{
func (pm *ProjectMsg) UpsertMsg() *stream.UpsertMsg {
return &stream.UpsertMsg{
JSONKey: ProjectsUpsertKey,
Msg: pm,
}
}

// DeleteMsg creates a Project stream delete message.
func (pm *ProjectMsg) DeleteMsg() stream.DeleteMsg {
deleted := strconv.FormatInt(int64(pm.ID), 10)
return stream.DeleteMsg{
func (pm *ProjectMsg) DeleteMsg() *stream.DeleteMsg {
deleted := strconv.Itoa(pm.ID)
return &stream.DeleteMsg{
Key: ProjectsDeleteKey,
Deleted: deleted,
}
Expand Down Expand Up @@ -150,7 +155,7 @@ func ProjectCollectStartupMsgs(
}
missing, appeared, err := processQuery(ctx, createQuery, spec.Since, known, "p")
if err != nil {
return nil, err
return nil, fmt.Errorf("processing known: %w", err)
}

// step 2: hydrate appeared IDs into full ProjectMsgs
Expand All @@ -161,14 +166,14 @@ func ProjectCollectStartupMsgs(
query = permFilterQuery(query, accessScopes)
}
err := query.Scan(ctx, &projMsgs)
if err != nil && errors.Cause(err) != sql.ErrNoRows {
if err != nil && !errors.Is(err, sql.ErrNoRows) {
log.Errorf("error: %v\n", err)
return nil, err
}
}

// step 3: emit deletions and updates to the client
out = append(out, stream.DeleteMsg{
out = append(out, &stream.DeleteMsg{
Key: ProjectsDeleteKey,
Deleted: missing,
})
Expand Down Expand Up @@ -233,3 +238,19 @@ func ProjectMakePermissionFilter(ctx context.Context, user model.User) (func(*Pr
}, nil
}
}

// ProjectMakeHydrator returns a function that gets properties of a project by
// its id.
func ProjectMakeHydrator() func(*ProjectMsg) (*ProjectMsg, error) {
return func(msg *ProjectMsg) (*ProjectMsg, error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same comments here as the others

var saturatedMsg ProjectMsg
query := db.Bun().NewSelect().Model(&saturatedMsg).Where("project_msg.id = ?", msg.GetID())
err := query.Scan(context.Background(), &saturatedMsg)
if err != nil && errors.Is(err, sql.ErrNoRows) {
return nil, err
} else if err != nil {
return nil, fmt.Errorf("error in project hydrator: %w", err)
}
return &saturatedMsg, nil
}
}
19 changes: 15 additions & 4 deletions master/internal/stream/publisher.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
"github.com/determined-ai/determined/master/pkg/syncx/errgroupx"
)

const maxEventCount = 120

// PublisherSet contains all publishers, and handles all websockets. It will connect each websocket
// with the appropriate set of publishers, based on that websocket's subscriptions.
//
Expand All @@ -38,9 +40,9 @@ func NewPublisherSet(dbAddress string) *PublisherSet {
lock := sync.Mutex{}
return &PublisherSet{
DBAddress: dbAddress,
Projects: stream.NewPublisher[*ProjectMsg](),
Models: stream.NewPublisher[*ModelMsg](),
ModelVersions: stream.NewPublisher[*ModelVersionMsg](),
Projects: stream.NewPublisher[*ProjectMsg](ProjectMakeHydrator()),
Models: stream.NewPublisher[*ModelMsg](ModelMakeHydrator()),
ModelVersions: stream.NewPublisher[*ModelVersionMsg](ModelVersionMakeHydrator()),
bootemChan: make(chan struct{}),
readyCond: *sync.NewCond(&lock),
}
Expand Down Expand Up @@ -414,6 +416,7 @@ func publishLoop[T stream.Msg](
events = append(events, event)
// Collect all available notifications before proceeding.
keepGoing := true
eventCount := 0
for keepGoing {
select {
case notification = <-listener.Notify:
Expand All @@ -423,12 +426,20 @@ func publishLoop[T stream.Msg](
return err
}
events = append(events, event)
eventCount++
keepGoing = eventCount < maxEventCount
default:
keepGoing = false
}
}

idToSaturatedMsg := map[int]*stream.UpsertMsg{}
// TODO: MD-434 improve performance by batch hydrating the messages.
for _, ev := range events {
publisher.HydrateMsg(ev.After, idToSaturatedMsg)
}
// Broadcast all the events.
publisher.Broadcast(events)
publisher.Broadcast(events, idToSaturatedMsg)
}
}
}
Expand Down
4 changes: 1 addition & 3 deletions master/internal/stream/subscription.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package stream
import (
"context"

"github.com/labstack/echo/v4"

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/stream"
)
Expand Down Expand Up @@ -156,7 +154,7 @@ func startup[T stream.Msg, S any](
// Scan for historical msgs matching newly-added subscriptions.
newmsgs, err := state.CollectStartupMsgs(ctx, user, known, *spec)
if err != nil {
return echo.ErrCookieNotFound
return err
}
for _, msg := range newmsgs {
*msgs = append(*msgs, prepare(msg))
Expand Down
4 changes: 2 additions & 2 deletions master/internal/stream/test_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ const (
// otherwise, returns the MarshallableMsg that the streamer sends.
func testPrepareFunc(i stream.MarshallableMsg) interface{} {
switch msg := i.(type) {
case stream.UpsertMsg:
case *stream.UpsertMsg:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious, why do we need to change to pointer types.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

testPrepareFunc is called in Broadcast() https://github.com/determined-ai/determined/pull/9164/files#diff-ab55b57cea1409ab6f97b1352891adb036d83ee04a82f8428eee733d70351cf6R247:

msg = sub.Streamer.PrepareFn(recordCache.UpsertMsg)

The type of recordCache.UpsertMsg is *stream.UpsertMsg.

switch typedMsg := msg.Msg.(type) {
case *ProjectMsg:
return fmt.Sprintf(
Expand All @@ -50,7 +50,7 @@ func testPrepareFunc(i stream.MarshallableMsg) interface{} {
typedMsg.WorkspaceID,
)
}
case stream.DeleteMsg:
case *stream.DeleteMsg:
return fmt.Sprintf("key: %s, deleted: %s", msg.Key, msg.Deleted)
case stream.SyncMsg:
return fmt.Sprintf("key: %s, sync_id: %s, complete: %t", syncKey, msg.SyncID, msg.Complete)
Expand Down
Loading
Loading