From ef2d883fb004805f080b14011add8dae667c3fa2 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 17 May 2023 09:14:16 -0400 Subject: [PATCH] Expose request method of unary requests to clients and server handlers (#502) Add ability for clients and servers to inspect the HTTP method used for unary RPCs. Combined with #494, this enables support for conditional GET requests. --- client.go | 17 ++++-- client_ext_test.go | 25 ++++---- connect.go | 24 +++++++- duplex_http_call.go | 4 ++ handler.go | 6 ++ interceptor_ext_test.go | 127 ++++++++++++++++++++++++++++++++++++++++ protocol.go | 28 ++++++--- protocol_connect.go | 18 +++++- protocol_grpc.go | 6 +- 9 files changed, 227 insertions(+), 28 deletions(-) diff --git a/client.go b/client.go index 08a1c066..c1179e44 100644 --- a/client.go +++ b/client.go @@ -77,6 +77,9 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien unarySpec := config.newSpec(StreamTypeUnary) unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) { conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header()) + conn.onRequestSend(func(r *http.Request) { + request.setRequestMethod(r.Method) + }) // Send always returns an io.EOF unless the error is from the client-side. // We want the user to continue to call Receive in those cases to get the // full error from the server-side. @@ -132,7 +135,7 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo if c.err != nil { return &ClientStreamForClient[Req, Res]{err: c.err} } - return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient)} + return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)} } // CallServerStream calls a server streaming procedure. @@ -140,7 +143,9 @@ func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Reques if c.err != nil { return nil, c.err } - conn := c.newConn(ctx, StreamTypeServer) + conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) { + request.method = r.Method + }) request.spec = conn.Spec() request.peer = conn.Peer() mergeHeaders(conn.RequestHeader(), request.header) @@ -163,14 +168,16 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli if c.err != nil { return &BidiStreamForClient[Req, Res]{err: c.err} } - return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi)} + return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)} } -func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType) StreamingClientConn { +func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn { newConn := func(ctx context.Context, spec Spec) StreamingClientConn { header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing c.protocolClient.WriteRequestHeader(streamType, header) - return c.protocolClient.NewConn(ctx, spec, header) + conn := c.protocolClient.NewConn(ctx, spec, header) + conn.onRequestSend(onRequestSend) + return conn } if interceptor := c.config.Interceptor; interceptor != nil { newConn = interceptor.WrapStreamingClient(newConn) diff --git a/client_ext_test.go b/client_ext_test.go index 5dc85d5d..4ba3bc68 100644 --- a/client_ext_test.go +++ b/client_ext_test.go @@ -80,7 +80,7 @@ func TestClientPeer(t *testing.T) { server.StartTLS() t.Cleanup(server.Close) - run := func(t *testing.T, opts ...connect.ClientOption) { + run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) { t.Helper() client := pingv1connect.NewPingServiceClient( server.Client(), @@ -90,8 +90,10 @@ func TestClientPeer(t *testing.T) { ) ctx := context.Background() // unary - _, err := client.Ping(ctx, connect.NewRequest[pingv1.PingRequest](nil)) + unaryReq := connect.NewRequest[pingv1.PingRequest](nil) + _, err := client.Ping(ctx, unaryReq) assert.Nil(t, err) + assert.Equal(t, unaryHTTPMethod, unaryReq.HTTPMethod()) text := strings.Repeat(".", 256) r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text})) assert.Nil(t, err) @@ -126,22 +128,22 @@ func TestClientPeer(t *testing.T) { t.Run("connect", func(t *testing.T) { t.Parallel() - run(t) + run(t, http.MethodPost) }) t.Run("connect+get", func(t *testing.T) { t.Parallel() - run(t, + run(t, http.MethodGet, connect.WithHTTPGet(), connect.WithSendGzip(), ) }) t.Run("grpc", func(t *testing.T) { t.Parallel() - run(t, connect.WithGRPC()) + run(t, http.MethodPost, connect.WithGRPC()) }) t.Run("grpcweb", func(t *testing.T) { t.Parallel() - run(t, connect.WithGRPCWeb()) + run(t, http.MethodPost, connect.WithGRPCWeb()) }) } @@ -167,14 +169,16 @@ func TestGetNotModified(t *testing.T) { ) ctx := context.Background() // unconditional request - res, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{})) + unaryReq := connect.NewRequest(&pingv1.PingRequest{}) + res, err := client.Ping(ctx, unaryReq) assert.Nil(t, err) assert.Equal(t, res.Header().Get("Etag"), etag) assert.Equal(t, res.Header().Values("Vary"), expectVary) + assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) - conditional := connect.NewRequest(&pingv1.PingRequest{}) - conditional.Header().Set("If-None-Match", etag) - _, err = client.Ping(ctx, conditional) + unaryReq = connect.NewRequest(&pingv1.PingRequest{}) + unaryReq.Header().Set("If-None-Match", etag) + _, err = client.Ping(ctx, unaryReq) assert.NotNil(t, err) assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown) assert.True(t, connect.IsNotModifiedError(err)) @@ -182,6 +186,7 @@ func TestGetNotModified(t *testing.T) { assert.True(t, errors.As(err, &connectErr)) assert.Equal(t, connectErr.Meta().Get("Etag"), etag) assert.Equal(t, connectErr.Meta().Values("Vary"), expectVary) + assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod()) } type notModifiedPingServer struct { diff --git a/connect.go b/connect.go index 4c486422..50b1810f 100644 --- a/connect.go +++ b/connect.go @@ -150,6 +150,7 @@ type Request[T any] struct { spec Spec peer Peer header http.Header + method string } // NewRequest wraps a generated request message. @@ -187,9 +188,28 @@ func (r *Request[_]) Header() http.Header { return r.header } +// HTTPMethod returns the HTTP method for this request. This is nearly always +// POST, but side-effect-free unary RPCs could be made via a GET. +// +// On a newly created request, via NewRequest, this will return the empty +// string until the actual request is actually sent and the HTTP method +// determined. This means that client interceptor functions will see the +// empty string until *after* they delegate to the handler they wrapped. It +// is even possible for this to return the empty string after such delegation, +// if the request was never actually sent to the server (and thus no +// determination ever made about the HTTP method). +func (r *Request[_]) HTTPMethod() string { + return r.method +} + // internalOnly implements AnyRequest. func (r *Request[_]) internalOnly() {} +// setRequestMethod sets the request method to the given value. +func (r *Request[_]) setRequestMethod(method string) { + r.method = method +} + // AnyRequest is the common method set of every [Request], regardless of type // parameter. It's used in unary interceptors. // @@ -205,8 +225,10 @@ type AnyRequest interface { Spec() Spec Peer() Peer Header() http.Header + HTTPMethod() string internalOnly() + setRequestMethod(string) } // Response is a wrapper around a generated response message. It provides @@ -322,7 +344,7 @@ func newPeerFromURL(url *url.URL, protocol string) Peer { } } -// handlerConnCloser extends HandlerConn with a method for handlers to +// handlerConnCloser extends StreamingHandlerConn with a method for handlers to // terminate the message exchange (and optionally send an error to the client). type handlerConnCloser interface { StreamingHandlerConn diff --git a/duplex_http_call.go b/duplex_http_call.go index 77a09a4d..efe065bf 100644 --- a/duplex_http_call.go +++ b/duplex_http_call.go @@ -33,6 +33,7 @@ type duplexHTTPCall struct { ctx context.Context httpClient HTTPClient streamType StreamType + onRequestSend func(*http.Request) validateResponse func(*http.Response) *Error // We'll use a pipe as the request body. We hand the read side of the pipe to @@ -255,6 +256,9 @@ func (d *duplexHTTPCall) makeRequest() { // on d.responseReady, so we can't race with them. defer close(d.responseReady) + if d.onRequestSend != nil { + d.onRequestSend(d.request) + } // Once we send a message to the server, they send a message back and // establish the receive side of the stream. response, err := d.httpClient.Do(d.request) //nolint:bodyclose diff --git a/handler.go b/handler.go index 89ceea13..af3d215b 100644 --- a/handler.go +++ b/handler.go @@ -67,11 +67,16 @@ func NewUnaryHandler[Req, Res any]( if err := conn.Receive(&msg); err != nil { return err } + method := http.MethodPost + if hasRequestMethod, ok := conn.(interface{ getHTTPMethod() string }); ok { + method = hasRequestMethod.getHTTPMethod() + } request := &Request[Req]{ Msg: &msg, spec: conn.Spec(), peer: conn.Peer(), header: conn.RequestHeader(), + method: method, } response, err := untyped(ctx, request) if err != nil { @@ -141,6 +146,7 @@ func NewServerStreamHandler[Req, Res any]( spec: conn.Spec(), peer: conn.Peer(), header: conn.RequestHeader(), + method: http.MethodPost, }, &ServerStream[Res]{conn: conn}, ) diff --git a/interceptor_ext_test.go b/interceptor_ext_test.go index af5a1edf..ade9bf1e 100644 --- a/interceptor_ext_test.go +++ b/interceptor_ext_test.go @@ -16,8 +16,10 @@ package connect_test import ( "context" + "fmt" "net/http" "net/http/httptest" + "sync/atomic" "testing" "connectrpc.com/connect" @@ -66,6 +68,8 @@ func TestOnionOrderingEndToEnd(t *testing.T) { } } + var client1, client2, client3, handler1, handler2, handler3 atomic.Int32 + // The client and handler interceptor onions are the meat of the test. The // order of interceptor execution must be the same for unary and streaming // procedures. @@ -79,6 +83,7 @@ func TestOnionOrderingEndToEnd(t *testing.T) { // intended order clear. clientOnion := connect.WithInterceptors( newHeaderInterceptor( + &client1, // 1 (start). request: should see protocol-related headers func(_ connect.Spec, h http.Header) { assert.NotZero(t, h.Get("Content-Type")) @@ -87,24 +92,29 @@ func TestOnionOrderingEndToEnd(t *testing.T) { assertAllPresent, ), newHeaderInterceptor( + &client2, newInspector("", "one"), // 2. request: add header "one" newInspector("three", "four"), // 11. response: check "three", add "four" ), newHeaderInterceptor( + &client3, newInspector("one", "two"), // 3. request: check "one", add "two" newInspector("two", "three"), // 10. response: check "two", add "three" ), ) handlerOnion := connect.WithInterceptors( newHeaderInterceptor( + &handler1, newInspector("two", "three"), // 4. request: check "two", add "three" newInspector("one", "two"), // 9. response: check "one", add "two" ), newHeaderInterceptor( + &handler2, newInspector("three", "four"), // 5. request: check "three", add "four" newInspector("", "one"), // 8. response: add "one" ), newHeaderInterceptor( + &handler3, assertAllPresent, // 6. request: check "one"-"four" nil, // 7. response: no-op ), @@ -129,6 +139,14 @@ func TestOnionOrderingEndToEnd(t *testing.T) { _, err := client.Ping(context.Background(), connect.NewRequest(&pingv1.PingRequest{Number: 10})) assert.Nil(t, err) + // make sure the interceptors were actually invoked + assert.Equal(t, int32(1), client1.Load()) + assert.Equal(t, int32(1), client2.Load()) + assert.Equal(t, int32(1), client3.Load()) + assert.Equal(t, int32(1), handler1.Load()) + assert.Equal(t, int32(1), handler2.Load()) + assert.Equal(t, int32(1), handler3.Load()) + responses, err := client.CountUp(context.Background(), connect.NewRequest(&pingv1.CountUpRequest{Number: 10})) assert.Nil(t, err) var sum int64 @@ -137,6 +155,14 @@ func TestOnionOrderingEndToEnd(t *testing.T) { } assert.Equal(t, sum, 55) assert.Nil(t, responses.Close()) + + // make sure the interceptors were invoked again + assert.Equal(t, int32(2), client1.Load()) + assert.Equal(t, int32(2), client2.Load()) + assert.Equal(t, int32(2), client3.Load()) + assert.Equal(t, int32(2), handler1.Load()) + assert.Equal(t, int32(2), handler2.Load()) + assert.Equal(t, int32(2), handler3.Load()) } func TestEmptyUnaryInterceptorFunc(t *testing.T) { @@ -166,6 +192,54 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { assert.Nil(t, countUpStream.Close()) } +func TestInterceptorFuncAccessingHTTPMethod(t *testing.T) { + t.Parallel() + clientChecker := &httpMethodChecker{client: true} + handlerChecker := &httpMethodChecker{} + + mux := http.NewServeMux() + mux.Handle( + pingv1connect.NewPingServiceHandler( + pingServer{}, + connect.WithInterceptors(handlerChecker), + ), + ) + server := httptest.NewServer(mux) + defer server.Close() + + client := pingv1connect.NewPingServiceClient( + server.Client(), + server.URL, + connect.WithInterceptors(clientChecker), + ) + + pingReq := connect.NewRequest(&pingv1.PingRequest{Number: 10}) + assert.Equal(t, "", pingReq.HTTPMethod()) + _, err := client.Ping(context.Background(), pingReq) + assert.Nil(t, err) + assert.Equal(t, http.MethodPost, pingReq.HTTPMethod()) + + // make sure interceptor was invoked + assert.Equal(t, int32(1), clientChecker.count.Load()) + assert.Equal(t, int32(1), handlerChecker.count.Load()) + + countUpReq := connect.NewRequest(&pingv1.CountUpRequest{Number: 10}) + assert.Equal(t, "", countUpReq.HTTPMethod()) + responses, err := client.CountUp(context.Background(), countUpReq) + assert.Nil(t, err) + var sum int64 + for responses.Receive() { + sum += responses.Msg().Number + } + assert.Equal(t, sum, 55) + assert.Nil(t, responses.Close()) + assert.Equal(t, http.MethodPost, countUpReq.HTTPMethod()) + + // make sure interceptor was invoked again + assert.Equal(t, int32(2), clientChecker.count.Load()) + assert.Equal(t, int32(2), handlerChecker.count.Load()) +} + // headerInterceptor makes it easier to write interceptors that inspect or // mutate HTTP headers. It applies the same logic to unary and streaming // procedures, wrapping the send or receive side of the stream as appropriate. @@ -173,6 +247,7 @@ func TestEmptyUnaryInterceptorFunc(t *testing.T) { // It's useful as a testing harness to make sure that we're chaining // interceptors in the correct order. type headerInterceptor struct { + counter *atomic.Int32 inspectRequestHeader func(connect.Spec, http.Header) inspectResponseHeader func(connect.Spec, http.Header) } @@ -180,10 +255,12 @@ type headerInterceptor struct { // newHeaderInterceptor constructs a headerInterceptor. Nil function pointers // are treated as no-ops. func newHeaderInterceptor( + counter *atomic.Int32, inspectRequestHeader func(connect.Spec, http.Header), inspectResponseHeader func(connect.Spec, http.Header), ) *headerInterceptor { interceptor := headerInterceptor{ + counter: counter, inspectRequestHeader: inspectRequestHeader, inspectResponseHeader: inspectResponseHeader, } @@ -198,6 +275,7 @@ func newHeaderInterceptor( func (h *headerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.counter.Add(1) h.inspectRequestHeader(req.Spec(), req.Header()) res, err := next(ctx, req) if err != nil { @@ -210,6 +288,7 @@ func (h *headerInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc func (h *headerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + h.counter.Add(1) return &headerInspectingClientConn{ StreamingClientConn: next(ctx, spec), inspectRequestHeader: h.inspectRequestHeader, @@ -220,6 +299,7 @@ func (h *headerInterceptor) WrapStreamingClient(next connect.StreamingClientFunc func (h *headerInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + h.counter.Add(1) h.inspectRequestHeader(conn.Spec(), conn.RequestHeader()) return next(ctx, &headerInspectingHandlerConn{ StreamingHandlerConn: conn, @@ -268,3 +348,50 @@ func (cc *headerInspectingClientConn) Receive(msg any) error { } return err } + +type httpMethodChecker struct { + client bool + count atomic.Int32 +} + +func (h *httpMethodChecker) WrapUnary(unaryFunc connect.UnaryFunc) connect.UnaryFunc { + return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { + h.count.Add(1) + if h.client { + // should be blank until after we make request + if req.HTTPMethod() != "" { + return nil, fmt.Errorf("expected blank HTTP method but instead got %q", req.HTTPMethod()) + } + } else { + // server interceptors see method from the start + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if req.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) + } + } + resp, err := unaryFunc(ctx, req) + // NB: In theory, the method could also be GET, not just POST. But for the + // configuration under test, it will always be POST. + if req.HTTPMethod() != http.MethodPost { + return nil, fmt.Errorf("expected HTTP method %s but instead got %q", http.MethodPost, req.HTTPMethod()) + } + return resp, err + } +} + +func (h *httpMethodChecker) WrapStreamingClient(clientFunc connect.StreamingClientFunc) connect.StreamingClientFunc { + return func(ctx context.Context, spec connect.Spec) connect.StreamingClientConn { + // method not exposed to streaming interceptor, but that's okay because it's always POST for streams + h.count.Add(1) + return clientFunc(ctx, spec) + } +} + +func (h *httpMethodChecker) WrapStreamingHandler(handlerFunc connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { + return func(ctx context.Context, conn connect.StreamingHandlerConn) error { + // method not exposed to streaming interceptor, but that's okay because it's always POST for streams + h.count.Add(1) + return handlerFunc(ctx, conn) + } +} diff --git a/protocol.go b/protocol.go index 77e899ed..c698f706 100644 --- a/protocol.go +++ b/protocol.go @@ -145,7 +145,15 @@ type protocolClient interface { // been populated by WriteRequestHeader. When constructing a stream for a // unary call, implementations may assume that the Sender's Send and Close // methods return before the Receiver's Receive or Close methods are called. - NewConn(context.Context, Spec, http.Header) StreamingClientConn + NewConn(context.Context, Spec, http.Header) streamingClientConn +} + +// streamingClientConn extends StreamingClientConn with a method for registering +// a hook when the HTTP request is actually sent. +type streamingClientConn interface { + StreamingClientConn + + onRequestSend(fn func(*http.Request)) } // errorTranslatingHandlerConnCloser wraps a handlerConnCloser to ensure that @@ -178,25 +186,29 @@ func (hc *errorTranslatingHandlerConnCloser) Close(err error) error { // // It's used in protocol implementations. type errorTranslatingClientConn struct { - StreamingClientConn + streamingClientConn fromWire func(error) error } func (cc *errorTranslatingClientConn) Send(msg any) error { - return cc.fromWire(cc.StreamingClientConn.Send(msg)) + return cc.fromWire(cc.streamingClientConn.Send(msg)) } func (cc *errorTranslatingClientConn) Receive(msg any) error { - return cc.fromWire(cc.StreamingClientConn.Receive(msg)) + return cc.fromWire(cc.streamingClientConn.Receive(msg)) } func (cc *errorTranslatingClientConn) CloseRequest() error { - return cc.fromWire(cc.StreamingClientConn.CloseRequest()) + return cc.fromWire(cc.streamingClientConn.CloseRequest()) } func (cc *errorTranslatingClientConn) CloseResponse() error { - return cc.fromWire(cc.StreamingClientConn.CloseResponse()) + return cc.fromWire(cc.streamingClientConn.CloseResponse()) +} + +func (cc *errorTranslatingClientConn) onRequestSend(fn func(*http.Request)) { + cc.streamingClientConn.onRequestSend(fn) } // wrapHandlerConnWithCodedErrors ensures that we (1) automatically code @@ -212,9 +224,9 @@ func wrapHandlerConnWithCodedErrors(conn handlerConnCloser) handlerConnCloser { // wrapClientConnWithCodedErrors ensures that we always return *Errors from // public APIs. -func wrapClientConnWithCodedErrors(conn StreamingClientConn) StreamingClientConn { +func wrapClientConnWithCodedErrors(conn streamingClientConn) streamingClientConn { return &errorTranslatingClientConn{ - StreamingClientConn: conn, + streamingClientConn: conn, fromWire: wrapIfUncoded, } } diff --git a/protocol_connect.go b/protocol_connect.go index 0232b1cd..8e6294c7 100644 --- a/protocol_connect.go +++ b/protocol_connect.go @@ -356,7 +356,7 @@ func (c *connectClient) NewConn( ctx context.Context, spec Spec, header http.Header, -) StreamingClientConn { +) streamingClientConn { if deadline, ok := ctx.Deadline(); ok { millis := int64(time.Until(deadline) / time.Millisecond) if millis > 0 { @@ -367,7 +367,7 @@ func (c *connectClient) NewConn( } } duplexCall := newDuplexHTTPCall(ctx, c.HTTPClient, c.URL, spec, header) - var conn StreamingClientConn + var conn streamingClientConn if spec.StreamType == StreamTypeUnary { unaryConn := &connectUnaryClientConn{ spec: spec, @@ -499,6 +499,10 @@ func (cc *connectUnaryClientConn) CloseResponse() error { return cc.duplexCall.CloseRead() } +func (cc *connectUnaryClientConn) onRequestSend(fn func(*http.Request)) { + cc.duplexCall.onRequestSend = fn +} + func (cc *connectUnaryClientConn) validateResponse(response *http.Response) *Error { for k, v := range response.Header { if !strings.HasPrefix(k, connectUnaryTrailerPrefix) { @@ -624,6 +628,10 @@ func (cc *connectStreamingClientConn) CloseResponse() error { return cc.duplexCall.CloseRead() } +func (cc *connectStreamingClientConn) onRequestSend(fn func(*http.Request)) { + cc.duplexCall.onRequestSend = fn +} + func (cc *connectStreamingClientConn) validateResponse(response *http.Response) *Error { if response.StatusCode != http.StatusOK { return errorf(connectHTTPToCode(response.StatusCode), "HTTP status %v", response.Status) @@ -719,6 +727,10 @@ func (hc *connectUnaryHandlerConn) Close(err error) error { return hc.request.Body.Close() } +func (hc *connectUnaryHandlerConn) getHTTPMethod() string { + return hc.request.Method +} + func (hc *connectUnaryHandlerConn) writeResponseHeader(err error) { header := hc.responseWriter.Header() if hc.request.Method == http.MethodGet { @@ -928,7 +940,7 @@ type connectUnaryRequestMarshaler struct { func (m *connectUnaryRequestMarshaler) Marshal(message any) *Error { if m.enableGet { if m.stableCodec == nil && !m.getUseFallback { - return errorf(CodeInternal, "codec %s doesn't support stable marshal; cam't use get", m.codec.Name()) + return errorf(CodeInternal, "codec %s doesn't support stable marshal; can't use get", m.codec.Name()) } if m.stableCodec != nil { return m.marshalWithGet(message) diff --git a/protocol_grpc.go b/protocol_grpc.go index 2b0c7b56..043b13ca 100644 --- a/protocol_grpc.go +++ b/protocol_grpc.go @@ -274,7 +274,7 @@ func (g *grpcClient) NewConn( ctx context.Context, spec Spec, header http.Header, -) StreamingClientConn { +) streamingClientConn { if deadline, ok := ctx.Deadline(); ok { if encodedDeadline, err := grpcEncodeTimeout(time.Until(deadline)); err == nil { // Tests verify that the error in encodeTimeout is unreachable, so we @@ -424,6 +424,10 @@ func (cc *grpcClientConn) CloseResponse() error { return cc.duplexCall.CloseRead() } +func (cc *grpcClientConn) onRequestSend(fn func(*http.Request)) { + cc.duplexCall.onRequestSend = fn +} + func (cc *grpcClientConn) validateResponse(response *http.Response) *Error { if err := grpcValidateResponse( response,