Skip to content

Commit

Permalink
Implement processing segments
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pfus committed Dec 20, 2024
1 parent 2634bdd commit 68b11d2
Show file tree
Hide file tree
Showing 22 changed files with 564 additions and 115 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
import net.snowflake.client.jdbc.MatDesc;
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;

class GcmEncryptionProvider {
@SnowflakeJdbcInternalApi
public class GcmEncryptionProvider {
private static final int TAG_LENGTH_IN_BITS = 128;
private static final int IV_LENGTH_IN_BYTES = 12;
private static final String AES = "AES";
private static final String FILE_CIPHER = "AES/GCM/NoPadding";
private static final String KEY_CIPHER = "AES/GCM/NoPadding";
private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB
private static final ThreadLocal<SecureRandom> random =
ThreadLocal.withInitial(SecureRandom::new);
Expand Down Expand Up @@ -85,7 +85,7 @@ private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvD
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData);
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
Expand All @@ -99,7 +99,7 @@ private static CipherInputStream encryptContent(
NoSuchAlgorithmException {
SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes);
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand Down Expand Up @@ -172,7 +172,7 @@ private static CipherInputStream decryptContentFromStream(
NoSuchAlgorithmException {
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand All @@ -187,7 +187,7 @@ private static void decryptContentFromFile(
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes);
byte[] buffer = new byte[BUFFER_SIZE];
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
if (aad != null) {
fileCipher.updateAAD(aad);
Expand Down Expand Up @@ -215,7 +215,7 @@ private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyByte
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec);
if (aad != null) {
keyCipher.updateAAD(aad);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,55 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm;

public enum Aead {
AES_GCM_128((byte) 0),
AES_GCM_256((byte) 1);
// TODO confirm id
AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new Gcm(16)),
AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new Gcm(16));

private byte id;
private String jceName;
private int keyLength;
private int ivLength;
private int authTagLength;
private AeadProvider aeadProvider;

Aead(byte id) {
Aead(
byte id,
String jceName,
int keyLength,
int ivLength,
int authTagLength,
AeadProvider aeadProvider) {
this.jceName = jceName;
this.keyLength = keyLength;
this.id = id;
this.ivLength = ivLength;
this.authTagLength = authTagLength;
this.aeadProvider = aeadProvider;
}

byte getId() {
return id;
}

String getJceName() {
return jceName;
}

int getKeyLength() {
return keyLength;
}

int getIvLength() {
return ivLength;
}

int getAuthTagLength() {
return authTagLength;
}

AeadProvider getAeadProvider() {
return aeadProvider;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;

class AeadAad {
private final byte[] bytes;

private AeadAad(long segmentCounter, byte terminalityByte) {
ByteBuffer buf = ByteBuffer.allocate(9);
buf.putLong(segmentCounter);
buf.put(terminalityByte);
this.bytes = buf.array();
}

static AeadAad nonTerminal(long segmentCounter) {
return new AeadAad(segmentCounter, (byte) 0);
}

byte[] getBytes() {
return bytes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;

class AeadIv {
private final byte[] bytes;

AeadIv(byte[] bytes) {
this.bytes = bytes;
}

public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) {
return new AeadIv(floeRandom.ofLength(ivLength));
}

public static AeadIv from(ByteBuffer buffer, int ivLength) {
byte[] bytes = new byte[ivLength];
buffer.get(bytes);
return new AeadIv(bytes);
}

byte[] getBytes() {
return bytes;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;

class AeadKey {
private final SecretKey key;

AeadKey(SecretKey key) {
this.key = key;
}

SecretKey getKey() {
return key;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.security.GeneralSecurityException;
import javax.crypto.SecretKey;

public interface AeadProvider {
byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext)
throws GeneralSecurityException;

byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext)
throws GeneralSecurityException;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

abstract class BaseSegmentProcessor {
protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1;
protected static final int headerTagLength = 32;

protected final FloeParameterSpec parameterSpec;
protected final FloeKey floeKey;
protected final FloeAad floeAad;

protected final KeyDerivator keyDerivator;

private AeadKey currentAeadKey;

BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) {
this.parameterSpec = parameterSpec;
this.floeKey = floeKey;
this.floeAad = floeAad;
this.keyDerivator = new KeyDerivator(parameterSpec);
}

protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) {
currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter);
}
return currentAeadKey;
}

private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
byte[] keyBytes =
keyDerivator.hkdfExpand(
floeKey,
floeIv,
floeAad,
new DekTagFloePurpose(segmentCounter),
parameterSpec.getAead().getKeyLength());
SecretKey key =
new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD
return new AeadKey(key);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

public interface FloeDecryptor {}
public interface FloeDecryptor extends SegmentProcessor {}
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

import java.nio.ByteBuffer;
import java.security.GeneralSecurityException;
import java.util.Arrays;

public class FloeDecryptorImpl extends FloeBase implements FloeDecryptor {
public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor {
private final FloeIv floeIv;
private long segmentCounter;

FloeDecryptorImpl(
FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) {
super(parameterSpec, floeKey, floeAad);
validate(floeHeaderAsBytes);
}

public void validate(byte[] floeHeaderAsBytes) {
byte[] encodedParams = parameterSpec.paramEncode();
byte[] encodedParams = this.parameterSpec.paramEncode();
if (floeHeaderAsBytes.length
!= encodedParams.length + parameterSpec.getFloeIvLength().getLength() + headerTagLength) {
!= encodedParams.length
+ this.parameterSpec.getFloeIvLength().getLength()
+ headerTagLength) {
throw new IllegalArgumentException("invalid header length");
}
ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes);
Expand All @@ -24,17 +26,56 @@ public void validate(byte[] floeHeaderAsBytes) {
throw new IllegalArgumentException("invalid parameters header");
}

byte[] floeIvBytes = new byte[parameterSpec.getFloeIvLength().getLength()];
byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()];
floeHeader.get(floeIvBytes, 0, floeIvBytes.length);
FloeIv floeIv = new FloeIv(floeIvBytes);
this.floeIv = new FloeIv(floeIvBytes);

byte[] headerTagFromHeader = new byte[headerTagLength];
floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length);

byte[] headerTag =
floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength);
keyDerivator.hkdfExpand(
this.floeKey, floeIv, this.floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength);
if (!Arrays.equals(headerTag, headerTagFromHeader)) {
throw new IllegalArgumentException("invalid header tag");
}
}

@Override
public byte[] processSegment(byte[] input) {
try {
verifySegmentLength(input);
ByteBuffer inputBuf = ByteBuffer.wrap(input);
verifySegmentSizeMarker(inputBuf);
AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter);
AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength());
AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++);
AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider();
byte[] ciphertext = new byte[inputBuf.remaining()];
inputBuf.get(ciphertext);
return aeadProvider.decrypt(
aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext);
} catch (GeneralSecurityException e) {
throw new RuntimeException(e);
}
}

private void verifySegmentLength(byte[] input) {
if (input.length != parameterSpec.getEncryptedSegmentLength()) {
throw new IllegalArgumentException(
String.format(
"segment length mismatch, expected %d, got %d",
parameterSpec.getEncryptedSegmentLength(), input.length));
}
}

private void verifySegmentSizeMarker(ByteBuffer inputBuf) {
int segmentSizeMarker = inputBuf.getInt();
if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) {
throw new IllegalStateException(
String.format(
"segment length marker mismatch, expected: %d, got :%d",
NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker));
}
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package net.snowflake.client.jdbc.cloud.storage.floe;

public interface FloeEncryptor {
public interface FloeEncryptor extends SegmentProcessor {
byte[] getHeader();
}
Loading

0 comments on commit 68b11d2

Please sign in to comment.