Skip to content

Commit

Permalink
Expose request method of unary requests to clients and server handlers (
Browse files Browse the repository at this point in the history
#502)

Add ability for clients and servers to inspect the HTTP method
used for unary RPCs. Combined with #494, this enables support
for conditional GET requests.
  • Loading branch information
jhump authored May 17, 2023
1 parent 5ae0544 commit ef2d883
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 28 deletions.
17 changes: 12 additions & 5 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ func NewClient[Req, Res any](httpClient HTTPClient, url string, options ...Clien
unarySpec := config.newSpec(StreamTypeUnary)
unaryFunc := UnaryFunc(func(ctx context.Context, request AnyRequest) (AnyResponse, error) {
conn := client.protocolClient.NewConn(ctx, unarySpec, request.Header())
conn.onRequestSend(func(r *http.Request) {
request.setRequestMethod(r.Method)
})
// Send always returns an io.EOF unless the error is from the client-side.
// We want the user to continue to call Receive in those cases to get the
// full error from the server-side.
Expand Down Expand Up @@ -132,15 +135,17 @@ func (c *Client[Req, Res]) CallClientStream(ctx context.Context) *ClientStreamFo
if c.err != nil {
return &ClientStreamForClient[Req, Res]{err: c.err}
}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient)}
return &ClientStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeClient, nil)}
}

// CallServerStream calls a server streaming procedure.
func (c *Client[Req, Res]) CallServerStream(ctx context.Context, request *Request[Req]) (*ServerStreamForClient[Res], error) {
if c.err != nil {
return nil, c.err
}
conn := c.newConn(ctx, StreamTypeServer)
conn := c.newConn(ctx, StreamTypeServer, func(r *http.Request) {
request.method = r.Method
})
request.spec = conn.Spec()
request.peer = conn.Peer()
mergeHeaders(conn.RequestHeader(), request.header)
Expand All @@ -163,14 +168,16 @@ func (c *Client[Req, Res]) CallBidiStream(ctx context.Context) *BidiStreamForCli
if c.err != nil {
return &BidiStreamForClient[Req, Res]{err: c.err}
}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi)}
return &BidiStreamForClient[Req, Res]{conn: c.newConn(ctx, StreamTypeBidi, nil)}
}

