diff --git a/runner/data.go b/runner/data.go index 68ac1337..85ef4e92 100644 --- a/runner/data.go +++ b/runner/data.go @@ -44,6 +44,16 @@ type StreamMessageProviderFunc func(*CallData) (*dynamic.Message, error) // Clients can return ErrEndStream to end the call early type StreamRecvMsgInterceptFunc func(*dynamic.Message, error) error +// StreamInterceptorProviderFunc is an interface for a function invoked to generate a stream interceptor +type StreamInterceptorProviderFunc func(*CallData) StreamInterceptor + +// StreamInterceptor is an interface for sending and receiving stream messages. +// The interceptor can keep shared state for the send and receive calls. +type StreamInterceptor interface { + Recv(*dynamic.Message, error) error + Send(*CallData) (*dynamic.Message, error) +} + type dataProvider struct { binary bool data []byte diff --git a/runner/options.go b/runner/options.go index d2d39925..2292aa48 100644 --- a/runner/options.go +++ b/runner/options.go @@ -129,12 +129,13 @@ type RunConfig struct { disableTemplateData bool // misc - name string - cpus int - tags []byte - skipFirst int - countErrors bool - recvMsgFunc StreamRecvMsgInterceptFunc + name string + cpus int + tags []byte + skipFirst int + countErrors bool + recvMsgFunc StreamRecvMsgInterceptFunc + streamInterceptorProviderFunc StreamInterceptorProviderFunc } // Option controls some aspect of run @@ -1034,6 +1035,15 @@ func WithStreamRecvMsgIntercept(fn StreamRecvMsgInterceptFunc) Option { } } +// WithStreamInterceptor specifies the stream interceptor provider function +func WithStreamInterceptorProviderFunc(interceptor StreamInterceptorProviderFunc) Option { + return func(o *RunConfig) error { + o.streamInterceptorProviderFunc = interceptor + + return nil + } +} + // WithDataProvider provides custom data provider // // WithDataProvider(func(*CallData) ([]*dynamic.Message, error) { diff --git a/runner/requester.go b/runner/requester.go index 53a9f093..cbbaadea 100644 --- a/runner/requester.go +++ b/runner/requester.go @@ -389,17 +389,18 @@ func (b *Requester) runWorkers(wt load.WorkerTicker, p load.Pacer) error { } w := Worker{ - ticks: ticks, - active: true, - stub: b.stubs[n], - mtd: b.mtd, - config: b.config, - stopCh: make(chan bool), - workerID: wID, - dataProvider: b.dataProvider, - metadataProvider: b.metadataProvider, - streamRecv: b.config.recvMsgFunc, - msgProvider: b.config.dataStreamFunc, + ticks: ticks, + active: true, + stub: b.stubs[n], + mtd: b.mtd, + config: b.config, + stopCh: make(chan bool), + workerID: wID, + dataProvider: b.dataProvider, + metadataProvider: b.metadataProvider, + streamRecv: b.config.recvMsgFunc, + msgProvider: b.config.dataStreamFunc, + streamInterceptorProviderFunc: b.config.streamInterceptorProviderFunc, } wc++ // increment worker id diff --git a/runner/worker.go b/runner/worker.go index 58ee2bb1..2bf3c425 100644 --- a/runner/worker.go +++ b/runner/worker.go @@ -40,7 +40,8 @@ type Worker struct { metadataProvider MetadataProviderFunc msgProvider StreamMessageProviderFunc - streamRecv StreamRecvMsgInterceptFunc + streamRecv StreamRecvMsgInterceptFunc + streamInterceptorProviderFunc StreamInterceptorProviderFunc } func (w *Worker) runWorker() error { @@ -83,6 +84,13 @@ func (w *Worker) makeRequest(tv TickValue) error { ctd := newCallData(w.mtd, w.workerID, reqNum, !w.config.disableTemplateFuncs, !w.config.disableTemplateData, w.config.funcs) + var streamInterceptor StreamInterceptor + if w.mtd.IsClientStreaming() || w.mtd.IsServerStreaming() { + if w.streamInterceptorProviderFunc != nil { + streamInterceptor = w.streamInterceptorProviderFunc(ctd) + } + } + reqMD, err := w.metadataProvider(ctd) if err != nil { return err @@ -115,6 +123,8 @@ func (w *Worker) makeRequest(tv TickValue) error { var msgProvider StreamMessageProviderFunc if w.msgProvider != nil { msgProvider = w.msgProvider + } else if streamInterceptor != nil { + msgProvider = streamInterceptor.Send } else if w.mtd.IsClientStreaming() { if w.config.streamDynamicMessages { mp, err := newDynamicMessageProvider(w.mtd, w.config.data, w.config.streamCallCount, !w.config.disableTemplateFuncs, !w.config.disableTemplateData) @@ -155,11 +165,11 @@ func (w *Worker) makeRequest(tv TickValue) error { // RPC errors are handled via stats handler if w.mtd.IsClientStreaming() && w.mtd.IsServerStreaming() { - _ = w.makeBidiRequest(&ctx, ctd, msgProvider) + _ = w.makeBidiRequest(&ctx, ctd, msgProvider, streamInterceptor) } else if w.mtd.IsClientStreaming() { _ = w.makeClientStreamingRequest(&ctx, ctd, msgProvider) } else if w.mtd.IsServerStreaming() { - _ = w.makeServerStreamingRequest(&ctx, inputs[0]) + _ = w.makeServerStreamingRequest(&ctx, inputs[0], streamInterceptor) } else { _ = w.makeUnaryRequest(&ctx, reqMD, inputs[0]) } @@ -314,7 +324,7 @@ func (w *Worker) makeClientStreamingRequest(ctx *context.Context, return nil } -func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message) error { +func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic.Message, streamInterceptor StreamInterceptor) error { var callOptions = []grpc.CallOption{} if w.config.enableCompression { callOptions = append(callOptions, grpc.UseCompressor(gzip.Name)) @@ -388,6 +398,18 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic } } + if streamInterceptor != nil { + if converted, ok := res.(*dynamic.Message); ok { + err = streamInterceptor.Recv(converted, err) + if errors.Is(err, ErrEndStream) && !interceptCanceled { + interceptCanceled = true + err = nil + + callCancel() + } + } + } + if err != nil { if err == io.EOF { err = nil @@ -415,7 +437,7 @@ func (w *Worker) makeServerStreamingRequest(ctx *context.Context, input *dynamic } func (w *Worker) makeBidiRequest(ctx *context.Context, - ctd *CallData, messageProvider StreamMessageProviderFunc) error { + ctd *CallData, messageProvider StreamMessageProviderFunc, streamInterceptor StreamInterceptor) error { var callOptions = []grpc.CallOption{} @@ -494,6 +516,19 @@ func (w *Worker) makeBidiRequest(ctx *context.Context, } } + if streamInterceptor != nil { + if converted, ok := res.(*dynamic.Message); ok { + iErr := streamInterceptor.Recv(converted, recvErr) + if errors.Is(iErr, ErrEndStream) && !interceptCanceled { + interceptCanceled = true + if len(cancel) == 0 { + cancel <- struct{}{} + } + recvErr = nil + } + } + } + if recvErr != nil { close(recvDone) break