From 4006710b9a6e8aaa8e02d7d6d1e6a51253e9257f Mon Sep 17 00:00:00 2001 From: Rob Skillington Date: Mon, 25 Jul 2016 22:29:12 -0400 Subject: [PATCH] Pass context to Thrift post response callback (#465) --- thrift/options.go | 7 +++++-- thrift/server.go | 2 +- thrift/thrift_test.go | 13 ++++++++++++- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/thrift/options.go b/thrift/options.go index 06985bf2..050a3282 100644 --- a/thrift/options.go +++ b/thrift/options.go @@ -20,7 +20,10 @@ package thrift -import "github.com/apache/thrift/lib/go/thrift" +import ( + "github.com/apache/thrift/lib/go/thrift" + "golang.org/x/net/context" +) // RegisterOption is the interface for options to Register. type RegisterOption interface { @@ -30,7 +33,7 @@ type RegisterOption interface { // PostResponseCB registers a callback that is run after a response has been // compeltely processed (e.g. written to the channel). // This gives the server a chance to clean up resources from the response object -type PostResponseCB func(method string, response thrift.TStruct) +type PostResponseCB func(ctx context.Context, method string, response thrift.TStruct) type optPostResponse PostResponseCB diff --git a/thrift/server.go b/thrift/server.go index 789dce63..af314905 100644 --- a/thrift/server.go +++ b/thrift/server.go @@ -168,7 +168,7 @@ func (s *Server) handle(origCtx context.Context, handler handler, method string, err = writer.Close() if handler.postResponseCB != nil { - handler.postResponseCB(method, resp) + handler.postResponseCB(ctx, method, resp) } return err diff --git a/thrift/thrift_test.go b/thrift/thrift_test.go index 317543b4..1be91ca7 100644 --- a/thrift/thrift_test.go +++ b/thrift/thrift_test.go @@ -238,6 +238,15 @@ func TestClientHostPort(t *testing.T) { func TestRegisterPostResponseCB(t *testing.T) { withSetup(t, func(ctx Context, args testArgs) { + var createdCtx Context + ctxKey := "key" + ctxValue := "value" + + args.server.SetContextFn(func(ctx context.Context, method string, headers map[string]string) Context { + createdCtx = WithHeaders(context.WithValue(ctx, ctxKey, ctxValue), headers) + return createdCtx + }) + arg := &gen.Data{ B1: true, S2: "str", @@ -250,8 +259,10 @@ func TestRegisterPostResponseCB(t *testing.T) { } called := make(chan struct{}) - cb := func(method string, response thrift.TStruct) { + cb := func(reqCtx context.Context, method string, response thrift.TStruct) { assert.Equal(t, "Call", method) + assert.Equal(t, createdCtx, reqCtx) + assert.Equal(t, ctxValue, reqCtx.Value(ctxKey)) res, ok := response.(*gen.SimpleServiceCallResult) if assert.True(t, ok, "response type should be Result struct") { assert.Equal(t, ret, res.GetSuccess(), "result should be returned value")