func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType) StreamingClientConn {
func (c *Client[Req, Res]) newConn(ctx context.Context, streamType StreamType, onRequestSend func(r *http.Request)) StreamingClientConn {
newConn := func(ctx context.Context, spec Spec) StreamingClientConn {
header := make(http.Header, 8) // arbitrary power of two, prevent immediate resizing
c.protocolClient.WriteRequestHeader(streamType, header)
return c.protocolClient.NewConn(ctx, spec, header)
conn := c.protocolClient.NewConn(ctx, spec, header)
conn.onRequestSend(onRequestSend)
return conn
}
if interceptor := c.config.Interceptor; interceptor != nil {
newConn = interceptor.WrapStreamingClient(newConn)
Expand Down
25 changes: 15 additions & 10 deletions client_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func TestClientPeer(t *testing.T) {
server.StartTLS()
t.Cleanup(server.Close)

run := func(t *testing.T, opts ...connect.ClientOption) {
run := func(t *testing.T, unaryHTTPMethod string, opts ...connect.ClientOption) {
t.Helper()
client := pingv1connect.NewPingServiceClient(
server.Client(),
Expand All @@ -90,8 +90,10 @@ func TestClientPeer(t *testing.T) {
)
ctx := context.Background()
// unary
_, err := client.Ping(ctx, connect.NewRequest[pingv1.PingRequest](nil))
unaryReq := connect.NewRequest[pingv1.PingRequest](nil)
_, err := client.Ping(ctx, unaryReq)
assert.Nil(t, err)
assert.Equal(t, unaryHTTPMethod, unaryReq.HTTPMethod())
text := strings.Repeat(".", 256)
r, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{Text: text}))
assert.Nil(t, err)
Expand Down Expand Up @@ -126,22 +128,22 @@ func TestClientPeer(t *testing.T) {

t.Run("connect", func(t *testing.T) {
t.Parallel()
run(t)
run(t, http.MethodPost)
})
t.Run("connect+get", func(t *testing.T) {
t.Parallel()
run(t,
run(t, http.MethodGet,
connect.WithHTTPGet(),
connect.WithSendGzip(),
)
})
t.Run("grpc", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPC())
run(t, http.MethodPost, connect.WithGRPC())
})
t.Run("grpcweb", func(t *testing.T) {
t.Parallel()
run(t, connect.WithGRPCWeb())
run(t, http.MethodPost, connect.WithGRPCWeb())
})
}

Expand All @@ -167,21 +169,24 @@ func TestGetNotModified(t *testing.T) {
)
ctx := context.Background()
// unconditional request
res, err := client.Ping(ctx, connect.NewRequest(&pingv1.PingRequest{}))
unaryReq := connect.NewRequest(&pingv1.PingRequest{})
res, err := client.Ping(ctx, unaryReq)
assert.Nil(t, err)
assert.Equal(t, res.Header().Get("Etag"), etag)
assert.Equal(t, res.Header().Values("Vary"), expectVary)
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())

conditional := connect.NewRequest(&pingv1.PingRequest{})
conditional.Header().Set("If-None-Match", etag)
_, err = client.Ping(ctx, conditional)
unaryReq = connect.NewRequest(&pingv1.PingRequest{})
unaryReq.Header().Set("If-None-Match", etag)
_, err = client.Ping(ctx, unaryReq)
assert.NotNil(t, err)
assert.Equal(t, connect.CodeOf(err), connect.CodeUnknown)
assert.True(t, connect.IsNotModifiedError(err))
var connectErr *connect.Error
assert.True(t, errors.As(err, &connectErr))
assert.Equal(t, connectErr.Meta().Get("Etag"), etag)
assert.Equal(t, connectErr.Meta().Values("Vary"), expectVary)
assert.Equal(t, http.MethodGet, unaryReq.HTTPMethod())
}

type notModifiedPingServer struct {
Expand Down
24 changes: 23 additions & 1 deletion connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ type Request[T any] struct {
spec Spec
peer Peer
header http.Header
method string
}

// NewRequest wraps a generated request message.
Expand Down Expand Up @@ -187,9 +188,28 @@ func (r *Request[_]) Header() http.Header {
return r.header
}

// HTTPMethod returns the HTTP method for this request. This is nearly always
// POST, but side-effect-free unary RPCs could be made via a GET.
//
// On a newly created request, via NewRequest, this will return the empty
// string until the actual request is actually sent and the HTTP method
// determined. This means that client interceptor functions will see the
// empty string until *after* they delegate to the handler they wrapped. It
// is even possible for this to return the empty string after such delegation,
// if the request was never actually sent to the server (and thus no
// determination ever made about the HTTP method).
func (r *Request[_]) HTTPMethod() string {
return r.method
}

// internalOnly implements AnyRequest.
func (r *Request[_]) internalOnly() {}

// setRequestMethod sets the request method to the given value.
func (r *Request[_]) setRequestMethod(method string) {
r.method = method
}

// AnyRequest is the common method set of every [Request], regardless of type
// parameter. It's used in unary interceptors.
//
Expand All @@ -205,8 +225,10 @@ type AnyRequest interface {
Spec() Spec
Peer() Peer
Header() http.Header
HTTPMethod() string

internalOnly()
setRequestMethod(string)
}

// Response is a wrapper around a generated response message. It provides
Expand Down Expand Up @@ -322,7 +344,7 @@ func newPeerFromURL(url *url.URL, protocol string) Peer {
}
}

// handlerConnCloser extends HandlerConn with a method for handlers to
// handlerConnCloser extends StreamingHandlerConn with a method for handlers to
// terminate the message exchange (and optionally send an error to the client).
type handlerConnCloser interface {
StreamingHandlerConn
Expand Down
4 changes: 4 additions & 0 deletions duplex_http_call.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type duplexHTTPCall struct {
ctx context.Context
httpClient HTTPClient
streamType StreamType
onRequestSend func(*http.Request)
validateResponse func(*http.Response) *Error

// We'll use a pipe as the request body. We hand the read side of the pipe to
Expand Down Expand Up @@ -255,6 +256,9 @@ func (d *duplexHTTPCall) makeRequest() {
// on d.responseReady, so we can't race with them.
defer close(d.responseReady)

if d.onRequestSend != nil {
d.onRequestSend(d.request)
}
// Once we send a message to the server, they send a message back and
// establish the receive side of the stream.
response, err := d.httpClient.Do(d.request) //nolint:bodyclose
Expand Down
6 changes: 6 additions & 0 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,16 @@ func NewUnaryHandler[Req, Res any](
if err := conn.Receive(&msg); err != nil {
return err
}
method := http.MethodPost
if hasRequestMethod, ok := conn.(interface{ getHTTPMethod() string }); ok {
method = hasRequestMethod.getHTTPMethod()
}
request := &Request[Req]{
Msg: &msg,
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
method: method,
}
response, err := untyped(ctx, request)
if err != nil {
Expand Down Expand Up @@ -141,6 +146,7 @@ func NewServerStreamHandler[Req, Res any](
spec: conn.Spec(),
peer: conn.Peer(),
header: conn.RequestHeader(),
method: http.MethodPost,
},
&ServerStream[Res]{conn: conn},
)
Expand Down
Loading

0 comments on commit ef2d883

Please sign in to comment.