diff --git a/v2/callctx/callctx.go b/v2/callctx/callctx.go index 9aab3d91..f5af5c99 100644 --- a/v2/callctx/callctx.go +++ b/v2/callctx/callctx.go @@ -74,9 +74,27 @@ func SetHeaders(ctx context.Context, keyvals ...string) context.Context { h, ok := ctx.Value(headerKey).(map[string][]string) if !ok { h = make(map[string][]string) + } else { + h = cloneHeaders(h) } + for i := 0; i < len(keyvals); i = i + 2 { h[keyvals[i]] = append(h[keyvals[i]], keyvals[i+1]) } return context.WithValue(ctx, headerKey, h) } + +// cloneHeaders makes a new key-value map while reusing the value slices. +// As such, new values should be appended to the value slice, and modifying +// indexed values is not thread safe. +// +// TODO: Replace this with maps.Clone when Go 1.21 is the minimum version. +func cloneHeaders(h map[string][]string) map[string][]string { + c := make(map[string][]string, len(h)) + for k, v := range h { + vc := make([]string, len(v)) + copy(vc, v) + c[k] = vc + } + return c +} diff --git a/v2/callctx/callctx_test.go b/v2/callctx/callctx_test.go index 46d91b0e..e644d55d 100644 --- a/v2/callctx/callctx_test.go +++ b/v2/callctx/callctx_test.go @@ -31,6 +31,7 @@ package callctx import ( "context" + "sync" "testing" "github.com/google/go-cmp/cmp" @@ -77,3 +78,45 @@ func TestSetHeaders_panics(t *testing.T) { ctx := context.Background() SetHeaders(ctx, "1", "2", "3") } + +func TestSetHeaders_reuse(t *testing.T) { + c := SetHeaders(context.Background(), "key", "value1") + v1 := HeadersFromContext(c) + c = SetHeaders(c, "key", "value2") + v2 := HeadersFromContext(c) + + if cmp.Diff(v2, v1) == "" { + t.Errorf("Second header set did not differ from first header set as expected") + } +} + +func TestSetHeaders_race(t *testing.T) { + key := "key" + value := "value" + want := map[string][]string{ + key: []string{value, value}, + } + + // Init the ctx so a value already exists to be "shared". + cctx := SetHeaders(context.Background(), key, value) + + // Reusing the same cctx and adding to the same header key + // should *not* produce a race condition when run with -race. + var wg sync.WaitGroup + for i := 0; i < 3; i++ { + wg.Add(1) + go func(ctx context.Context) { + defer wg.Done() + c := SetHeaders(ctx, key, value) + h := HeadersFromContext(c) + + // Additionally, if there was a race condition, + // we may see that one instance of these headers + // contains extra values. + if diff := cmp.Diff(h, want); diff != "" { + t.Errorf("got(-),want(+):\n%s", diff) + } + }(cctx) + } + wg.Wait() +}