diff --git a/v2/pkg/astjson/astjson.go b/v2/pkg/astjson/astjson.go
index b06b304ec..f3f451da0 100644
--- a/v2/pkg/astjson/astjson.go
+++ b/v2/pkg/astjson/astjson.go
@@ -50,6 +50,10 @@ type JSON struct {
 	_intSlicePos int
 }
 
+func (j *JSON) Size() int {
+	return len(j.storage)
+}
+
 func (j *JSON) Get(nodeRef int, path []string) int {
 	if len(path) == 0 {
 		return nodeRef
diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go
index 1d44f650f..ccc5c0db7 100644
--- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go
+++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go
@@ -5,7 +5,6 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
-	"io"
 	"net/http"
 	"regexp"
 	"slices"
@@ -1697,14 +1696,14 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) {
 	return variables, false
 }
 
-func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, writer io.Writer) (err error) {
+func (s *Source) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) (err error) {
 	input = s.compactAndUnNullVariables(input)
-	return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, writer)
+	return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, out)
 }
 
-func (s *Source) Load(ctx context.Context, input []byte, writer io.Writer) (err error) {
+func (s *Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) {
 	input = s.compactAndUnNullVariables(input)
-	return httpclient.Do(s.httpClient, ctx, input, writer)
+	return httpclient.Do(s.httpClient, ctx, input, out)
 }
 
 type GraphQLSubscriptionClient interface {
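Note: across this patch, Load/LoadWithFiles move from io.Writer to *bytes.Buffer, which lets data sources use Buffer-specific fast paths (ReadFrom/WriteTo) and lets callers pre-size the output. The new JSON.Size method supports that pre-sizing. A minimal caller-side sketch under those assumptions (newOutputBuffer is a hypothetical helper, not part of the patch):

    package example

    import (
    	"bytes"

    	"github.com/wundergraph/graphql-go-tools/v2/pkg/astjson"
    )

    // newOutputBuffer pre-sizes a response buffer from the parsed document's
    // backing storage, so printing that document typically fits without the
    // buffer having to grow.
    func newOutputBuffer(doc *astjson.JSON) *bytes.Buffer {
    	return bytes.NewBuffer(make([]byte, 0, doc.Size()))
    }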
diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go
index 14c911d50..7d5648083 100644
--- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go
+++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go
@@ -132,7 +132,7 @@ func releaseBuffer(buf *bytes.Buffer) {
 	requestBufferPool.Put(buf)
 }
 
-func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out io.Writer, contentType string) (err error) {
+func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, headers, queryParams []byte, body io.Reader, enableTrace bool, out *bytes.Buffer, contentType string) (err error) {
 	request, err := http.NewRequestWithContext(ctx, string(method), string(url), body)
 	if err != nil {
 		return err
@@ -204,18 +204,14 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head
 		return err
 	}
 
-	buf := getBuffer()
-	defer releaseBuffer(buf)
-
 	if !enableTrace {
-		_, err = buf.ReadFrom(respReader)
-		if err != nil {
-			return err
-		}
-		_, err = buf.WriteTo(out)
+		_, err = out.ReadFrom(respReader)
 		return
 	}
 
+	buf := getBuffer()
+	defer releaseBuffer(buf)
+
 	_, err = buf.ReadFrom(respReader)
 	if err != nil {
 		return err
@@ -245,14 +241,14 @@ func makeHTTPRequest(client *http.Client, ctx context.Context, url, method, head
 	return err
 }
 
-func Do(client *http.Client, ctx context.Context, requestInput []byte, out io.Writer) (err error) {
+func Do(client *http.Client, ctx context.Context, requestInput []byte, out *bytes.Buffer) (err error) {
 	url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput)
 	return makeHTTPRequest(client, ctx, url, method, headers, queryParams, bytes.NewReader(body), enableTrace, out,
 		ContentTypeJSON)
 }
 
 func DoMultipartForm(
-	client *http.Client, ctx context.Context, requestInput []byte, files []File, out io.Writer,
+	client *http.Client, ctx context.Context, requestInput []byte, files []File, out *bytes.Buffer,
) (err error) {
 	if len(files) == 0 {
 		return errors.New("no files provided")
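Note: in the non-trace path, makeHTTPRequest now streams the response body straight into the caller's buffer and only allocates the pooled scratch buffer when tracing has to inspect the body. The fast path in isolation, as a hedged sketch:

    package example

    import (
    	"bytes"
    	"io"
    )

    // copyResponse mirrors the non-trace branch above: bytes.Buffer.ReadFrom
    // drains the reader directly into out (one copy), instead of
    // reader -> scratch buffer -> out (two copies plus a pool round-trip).
    func copyResponse(out *bytes.Buffer, respReader io.Reader) error {
    	_, err := out.ReadFrom(respReader)
    	return err
    }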
`{"success": false}`) + _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(w, `{"success": true}`) + _, err = io.WriteString(out, `{"success": true}`) return err } -func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) { +func (s *KafkaPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) (err error) { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go index 1088cd810..b6b5e5049 100644 --- a/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go +++ b/v2/pkg/engine/datasource/pubsub_datasource/pubsub_nats.go @@ -1,13 +1,15 @@ package pubsub_datasource import ( + "bytes" "context" "encoding/json" + "io" + "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" - "io" ) type NatsStreamConfiguration struct { @@ -73,7 +75,7 @@ type NatsPublishDataSource struct { pubSub NatsPubSub } -func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, w io.Writer) error { +func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var publishConfiguration NatsPublishAndRequestEventConfiguration err := json.Unmarshal(input, &publishConfiguration) if err != nil { @@ -81,15 +83,15 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, w io.Wri } if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(w, `{"success": false}`) + _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(w, `{"success": true}`) + _, err = io.WriteString(out, `{"success": true}`) return err } -func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error { +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) error { panic("not implemented") } @@ -97,16 +99,16 @@ type NatsRequestDataSource struct { pubSub NatsPubSub } -func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, w io.Writer) error { +func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var subscriptionConfiguration NatsPublishAndRequestEventConfiguration err := json.Unmarshal(input, &subscriptionConfiguration) if err != nil { return err } - return s.pubSub.Request(ctx, subscriptionConfiguration, w) + return s.pubSub.Request(ctx, subscriptionConfiguration, out) } -func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error { +func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) error { panic("not implemented") } diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go index 0e4a53112..86325713c 100644 --- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go +++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go @@ -1,8 +1,8 @@ package staticdatasource import ( + "bytes" "context" - "io" "github.com/jensneuse/abstractlogger" @@ 
diff --git a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go
index 0e4a53112..86325713c 100644
--- a/v2/pkg/engine/datasource/staticdatasource/static_datasource.go
+++ b/v2/pkg/engine/datasource/staticdatasource/static_datasource.go
@@ -1,8 +1,8 @@
 package staticdatasource
 
 import (
+	"bytes"
 	"context"
-	"io"
 
 	"github.com/jensneuse/abstractlogger"
 
@@ -66,11 +66,11 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration {
 
 type Source struct{}
 
-func (Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
-	_, err = w.Write(input)
+func (Source) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) {
+	_, err = out.Write(input)
 	return
 }
 
-func (Source) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
+func (Source) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) (err error) {
 	panic("not implemented")
 }
diff --git a/v2/pkg/engine/plan/schemausageinfo_test.go b/v2/pkg/engine/plan/schemausageinfo_test.go
index 4e8f260bf..462e468b9 100644
--- a/v2/pkg/engine/plan/schemausageinfo_test.go
+++ b/v2/pkg/engine/plan/schemausageinfo_test.go
@@ -3,7 +3,6 @@ package plan
 import (
 	"bytes"
 	"context"
-	"io"
 	"testing"
 
 	"github.com/jensneuse/abstractlogger"
@@ -487,10 +486,10 @@ type FakeDataSource struct {
 	source *StatefulSource
 }
 
-func (f *FakeDataSource) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
+func (f *FakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) {
 	return
 }
 
-func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
+func (f *FakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) (err error) {
 	return
 }
diff --git a/v2/pkg/engine/resolve/datasource.go b/v2/pkg/engine/resolve/datasource.go
index 64a36cec0..bba3295b1 100644
--- a/v2/pkg/engine/resolve/datasource.go
+++ b/v2/pkg/engine/resolve/datasource.go
@@ -1,16 +1,16 @@
 package resolve
 
 import (
+	"bytes"
 	"context"
-	"io"
 
 	"github.com/cespare/xxhash/v2"
 
 	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
 )
 
 type DataSource interface {
-	Load(ctx context.Context, input []byte, w io.Writer) (err error)
-	LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error)
+	Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error)
+	LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) (err error)
 }
 
 type SubscriptionDataSource interface {
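Note: the DataSource interface above is the contract every implementation in this patch migrates to; each Load receives a *bytes.Buffer it can fill directly. A minimal conforming implementation as a sketch (fixedSource and its canned payload are illustrative only, mirroring the staticdatasource change):

    package example

    import (
    	"bytes"
    	"context"

    	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
    )

    // fixedSource is a hypothetical DataSource returning a canned JSON payload.
    type fixedSource struct {
    	payload []byte
    }

    func (s *fixedSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error {
    	// The resolver owns the buffer; implementations append to it and must
    	// not retain it past the call.
    	_, err := out.Write(s.payload)
    	return err
    }

    func (s *fixedSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) error {
    	// File uploads are out of scope for this sketch.
    	_, err := out.Write(s.payload)
    	return err
    }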
diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go
index 58a8f2cfd..6d5a31038 100644
--- a/v2/pkg/engine/resolve/loader.go
+++ b/v2/pkg/engine/resolve/loader.go
@@ -11,17 +11,19 @@ import (
 	"net/http/httptrace"
 	"slices"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/buger/jsonparser"
+	"github.com/cespare/xxhash/v2"
 	"github.com/pkg/errors"
 	"github.com/tidwall/gjson"
+	"go.uber.org/atomic"
 	"golang.org/x/sync/errgroup"
 
 	"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
 	"github.com/wundergraph/graphql-go-tools/v2/pkg/astjson"
 	"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
-	"github.com/wundergraph/graphql-go-tools/v2/pkg/pool"
 )
 
 const (
@@ -41,6 +43,25 @@ func IsIntrospectionDataSource(dataSourceID string) bool {
 	return dataSourceID == IntrospectionSchemaTypeDataSourceID || dataSourceID == IntrospectionTypeFieldsDataSourceID || dataSourceID == IntrospectionTypeEnumValuesDataSourceID
 }
 
+var (
+	loaderBufPool = sync.Pool{}
+	loaderBufSize = atomic.NewInt32(128)
+)
+
+func acquireLoaderBuf() *bytes.Buffer {
+	v := loaderBufPool.Get()
+	if v == nil {
+		return bytes.NewBuffer(make([]byte, 0, loaderBufSize.Load()))
+	}
+	return v.(*bytes.Buffer)
+}
+
+func releaseLoaderBuf(buf *bytes.Buffer) {
+	loaderBufSize.Store(int32(buf.Cap()))
+	buf.Reset()
+	loaderBufPool.Put(buf)
+}
+
 type Loader struct {
 	data     *astjson.JSON
 	dataRoot int
@@ -210,9 +231,8 @@ func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []int) error {
 	switch f := fetch.(type) {
 	case *SingleFetch:
 		res := &result{
-			out: pool.BytesBuffer.Get(),
+			out: acquireLoaderBuf(),
 		}
-
 		err := l.loadSingleFetch(l.ctx.ctx, f, items, res)
 		if err != nil {
 			return err
@@ -288,7 +308,7 @@ func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []int) error {
 		for i := range items {
 			i := i
 			results[i] = &result{
-				out: pool.BytesBuffer.Get(),
+				out: acquireLoaderBuf(),
 			}
 			g.Go(func() error {
 				return l.loadFetch(ctx, f.Fetch, items[i:i+1], results[i])
@@ -309,7 +329,7 @@ func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []int) error {
 		}
 	case *EntityFetch:
 		res := &result{
-			out: pool.BytesBuffer.Get(),
+			out: acquireLoaderBuf(),
 		}
 		err := l.loadEntityFetch(l.ctx.ctx, f, items, res)
 		if err != nil {
@@ -322,7 +342,7 @@ func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []int) error {
 		return err
 	case *BatchEntityFetch:
 		res := &result{
-			out: pool.BytesBuffer.Get(),
+			out: acquireLoaderBuf(),
 		}
 		err := l.loadBatchEntityFetch(l.ctx.ctx, f, items, res)
 		if err != nil {
@@ -340,7 +360,7 @@ func (l *Loader) resolveAndMergeFetch(fetch Fetch, items []int) error {
 func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, items []int, res *result) error {
 	switch f := fetch.(type) {
 	case *SingleFetch:
-		res.out = pool.BytesBuffer.Get()
+		res.out = acquireLoaderBuf()
 		return l.loadSingleFetch(ctx, f, items, res)
 	case *SerialFetch:
 		return fmt.Errorf("serial fetch must not be nested")
@@ -355,7 +375,7 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, items []int, res *r
 		for i := range items {
 			i := i
 			results[i] = &result{
-				out: pool.BytesBuffer.Get(),
+				out: acquireLoaderBuf(),
 			}
 			if l.ctx.TracingOptions.Enable {
 				f.Traces[i] = new(SingleFetch)
@@ -376,17 +396,17 @@ func (l *Loader) loadFetch(ctx context.Context, fetch Fetch, items []int, res *r
 		res.nestedMergeItems = results
 		return nil
 	case *EntityFetch:
-		res.out = pool.BytesBuffer.Get()
+		res.out = acquireLoaderBuf()
 		return l.loadEntityFetch(ctx, f, items, res)
 	case *BatchEntityFetch:
-		res.out = pool.BytesBuffer.Get()
+		res.out = acquireLoaderBuf()
 		return l.loadBatchEntityFetch(ctx, f, items, res)
 	}
 	return nil
 }
 
 func (l *Loader) mergeResult(res *result, items []int) error {
-	defer pool.BytesBuffer.Put(res.out)
+	defer releaseLoaderBuf(res.out)
 	if res.err != nil {
 		return l.renderErrorsFailedToFetch(res, failedToFetchNoReason)
 	}
@@ -467,8 +487,8 @@ func (l *Loader) mergeResult(res *result, items []int) error {
 
 	withPostProcessing := res.postProcessing.ResponseTemplate != nil
 	if withPostProcessing && len(items) <= 1 {
-		postProcessed := pool.BytesBuffer.Get()
-		defer pool.BytesBuffer.Put(postProcessed)
+		postProcessed := acquireLoaderBuf()
+		defer releaseLoaderBuf(postProcessed)
 		res.out.Reset()
 		err = l.data.PrintNode(l.data.Nodes[node], res.out)
 		if err != nil {
@@ -497,10 +517,10 @@ func (l *Loader) mergeResult(res *result, items []int) error {
 		rendered *bytes.Buffer
 	)
 	if withPostProcessing {
-		postProcessed = pool.BytesBuffer.Get()
-		defer pool.BytesBuffer.Put(postProcessed)
-		rendered = pool.BytesBuffer.Get()
-		defer pool.BytesBuffer.Put(rendered)
+		postProcessed = acquireLoaderBuf()
+		defer releaseLoaderBuf(postProcessed)
+		rendered = acquireLoaderBuf()
+		defer releaseLoaderBuf(rendered)
 		for i, stats := range res.batchStats {
 			postProcessed.Reset()
 			rendered.Reset()
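Note: acquireLoaderBuf/releaseLoaderBuf above replace the shared pool.BytesBuffer with a loader-local sync.Pool whose fresh allocations are sized from the capacity of the last buffer returned, so the pool adapts to the workload instead of using a fixed initial size. The same pattern in isolation, as a generic sketch (package and names hypothetical):

    package bufpool

    import (
    	"bytes"
    	"sync"

    	"go.uber.org/atomic"
    )

    var (
    	pool     = sync.Pool{}
    	sizeHint = atomic.NewInt32(128) // starting capacity; adapts on release
    )

    func Acquire() *bytes.Buffer {
    	if v := pool.Get(); v != nil {
    		return v.(*bytes.Buffer)
    	}
    	// New buffers start at the capacity observed on the last release,
    	// so steady-state workloads stop re-growing their buffers.
    	return bytes.NewBuffer(make([]byte, 0, int(sizeHint.Load())))
    }

    func Release(buf *bytes.Buffer) {
    	sizeHint.Store(int32(buf.Cap()))
    	buf.Reset()
    	pool.Put(buf)
    }

One trade-off of this last-writer-wins hint: a single unusually large response raises the size used for subsequent fresh allocations until a smaller buffer is released after it.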
@@ -608,8 +628,8 @@ func (l *Loader) mergeErrors(res *result, ref int) error {
 
 	path := l.renderPath()
 
-	responseErrorsBuf := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(responseErrorsBuf)
+	responseErrorsBuf := acquireLoaderBuf()
+	defer releaseLoaderBuf(responseErrorsBuf)
 
 	// print them into the buffer to be able to parse them
 	err := l.data.PrintNode(l.data.Nodes[ref], responseErrorsBuf)
@@ -946,29 +966,57 @@ func (l *Loader) validatePreFetch(input []byte, info *FetchInfo, res *result) (a
 	return l.rateLimitFetch(input, info, res)
 }
 
+var (
+	singleFetchPool              = sync.Pool{}
+	singleFetchInputSize         = atomic.NewInt32(32)
+	singleFetchPreparedInputSize = atomic.NewInt32(32)
+)
+
+type singleFetchBuffer struct {
+	input         *bytes.Buffer
+	preparedInput *bytes.Buffer
+}
+
+func acquireSingleFetchBuffer() *singleFetchBuffer {
+	buf := singleFetchPool.Get()
+	if buf == nil {
+		return &singleFetchBuffer{
+			input:         bytes.NewBuffer(make([]byte, 0, int(singleFetchInputSize.Load()))),
+			preparedInput: bytes.NewBuffer(make([]byte, 0, int(singleFetchPreparedInputSize.Load()))),
+		}
+	}
+	return buf.(*singleFetchBuffer)
+}
+
+func releaseSingleFetchBuffer(buf *singleFetchBuffer) {
+	singleFetchInputSize.Store(int32(buf.input.Cap()))
+	singleFetchPreparedInputSize.Store(int32(buf.preparedInput.Cap()))
+	buf.input.Reset()
+	buf.preparedInput.Reset()
+	singleFetchPool.Put(buf)
+}
+
 func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, items []int, res *result) error {
 	res.init(fetch.PostProcessing, fetch.Info)
-	input := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(input)
-	preparedInput := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(preparedInput)
-	err := l.itemsData(items, input)
+	buf := acquireSingleFetchBuffer()
+	defer releaseSingleFetchBuffer(buf)
+	err := l.itemsData(items, buf.input)
 	if err != nil {
 		return errors.WithStack(err)
 	}
 	if l.ctx.TracingOptions.Enable {
 		fetch.Trace = &DataSourceLoadTrace{}
 		if !l.ctx.TracingOptions.ExcludeRawInputData {
-			inputCopy := make([]byte, input.Len())
-			copy(inputCopy, input.Bytes())
+			inputCopy := make([]byte, buf.input.Len())
+			copy(inputCopy, buf.input.Bytes())
 			fetch.Trace.RawInputData = inputCopy
 		}
 	}
-	err = fetch.InputTemplate.Render(l.ctx, input.Bytes(), preparedInput)
+	err = fetch.InputTemplate.Render(l.ctx, buf.input.Bytes(), buf.preparedInput)
 	if err != nil {
 		return l.renderErrorsInvalidInput(res.out)
 	}
-	fetchInput := preparedInput.Bytes()
+	fetchInput := buf.preparedInput.Bytes()
 	allowed, err := l.validatePreFetch(fetchInput, fetch.Info, res)
 	if err != nil {
 		return err
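Note: loadSingleFetch now draws both of its scratch buffers from a single pooled struct, so each fetch pays for one pool Get/Put pair instead of two. The call shape, as a usage sketch of the helpers defined above:

    // One acquire/release pair per fetch; both buffers come back empty,
    // with capacities carried over from earlier fetches of the same kind.
    buf := acquireSingleFetchBuffer()
    defer releaseSingleFetchBuffer(buf)
    _ = buf.input         // scratch for the rendered items data
    _ = buf.preparedInput // scratch for the final upstream request body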
@@ -980,15 +1028,46 @@ func (l *Loader) loadSingleFetch(ctx context.Context, fetch *SingleFetch, items
 	return nil
 }
 
+var (
+	entityFetchPool              = sync.Pool{}
+	entityFetchItemDataSize      = atomic.NewInt32(32)
+	entityFetchPreparedInputSize = atomic.NewInt32(32)
+	entityFetchItemSize          = atomic.NewInt32(32)
+)
+
+type entityFetchBuffer struct {
+	itemData      *bytes.Buffer
+	preparedInput *bytes.Buffer
+	item          *bytes.Buffer
+}
+
+func acquireEntityFetchBuffer() *entityFetchBuffer {
+	buf := entityFetchPool.Get()
+	if buf == nil {
+		return &entityFetchBuffer{
+			itemData:      bytes.NewBuffer(make([]byte, 0, int(entityFetchItemDataSize.Load()))),
+			preparedInput: bytes.NewBuffer(make([]byte, 0, int(entityFetchPreparedInputSize.Load()))),
+			item:          bytes.NewBuffer(make([]byte, 0, int(entityFetchItemSize.Load()))),
+		}
+	}
+	return buf.(*entityFetchBuffer)
+}
+
+func releaseEntityFetchBuffer(buf *entityFetchBuffer) {
+	entityFetchItemDataSize.Store(int32(buf.itemData.Cap()))
+	entityFetchPreparedInputSize.Store(int32(buf.preparedInput.Cap()))
+	entityFetchItemSize.Store(int32(buf.item.Cap()))
+	buf.itemData.Reset()
+	buf.preparedInput.Reset()
+	buf.item.Reset()
+	entityFetchPool.Put(buf)
+}
+
 func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items []int, res *result) error {
 	res.init(fetch.PostProcessing, fetch.Info)
-	itemData := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(itemData)
-	preparedInput := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(preparedInput)
-	item := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(item)
-	err := l.itemsData(items, itemData)
+	buf := acquireEntityFetchBuffer()
+	defer releaseEntityFetchBuffer(buf)
+	err := l.itemsData(items, buf.itemData)
 	if err != nil {
 		return errors.WithStack(err)
 	}
@@ -996,20 +1075,20 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items
 	if l.ctx.TracingOptions.Enable {
 		fetch.Trace = &DataSourceLoadTrace{}
 		if !l.ctx.TracingOptions.ExcludeRawInputData {
-			itemDataCopy := make([]byte, itemData.Len())
-			copy(itemDataCopy, itemData.Bytes())
+			itemDataCopy := make([]byte, buf.itemData.Len())
+			copy(itemDataCopy, buf.itemData.Bytes())
 			fetch.Trace.RawInputData = itemDataCopy
 		}
 	}
 
 	var undefinedVariables []string
 
-	err = fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables)
+	err = fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables)
 	if err != nil {
 		return errors.WithStack(err)
 	}
 
-	err = fetch.Input.Item.Render(l.ctx, itemData.Bytes(), item)
+	err = fetch.Input.Item.Render(l.ctx, buf.itemData.Bytes(), buf.item)
 	if err != nil {
 		if fetch.Input.SkipErrItem {
 			err = nil // nolint:ineffassign
@@ -1021,7 +1100,7 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items
 		}
 		return errors.WithStack(err)
 	}
-	renderedItem := item.Bytes()
+	renderedItem := buf.item.Bytes()
 	if bytes.Equal(renderedItem, null) {
 		// skip fetch if item is null
 		res.fetchSkipped = true
@@ -1040,17 +1119,17 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items
 			return nil
 		}
 	}
-	_, _ = item.WriteTo(preparedInput)
-	err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables)
+	_, _ = buf.item.WriteTo(buf.preparedInput)
+	err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables)
 	if err != nil {
 		return errors.WithStack(err)
 	}
 
-	err = SetInputUndefinedVariables(preparedInput, undefinedVariables)
+	err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables)
 	if err != nil {
 		return errors.WithStack(err)
 	}
 
-	fetchInput := preparedInput.Bytes()
+	fetchInput := buf.preparedInput.Bytes()
 
 	if l.ctx.TracingOptions.Enable && res.fetchSkipped {
 		l.setTracingInput(fetchInput, fetch.Trace)
@@ -1068,9 +1147,50 @@ func (l *Loader) loadEntityFetch(ctx context.Context, fetch *EntityFetch, items
 	return nil
 }
 
+var (
+	batchEntityFetchPool         = sync.Pool{}
+	batchEntityPreparedInputSize = atomic.NewInt32(32)
+	batchEntityItemDataSize      = atomic.NewInt32(32)
+	batchEntityItemInputSize     = atomic.NewInt32(32)
+)
+
+type batchEntityFetchBuffer struct {
+	preparedInput *bytes.Buffer
+	itemData      *bytes.Buffer
+	itemInput     *bytes.Buffer
+	keyGen        *xxhash.Digest
+}
+
+func acquireBatchEntityFetchBuffer() *batchEntityFetchBuffer {
+	buf := batchEntityFetchPool.Get()
+	if buf == nil {
+		return &batchEntityFetchBuffer{
+			preparedInput: bytes.NewBuffer(make([]byte, 0, int(batchEntityPreparedInputSize.Load()))),
+			itemData:      bytes.NewBuffer(make([]byte, 0, int(batchEntityItemDataSize.Load()))),
+			itemInput:     bytes.NewBuffer(make([]byte, 0, int(batchEntityItemInputSize.Load()))),
+			keyGen:        xxhash.New(),
+		}
+	}
+	return buf.(*batchEntityFetchBuffer)
+}
+
+func releaseBatchEntityFetchBuffer(buf *batchEntityFetchBuffer) {
+	batchEntityPreparedInputSize.Store(int32(buf.preparedInput.Cap()))
+	batchEntityItemDataSize.Store(int32(buf.itemData.Cap()))
+	batchEntityItemInputSize.Store(int32(buf.itemInput.Cap()))
+	buf.preparedInput.Reset()
+	buf.itemData.Reset()
+	buf.itemInput.Reset()
+	buf.keyGen.Reset()
+	batchEntityFetchPool.Put(buf)
+}
+
 func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetch *BatchEntityFetch, items []int, res *result) error {
 	res.init(fetch.PostProcessing, fetch.Info)
 
+	buf := acquireBatchEntityFetchBuffer()
+	defer releaseBatchEntityFetchBuffer(buf)
+
 	if l.ctx.TracingOptions.Enable {
 		fetch.Trace = &DataSourceLoadTrace{}
 		if !l.ctx.TracingOptions.ExcludeRawInputData {
@@ -1083,12 +1203,9 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetch *BatchEntityFet
 		}
 	}
 
-	preparedInput := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(preparedInput)
-
 	var undefinedVariables []string
 
-	err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables)
+	err := fetch.Input.Header.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables)
 	if err != nil {
 		return errors.WithStack(err)
 	}
@@ -1097,25 +1214,16 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetch *BatchEntityFet
 	batchItemIndex := 0
 	addSeparator := false
 
-	keyGen := pool.Hash64.Get()
-	defer pool.Hash64.Put(keyGen)
-
-	itemData := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(itemData)
-
-	itemInput := pool.BytesBuffer.Get()
-	defer pool.BytesBuffer.Put(itemInput)
-
 WithNextItem:
 	for i, item := range items {
-		itemData.Reset()
-		err = l.data.PrintNode(l.data.Nodes[item], itemData)
+		buf.itemData.Reset()
+		err = l.data.PrintNode(l.data.Nodes[item], buf.itemData)
 		if err != nil {
 			return errors.WithStack(err)
 		}
 		for j := range fetch.Input.Items {
-			itemInput.Reset()
-			err = fetch.Input.Items[j].Render(l.ctx, itemData.Bytes(), itemInput)
+			buf.itemInput.Reset()
+			err = fetch.Input.Items[j].Render(l.ctx, buf.itemData.Bytes(), buf.itemInput)
 			if err != nil {
 				if fetch.Input.SkipErrItems {
 					err = nil // nolint:ineffassign
@@ -1127,18 +1235,18 @@ WithNextItem:
 				}
 				return errors.WithStack(err)
 			}
-			if fetch.Input.SkipNullItems && itemInput.Len() == 4 && bytes.Equal(itemInput.Bytes(), null) {
+			if fetch.Input.SkipNullItems && buf.itemInput.Len() == 4 && bytes.Equal(buf.itemInput.Bytes(), null) {
 				res.batchStats[i] = append(res.batchStats[i], -1)
 				continue
 			}
-			if fetch.Input.SkipEmptyObjectItems && itemInput.Len() == 2 && bytes.Equal(itemInput.Bytes(), emptyObject) {
+			if fetch.Input.SkipEmptyObjectItems && buf.itemInput.Len() == 2 && bytes.Equal(buf.itemInput.Bytes(), emptyObject) {
 				res.batchStats[i] = append(res.batchStats[i], -1)
 				continue
 			}
 
-			keyGen.Reset()
-			_, _ = keyGen.Write(itemInput.Bytes())
-			itemHash := keyGen.Sum64()
+			buf.keyGen.Reset()
+			_, _ = buf.keyGen.Write(buf.itemInput.Bytes())
+			itemHash := buf.keyGen.Sum64()
 			for k := range itemHashes {
 				if itemHashes[k] == itemHash {
 					res.batchStats[i] = append(res.batchStats[i], k)
@@ -1147,12 +1255,12 @@
 			}
 			itemHashes = append(itemHashes, itemHash)
 			if addSeparator {
-				err = fetch.Input.Separator.Render(l.ctx, nil, preparedInput)
+				err = fetch.Input.Separator.Render(l.ctx, nil, buf.preparedInput)
 				if err != nil {
 					return errors.WithStack(err)
 				}
 			}
-			_, _ = itemInput.WriteTo(preparedInput)
+			_, _ = buf.itemInput.WriteTo(buf.preparedInput)
 			res.batchStats[i] = append(res.batchStats[i], batchItemIndex)
 			batchItemIndex++
 			addSeparator = true
@@ -1169,16 +1277,16 @@
 		}
 	}
 
-	err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, preparedInput, &undefinedVariables)
+	err = fetch.Input.Footer.RenderAndCollectUndefinedVariables(l.ctx, nil, buf.preparedInput, &undefinedVariables)
 	if err != nil {
 		return errors.WithStack(err)
 	}
 
-	err = SetInputUndefinedVariables(preparedInput, undefinedVariables)
+	err = SetInputUndefinedVariables(buf.preparedInput, undefinedVariables)
 	if err != nil {
 		return errors.WithStack(err)
 	}
 
-	fetchInput := preparedInput.Bytes()
+	fetchInput := buf.preparedInput.Bytes()
 	if l.ctx.TracingOptions.Enable && res.fetchSkipped {
 		l.setTracingInput(fetchInput, fetch.Trace)
@@ -1275,7 +1383,6 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, input []b
 	if l.ctx.Files != nil {
 		return source.LoadWithFiles(ctx, input, l.ctx.Files, res.out)
 	}
-
 	return source.Load(ctx, input, res.out)
 }
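Note: the batch loop hashes each rendered entity representation with the pooled xxhash.Digest and points duplicates at the batch index of the first occurrence, so identical keys are fetched upstream only once. The dedup scheme in isolation, as a standalone sketch (dedupeBatch is hypothetical):

    package example

    import "github.com/cespare/xxhash/v2"

    // dedupeBatch mirrors the loop above: each rendered representation is
    // hashed, and a duplicate reuses the batch slot of its first occurrence.
    // The returned slice maps item position -> index in the deduplicated batch.
    func dedupeBatch(items [][]byte) (batchIndex []int) {
    	digest := xxhash.New()
    	var hashes []uint64 // one hash per distinct item admitted to the batch
    	for _, item := range items {
    		digest.Reset()
    		_, _ = digest.Write(item)
    		h := digest.Sum64()
    		idx := -1
    		for k := range hashes {
    			if hashes[k] == h {
    				idx = k // duplicate: point at the existing batch slot
    				break
    			}
    		}
    		if idx == -1 {
    			idx = len(hashes)
    			hashes = append(hashes, h)
    		}
    		batchIndex = append(batchIndex, idx)
    	}
    	return batchIndex
    }

For example, dedupeBatch of three representations where the third equals the first yields [0 1 0]: the third item reuses the first fetch result, matching how res.batchStats records the shared index k.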
diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go
index 8926f535d..34adaa283 100644
--- a/v2/pkg/engine/resolve/resolve.go
+++ b/v2/pkg/engine/resolve/resolve.go
@@ -34,15 +34,14 @@ type Reporter interface {
 }
 
 type AsyncErrorWriter interface {
-	WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer, buf *bytes.Buffer)
+	WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer)
 }
 
 type Resolver struct {
-	ctx                 context.Context
-	options             ResolverOptions
-	toolPool            sync.Pool
-	limitMaxConcurrency bool
-	maxConcurrency      chan struct{}
+	ctx            context.Context
+	options        ResolverOptions
+	bufPool        sync.Pool
+	maxConcurrency chan struct{}
 
 	triggers         map[uint64]*trigger
 	events           chan subscriptionEvent
@@ -56,6 +55,8 @@ type Resolver struct {
 
 	propagateSubgraphErrors      bool
 	propagateSubgraphStatusCodes bool
+
+	tools *sync.Pool
 }
 
 func (r *Resolver) SetAsyncErrorWriter(w AsyncErrorWriter) {
@@ -110,13 +111,21 @@ type ResolverOptions struct {
 
 // New returns a new Resolver, ctx.Done() is used to cancel all active subscriptions & streams
 func New(ctx context.Context, options ResolverOptions) *Resolver {
 	// options.Debug = true
+	if options.MaxConcurrency <= 0 {
+		options.MaxConcurrency = 32
+	}
 	resolver := &Resolver{
 		ctx:                          ctx,
 		options:                      options,
 		propagateSubgraphErrors:      options.PropagateSubgraphErrors,
 		propagateSubgraphStatusCodes: options.PropagateSubgraphStatusCodes,
-		toolPool: sync.Pool{
-			New: func() interface{} {
+		events:           make(chan subscriptionEvent),
+		triggers:         make(map[uint64]*trigger),
+		reporter:         options.Reporter,
+		asyncErrorWriter: options.AsyncErrorWriter,
+		triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)),
+		tools: &sync.Pool{
+			New: func() any {
 				return &tools{
 					resolvable: NewResolvable(),
 					loader: &Loader{
@@ -130,19 +139,10 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
 				}
 			},
 		},
-		events:           make(chan subscriptionEvent),
-		triggers:         make(map[uint64]*trigger),
-		reporter:         options.Reporter,
-		asyncErrorWriter: options.AsyncErrorWriter,
-		triggerUpdateBuf: bytes.NewBuffer(make([]byte, 0, 1024)),
-	}
-	if options.MaxConcurrency > 0 {
-		semaphore := make(chan struct{}, options.MaxConcurrency)
-		for i := 0; i < options.MaxConcurrency; i++ {
-			semaphore <- struct{}{}
-		}
-		resolver.limitMaxConcurrency = true
-		resolver.maxConcurrency = semaphore
+	}
+	resolver.maxConcurrency = make(chan struct{}, options.MaxConcurrency)
+	for i := 0; i < options.MaxConcurrency; i++ {
+		resolver.maxConcurrency <- struct{}{}
 	}
 	if options.MaxSubscriptionWorkers == 0 {
 		options.MaxSubscriptionWorkers = 1024
@@ -160,20 +160,28 @@ func New(ctx context.Context, options ResolverOptions) *Resolver {
 }
 
 func (r *Resolver) getTools() *tools {
-	if r.limitMaxConcurrency {
-		<-r.maxConcurrency
-	}
-	t := r.toolPool.Get().(*tools)
-	return t
+	<-r.maxConcurrency
+	return r.tools.Get().(*tools)
 }
 
 func (r *Resolver) putTools(t *tools) {
 	t.loader.Free()
 	t.resolvable.Reset()
-	r.toolPool.Put(t)
-	if r.limitMaxConcurrency {
-		r.maxConcurrency <- struct{}{}
+	r.tools.Put(t)
+	r.maxConcurrency <- struct{}{}
+}
+
+func (r *Resolver) getBuffer(preferredSize int) *bytes.Buffer {
+	maybeBuffer := r.bufPool.Get()
+	if maybeBuffer == nil {
+		return bytes.NewBuffer(make([]byte, 0, preferredSize))
 	}
+	return maybeBuffer.(*bytes.Buffer)
+}
+
+func (r *Resolver) releaseBuffer(buf *bytes.Buffer) {
+	buf.Reset()
+	r.bufPool.Put(buf)
 }
 
 func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLResponse, data []byte, writer io.Writer) (err error) {
@@ -184,15 +192,16 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons
 	}
 
 	t := r.getTools()
-	defer r.putTools(t)
 
 	err = t.resolvable.Init(ctx, data, response.Info.OperationType)
 	if err != nil {
+		r.putTools(t)
 		return err
 	}
 
 	err = t.loader.LoadGraphQLResponseData(ctx, response, t.resolvable)
 	if err != nil {
+		r.putTools(t)
 		return err
 	}
 
@@ -202,7 +211,15 @@ func (r *Resolver) ResolveGraphQLResponse(ctx *Context, response *GraphQLRespons
 		fetchTree = response.Data
 	}
 
-	return t.resolvable.Resolve(ctx.ctx, response.Data, fetchTree, writer)
+	buf := r.getBuffer(t.resolvable.storage.Size())
+	defer r.releaseBuffer(buf)
+	err = t.resolvable.Resolve(ctx.ctx, response.Data, fetchTree, buf)
+	r.putTools(t)
+	if err != nil {
+		return err
+	}
+	_, err = buf.WriteTo(writer)
+	return err
 }
 
 type trigger struct {
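Note: the resolver now always enforces a concurrency limit (defaulting to 32) via a buffered channel used as a counting semaphore, and it resolves into a pooled buffer pre-sized with resolvable.storage.Size() so the tools (and their semaphore token) are returned before the potentially slow client write; only the cheap pooled buffer stays held during WriteTo. The semaphore pattern on its own, as a hedged sketch (types and names hypothetical):

    package example

    // A counting semaphore built from a buffered channel pre-filled with
    // tokens. acquire blocks once all n tokens are in use, exactly like
    // getTools above; release returns the token, like putTools.
    type semaphore chan struct{}

    func newSemaphore(n int) semaphore {
    	s := make(semaphore, n)
    	for i := 0; i < n; i++ {
    		s <- struct{}{}
    	}
    	return s
    }

    func (s semaphore) acquire() { <-s }
    func (s semaphore) release() { s <- struct{}{} }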
@@ -245,10 +262,8 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
 	input := make([]byte, len(sharedInput))
 	copy(input, sharedInput)
 	if err := t.resolvable.InitSubscription(ctx, input, sub.resolve.Trigger.PostProcessing); err != nil {
-		buf := pool.BytesBuffer.Get()
-		defer pool.BytesBuffer.Put(buf)
 		sub.mux.Lock()
-		r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer, buf)
+		r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer)
 		sub.pendingUpdates--
 		sub.mux.Unlock()
 		if r.options.Debug {
@@ -260,10 +275,8 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
 		return
 	}
 	if err := t.loader.LoadGraphQLResponseData(ctx, sub.resolve.Response, t.resolvable); err != nil {
-		buf := pool.BytesBuffer.Get()
-		defer pool.BytesBuffer.Put(buf)
 		sub.mux.Lock()
-		r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer, buf)
+		r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer)
 		sub.pendingUpdates--
 		sub.mux.Unlock()
 		if r.options.Debug {
@@ -284,9 +297,7 @@ func (r *Resolver) executeSubscriptionUpdate(ctx *Context, sub *sub, sharedInput
 		return // subscription was already closed by the client
 	}
 	if err := t.resolvable.Resolve(ctx.ctx, sub.resolve.Response.Data, sub.resolve.Response.FetchTree, sub.writer); err != nil {
-		buf := pool.BytesBuffer.Get()
-		defer pool.BytesBuffer.Put(buf)
-		r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer, buf)
+		r.asyncErrorWriter.WriteError(ctx, err, sub.resolve.Response, sub.writer)
 		if r.options.Debug {
 			fmt.Printf("resolver:trigger:subscription:resolve:failed:%d\n", sub.id.SubscriptionID)
 		}
@@ -452,11 +463,7 @@ func (r *Resolver) handleAddSubscription(triggerID uint64, add *addSubscription)
 		if r.options.Debug {
 			fmt.Printf("resolver:trigger:failed:%d\n", triggerID)
 		}
-		buf := pool.BytesBuffer.Get()
-		defer pool.BytesBuffer.Put(buf)
-
-		r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer, buf)
-
+		r.asyncErrorWriter.WriteError(add.ctx, err, add.resolve.Response, add.writer)
 		r.emitTriggerShutdown(triggerID)
 		return
 	}
@@ -579,9 +586,7 @@ func (r *Resolver) handleTriggerUpdate(id uint64, data []byte) {
 		c, s := c, s
 		skip, err := s.resolve.Filter.SkipEvent(c, data, r.triggerUpdateBuf)
 		if err != nil {
-			buf := pool.BytesBuffer.Get()
-			r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer, buf)
-			pool.BytesBuffer.Put(buf)
+			r.asyncErrorWriter.WriteError(c, err, s.resolve.Response, s.writer)
 			continue
 		}
 		if skip {
diff --git a/v2/pkg/engine/resolve/resolve_mock_test.go b/v2/pkg/engine/resolve/resolve_mock_test.go
index 1c011bfda..25abcc3a3 100644
--- a/v2/pkg/engine/resolve/resolve_mock_test.go
+++ b/v2/pkg/engine/resolve/resolve_mock_test.go
@@ -5,8 +5,8 @@
 package resolve
 
 import (
+	bytes "bytes"
 	context "context"
-	io "io"
 	reflect "reflect"
 
 	gomock "github.com/golang/mock/gomock"
@@ -37,7 +37,7 @@ func (m *MockDataSource) EXPECT() *MockDataSourceMockRecorder {
 }
 
 // Load mocks base method.
-func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 io.Writer) error {
+func (m *MockDataSource) Load(arg0 context.Context, arg1 []byte, arg2 *bytes.Buffer) error {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "Load", arg0, arg1, arg2)
 	ret0, _ := ret[0].(error)
@@ -51,7 +51,7 @@ func (mr *MockDataSourceMockRecorder) Load(arg0, arg1, arg2 interface{}) *gomock
 }
 
 // LoadWithFiles mocks base method.
-func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []httpclient.File, arg3 io.Writer) error {
+func (m *MockDataSource) LoadWithFiles(arg0 context.Context, arg1 []byte, arg2 []httpclient.File, arg3 *bytes.Buffer) error {
 	m.ctrl.T.Helper()
 	ret := m.ctrl.Call(m, "LoadWithFiles", arg0, arg1, arg2, arg3)
 	ret0, _ := ret[0].(error)
diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go
index 745169959..79aabad74 100644
--- a/v2/pkg/engine/resolve/resolve_test.go
+++ b/v2/pkg/engine/resolve/resolve_test.go
@@ -29,7 +29,7 @@ type _fakeDataSource struct {
 	artificialLatency time.Duration
 }
 
-func (f *_fakeDataSource) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
+func (f *_fakeDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) {
 	if f.artificialLatency != 0 {
 		time.Sleep(f.artificialLatency)
 	}
@@ -38,11 +38,11 @@ func (f *_fakeDataSource) Load(ctx context.Context, input []byte, w io.Writer) (
 			require.Equal(f.t, string(f.input), string(input), "input mismatch")
 		}
 	}
-	_, err = w.Write(f.data)
+	_, err = out.Write(f.data)
 	return
 }
 
-func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
+func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files []httpclient.File, out *bytes.Buffer) (err error) {
 	if f.artificialLatency != 0 {
 		time.Sleep(f.artificialLatency)
 	}
@@ -51,7 +51,7 @@ func (f *_fakeDataSource) LoadWithFiles(ctx context.Context, input []byte, files
 			require.Equal(f.t, string(f.input), string(input), "input mismatch")
 		}
 	}
-	_, err = w.Write(f.data)
+	_, err = out.Write(f.data)
 	return
 }
 
@@ -72,7 +72,7 @@ func fakeDataSourceWithInputCheck(t TestingTB, input []byte, data []byte) *_fake
 type TestErrorWriter struct {
 }
 
-func (t *TestErrorWriter) WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer, buf *bytes.Buffer) {
+func (t *TestErrorWriter) WriteError(ctx *Context, err error, res *GraphQLResponse, w io.Writer) {
 	_, err = w.Write([]byte(fmt.Sprintf(`{"errors":[{"message":"%s"}],"data":null}`, err.Error())))
 	if err != nil {
 		panic(err)