From b420ee7184b3ad75a2dd1aca0c57b1a689c7483b Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 30 Sep 2021 15:30:24 -0400 Subject: [PATCH 1/4] Propagate stream context in streaming API Signed-off-by: Micah Hausler --- grpc-server/tinkerbell.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/grpc-server/tinkerbell.go b/grpc-server/tinkerbell.go index 14ed269a0..261691dd1 100644 --- a/grpc-server/tinkerbell.go +++ b/grpc-server/tinkerbell.go @@ -35,11 +35,11 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W return err } for _, wf := range wfs { - wfContext, err := s.db.GetWorkflowContexts(context.Background(), wf) + wfContext, err := s.db.GetWorkflowContexts(stream.Context(), wf) if err != nil { return status.Errorf(codes.Aborted, err.Error()) } - if isApplicableToSend(context.Background(), s.logger, wfContext, req.WorkerId, s.db) { + if isApplicableToSend(stream.Context(), s.logger, wfContext, req.WorkerId, s.db) { if err := stream.Send(wfContext); err != nil { return err } From 4c2bb157fb3b40a13a7ce2a569aa66cc12c60eb6 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 30 Sep 2021 15:33:36 -0400 Subject: [PATCH 2/4] Refactored shadowed context package name Signed-off-by: Micah Hausler --- grpc-server/tinkerbell.go | 42 +++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/grpc-server/tinkerbell.go b/grpc-server/tinkerbell.go index 261691dd1..c01119c12 100644 --- a/grpc-server/tinkerbell.go +++ b/grpc-server/tinkerbell.go @@ -49,8 +49,8 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W } // GetWorkflowContextList implements tinkerbell.GetWorkflowContextList -func (s *server) GetWorkflowContextList(context context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) { wfs, err := getWorkflowsForWorker(s.db, req.WorkerId) +func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) { if err != nil { return nil, err } @@ -58,7 +58,7 @@ func (s *server) GetWorkflowContextList(context context.Context, req *pb.Workflo if wfs != nil { wfContexts := []*pb.WorkflowContext{} for _, wf := range wfs { - wfContext, err := s.db.GetWorkflowContexts(context, wf) + wfContext, err := s.db.GetWorkflowContexts(ctx, wf) if err != nil { return nil, status.Errorf(codes.Aborted, err.Error()) } @@ -72,16 +72,16 @@ func (s *server) GetWorkflowContextList(context context.Context, req *pb.Workflo } // GetWorkflowActions implements tinkerbell.GetWorkflowActions -func (s *server) GetWorkflowActions(context context.Context, req *pb.WorkflowActionsRequest) (*pb.WorkflowActionList, error) { +func (s *server) GetWorkflowActions(ctx context.Context, req *pb.WorkflowActionsRequest) (*pb.WorkflowActionList, error) { wfID := req.GetWorkflowId() if wfID == "" { return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId) } - return getWorkflowActions(context, s.db, wfID) + return getWorkflowActions(ctx, s.db, wfID) } // ReportActionStatus implements tinkerbell.ReportActionStatus -func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowActionStatus) (*pb.Empty, error) { +func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionStatus) (*pb.Empty, error) { wfID := req.GetWorkflowId() if wfID == "" { return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId) @@ -96,11 +96,11 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct l := s.logger.With("actionName", req.GetActionName(), "workflowID", req.GetWorkflowId()) l.Info(fmt.Sprintf(msgReceivedStatus, req.GetActionStatus())) - wfContext, err := s.db.GetWorkflowContexts(context, wfID) + wfContext, err := s.db.GetWorkflowContexts(ctx, wfID) if err != nil { return nil, status.Errorf(codes.Aborted, err.Error()) } - wfActions, err := s.db.GetWorkflowActions(context, wfID) + wfActions, err := s.db.GetWorkflowActions(ctx, wfID) if err != nil { return nil, status.Errorf(codes.Aborted, err.Error()) } @@ -123,14 +123,14 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct wfContext.CurrentAction = req.GetActionName() wfContext.CurrentActionState = req.GetActionStatus() wfContext.CurrentActionIndex = actionIndex - err = s.db.UpdateWorkflowState(context, wfContext) + err = s.db.UpdateWorkflowState(ctx, wfContext) if err != nil { return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error()) } // TODO the below "time" would be a part of the request which is coming form worker. time := time.Now() - err = s.db.InsertIntoWorkflowEventTable(context, req, time) + err = s.db.InsertIntoWorkflowEventTable(ctx, req, time) if err != nil { return &pb.Empty{}, status.Error(codes.Aborted, err.Error()) } @@ -149,7 +149,7 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct } // UpdateWorkflowData updates workflow ephemeral data -func (s *server) UpdateWorkflowData(context context.Context, req *pb.UpdateWorkflowDataRequest) (*pb.Empty, error) { +func (s *server) UpdateWorkflowData(ctx context.Context, req *pb.UpdateWorkflowDataRequest) (*pb.Empty, error) { wfID := req.GetWorkflowId() if wfID == "" { return &pb.Empty{}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId) @@ -158,7 +158,7 @@ func (s *server) UpdateWorkflowData(context context.Context, req *pb.UpdateWorkf if !ok { workflowData[wfID] = 1 } - err := s.db.InsertIntoWfDataTable(context, req) + err := s.db.InsertIntoWfDataTable(ctx, req) if err != nil { return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error()) } @@ -166,12 +166,12 @@ func (s *server) UpdateWorkflowData(context context.Context, req *pb.UpdateWorkf } // GetWorkflowData gets the ephemeral data for a workflow -func (s *server) GetWorkflowData(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { +func (s *server) GetWorkflowData(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { wfID := req.GetWorkflowId() if wfID == "" { return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId) } - data, err := s.db.GetfromWfDataTable(context, req) + data, err := s.db.GetfromWfDataTable(ctx, req) if err != nil { return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error()) } @@ -179,8 +179,8 @@ func (s *server) GetWorkflowData(context context.Context, req *pb.GetWorkflowDat } // GetWorkflowMetadata returns metadata wrt to the ephemeral data of a workflow -func (s *server) GetWorkflowMetadata(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { - data, err := s.db.GetWorkflowMetadata(context, req) +func (s *server) GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { + data, err := s.db.GetWorkflowMetadata(ctx, req) if err != nil { return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error()) } @@ -188,8 +188,8 @@ func (s *server) GetWorkflowMetadata(context context.Context, req *pb.GetWorkflo } // GetWorkflowDataVersion returns the latest version of data for a workflow -func (s *server) GetWorkflowDataVersion(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { - version, err := s.db.GetWorkflowDataVersion(context, req.WorkflowId) +func (s *server) GetWorkflowDataVersion(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) { + version, err := s.db.GetWorkflowDataVersion(ctx, req.WorkflowId) if err != nil { return &pb.GetWorkflowDataResponse{Version: version}, status.Errorf(codes.Aborted, err.Error()) } @@ -207,8 +207,8 @@ func getWorkflowsForWorker(db db.Database, id string) ([]string, error) { return wfs, nil } -func getWorkflowActions(context context.Context, db db.Database, wfID string) (*pb.WorkflowActionList, error) { - actions, err := db.GetWorkflowActions(context, wfID) +func getWorkflowActions(ctx context.Context, db db.Database, wfID string) (*pb.WorkflowActionList, error) { + actions, err := db.GetWorkflowActions(ctx, wfID) if err != nil { return nil, status.Errorf(codes.Aborted, errInvalidWorkflowId) } @@ -217,12 +217,12 @@ func getWorkflowActions(context context.Context, db db.Database, wfID string) (* // isApplicableToSend checks if a particular workflow context is applicable or if it is needed to // be sent to a worker based on the state of the current action and the targeted workerID -func isApplicableToSend(context context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool { +func isApplicableToSend(ctx context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool { if wfContext.GetCurrentActionState() == pb.State_STATE_FAILED || wfContext.GetCurrentActionState() == pb.State_STATE_TIMEOUT { return false } - actions, err := getWorkflowActions(context, db, wfContext.GetWorkflowId()) + actions, err := getWorkflowActions(ctx, db, wfContext.GetWorkflowId()) if err != nil { return false } From 01075a821b5242ca73dcb0c535bff700ab1bc46d Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 30 Sep 2021 15:35:23 -0400 Subject: [PATCH 3/4] Add context to GetWorkflowsForWorker Signed-off-by: Micah Hausler --- db/db.go | 2 +- db/mock/mock.go | 2 +- db/mock/workflow.go | 4 ++-- db/workflow.go | 4 ++-- grpc-server/tinkerbell.go | 8 ++++---- grpc-server/tinkerbell_test.go | 14 +++++++------- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/db/db.go b/db/db.go index e4107bca8..f8179a768 100644 --- a/db/db.go +++ b/db/db.go @@ -47,11 +47,11 @@ type workflow interface { GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowDataVersion(ctx context.Context, workflowID string) (int32, error) - GetWorkflowsForWorker(id string) ([]string, error) GetWorkflow(ctx context.Context, id string) (Workflow, error) DeleteWorkflow(ctx context.Context, id string, state int32) error ListWorkflows(fn func(wf Workflow) error) error UpdateWorkflow(ctx context.Context, wf Workflow, state int32) error + GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) UpdateWorkflowState(ctx context.Context, wfContext *pb.WorkflowContext) error GetWorkflowContexts(ctx context.Context, wfID string) (*pb.WorkflowContext, error) GetWorkflowActions(ctx context.Context, wfID string) (*pb.WorkflowActionList, error) diff --git a/db/mock/mock.go b/db/mock/mock.go index bccd3976f..f45cc26f3 100644 --- a/db/mock/mock.go +++ b/db/mock/mock.go @@ -19,7 +19,7 @@ type DB struct { InsertIntoWfDataTableFunc func(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error GetWorkflowMetadataFunc func(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowDataVersionFunc func(ctx context.Context, workflowID string) (int32, error) - GetWorkflowsForWorkerFunc func(id string) ([]string, error) + GetWorkflowsForWorkerFunc func(ctx context.Context, id string) ([]string, error) GetWorkflowContextsFunc func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) GetWorkflowActionsFunc func(ctx context.Context, wfID string) (*pb.WorkflowActionList, error) UpdateWorkflowStateFunc func(ctx context.Context, wfContext *pb.WorkflowContext) error diff --git a/db/mock/workflow.go b/db/mock/workflow.go index 4a845802f..9f3908b76 100644 --- a/db/mock/workflow.go +++ b/db/mock/workflow.go @@ -35,8 +35,8 @@ func (d DB) GetWorkflowDataVersion(ctx context.Context, workflowID string) (int3 } // GetWorkflowsForWorker : returns the list of workflows for a particular worker -func (d DB) GetWorkflowsForWorker(id string) ([]string, error) { - return d.GetWorkflowsForWorkerFunc(id) +func (d DB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) { + return d.GetWorkflowsForWorkerFunc(ctx, id) } // GetWorkflow returns a workflow diff --git a/db/workflow.go b/db/workflow.go index fa470cd5a..ec7da9483 100644 --- a/db/workflow.go +++ b/db/workflow.go @@ -304,8 +304,8 @@ func (d TinkDB) GetWorkflowDataVersion(ctx context.Context, workflowID string) ( } // GetWorkflowsForWorker : returns the list of workflows for a particular worker -func (d TinkDB) GetWorkflowsForWorker(id string) ([]string, error) { - rows, err := d.instance.Query(` +func (d TinkDB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) { + rows, err := d.instance.QueryContext(ctx, ` SELECT workflow_id FROM workflow_worker_map WHERE diff --git a/grpc-server/tinkerbell.go b/grpc-server/tinkerbell.go index c01119c12..6c3f0a722 100644 --- a/grpc-server/tinkerbell.go +++ b/grpc-server/tinkerbell.go @@ -30,7 +30,7 @@ const ( // GetWorkflowContexts implements tinkerbell.GetWorkflowContexts func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.WorkflowService_GetWorkflowContextsServer) error { - wfs, err := getWorkflowsForWorker(s.db, req.WorkerId) + wfs, err := getWorkflowsForWorker(stream.Context(), s.db, req.WorkerId) if err != nil { return err } @@ -49,8 +49,8 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W } // GetWorkflowContextList implements tinkerbell.GetWorkflowContextList - wfs, err := getWorkflowsForWorker(s.db, req.WorkerId) func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) { + wfs, err := getWorkflowsForWorker(ctx, s.db, req.WorkerId) if err != nil { return nil, err } @@ -196,11 +196,11 @@ func (s *server) GetWorkflowDataVersion(ctx context.Context, req *pb.GetWorkflow return &pb.GetWorkflowDataResponse{Version: version}, nil } -func getWorkflowsForWorker(db db.Database, id string) ([]string, error) { +func getWorkflowsForWorker(ctx context.Context, db db.Database, id string) ([]string, error) { if id == "" { return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkerID) } - wfs, err := db.GetWorkflowsForWorker(id) + wfs, err := db.GetWorkflowsForWorker(ctx, id) if err != nil { return nil, status.Errorf(codes.Aborted, err.Error()) } diff --git a/grpc-server/tinkerbell_test.go b/grpc-server/tinkerbell_test.go index d30fe7425..c7e11b9b9 100644 --- a/grpc-server/tinkerbell_test.go +++ b/grpc-server/tinkerbell_test.go @@ -67,7 +67,7 @@ func TestGetWorkflowContextList(t *testing.T) { "database failure": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return []string{workflowID}, nil }, GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) { @@ -83,7 +83,7 @@ func TestGetWorkflowContextList(t *testing.T) { "no workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return nil, nil }, GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) { @@ -99,7 +99,7 @@ func TestGetWorkflowContextList(t *testing.T) { "workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return []string{workflowID}, nil }, GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) { @@ -758,7 +758,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { "database failure": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return nil, errors.New("database failed") }, }, @@ -771,7 +771,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { "no workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return nil, nil }, }, @@ -784,7 +784,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { "workflows found": { args: args{ db: &mock.DB{ - GetWorkflowsForWorkerFunc: func(id string) ([]string, error) { + GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) { return []string{workflowID}, nil }, }, @@ -799,7 +799,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { s := testServer(t, tc.args.db) - res, err := getWorkflowsForWorker(s.db, tc.args.workerID) + res, err := getWorkflowsForWorker(context.Background(), s.db, tc.args.workerID) if err != nil { assert.True(t, tc.want.expectedError) assert.Error(t, err) From d932a471178c195e85df56c878f43f0bef2213f5 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 30 Sep 2021 15:36:38 -0400 Subject: [PATCH 4/4] Refactor database interface This change adds a separate interface for calls that are only invoked by APIs made by the Tink worker Signed-off-by: Micah Hausler --- db/db.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/db/db.go b/db/db.go index f8179a768..d0e35d643 100644 --- a/db/db.go +++ b/db/db.go @@ -22,6 +22,7 @@ type Database interface { hardware template workflow + WorkerWorkflow } type hardware interface { @@ -43,20 +44,24 @@ type template interface { type workflow interface { CreateWorkflow(ctx context.Context, wf Workflow, data string, id uuid.UUID) error - InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error - GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowDataVersion(ctx context.Context, workflowID string) (int32, error) GetWorkflow(ctx context.Context, id string) (Workflow, error) DeleteWorkflow(ctx context.Context, id string, state int32) error ListWorkflows(fn func(wf Workflow) error) error UpdateWorkflow(ctx context.Context, wf Workflow, state int32) error + InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, time time.Time) error + ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error +} + +// WorkerWorkflow is an interface for methods invoked by APIs that the worker calls +type WorkerWorkflow interface { + InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error + GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) UpdateWorkflowState(ctx context.Context, wfContext *pb.WorkflowContext) error GetWorkflowContexts(ctx context.Context, wfID string) (*pb.WorkflowContext, error) GetWorkflowActions(ctx context.Context, wfID string) (*pb.WorkflowActionList, error) - InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, time time.Time) error - ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error } // TinkDB implements the Database interface