Skip to content

Commit

Permalink
Pass context to Thrift post response callback (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
robskillington authored and prashantv committed Jul 26, 2016
1 parent e52924f commit 4006710
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 4 deletions.
7 changes: 5 additions & 2 deletions thrift/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

package thrift

import "github.com/apache/thrift/lib/go/thrift"
import (
"github.com/apache/thrift/lib/go/thrift"
"golang.org/x/net/context"
)

// RegisterOption is the interface for options to Register.
type RegisterOption interface {
Expand All @@ -30,7 +33,7 @@ type RegisterOption interface {
// PostResponseCB registers a callback that is run after a response has been
// compeltely processed (e.g. written to the channel).
// This gives the server a chance to clean up resources from the response object
type PostResponseCB func(method string, response thrift.TStruct)
type PostResponseCB func(ctx context.Context, method string, response thrift.TStruct)

type optPostResponse PostResponseCB

Expand Down
2 changes: 1 addition & 1 deletion thrift/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func (s *Server) handle(origCtx context.Context, handler handler, method string,
err = writer.Close()

if handler.postResponseCB != nil {
handler.postResponseCB(method, resp)
handler.postResponseCB(ctx, method, resp)
}

return err
Expand Down
13 changes: 12 additions & 1 deletion thrift/thrift_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,15 @@ func TestClientHostPort(t *testing.T) {

func TestRegisterPostResponseCB(t *testing.T) {
withSetup(t, func(ctx Context, args testArgs) {
var createdCtx Context
ctxKey := "key"
ctxValue := "value"

args.server.SetContextFn(func(ctx context.Context, method string, headers map[string]string) Context {
createdCtx = WithHeaders(context.WithValue(ctx, ctxKey, ctxValue), headers)
return createdCtx
})

arg := &gen.Data{
B1: true,
S2: "str",
Expand All @@ -250,8 +259,10 @@ func TestRegisterPostResponseCB(t *testing.T) {
}

called := make(chan struct{})
cb := func(method string, response thrift.TStruct) {
cb := func(reqCtx context.Context, method string, response thrift.TStruct) {
assert.Equal(t, "Call", method)
assert.Equal(t, createdCtx, reqCtx)
assert.Equal(t, ctxValue, reqCtx.Value(ctxKey))
res, ok := response.(*gen.SimpleServiceCallResult)
if assert.True(t, ok, "response type should be Result struct") {
assert.Equal(t, ret, res.GetSuccess(), "result should be returned value")
Expand Down

0 comments on commit 4006710

Please sign in to comment.