From 0ccf90df0d3529ad65ee100580d0ad883eb1b22b Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Mon, 30 Oct 2023 12:51:19 -0400 Subject: [PATCH] metadata: Use strings.EqualFold for ValueFromIncomingContext (#6743) --- metadata/metadata.go | 18 +++++++----- metadata/metadata_test.go | 61 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/metadata/metadata.go b/metadata/metadata.go index a2cdcaf12a87..49446825763b 100644 --- a/metadata/metadata.go +++ b/metadata/metadata.go @@ -153,14 +153,16 @@ func Join(mds ...MD) MD { type mdIncomingKey struct{} type mdOutgoingKey struct{} -// NewIncomingContext creates a new context with incoming md attached. +// NewIncomingContext creates a new context with incoming md attached. md must +// not be modified after calling this function. func NewIncomingContext(ctx context.Context, md MD) context.Context { return context.WithValue(ctx, mdIncomingKey{}, md) } // NewOutgoingContext creates a new context with outgoing md attached. If used // in conjunction with AppendToOutgoingContext, NewOutgoingContext will -// overwrite any previously-appended metadata. +// overwrite any previously-appended metadata. md must not be modified after +// calling this function. func NewOutgoingContext(ctx context.Context, md MD) context.Context { return context.WithValue(ctx, mdOutgoingKey{}, rawMD{md: md}) } @@ -203,7 +205,8 @@ func FromIncomingContext(ctx context.Context) (MD, bool) { } // ValueFromIncomingContext returns the metadata value corresponding to the metadata -// key from the incoming metadata if it exists. Key must be lower-case. +// key from the incoming metadata if it exists. Keys are matched in a case insensitive +// manner. // // # Experimental // @@ -219,17 +222,16 @@ func ValueFromIncomingContext(ctx context.Context, key string) []string { return copyOf(v) } for k, v := range md { - // We need to manually convert all keys to lower case, because MD is a - // map, and there's no guarantee that the MD attached to the context is - // created using our helper functions. - if strings.ToLower(k) == key { + // Case insenitive comparison: MD is a map, and there's no guarantee + // that the MD attached to the context is created using our helper + // functions. + if strings.EqualFold(k, key) { return copyOf(v) } } return nil } -// the returned slice must not be modified in place func copyOf(v []string) []string { vals := make([]string, len(v)) copy(vals, v) diff --git a/metadata/metadata_test.go b/metadata/metadata_test.go index 9277f2d6c84f..fbee086fb919 100644 --- a/metadata/metadata_test.go +++ b/metadata/metadata_test.go @@ -198,6 +198,39 @@ func (s) TestDelete(t *testing.T) { } } +func (s) TestFromIncomingContext(t *testing.T) { + md := Pairs( + "X-My-Header-1", "42", + ) + // Verify that we lowercase if callers directly modify md + md["X-INCORRECT-UPPERCASE"] = []string{"foo"} + ctx := NewIncomingContext(context.Background(), md) + + result, found := FromIncomingContext(ctx) + if !found { + t.Fatal("FromIncomingContext must return metadata") + } + expected := MD{ + "x-my-header-1": []string{"42"}, + "x-incorrect-uppercase": []string{"foo"}, + } + if !reflect.DeepEqual(result, expected) { + t.Errorf("FromIncomingContext returned %#v, expected %#v", result, expected) + } + + // ensure modifying result does not modify the value in the context + result["new_key"] = []string{"foo"} + result["x-my-header-1"][0] = "mutated" + + result2, found := FromIncomingContext(ctx) + if !found { + t.Fatal("FromIncomingContext must return metadata") + } + if !reflect.DeepEqual(result2, expected) { + t.Errorf("FromIncomingContext after modifications returned %#v, expected %#v", result2, expected) + } +} + func (s) TestValueFromIncomingContext(t *testing.T) { md := Pairs( "X-My-Header-1", "42", @@ -205,6 +238,8 @@ func (s) TestValueFromIncomingContext(t *testing.T) { "X-My-Header-2", "43-2", "x-my-header-3", "44", ) + // Verify that we lowercase if callers directly modify md + md["X-INCORRECT-UPPERCASE"] = []string{"foo"} ctx := NewIncomingContext(context.Background(), md) for _, test := range []struct { @@ -227,6 +262,10 @@ func (s) TestValueFromIncomingContext(t *testing.T) { key: "x-unknown", want: nil, }, + { + key: "x-incorrect-uppercase", + want: []string{"foo"}, + }, } { v := ValueFromIncomingContext(ctx, test.key) if !reflect.DeepEqual(v, test.want) { @@ -348,8 +387,22 @@ func BenchmarkFromIncomingContext(b *testing.B) { func BenchmarkValueFromIncomingContext(b *testing.B) { md := Pairs("X-My-Header-1", "42") ctx := NewIncomingContext(context.Background(), md) - b.ResetTimer() - for n := 0; n < b.N; n++ { - ValueFromIncomingContext(ctx, "x-my-header-1") - } + + b.Run("key-found", func(b *testing.B) { + for n := 0; n < b.N; n++ { + result := ValueFromIncomingContext(ctx, "x-my-header-1") + if len(result) != 1 { + b.Fatal("ensures not optimized away") + } + } + }) + + b.Run("key-not-found", func(b *testing.B) { + for n := 0; n < b.N; n++ { + result := ValueFromIncomingContext(ctx, "key-not-found") + if len(result) != 0 { + b.Fatal("ensures not optimized away") + } + } + }) }