diff --git a/server/http.go b/server/http.go index 2c0f85840..59bf49cd2 100644 --- a/server/http.go +++ b/server/http.go @@ -212,7 +212,7 @@ func (h *httpCache) CacheHandler(w http.ResponseWriter, r *http.Request) { return } - if h.mangleACKeys && kind == cache.AC { + if h.mangleACKeys && (kind == cache.AC || kind == cache.RAW) { hash = cache.TransformActionCacheKey(hash, instance, h.accessLogger) } diff --git a/server/http_test.go b/server/http_test.go index baf1153a7..142d19c09 100644 --- a/server/http_test.go +++ b/server/http_test.go @@ -2,6 +2,7 @@ package server import ( "bytes" + "context" "crypto/sha256" "encoding/hex" "encoding/json" @@ -506,3 +507,59 @@ func TestRemoteReturnsNotFound(t *testing.T) { t.Errorf("Wrong status code, expected %d, got %d", http.StatusNotFound, statusCode) } } + +func TestManglingACKeys(t *testing.T) { + cacheDir, err := os.MkdirTemp("", "bazel-remote") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(cacheDir) + + blobSize := int64(1024) + cacheSize := blobSize*2 + disk.BlockSize + diskCache, err := disk.New(cacheDir, cacheSize, disk.WithAccessLogger(testutils.NewSilentLogger())) + if err != nil { + t.Fatal(err) + } + + h := NewHTTPCache(diskCache, testutils.NewSilentLogger(), testutils.NewSilentLogger(), false, true, false, false, "") + // create a fake http.Request + data, hash := testutils.RandomDataAndHash(blobSize) + err = diskCache.Put(context.Background(), cache.RAW, hash, int64(len(data)), bytes.NewReader(data)) + if err != nil { + t.Fatal(err) + } + + url, _ := url.Parse(fmt.Sprintf("http://localhost:8080/ac/%s", hash)) + reader := bytes.NewReader([]byte{}) + body := io.NopCloser(reader) + req := &http.Request{ + Method: "GET", + URL: url, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: body, + } + statusCode := 0 + respWriter := &fakeResponseWriter{ + statusCode: &statusCode, + } + h.CacheHandler(respWriter, req) + if statusCode != 0 { + t.Errorf("Wrong status code, expected %d, got %d", 0, statusCode) + } + + url, _ = url.Parse(fmt.Sprintf("http://localhost:8080/test-instance/ac/%s", hash)) + reader.Reset([]byte{}) + body.Close() + body = io.NopCloser(reader) + req.URL = url + req.Body = body + statusCode = 0 + + h.CacheHandler(respWriter, req) + if statusCode != http.StatusNotFound { + t.Errorf("Wrong status code, expected %d, got %d", http.StatusNotFound, statusCode) + } +}