Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SNOW-1858529 Implement FLOE #2006

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ public class EncryptionProvider {
private static final String FILE_CIPHER = "AES/CBC/PKCS5Padding";
private static final String KEY_CIPHER = "AES/ECB/PKCS5Padding";
private static final int BUFFER_SIZE = 2 * 1024 * 1024; // 2 MB
private static ThreadLocal<SecureRandom> secRnd =
new ThreadLocal<>().withInitial(SecureRandom::new);
private static ThreadLocal<SecureRandom> secRnd = ThreadLocal.withInitial(SecureRandom::new);

/**
* Decrypt a InputStream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,18 @@
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.MatDesc;
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;

class GcmEncryptionProvider {
@SnowflakeJdbcInternalApi
public class GcmEncryptionProvider {
private static final int TAG_LENGTH_IN_BITS = 128;
private static final int IV_LENGTH_IN_BYTES = 12;
private static final String AES = "AES";
private static final String FILE_CIPHER = "AES/GCM/NoPadding";
private static final String KEY_CIPHER = "AES/GCM/NoPadding";
private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB
private static final ThreadLocal<SecureRandom> random =
new ThreadLocal<>().withInitial(SecureRandom::new);
ThreadLocal.withInitial(SecureRandom::new);
private static final Base64.Decoder base64Decoder = Base64.getDecoder();

static InputStream encrypt(
Expand Down Expand Up @@ -85,7 +85,7 @@ private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvD
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData);
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
Expand All @@ -99,7 +99,7 @@ private static CipherInputStream encryptContent(
NoSuchAlgorithmException {
SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes);
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand Down Expand Up @@ -172,7 +172,7 @@ private static CipherInputStream decryptContentFromStream(
NoSuchAlgorithmException {
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand All @@ -187,7 +187,7 @@ private static void decryptContentFromFile(
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes);
byte[] buffer = new byte[BUFFER_SIZE];
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand Down Expand Up @@ -215,7 +215,7 @@ private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyByte
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm;

public enum Aead {
// TODO confirm id
AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new Gcm(16)),
AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new Gcm(16));

private byte id;
private String jceName;
private int keyLength;
private int ivLength;
private int authTagLength;
private AeadProvider aeadProvider;

Aead(
byte id,
String jceName,
int keyLength,
int ivLength,
int authTagLength,
AeadProvider aeadProvider) {
this.jceName = jceName;
this.keyLength = keyLength;
this.id = id;
this.ivLength = ivLength;
this.authTagLength = authTagLength;
this.aeadProvider = aeadProvider;
}

byte getId() {
return id;
}

String getJceName() {
return jceName;
}

int getKeyLength() {
return keyLength;
}

int getIvLength() {
return ivLength;
}

int getAuthTagLength() {
return authTagLength;
}

AeadProvider getAeadProvider() {
return aeadProvider;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;

class AeadAad {
private final byte[] bytes;

private AeadAad(long segmentCounter, byte terminalityByte) {
ByteBuffer buf = ByteBuffer.allocate(9);
buf.putLong(segmentCounter);
buf.put(terminalityByte);
this.bytes = buf.array();
}

static AeadAad nonTerminal(long segmentCounter) {
return new AeadAad(segmentCounter, (byte) 0);
}

byte[] getBytes() {
return bytes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;

class AeadIv {
private final byte[] bytes;

AeadIv(byte[] bytes) {
this.bytes = bytes;
}

public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) {
return new AeadIv(floeRandom.ofLength(ivLength));
}

public static AeadIv from(ByteBuffer buffer, int ivLength) {
byte[] bytes = new byte[ivLength];
buffer.get(bytes);
return new AeadIv(bytes);
}

byte[] getBytes() {
return bytes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;

class AeadKey {
private final SecretKey key;

AeadKey(SecretKey key) {
this.key = key;
}

SecretKey getKey() {
return key;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.security.GeneralSecurityException;
import javax.crypto.SecretKey;

public interface AeadProvider {
byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext)
throws GeneralSecurityException;

byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext)
throws GeneralSecurityException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

abstract class BaseSegmentProcessor {
protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1;
protected static final int headerTagLength = 32;

protected final FloeParameterSpec parameterSpec;
protected final FloeKey floeKey;
protected final FloeAad floeAad;

protected final KeyDerivator keyDerivator;

private AeadKey currentAeadKey;

BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) {
this.parameterSpec = parameterSpec;
this.floeKey = floeKey;
this.floeAad = floeAad;
this.keyDerivator = new KeyDerivator(parameterSpec);
}

protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) {
currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter);
}
return currentAeadKey;
}

private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
byte[] keyBytes =
keyDerivator.hkdfExpand(
floeKey,
floeIv,
floeAad,
new DekTagFloePurpose(segmentCounter),
parameterSpec.getAead().getKeyLength());
SecretKey key =
new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD
return new AeadKey(key);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;

public class Floe {
private final FloeParameterSpec parameterSpec;

private Floe(FloeParameterSpec parameterSpec) {
this.parameterSpec = parameterSpec;
}

public static Floe getInstance(FloeParameterSpec parameterSpec) {
return new Floe(parameterSpec);
}

public FloeEncryptor createEncryptor(SecretKey key, byte[] aad) {
return new FloeEncryptorImpl(parameterSpec, new FloeKey(key), new FloeAad(aad));
}

public FloeDecryptor createDecryptor(SecretKey key, byte[] aad, byte[] floeHeader) {
return new FloeDecryptorImpl(parameterSpec, new FloeKey(key), new FloeAad(aad), floeHeader);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.util.Optional;

class FloeAad {
private final byte[] aad;

FloeAad(byte[] aad) {
this.aad = Optional.ofNullable(aad).orElse(new byte[0]);
}

byte[] getBytes() {
return aad;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

public interface FloeDecryptor extends SegmentProcessor {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Arrays;

public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor {
private final FloeIv floeIv;
private long segmentCounter;

FloeDecryptorImpl(
FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) {
super(parameterSpec, floeKey, floeAad);
byte[] encodedParams = this.parameterSpec.paramEncode();
if (floeHeaderAsBytes.length
!= encodedParams.length
+ this.parameterSpec.getFloeIvLength().getLength()
+ headerTagLength) {
throw new IllegalArgumentException("invalid header length");
}
ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes);

byte[] encodedParamsFromHeader = new byte[10];
floeHeader.get(encodedParamsFromHeader, 0, encodedParamsFromHeader.length);
if (!Arrays.equals(encodedParams, encodedParamsFromHeader)) {
throw new IllegalArgumentException("invalid parameters header");
}

byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()];
floeHeader.get(floeIvBytes, 0, floeIvBytes.length);
this.floeIv = new FloeIv(floeIvBytes);

byte[] headerTagFromHeader = new byte[headerTagLength];
floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length);

byte[] headerTag =
keyDerivator.hkdfExpand(
this.floeKey, floeIv, this.floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength);
if (!Arrays.equals(headerTag, headerTagFromHeader)) {
throw new IllegalArgumentException("invalid header tag");
}
}

@Override
public byte[] processSegment(byte[] input) {
try {
verifySegmentLength(input);
ByteBuffer inputBuf = ByteBuffer.wrap(input);
verifySegmentSizeMarker(inputBuf);
AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter);
AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength());
AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++);
AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider();
byte[] ciphertext = new byte[inputBuf.remaining()];
inputBuf.get(ciphertext);
return aeadProvider.decrypt(
aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext);
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
}

private void verifySegmentLength(byte[] input) {
if (input.length != parameterSpec.getEncryptedSegmentLength()) {
throw new IllegalArgumentException(
String.format(
"segment length mismatch, expected %d, got %d",
parameterSpec.getEncryptedSegmentLength(), input.length));
}
}

private void verifySegmentSizeMarker(ByteBuffer inputBuf) {
int segmentSizeMarker = inputBuf.getInt();
if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) {
throw new IllegalStateException(
String.format(
"segment length marker mismatch, expected: %d, got :%d",
NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker));
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

public interface FloeEncryptor extends SegmentProcessor {
byte[] getHeader();
}
Loading
Loading