diff --git a/pkg/runner/grpc.go b/pkg/runner/grpc.go index 24fedf923..8053e5080 100644 --- a/pkg/runner/grpc.go +++ b/pkg/runner/grpc.go @@ -90,17 +90,15 @@ func (r *gRPCTestCaseRunner) RunTestCase(testcase *testing.TestCase, dataContext func invokeRequest(ctx context.Context, md protoreflect.MethodDescriptor, payload string, conn *grpc.ClientConn) (respones []string, err error) { resps := make([]*dynamicpb.Message, 0) - if md.IsStreamingClient() || md.IsStreamingServer() { - result := gjson.Parse(payload) - if !result.IsArray() { + gpayload := gjson.Parse(payload) + if !gpayload.IsArray() { return nil, fmt.Errorf("payload is not a JSON array") } - reqs := make([]*dynamicpb.Message, len(result.Array())) - for i, v := range result.Array() { - req := dynamicpb.NewMessage(md.Input()) - err := protojson.Unmarshal([]byte(v.Raw), req) + reqs := make([]*dynamicpb.Message, len(gpayload.Array())) + for i, v := range gpayload.Array() { + req,err:=getReqMessagePb(md,v.Raw) if err != nil { return nil, err } @@ -113,12 +111,9 @@ func invokeRequest(ctx context.Context, md protoreflect.MethodDescriptor, payloa } } else { - request := dynamicpb.NewMessage(md.Input()) - if payload != "" { - err = protojson.Unmarshal([]byte(payload), request) - if err != nil { - return nil, err - } + request, err := getReqMessagePb(md, payload) + if err != nil { + return nil, err } resp, err := invokeRPC(ctx, conn, md, request) @@ -128,6 +123,21 @@ func invokeRequest(ctx context.Context, md protoreflect.MethodDescriptor, payloa resps = append(resps, resp) } + return buildResponses(resps) +} + +func getReqMessagePb(md protoreflect.MethodDescriptor, message string) (messagepb *dynamicpb.Message, err error) { + request := dynamicpb.NewMessage(md.Input()) + if message != "" { + err := protojson.Unmarshal([]byte(message), request) + if err != nil { + return nil, err + } + } + return request, nil +} + +func buildResponses(resps []*dynamicpb.Message) ([]string, error) { respsStr := make([]string, 0) for i := range resps { respbR, err := protojson.Marshal(resps[i]) @@ -289,7 +299,7 @@ func payloadFieldsVerify(caseName string, expect testing.Response, jsonPayload [ msg += err.Error() } } - + if msg != "" { return fmt.Errorf(msg) }