From 4956ce1c21b4c45ceacc7c916c91685a3677182c Mon Sep 17 00:00:00 2001 From: "Mateusz \"Serafin\" Gajewski" Date: Thu, 3 Oct 2024 14:17:01 +0200 Subject: [PATCH] Normalize spooling encryption headers casing --- .../AzureEncryptionHeadersTranslator.java | 2 +- .../EncryptionHeadersTranslator.java | 31 ++++++++++++++++-- .../GcsEncryptionHeadersTranslator.java | 2 +- .../filesystem/encryption/HeadersUtils.java | 23 +++++++++++-- .../S3EncryptionHeadersTranslator.java | 6 ++-- .../TestAzureEncryptionHeadersTranslator.java | 28 +++++++++++++++- .../TestGcsEncryptionHeadersTranslator.java | 31 ++++++++++++++++-- .../TestS3EncryptionHeadersTranslator.java | 32 +++++++++++++++++-- 8 files changed, 138 insertions(+), 17 deletions(-) diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/AzureEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/AzureEncryptionHeadersTranslator.java index ec638e0ecf41..98f9bc5c64a2 100644 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/AzureEncryptionHeadersTranslator.java +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/AzureEncryptionHeadersTranslator.java @@ -24,7 +24,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader; -public class AzureEncryptionHeadersTranslator +class AzureEncryptionHeadersTranslator implements EncryptionHeadersTranslator { @Override diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionHeadersTranslator.java index 854089641fcb..5254d35dcbd3 100644 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionHeadersTranslator.java +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/EncryptionHeadersTranslator.java @@ -19,6 +19,7 @@ import java.util.List; import java.util.Map; +import static io.trino.spooling.filesystem.encryption.HeadersUtils.normalizeHeaders; import static java.util.Objects.requireNonNull; public interface EncryptionHeadersTranslator @@ -35,14 +36,40 @@ static EncryptionHeadersTranslator encryptionHeadersTranslator(Location location .orElseThrow(() -> new IllegalArgumentException("Unknown location scheme: " + location)); } - private static EncryptionHeadersTranslator forScheme(String scheme) + static EncryptionHeadersTranslator forScheme(String scheme) { // These should match schemes supported in the FileSystemSpoolingModule - return switch (scheme) { + EncryptionHeadersTranslator schemeHeadersTranslator = switch (scheme) { case "s3" -> new S3EncryptionHeadersTranslator(); case "gs" -> new GcsEncryptionHeadersTranslator(); case "abfs" -> new AzureEncryptionHeadersTranslator(); default -> throw new IllegalArgumentException("Unknown file system scheme: " + scheme); }; + + // Normalize header case so it won't matter which case we will get from the client + return new NormalizingHeadersTranslator(schemeHeadersTranslator); + } + + class NormalizingHeadersTranslator + implements EncryptionHeadersTranslator + { + private final EncryptionHeadersTranslator delegate; + + NormalizingHeadersTranslator(EncryptionHeadersTranslator delegate) + { + this.delegate = requireNonNull(delegate, "delegate is null"); + } + + @Override + public EncryptionKey extractKey(Map> headers) + { + return delegate.extractKey(normalizeHeaders(headers)); + } + + @Override + public Map> createHeaders(EncryptionKey encryption) + { + return normalizeHeaders(delegate.createHeaders(encryption)); + } } } diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/GcsEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/GcsEncryptionHeadersTranslator.java index f4fa875b3791..440b53980a03 100644 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/GcsEncryptionHeadersTranslator.java +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/GcsEncryptionHeadersTranslator.java @@ -25,7 +25,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader; -public class GcsEncryptionHeadersTranslator +class GcsEncryptionHeadersTranslator implements EncryptionHeadersTranslator { @Override diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/HeadersUtils.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/HeadersUtils.java index 40532a76f1c7..96e8a96d4706 100644 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/HeadersUtils.java +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/HeadersUtils.java @@ -17,6 +17,8 @@ import java.util.Map; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Locale.ENGLISH; public class HeadersUtils { @@ -24,9 +26,24 @@ private HeadersUtils() {} public static String getOnlyHeader(Map> headers, String name) { - List values = headers.get(name); - checkArgument(values != null && !values.isEmpty(), "Required header " + name + " was not found"); - checkArgument(values.size() == 1, "Required header " + name + " contains more than one value"); + String headerName = normalizeHeaderNameCase(name); + List values = headers.get(headerName); + checkArgument(values != null && !values.isEmpty(), "Required header %s was not found", headerName); + checkArgument(values.size() == 1, "Required header %s contains more than one value", headerName); return values.getFirst(); } + + static Map> normalizeHeaders(Map> headers) + { + return headers.entrySet() + .stream() + .collect(toImmutableMap( + entry -> normalizeHeaderNameCase(entry.getKey()), + entry -> List.copyOf(entry.getValue()))); + } + + static String normalizeHeaderNameCase(String headerName) + { + return headerName.toLowerCase(ENGLISH); + } } diff --git a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/S3EncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/S3EncryptionHeadersTranslator.java index de902cabe4e2..86f461e24d78 100644 --- a/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/S3EncryptionHeadersTranslator.java +++ b/plugin/trino-spooling-filesystem/src/main/java/io/trino/spooling/filesystem/encryption/S3EncryptionHeadersTranslator.java @@ -25,14 +25,14 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader; -public class S3EncryptionHeadersTranslator +class S3EncryptionHeadersTranslator implements EncryptionHeadersTranslator { @Override public EncryptionKey extractKey(Map> headers) { byte[] key = Base64.getDecoder().decode(getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key")); - String md5Checksum = getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key-MD5"); + String md5Checksum = getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key-md5"); EncryptionKey encryption = new EncryptionKey(key, getOnlyHeader(headers, "x-amz-server-side-encryption-customer-algorithm")); checkArgument(md5(encryption).equals(md5Checksum), "Key MD5 checksum does not match"); return encryption; @@ -44,7 +44,7 @@ public Map> createHeaders(EncryptionKey encryption) return ImmutableMap.of( "x-amz-server-side-encryption-customer-algorithm", ImmutableList.of(encryption.algorithm()), "x-amz-server-side-encryption-customer-key", ImmutableList.of(encoded(encryption)), - "x-amz-server-side-encryption-customer-key-MD5", ImmutableList.of(md5(encryption))); + "x-amz-server-side-encryption-customer-key-md5", ImmutableList.of(md5(encryption))); } public static String encoded(EncryptionKey key) diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestAzureEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestAzureEncryptionHeadersTranslator.java index a7fbd295d9ed..f46f6daa8c7b 100644 --- a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestAzureEncryptionHeadersTranslator.java +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestAzureEncryptionHeadersTranslator.java @@ -14,19 +14,23 @@ package io.trino.spooling.filesystem.encryption; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; import io.trino.filesystem.encryption.EncryptionKey; import org.junit.jupiter.api.Test; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.function.Function; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; class TestAzureEncryptionHeadersTranslator { - private static final EncryptionHeadersTranslator SSE = new AzureEncryptionHeadersTranslator(); + private static final EncryptionHeadersTranslator SSE = EncryptionHeadersTranslator.forScheme("abfs"); @Test public void testKnownKey() @@ -49,6 +53,14 @@ public void testRoundTrip() assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key); } + @Test + public void testRoundTripWithMixedCaseHeaders() + { + EncryptionKey key = EncryptionKey.randomAes256(); + Map> headers = mixCase(SSE.createHeaders(key)); + assertThat(SSE.extractKey(headers)).isEqualTo(key); + } + @Test public void testThrowsOnInvalidChecksum() { @@ -61,4 +73,18 @@ public void testThrowsOnInvalidChecksum() .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Key SHA256 checksum does not match"); } + + private static Map> mixCase(Map> headers) + { + Iterator> iterator = Iterators.cycle( + String::toUpperCase, + value -> value.replaceFirst("x-ms-", "X-Ms-"), + value -> value.replaceFirst("x-ms-encryption", "X-ms-Encryption")); + + return headers.entrySet() + .stream() + .collect(toImmutableMap( + entry -> iterator.next().apply(entry.getKey()), + Map.Entry::getValue)); + } } diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestGcsEncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestGcsEncryptionHeadersTranslator.java index 2b58a85527ad..a0d77ce57cdf 100644 --- a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestGcsEncryptionHeadersTranslator.java +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestGcsEncryptionHeadersTranslator.java @@ -14,20 +14,23 @@ package io.trino.spooling.filesystem.encryption; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; import io.trino.filesystem.encryption.EncryptionKey; -import org.assertj.core.api.AssertionsForClassTypes; import org.junit.jupiter.api.Test; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.function.Function; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; class TestGcsEncryptionHeadersTranslator { - private static final EncryptionHeadersTranslator SSE = new GcsEncryptionHeadersTranslator(); + private static final EncryptionHeadersTranslator SSE = EncryptionHeadersTranslator.forScheme("gs"); @Test public void testKnownKey() @@ -47,7 +50,15 @@ public void testKnownKey() public void testRoundTrip() { EncryptionKey key = EncryptionKey.randomAes256(); - AssertionsForClassTypes.assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key); + assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key); + } + + @Test + public void testRoundTripWithMixedCaseHeaders() + { + EncryptionKey key = EncryptionKey.randomAes256(); + Map> headers = mixCase(SSE.createHeaders(key)); + assertThat(SSE.extractKey(headers)).isEqualTo(key); } @Test @@ -62,4 +73,18 @@ public void testThrowsOnInvalidChecksum() .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Key SHA256 checksum does not match"); } + + private static Map> mixCase(Map> headers) + { + Iterator> iterator = Iterators.cycle( + String::toUpperCase, + value -> value.replaceFirst("x-goog-", "X-Goog-"), + value -> value.replaceFirst("x-goog-encryption", "X-goog-Encryption")); + + return headers.entrySet() + .stream() + .collect(toImmutableMap( + entry -> iterator.next().apply(entry.getKey()), + Map.Entry::getValue)); + } } diff --git a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestS3EncryptionHeadersTranslator.java b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestS3EncryptionHeadersTranslator.java index 71267cd4cb33..a7a71974f0bd 100644 --- a/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestS3EncryptionHeadersTranslator.java +++ b/plugin/trino-spooling-filesystem/src/test/java/io/trino/spooling/filesystem/encryption/TestS3EncryptionHeadersTranslator.java @@ -14,19 +14,23 @@ package io.trino.spooling.filesystem.encryption; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Iterators; import io.trino.filesystem.encryption.EncryptionKey; import org.junit.jupiter.api.Test; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.function.Function; +import static com.google.common.collect.ImmutableMap.toImmutableMap; import static java.nio.charset.StandardCharsets.UTF_8; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; class TestS3EncryptionHeadersTranslator { - private static final EncryptionHeadersTranslator SSE = new S3EncryptionHeadersTranslator(); + private static final EncryptionHeadersTranslator SSE = EncryptionHeadersTranslator.forScheme("s3"); @Test public void testKnownKey() @@ -38,7 +42,7 @@ public void testKnownKey() assertThat(headers) .hasSize(3) .containsEntry("x-amz-server-side-encryption-customer-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE=")) - .containsEntry("x-amz-server-side-encryption-customer-key-MD5", List.of("CX3f4fSIpiyVyQDCzuhDWg==")) + .containsEntry("x-amz-server-side-encryption-customer-key-md5", List.of("CX3f4fSIpiyVyQDCzuhDWg==")) .containsEntry("x-amz-server-side-encryption-customer-algorithm", List.of("AES256")); } @@ -49,16 +53,38 @@ public void testRoundTrip() assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key); } + @Test + public void testRoundTripWithMixedCaseHeaders() + { + EncryptionKey key = EncryptionKey.randomAes256(); + Map> headers = mixCase(SSE.createHeaders(key)); + assertThat(SSE.extractKey(headers)).isEqualTo(key); + } + @Test public void testThrowsOnInvalidChecksum() { Map> headers = ImmutableMap.of( "x-amz-server-side-encryption-customer-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE="), - "x-amz-server-side-encryption-customer-key-MD5", List.of("brokenchecksum"), + "x-amz-server-side-encryption-customer-key-md5", List.of("brokenchecksum"), "x-amz-server-side-encryption-customer-algorithm", List.of("AES256")); assertThatThrownBy(() -> SSE.extractKey(headers)) .isInstanceOf(IllegalArgumentException.class) .hasMessageContaining("Key MD5 checksum does not match"); } + + private static Map> mixCase(Map> headers) + { + Iterator> iterator = Iterators.cycle( + String::toUpperCase, + value -> value.replaceFirst("x-amz-", "X-Amz-"), + value -> value.replaceFirst("x-amz-server-side", "X-amz-Server-Side")); + + return headers.entrySet() + .stream() + .collect(toImmutableMap( + entry -> iterator.next().apply(entry.getKey()), + Map.Entry::getValue)); + } }