Skip to content

Commit

Permalink
Database interface refactor (#544)
Browse files Browse the repository at this point in the history
Signed-off-by: Micah Hausler <[email protected]>

## Description

This includes 4 changes that are preparatory for future work to migrate to the Kubernetes data model.

* Propagate stream context in streaming API. Previously `context.Background()` was used, but the stream provides a context object
* Refactored shadowed "context" package name. By naming a method variable "context", no `context` package calls could be made in the method
* Added `context.Context` to `GetWorkflowsForWorker()` database API. This plumbs down the context from the API call into the `d.instance.QueryContext()` call.
* Refactor the database interface. I added a new interface `WorkerWorkflow` with the methods that get used by APIs the Tink Worker invokes. This is essentially a no-op for now. 

## Why is this needed

See tinkerbell/proposals#46

## How Has This Been Tested?

Locally ran tests.

## How are existing users impacted? What migration steps/scripts do we need?

No impact to existing users

## Checklist:

I have:

- [ ] updated the documentation and/or roadmap (if required)
- [ ] added unit or e2e tests
- [ ] provided instructions on how to upgrade
  • Loading branch information
mergify[bot] authored Oct 1, 2021
2 parents 3743d31 + d932a47 commit 7ef02f8
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 44 deletions.
15 changes: 10 additions & 5 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type Database interface {
hardware
template
workflow
WorkerWorkflow
}

type hardware interface {
Expand All @@ -43,20 +44,24 @@ type template interface {

type workflow interface {
CreateWorkflow(ctx context.Context, wf Workflow, data string, id uuid.UUID) error
InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error
GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
GetWorkflowDataVersion(ctx context.Context, workflowID string) (int32, error)
GetWorkflowsForWorker(id string) ([]string, error)
GetWorkflow(ctx context.Context, id string) (Workflow, error)
DeleteWorkflow(ctx context.Context, id string, state int32) error
ListWorkflows(fn func(wf Workflow) error) error
UpdateWorkflow(ctx context.Context, wf Workflow, state int32) error
InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, time time.Time) error
ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error
}

// WorkerWorkflow is an interface for methods invoked by APIs that the worker calls
type WorkerWorkflow interface {
InsertIntoWfDataTable(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error
GetfromWfDataTable(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error)
UpdateWorkflowState(ctx context.Context, wfContext *pb.WorkflowContext) error
GetWorkflowContexts(ctx context.Context, wfID string) (*pb.WorkflowContext, error)
GetWorkflowActions(ctx context.Context, wfID string) (*pb.WorkflowActionList, error)
InsertIntoWorkflowEventTable(ctx context.Context, wfEvent *pb.WorkflowActionStatus, time time.Time) error
ShowWorkflowEvents(wfID string, fn func(wfs *pb.WorkflowActionStatus) error) error
}

// TinkDB implements the Database interface
Expand Down
2 changes: 1 addition & 1 deletion db/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type DB struct {
InsertIntoWfDataTableFunc func(ctx context.Context, req *pb.UpdateWorkflowDataRequest) error
GetWorkflowMetadataFunc func(ctx context.Context, req *pb.GetWorkflowDataRequest) ([]byte, error)
GetWorkflowDataVersionFunc func(ctx context.Context, workflowID string) (int32, error)
GetWorkflowsForWorkerFunc func(id string) ([]string, error)
GetWorkflowsForWorkerFunc func(ctx context.Context, id string) ([]string, error)
GetWorkflowContextsFunc func(ctx context.Context, wfID string) (*pb.WorkflowContext, error)
GetWorkflowActionsFunc func(ctx context.Context, wfID string) (*pb.WorkflowActionList, error)
UpdateWorkflowStateFunc func(ctx context.Context, wfContext *pb.WorkflowContext) error
Expand Down
4 changes: 2 additions & 2 deletions db/mock/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func (d DB) GetWorkflowDataVersion(ctx context.Context, workflowID string) (int3
}

// GetWorkflowsForWorker : returns the list of workflows for a particular worker
func (d DB) GetWorkflowsForWorker(id string) ([]string, error) {
return d.GetWorkflowsForWorkerFunc(id)
func (d DB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) {
return d.GetWorkflowsForWorkerFunc(ctx, id)
}

// GetWorkflow returns a workflow
Expand Down
4 changes: 2 additions & 2 deletions db/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ func (d TinkDB) GetWorkflowDataVersion(ctx context.Context, workflowID string) (
}

// GetWorkflowsForWorker : returns the list of workflows for a particular worker
func (d TinkDB) GetWorkflowsForWorker(id string) ([]string, error) {
rows, err := d.instance.Query(`
func (d TinkDB) GetWorkflowsForWorker(ctx context.Context, id string) ([]string, error) {
rows, err := d.instance.QueryContext(ctx, `
SELECT workflow_id
FROM workflow_worker_map
WHERE
Expand Down
54 changes: 27 additions & 27 deletions grpc-server/tinkerbell.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ const (

// GetWorkflowContexts implements tinkerbell.GetWorkflowContexts
func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.WorkflowService_GetWorkflowContextsServer) error {
wfs, err := getWorkflowsForWorker(s.db, req.WorkerId)
wfs, err := getWorkflowsForWorker(stream.Context(), s.db, req.WorkerId)
if err != nil {
return err
}
for _, wf := range wfs {
wfContext, err := s.db.GetWorkflowContexts(context.Background(), wf)
wfContext, err := s.db.GetWorkflowContexts(stream.Context(), wf)
if err != nil {
return status.Errorf(codes.Aborted, err.Error())
}
if isApplicableToSend(context.Background(), s.logger, wfContext, req.WorkerId, s.db) {
if isApplicableToSend(stream.Context(), s.logger, wfContext, req.WorkerId, s.db) {
if err := stream.Send(wfContext); err != nil {
return err
}
Expand All @@ -49,16 +49,16 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W
}

// GetWorkflowContextList implements tinkerbell.GetWorkflowContextList
func (s *server) GetWorkflowContextList(context context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) {
wfs, err := getWorkflowsForWorker(s.db, req.WorkerId)
func (s *server) GetWorkflowContextList(ctx context.Context, req *pb.WorkflowContextRequest) (*pb.WorkflowContextList, error) {
wfs, err := getWorkflowsForWorker(ctx, s.db, req.WorkerId)
if err != nil {
return nil, err
}

if wfs != nil {
wfContexts := []*pb.WorkflowContext{}
for _, wf := range wfs {
wfContext, err := s.db.GetWorkflowContexts(context, wf)
wfContext, err := s.db.GetWorkflowContexts(ctx, wf)
if err != nil {
return nil, status.Errorf(codes.Aborted, err.Error())
}
Expand All @@ -72,16 +72,16 @@ func (s *server) GetWorkflowContextList(context context.Context, req *pb.Workflo
}

// GetWorkflowActions implements tinkerbell.GetWorkflowActions
func (s *server) GetWorkflowActions(context context.Context, req *pb.WorkflowActionsRequest) (*pb.WorkflowActionList, error) {
func (s *server) GetWorkflowActions(ctx context.Context, req *pb.WorkflowActionsRequest) (*pb.WorkflowActionList, error) {
wfID := req.GetWorkflowId()
if wfID == "" {
return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
}
return getWorkflowActions(context, s.db, wfID)
return getWorkflowActions(ctx, s.db, wfID)
}

// ReportActionStatus implements tinkerbell.ReportActionStatus
func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowActionStatus) (*pb.Empty, error) {
func (s *server) ReportActionStatus(ctx context.Context, req *pb.WorkflowActionStatus) (*pb.Empty, error) {
wfID := req.GetWorkflowId()
if wfID == "" {
return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
Expand All @@ -96,11 +96,11 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct
l := s.logger.With("actionName", req.GetActionName(), "workflowID", req.GetWorkflowId())
l.Info(fmt.Sprintf(msgReceivedStatus, req.GetActionStatus()))

wfContext, err := s.db.GetWorkflowContexts(context, wfID)
wfContext, err := s.db.GetWorkflowContexts(ctx, wfID)
if err != nil {
return nil, status.Errorf(codes.Aborted, err.Error())
}
wfActions, err := s.db.GetWorkflowActions(context, wfID)
wfActions, err := s.db.GetWorkflowActions(ctx, wfID)
if err != nil {
return nil, status.Errorf(codes.Aborted, err.Error())
}
Expand All @@ -123,14 +123,14 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct
wfContext.CurrentAction = req.GetActionName()
wfContext.CurrentActionState = req.GetActionStatus()
wfContext.CurrentActionIndex = actionIndex
err = s.db.UpdateWorkflowState(context, wfContext)
err = s.db.UpdateWorkflowState(ctx, wfContext)
if err != nil {
return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error())
}

// TODO the below "time" would be a part of the request which is coming form worker.
time := time.Now()
err = s.db.InsertIntoWorkflowEventTable(context, req, time)
err = s.db.InsertIntoWorkflowEventTable(ctx, req, time)
if err != nil {
return &pb.Empty{}, status.Error(codes.Aborted, err.Error())
}
Expand All @@ -149,7 +149,7 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct
}

// UpdateWorkflowData updates workflow ephemeral data
func (s *server) UpdateWorkflowData(context context.Context, req *pb.UpdateWorkflowDataRequest) (*pb.Empty, error) {
func (s *server) UpdateWorkflowData(ctx context.Context, req *pb.UpdateWorkflowDataRequest) (*pb.Empty, error) {
wfID := req.GetWorkflowId()
if wfID == "" {
return &pb.Empty{}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
Expand All @@ -158,57 +158,57 @@ func (s *server) UpdateWorkflowData(context context.Context, req *pb.UpdateWorkf
if !ok {
workflowData[wfID] = 1
}
err := s.db.InsertIntoWfDataTable(context, req)
err := s.db.InsertIntoWfDataTable(ctx, req)
if err != nil {
return &pb.Empty{}, status.Errorf(codes.Aborted, err.Error())
}
return &pb.Empty{}, nil
}

// GetWorkflowData gets the ephemeral data for a workflow
func (s *server) GetWorkflowData(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
func (s *server) GetWorkflowData(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
wfID := req.GetWorkflowId()
if wfID == "" {
return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.InvalidArgument, errInvalidWorkflowId)
}
data, err := s.db.GetfromWfDataTable(context, req)
data, err := s.db.GetfromWfDataTable(ctx, req)
if err != nil {
return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error())
}
return &pb.GetWorkflowDataResponse{Data: data}, nil
}

// GetWorkflowMetadata returns metadata wrt to the ephemeral data of a workflow
func (s *server) GetWorkflowMetadata(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
data, err := s.db.GetWorkflowMetadata(context, req)
func (s *server) GetWorkflowMetadata(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
data, err := s.db.GetWorkflowMetadata(ctx, req)
if err != nil {
return &pb.GetWorkflowDataResponse{Data: []byte("")}, status.Errorf(codes.Aborted, err.Error())
}
return &pb.GetWorkflowDataResponse{Data: data}, nil
}

// GetWorkflowDataVersion returns the latest version of data for a workflow
func (s *server) GetWorkflowDataVersion(context context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
version, err := s.db.GetWorkflowDataVersion(context, req.WorkflowId)
func (s *server) GetWorkflowDataVersion(ctx context.Context, req *pb.GetWorkflowDataRequest) (*pb.GetWorkflowDataResponse, error) {
version, err := s.db.GetWorkflowDataVersion(ctx, req.WorkflowId)
if err != nil {
return &pb.GetWorkflowDataResponse{Version: version}, status.Errorf(codes.Aborted, err.Error())
}
return &pb.GetWorkflowDataResponse{Version: version}, nil
}

func getWorkflowsForWorker(db db.Database, id string) ([]string, error) {
func getWorkflowsForWorker(ctx context.Context, db db.Database, id string) ([]string, error) {
if id == "" {
return nil, status.Errorf(codes.InvalidArgument, errInvalidWorkerID)
}
wfs, err := db.GetWorkflowsForWorker(id)
wfs, err := db.GetWorkflowsForWorker(ctx, id)
if err != nil {
return nil, status.Errorf(codes.Aborted, err.Error())
}
return wfs, nil
}

func getWorkflowActions(context context.Context, db db.Database, wfID string) (*pb.WorkflowActionList, error) {
actions, err := db.GetWorkflowActions(context, wfID)
func getWorkflowActions(ctx context.Context, db db.Database, wfID string) (*pb.WorkflowActionList, error) {
actions, err := db.GetWorkflowActions(ctx, wfID)
if err != nil {
return nil, status.Errorf(codes.Aborted, errInvalidWorkflowId)
}
Expand All @@ -217,12 +217,12 @@ func getWorkflowActions(context context.Context, db db.Database, wfID string) (*

// isApplicableToSend checks if a particular workflow context is applicable or if it is needed to
// be sent to a worker based on the state of the current action and the targeted workerID
func isApplicableToSend(context context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool {
func isApplicableToSend(ctx context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool {
if wfContext.GetCurrentActionState() == pb.State_STATE_FAILED ||
wfContext.GetCurrentActionState() == pb.State_STATE_TIMEOUT {
return false
}
actions, err := getWorkflowActions(context, db, wfContext.GetWorkflowId())
actions, err := getWorkflowActions(ctx, db, wfContext.GetWorkflowId())
if err != nil {
return false
}
Expand Down
14 changes: 7 additions & 7 deletions grpc-server/tinkerbell_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestGetWorkflowContextList(t *testing.T) {
"database failure": {
args: args{
db: &mock.DB{
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
return []string{workflowID}, nil
},
GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) {
Expand All @@ -83,7 +83,7 @@ func TestGetWorkflowContextList(t *testing.T) {
"no workflows found": {
args: args{
db: &mock.DB{
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
return nil, nil
},
GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) {
Expand All @@ -99,7 +99,7 @@ func TestGetWorkflowContextList(t *testing.T) {
"workflows found": {
args: args{
db: &mock.DB{
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
return []string{workflowID}, nil
},
GetWorkflowContextsFunc: func(ctx context.Context, wfID string) (*pb.WorkflowContext, error) {
Expand Down Expand Up @@ -758,7 +758,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
"database failure": {
args: args{
db: &mock.DB{
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
return nil, errors.New("database failed")
},
},
Expand All @@ -771,7 +771,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
"no workflows found": {
args: args{
db: &mock.DB{
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
return nil, nil
},
},
Expand All @@ -784,7 +784,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
"workflows found": {
args: args{
db: &mock.DB{
GetWorkflowsForWorkerFunc: func(id string) ([]string, error) {
GetWorkflowsForWorkerFunc: func(ctx context.Context, id string) ([]string, error) {
return []string{workflowID}, nil
},
},
Expand All @@ -799,7 +799,7 @@ func TestGetWorkflowsForWorker(t *testing.T) {
for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
s := testServer(t, tc.args.db)
res, err := getWorkflowsForWorker(s.db, tc.args.workerID)
res, err := getWorkflowsForWorker(context.Background(), s.db, tc.args.workerID)
if err != nil {
assert.True(t, tc.want.expectedError)
assert.Error(t, err)
Expand Down

0 comments on commit 7ef02f8

Please sign in to comment.