Skip to content

Commit

Permalink
Normalize spooling encryption headers casing
Browse files Browse the repository at this point in the history
  • Loading branch information
wendigo committed Oct 3, 2024
1 parent 8bf23b4 commit 286e40c
Show file tree
Hide file tree
Showing 8 changed files with 138 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader;

public class AzureEncryptionHeadersTranslator
class AzureEncryptionHeadersTranslator
implements EncryptionHeadersTranslator
{
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.List;
import java.util.Map;

import static io.trino.spooling.filesystem.encryption.HeadersUtils.normalizeHeaders;
import static java.util.Objects.requireNonNull;

public interface EncryptionHeadersTranslator
Expand All @@ -35,14 +36,40 @@ static EncryptionHeadersTranslator encryptionHeadersTranslator(Location location
.orElseThrow(() -> new IllegalArgumentException("Unknown location scheme: " + location));
}

private static EncryptionHeadersTranslator forScheme(String scheme)
static EncryptionHeadersTranslator forScheme(String scheme)
{
// These should match schemes supported in the FileSystemSpoolingModule
return switch (scheme) {
EncryptionHeadersTranslator schemeHeadersTranslator = switch (scheme) {
case "s3" -> new S3EncryptionHeadersTranslator();
case "gs" -> new GcsEncryptionHeadersTranslator();
case "abfs" -> new AzureEncryptionHeadersTranslator();
default -> throw new IllegalArgumentException("Unknown file system scheme: " + scheme);
};

// Normalize header case so it won't matter which case we will get from the client
return new NormalizingHeadersTranslator(schemeHeadersTranslator);
}

class NormalizingHeadersTranslator
implements EncryptionHeadersTranslator
{
private final EncryptionHeadersTranslator delegate;

NormalizingHeadersTranslator(EncryptionHeadersTranslator delegate)
{
this.delegate = requireNonNull(delegate, "delegate is null");
}

@Override
public EncryptionKey extractKey(Map<String, List<String>> headers)
{
return delegate.extractKey(normalizeHeaders(headers));
}

@Override
public Map<String, List<String>> createHeaders(EncryptionKey encryption)
{
return normalizeHeaders(delegate.createHeaders(encryption));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader;

public class GcsEncryptionHeadersTranslator
class GcsEncryptionHeadersTranslator
implements EncryptionHeadersTranslator
{
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,33 @@
import java.util.Map;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Locale.ENGLISH;

public class HeadersUtils
{
private HeadersUtils() {}

public static String getOnlyHeader(Map<String, List<String>> headers, String name)
{
List<String> values = headers.get(name);
checkArgument(values != null && !values.isEmpty(), "Required header " + name + " was not found");
checkArgument(values.size() == 1, "Required header " + name + " contains more than one value");
String headerName = normalizeHeaderNameCase(name);
List<String> values = headers.get(headerName);
checkArgument(values != null && !values.isEmpty(), "Required header %s was not found", headerName);
checkArgument(values.size() == 1, "Required header %s contains more than one value", headerName);
return values.getFirst();
}

static Map<String, List<String>> normalizeHeaders(Map<String, List<String>> headers)
{
return headers.entrySet()
.stream()
.collect(toImmutableMap(
entry -> normalizeHeaderNameCase(entry.getKey()),
entry -> List.copyOf(entry.getValue())));
}

static String normalizeHeaderNameCase(String headerName)
{
return headerName.toLowerCase(ENGLISH);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@
import static com.google.common.base.Preconditions.checkArgument;
import static io.trino.spooling.filesystem.encryption.HeadersUtils.getOnlyHeader;

public class S3EncryptionHeadersTranslator
class S3EncryptionHeadersTranslator
implements EncryptionHeadersTranslator
{
@Override
public EncryptionKey extractKey(Map<String, List<String>> headers)
{
byte[] key = Base64.getDecoder().decode(getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key"));
String md5Checksum = getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key-MD5");
String md5Checksum = getOnlyHeader(headers, "x-amz-server-side-encryption-customer-key-md5");
EncryptionKey encryption = new EncryptionKey(key, getOnlyHeader(headers, "x-amz-server-side-encryption-customer-algorithm"));
checkArgument(md5(encryption).equals(md5Checksum), "Key MD5 checksum does not match");
return encryption;
Expand All @@ -44,7 +44,7 @@ public Map<String, List<String>> createHeaders(EncryptionKey encryption)
return ImmutableMap.of(
"x-amz-server-side-encryption-customer-algorithm", ImmutableList.of(encryption.algorithm()),
"x-amz-server-side-encryption-customer-key", ImmutableList.of(encoded(encryption)),
"x-amz-server-side-encryption-customer-key-MD5", ImmutableList.of(md5(encryption)));
"x-amz-server-side-encryption-customer-key-md5", ImmutableList.of(md5(encryption)));
}

public static String encoded(EncryptionKey key)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@
package io.trino.spooling.filesystem.encryption;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterators;
import io.trino.filesystem.encryption.EncryptionKey;
import org.junit.jupiter.api.Test;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class TestAzureEncryptionHeadersTranslator
{
private static final EncryptionHeadersTranslator SSE = new AzureEncryptionHeadersTranslator();
private static final EncryptionHeadersTranslator SSE = EncryptionHeadersTranslator.forScheme("abfs");

@Test
public void testKnownKey()
Expand All @@ -49,6 +53,14 @@ public void testRoundTrip()
assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key);
}

@Test
public void testRoundTripWithMixedCaseHeaders()
{
EncryptionKey key = EncryptionKey.randomAes256();
Map<String, List<String>> headers = mixCase(SSE.createHeaders(key));
assertThat(SSE.extractKey(headers)).isEqualTo(key);
}

@Test
public void testThrowsOnInvalidChecksum()
{
Expand All @@ -61,4 +73,18 @@ public void testThrowsOnInvalidChecksum()
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Key SHA256 checksum does not match");
}

private static Map<String, List<String>> mixCase(Map<String, List<String>> headers)
{
Iterator<Function<String, String>> iterator = Iterators.cycle(
String::toUpperCase,
value -> value.replaceFirst("x-ms-", "X-Ms-"),
value -> value.replaceFirst("x-ms-encryption", "X-ms-Encryption"));

return headers.entrySet()
.stream()
.collect(toImmutableMap(
entry -> iterator.next().apply(entry.getKey()),
Map.Entry::getValue));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,23 @@
package io.trino.spooling.filesystem.encryption;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterators;
import io.trino.filesystem.encryption.EncryptionKey;
import org.assertj.core.api.AssertionsForClassTypes;
import org.junit.jupiter.api.Test;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class TestGcsEncryptionHeadersTranslator
{
private static final EncryptionHeadersTranslator SSE = new GcsEncryptionHeadersTranslator();
private static final EncryptionHeadersTranslator SSE = EncryptionHeadersTranslator.forScheme("gs");

@Test
public void testKnownKey()
Expand All @@ -47,7 +50,15 @@ public void testKnownKey()
public void testRoundTrip()
{
EncryptionKey key = EncryptionKey.randomAes256();
AssertionsForClassTypes.assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key);
assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key);
}

@Test
public void testRoundTripWithMixedCaseHeaders()
{
EncryptionKey key = EncryptionKey.randomAes256();
Map<String, List<String>> headers = mixCase(SSE.createHeaders(key));
assertThat(SSE.extractKey(headers)).isEqualTo(key);
}

@Test
Expand All @@ -62,4 +73,18 @@ public void testThrowsOnInvalidChecksum()
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Key SHA256 checksum does not match");
}

private static Map<String, List<String>> mixCase(Map<String, List<String>> headers)
{
Iterator<Function<String, String>> iterator = Iterators.cycle(
String::toUpperCase,
value -> value.replaceFirst("x-goog-", "X-Goog-"),
value -> value.replaceFirst("x-goog-encryption", "X-goog-Encryption"));

return headers.entrySet()
.stream()
.collect(toImmutableMap(
entry -> iterator.next().apply(entry.getKey()),
Map.Entry::getValue));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@
package io.trino.spooling.filesystem.encryption;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterators;
import io.trino.filesystem.encryption.EncryptionKey;
import org.junit.jupiter.api.Test;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

class TestS3EncryptionHeadersTranslator
{
private static final EncryptionHeadersTranslator SSE = new S3EncryptionHeadersTranslator();
private static final EncryptionHeadersTranslator SSE = EncryptionHeadersTranslator.forScheme("s3");

@Test
public void testKnownKey()
Expand All @@ -38,7 +42,7 @@ public void testKnownKey()
assertThat(headers)
.hasSize(3)
.containsEntry("x-amz-server-side-encryption-customer-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE="))
.containsEntry("x-amz-server-side-encryption-customer-key-MD5", List.of("CX3f4fSIpiyVyQDCzuhDWg=="))
.containsEntry("x-amz-server-side-encryption-customer-key-md5", List.of("CX3f4fSIpiyVyQDCzuhDWg=="))
.containsEntry("x-amz-server-side-encryption-customer-algorithm", List.of("AES256"));
}

Expand All @@ -49,16 +53,38 @@ public void testRoundTrip()
assertThat(SSE.extractKey(SSE.createHeaders(key))).isEqualTo(key);
}

@Test
public void testRoundTripWithMixedCaseHeaders()
{
EncryptionKey key = EncryptionKey.randomAes256();
Map<String, List<String>> headers = mixCase(SSE.createHeaders(key));
assertThat(SSE.extractKey(headers)).isEqualTo(key);
}

@Test
public void testThrowsOnInvalidChecksum()
{
Map<String, List<String>> headers = ImmutableMap.of(
"x-amz-server-side-encryption-customer-key", List.of("VHJpbm9XaWxsRmx5V2l0aFNwb29sZWRQcm90b2NvbCE="),
"x-amz-server-side-encryption-customer-key-MD5", List.of("brokenchecksum"),
"x-amz-server-side-encryption-customer-key-md5", List.of("brokenchecksum"),
"x-amz-server-side-encryption-customer-algorithm", List.of("AES256"));

assertThatThrownBy(() -> SSE.extractKey(headers))
.isInstanceOf(IllegalArgumentException.class)
.hasMessageContaining("Key MD5 checksum does not match");
}

private static Map<String, List<String>> mixCase(Map<String, List<String>> headers)
{
Iterator<Function<String, String>> iterator = Iterators.cycle(
String::toUpperCase,
value -> value.replaceFirst("x-amz-", "X-Amz-"),
value -> value.replaceFirst("x-amz-server-side", "X-amz-Server-Side"));

return headers.entrySet()
.stream()
.collect(toImmutableMap(
entry -> iterator.next().apply(entry.getKey()),
Map.Entry::getValue));
}
}

0 comments on commit 286e40c

Please sign in to comment.