diff --git a/cache/async_cache.go b/cache/async_cache.go
index 6ecd3d67..1b9ccdd2 100644
--- a/cache/async_cache.go
+++ b/cache/async_cache.go
@@ -20,6 +20,8 @@ type AsyncCache struct {
 	TransactionRegistry
 
 	graceTime time.Duration
+
+	MaxPayloadSize int64
 }
 
 func (c *AsyncCache) Close() error {
@@ -103,9 +105,12 @@ func NewAsyncCache(cfg config.Cache, maxExecutionTime time.Duration) (*AsyncCach
 		return nil, err
 	}
 
+	maxPayloadSize := int64(cfg.MaxPayloadSize)
+
 	return &AsyncCache{
 		Cache:               cache,
 		TransactionRegistry: transaction,
 		graceTime:           graceTime,
+		MaxPayloadSize:      maxPayloadSize,
 	}, nil
 }
diff --git a/cache/async_cache_test.go b/cache/async_cache_test.go
index ed059cb1..e6a07b1e 100644
--- a/cache/async_cache_test.go
+++ b/cache/async_cache_test.go
@@ -214,6 +214,7 @@ func TestAsyncCache_FilesystemCache_instantiation(t *testing.T) {
 			MaxSize: 8192,
 		},
 		Expire: config.Duration(time.Minute),
+		MaxPayloadSize: config.ByteSize(100000000),
 	}
 	if err := os.RemoveAll(testDirAsync); err != nil {
 		log.Fatalf("cannot remove %q: %s", testDirAsync, err)
 	}
@@ -249,6 +250,7 @@ func TestAsyncCache_RedisCache_instantiation(t *testing.T) {
 			Addresses: []string{s.Addr()},
 		},
 		Expire: config.Duration(cacheTTL),
+		MaxPayloadSize: config.ByteSize(100000000),
 	}
 
 	_, err := NewAsyncCache(redisCfg, 1*time.Second)
diff --git a/config/config.go b/config/config.go
index 85656940..ecc84649 100644
--- a/config/config.go
+++ b/config/config.go
@@ -32,6 +32,8 @@ var (
 	}
 
 	defaultExecutionTime = Duration(30 * time.Second)
+
+	defaultMaxPayloadSize = ByteSize(100000000)
 )
 
 // Config describes server configuration, access and proxy rules
@@ -609,6 +611,9 @@ type Cache struct {
 
 	// Catches all undefined fields
 	XXX map[string]interface{} `yaml:",inline"`
+
+	// Maximum size of the response payload that may be stored in the cache
+	MaxPayloadSize ByteSize `yaml:"max_payload_size,omitempty"`
 }
 
 type FileSystemCacheConfig struct {
@@ -820,6 +825,13 @@ func LoadFile(filename string) (*Config, error) {
 		}
 	}
 
+	for i := range cfg.Caches {
+		c := &cfg.Caches[i]
+		if c.MaxPayloadSize <= 0 {
+			c.MaxPayloadSize = defaultMaxPayloadSize
+		}
+	}
+
 	if maxResponseTime < 0 {
 		maxResponseTime = 0
 	}
diff --git a/config/config_test.go b/config/config_test.go
index fa976e50..9a2f6706 100644
--- a/config/config_test.go
+++ b/config/config_test.go
@@ -21,8 +21,9 @@ var fullConfig = Config{
 				Dir:     "/path/to/longterm/cachedir",
 				MaxSize: ByteSize(100 << 30),
 			},
-			Expire:    Duration(time.Hour),
-			GraceTime: Duration(20 * time.Second),
+			Expire:         Duration(time.Hour),
+			GraceTime:      Duration(20 * time.Second),
+			MaxPayloadSize: ByteSize(100 << 30),
 		},
 		{
 			Name: "shortterm",
@@ -31,7 +32,8 @@ var fullConfig = Config{
 				Dir:     "/path/to/shortterm/cachedir",
 				MaxSize: ByteSize(100 << 20),
 			},
-			Expire: Duration(10 * time.Second),
+			Expire:         Duration(10 * time.Second),
+			MaxPayloadSize: ByteSize(100 << 20),
 		},
 	},
 	HackMePlease: true,
@@ -452,6 +454,11 @@ func TestBadConfig(t *testing.T) {
 			"testdata/bad.heartbeat_section.empty2.yml",
 			"cannot be use `heartbeat_interval` with `heartbeat` section",
 		},
