diff --git a/plugin/plugin.go b/plugin/plugin.go index eae981b1..85d2973c 100644 --- a/plugin/plugin.go +++ b/plugin/plugin.go @@ -121,7 +121,7 @@ func (p *Plugin) Execute(req *proto.ExecuteRequest, stream proto.WrapperPlugin_E // 3) Build row spawns goroutines for any required hydrate functions. // 4) When hydrate functions are complete, apply transforms to generate column values. When row is ready, send on rowChan // 5) Range over rowChan - for each row, send on results stream - ctx := context.WithValue(context.Background(), context_key.Logger, p.Logger) + ctx := context.WithValue(stream.Context(), context_key.Logger, p.Logger) var matrixItem []map[string]interface{} var connection *Connection diff --git a/plugin/query_data.go b/plugin/query_data.go index 1c341f50..df10a277 100644 --- a/plugin/query_data.go +++ b/plugin/query_data.go @@ -2,7 +2,6 @@ package plugin import ( "context" - "errors" "fmt" "log" "sync" @@ -55,9 +54,6 @@ type QueryData struct { listWg sync.WaitGroup // when executing parent child list calls, we cache the parent list result in the query data passed to the child list call parentItem interface{} - - // there was an error streaming to the grpc stream - streamingError error } func newQueryData(queryContext *QueryContext, table *Table, stream proto.WrapperPlugin_ExecuteServer, connection *Connection, matrix []map[string]interface{}, connectionManager *connection_manager.Manager) *QueryData { @@ -265,9 +261,11 @@ func (d *QueryData) verifyCallerIsListCall(callingFunction string) bool { } func (d *QueryData) streamLeafListItem(ctx context.Context, item interface{}) { - if d.streamingError != nil { - // if there is streaming error, panic to force exit thread - this will be recovered higher up - panic(d.streamingError) + // if the context is cancelled, panic to break out + select { + case <-d.stream.Context().Done(): + panic(contextCancelledError) + default: } // create rowData, passing matrixItem from context @@ -292,8 +290,6 @@ func (d *QueryData) streamRows(_ context.Context, rowChan chan *proto.Row) error for { // wait for either an item or an error select { - case <-d.stream.Context().Done(): - d.streamingError = errors.New(contextCancelledError) case err := <-d.errorChan: log.Printf("[ERROR] streamRows error chan select: %v\n", err) return err @@ -306,10 +302,6 @@ func (d *QueryData) streamRows(_ context.Context, rowChan chan *proto.Row) error return nil } if err := d.streamRow(row); err != nil { - // if there was an error streaming, store in d.streamingError - // - this is checked by the thread streaming list items and will cause it to terminate - d.streamingError = err - return err } }