From f3a87486c7dc980034f3487466539424739d80ca Mon Sep 17 00:00:00 2001 From: Bojan Date: Sat, 13 Feb 2021 21:42:27 -0400 Subject: [PATCH] enhance: add preseed and caching for metadata --- runner/data.go | 28 ++++++++++++++-- runner/data_test.go | 80 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/runner/data.go b/runner/data.go index e521dc22..96374516 100644 --- a/runner/data.go +++ b/runner/data.go @@ -59,6 +59,7 @@ type dataProvider struct { type mdProvider struct { metadata []byte + preseed metadata.MD } func newDataProvider(mtd *desc.MethodDescriptor, @@ -215,11 +216,34 @@ func (dp *dataProvider) getMessages(ctd *CallData, i int, inputData []byte) ([]* return inputs, nil } -func newMetadataProvider(mtd *desc.MethodDescriptor, metadata []byte) (*mdProvider, error) { - return &mdProvider{metadata: metadata}, nil +func newMetadataProvider(mtd *desc.MethodDescriptor, mdData []byte) (*mdProvider, error) { + // Test if we can preseed data + ctd := newCallData(mtd, nil, "", 0) + ha, err := ctd.hasAction(string(mdData)) + if err != nil { + return nil, err + } + + var preseed metadata.MD = nil + if !ha { + mdMap, err := ctd.executeMetadata(string(mdData)) + if err != nil { + return nil, err + } + + if len(mdMap) > 0 { + preseed = metadata.New(mdMap) + } + } + + return &mdProvider{metadata: mdData, preseed: preseed}, nil } func (dp *mdProvider) getMetadataForCall(ctd *CallData) (*metadata.MD, error) { + if dp.preseed != nil { + return &dp.preseed, nil + } + mdMap, err := ctd.executeMetadata(string(dp.metadata)) if err != nil { return nil, err diff --git a/runner/data_test.go b/runner/data_test.go index 7f9e31d8..51783db0 100644 --- a/runner/data_test.go +++ b/runner/data_test.go @@ -265,3 +265,83 @@ func TestData_createPayloads(t *testing.T) { assert.Len(t, inputs, 0) }) } + +func TestMetadata_newMetadataProvider(t *testing.T) { + t.Run("no action", func(t *testing.T) { + mtdUnary, err := protodesc.GetMethodDescFromProto( + "helloworld.Greeter.SayHello", + "../testdata/greeter.proto", + nil) + assert.NoError(t, err) + + mdp, err := newMetadataProvider(mtdUnary, []byte(`{"token":"asdf"}`)) + assert.NoError(t, err) + assert.NotNil(t, mdp) + assert.NotNil(t, mdp.preseed) + assert.Equal(t, []string{"asdf"}, mdp.preseed.Get("token")) + }) + + t.Run("with action", func(t *testing.T) { + mtdUnary, err := protodesc.GetMethodDescFromProto( + "helloworld.Greeter.SayHello", + "../testdata/greeter.proto", + nil) + assert.NoError(t, err) + + mdp, err := newMetadataProvider(mtdUnary, []byte(`{"token":"{{ .RequestNumber }}"}`)) + assert.NoError(t, err) + assert.NotNil(t, mdp) + assert.Nil(t, mdp.preseed) + }) +} + +func TestMetadata_getMetadataForCall(t *testing.T) { + t.Run("no action", func(t *testing.T) { + mtdUnary, err := protodesc.GetMethodDescFromProto( + "helloworld.Greeter.SayHello", + "../testdata/greeter.proto", + nil) + assert.NoError(t, err) + + mdp, err := newMetadataProvider(mtdUnary, []byte(`{"token":"asdf"}`)) + assert.NoError(t, err) + assert.NotNil(t, mdp.preseed) + + cd := newCallData(mtdUnary, nil, "123", 1) + + md, err := mdp.getMetadataForCall(cd) + assert.NoError(t, err) + assert.NotNil(t, md) + assert.Equal(t, []string{"asdf"}, md.Get("token")) + assert.Same(t, &mdp.preseed, md) + }) + + t.Run("with action", func(t *testing.T) { + mtdUnary, err := protodesc.GetMethodDescFromProto( + "helloworld.Greeter.SayHello", + "../testdata/greeter.proto", + nil) + assert.NoError(t, err) + + mdp, err := newMetadataProvider(mtdUnary, []byte(`{"token":"{{ .RequestNumber }}"}`)) + assert.NoError(t, err) + assert.Nil(t, mdp.preseed) + + cd := newCallData(mtdUnary, nil, "123", 1) + + md1, err := mdp.getMetadataForCall(cd) + assert.NoError(t, err) + assert.NotNil(t, md1) + assert.Equal(t, []string{"1"}, md1.Get("token")) + assert.NotSame(t, mdp.preseed, md1) + + cd = newCallData(mtdUnary, nil, "123", 2) + md2, err := mdp.getMetadataForCall(cd) + assert.NoError(t, err) + assert.NotNil(t, md2) + assert.Equal(t, []string{"2"}, md2.Get("token")) + assert.NotSame(t, mdp.preseed, md2) + assert.NotSame(t, md1, md2) + assert.NotEqual(t, md1, md2) + }) +}