+		{
+			"max payload size to cache",
+			"testdata/bad.max_payload_size.yml",
+			"cannot parse byte size \"-10B\": it must be positive float followed by optional units. For example, 1Gb, 100Mb",
+		},
 	}
 
 	for _, tc := range testCases {
@@ -827,12 +834,14 @@ caches:
   file_system:
     dir: /path/to/longterm/cachedir
     max_size: 107374182400
+  max_payload_size: 107374182400
 - mode: file_system
   name: shortterm
   expire: 10s
   file_system:
     dir: /path/to/shortterm/cachedir
     max_size: 104857600
+  max_payload_size: 104857600
 param_groups:
 - name: cron-job
   params:
diff --git a/config/testdata/bad.max_payload_size.yml b/config/testdata/bad.max_payload_size.yml
new file mode 100644
index 00000000..3607e06c
--- /dev/null
+++ b/config/testdata/bad.max_payload_size.yml
@@ -0,0 +1,21 @@
+caches:
+  - name: "longterm"
+    mode: "file_system"
+    max_payload_size: "-10B"
+    file_system:
+      dir: "cache_dir"
+      max_size: 100Gb
+
+server:
+  http:
+    listen_addr: ":8080"
+
+users:
+  - name: "dummy"
+    allowed_networks: ["1.2.3.4"]
+    to_cluster: "cluster"
+    to_user: "default"
+
+clusters:
+  - name: "cluster"
+    nodes: ["127.0.1.1:8123"]
diff --git a/config/testdata/full.yml b/config/testdata/full.yml
index 450d66ad..bdfb4e35 100644
--- a/config/testdata/full.yml
+++ b/config/testdata/full.yml
@@ -26,6 +26,8 @@ caches:
       # Path to directory where cached responses will be stored.
      dir: "/path/to/longterm/cachedir"
 
+    max_payload_size: 100Gb
+
     # Expiration time for cached responses.
     expire: 1h
 
@@ -44,6 +46,7 @@ caches:
     file_system:
       max_size: 100Mb
       dir: "/path/to/shortterm/cachedir"
+    max_payload_size: 100Mb
 
     expire: 10s
 
 # Optional network lists, might be used as values for `allowed_networks`.
diff --git a/proxy.go b/proxy.go
index 2d89fad4..d4cdd212 100644
--- a/proxy.go
+++ b/proxy.go
@@ -339,25 +339,32 @@ func (rp *reverseProxy) serveFromCache(s *scope, srw *statResponseWriter, req *h
 		contentLength := bufferedRespWriter.GetCapturedContentLength()
 		reader := bufferedRespWriter.Reader()
 
-		// we create this buffer to be able to stream data both to cache as well as to an end user
-		var buf bytes.Buffer
-		tee := io.TeeReader(reader, &buf)
-		contentMetadata := cache.ContentMetadata{Length: contentLength, Encoding: contentEncoding, Type: contentType}
-		expiration, err := userCache.Put(tee, contentMetadata, key)
-		if err != nil {
-			log.Errorf("%s: %s; query: %q - failed to put response in the cache", s, err, q)
-		}
+		if isToCache(contentLength, s) {
+			// we create this buffer to be able to stream data both to cache as well as to an end user
+			var buf bytes.Buffer
+			tee := io.TeeReader(reader, &buf)
+			contentMetadata := cache.ContentMetadata{Length: contentLength, Encoding: contentEncoding, Type: contentType}
+			expiration, err := userCache.Put(tee, contentMetadata, key)
+			if err != nil {
+				log.Errorf("%s: %s; query: %q - failed to put response in the cache", s, err, q)
+			}
 
-		// mark transaction as completed
-		if err = userCache.Complete(key); err != nil {
-			log.Errorf("%s: %s; query: %q", s, err, q)
-		}
+			// mark transaction as completed
+			if err = userCache.Complete(key); err != nil {
+				log.Errorf("%s: %s; query: %q", s, err, q)
+			}
 
-		err = RespondWithData(srw, &buf, contentMetadata, expiration, bufferedRespWriter.StatusCode())
-		if err != nil {
-			err = fmt.Errorf("%s: %w; query: %q", s, err, q)
-			respondWith(srw, err, http.StatusInternalServerError)
-			return
+			err = RespondWithData(srw, &buf, contentMetadata, expiration, bufferedRespWriter.StatusCode())
+			if err != nil {
+				err = fmt.Errorf("%s: %w; query: %q", s, err, q)
+				respondWith(srw, err, http.StatusInternalServerError)
+				return
+			}
+		} else {
+			err = RespondWithoutData(srw)
+			if err != nil {
+				log.Errorf("%s: %s; query: %q - failed to send response to the client without caching", s, err, q)
+			}
 		}
 	}
 }
diff --git a/utils.go b/utils.go
index 6083c223..67978b19 100644
--- a/utils.go
+++ b/utils.go
@@ -288,3 +288,8 @@ func calcMapHash(m map[string]string) (uint32, error) {
 	}
 	return h.Sum32(), nil
 }
+
+func isToCache(length int64, s *scope) bool {
+	maxPayloadSize := s.user.cache.MaxPayloadSize
+	return length <= maxPayloadSize
+}
diff --git a/utils_test.go b/utils_test.go
index 39024b8c..9fbbc31d 100644
--- a/utils_test.go
+++ b/utils_test.go
@@ -354,13 +354,13 @@ func TestCalcMapHash(t *testing.T) {
 		},
 		{
 			"map with multiple value",
-			map[string]string{"param_table_name": "clients","param_database":"default"},
+			map[string]string{"param_table_name": "clients", "param_database": "default"},
 			0x6fddf04d,
 			nil,
 		},
 		{
 			"map with exchange value",
-			map[string]string{"param_database":"default","param_table_name":"clients"},
+			map[string]string{"param_database": "default", "param_table_name": "clients"},
 			0x6fddf04d,
 			nil,
 		},