Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding queryable encryption range support
Browse files Browse the repository at this point in the history
Supports range style queries for encrypted fields
rozza committed Jan 16, 2025
1 parent 14985a9 commit 18050c0
Showing 15 changed files with 1,145 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -19,6 +19,7 @@
import java.util.Optional;
import java.util.function.Function;

import org.bson.conversions.Bson;
import org.springframework.data.mongodb.core.mapping.Field;
import org.springframework.data.mongodb.core.query.Collation;
import org.springframework.data.mongodb.core.schema.MongoJsonSchema;
@@ -51,10 +52,11 @@ public class CollectionOptions {
private ValidationOptions validationOptions;
private @Nullable TimeSeriesOptions timeSeriesOptions;
private @Nullable CollectionChangeStreamOptions changeStreamOptions;
private @Nullable Bson encryptedFields;

private CollectionOptions(@Nullable Long size, @Nullable Long maxDocuments, @Nullable Boolean capped,
@Nullable Collation collation, ValidationOptions validationOptions, @Nullable TimeSeriesOptions timeSeriesOptions,
@Nullable CollectionChangeStreamOptions changeStreamOptions) {
@Nullable CollectionChangeStreamOptions changeStreamOptions, @Nullable Bson encryptedFields) {

this.maxDocuments = maxDocuments;
this.size = size;
@@ -63,6 +65,7 @@ private CollectionOptions(@Nullable Long size, @Nullable Long maxDocuments, @Nul
this.validationOptions = validationOptions;
this.timeSeriesOptions = timeSeriesOptions;
this.changeStreamOptions = changeStreamOptions;
this.encryptedFields = encryptedFields;
}

/**
@@ -76,7 +79,7 @@ public static CollectionOptions just(Collation collation) {

Assert.notNull(collation, "Collation must not be null");

return new CollectionOptions(null, null, null, collation, ValidationOptions.none(), null, null);
return new CollectionOptions(null, null, null, collation, ValidationOptions.none(), null, null, null);
}

/**
@@ -86,7 +89,7 @@ public static CollectionOptions just(Collation collation) {
* @since 2.0
*/
public static CollectionOptions empty() {
return new CollectionOptions(null, null, null, null, ValidationOptions.none(), null, null);
return new CollectionOptions(null, null, null, null, ValidationOptions.none(), null, null, null);
}

/**
@@ -136,7 +139,7 @@ public static CollectionOptions emitChangedRevisions() {
*/
public CollectionOptions capped() {
return new CollectionOptions(size, maxDocuments, true, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
@@ -148,7 +151,7 @@ public CollectionOptions capped() {
*/
public CollectionOptions maxDocuments(long maxDocuments) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
@@ -160,7 +163,7 @@ public CollectionOptions maxDocuments(long maxDocuments) {
*/
public CollectionOptions size(long size) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
@@ -172,7 +175,7 @@ public CollectionOptions size(long size) {
*/
public CollectionOptions collation(@Nullable Collation collation) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
@@ -293,7 +296,7 @@ public CollectionOptions validation(ValidationOptions validationOptions) {

Assert.notNull(validationOptions, "ValidationOptions must not be null");
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
@@ -307,7 +310,7 @@ public CollectionOptions timeSeries(TimeSeriesOptions timeSeriesOptions) {

Assert.notNull(timeSeriesOptions, "TimeSeriesOptions must not be null");
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
@@ -321,7 +324,19 @@ public CollectionOptions changeStream(CollectionChangeStreamOptions changeStream

Assert.notNull(changeStreamOptions, "ChangeStreamOptions must not be null");
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions);
changeStreamOptions, encryptedFields);
}

/**
* Create new {@link CollectionOptions} with the given {@code encryptedFields}.
*
* @param encryptedFields can be null
* @return new instance of {@link CollectionOptions}.
* @since QERange
*/
public CollectionOptions encryptedFields(@Nullable Bson encryptedFields) {
return new CollectionOptions(size, maxDocuments, capped, collation, validationOptions, timeSeriesOptions,
changeStreamOptions, encryptedFields);
}

/**
@@ -392,12 +407,22 @@ public Optional<CollectionChangeStreamOptions> getChangeStreamOptions() {
return Optional.ofNullable(changeStreamOptions);
}

/**
* Get the {@code encryptedFields} if available.
*
* @return {@link Optional#empty()} if not specified.
* @since QERange
*/
public Optional<Bson> getEncryptedFields() {
return Optional.ofNullable(encryptedFields);
}

@Override
public String toString() {
return "CollectionOptions{" + "maxDocuments=" + maxDocuments + ", size=" + size + ", capped=" + capped
+ ", collation=" + collation + ", validationOptions=" + validationOptions + ", timeSeriesOptions="
+ timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", disableValidation="
+ disableValidation() + ", strictValidation=" + strictValidation() + ", moderateValidation="
+ timeSeriesOptions + ", changeStreamOptions=" + changeStreamOptions + ", encryptedFields=" + encryptedFields
+ ", disableValidation=" + disableValidation() + ", strictValidation=" + strictValidation() + ", moderateValidation="
+ moderateValidation() + ", warnOnValidationError=" + warnOnValidationError() + ", failOnValidationError="
+ failOnValidationError() + '}';
}
@@ -431,7 +456,10 @@ public boolean equals(@Nullable Object o) {
if (!ObjectUtils.nullSafeEquals(timeSeriesOptions, that.timeSeriesOptions)) {
return false;
}
return ObjectUtils.nullSafeEquals(changeStreamOptions, that.changeStreamOptions);
if (!ObjectUtils.nullSafeEquals(changeStreamOptions, that.changeStreamOptions)) {
return false;
}
return ObjectUtils.nullSafeEquals(encryptedFields, that.encryptedFields);
}

@Override
@@ -443,6 +471,7 @@ public int hashCode() {
result = 31 * result + ObjectUtils.nullSafeHashCode(validationOptions);
result = 31 * result + ObjectUtils.nullSafeHashCode(timeSeriesOptions);
result = 31 * result + ObjectUtils.nullSafeHashCode(changeStreamOptions);
result = 31 * result + ObjectUtils.nullSafeHashCode(encryptedFields);
return result;
}

Original file line number Diff line number Diff line change
@@ -26,4 +26,6 @@ public final class EncryptionAlgorithms {
public static final String AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic";
public static final String AEAD_AES_256_CBC_HMAC_SHA_512_Random = "AEAD_AES_256_CBC_HMAC_SHA_512-Random";

public static final String RANGE = "Range";

}
Original file line number Diff line number Diff line change
@@ -378,6 +378,7 @@ public CreateCollectionOptions convertToCreateCollectionOptions(@Nullable Collec
collectionOptions.getChangeStreamOptions().ifPresent(it -> result
.changeStreamPreAndPostImagesOptions(new ChangeStreamPreAndPostImagesOptions(it.getPreAndPostImages())));

collectionOptions.getEncryptedFields().ifPresent(result::encryptedFields);
return result;
}

Original file line number Diff line number Diff line change
@@ -2172,8 +2172,9 @@ protected <O> AggregationResults<O> doAggregate(Aggregation aggregation, String

List<Document> pipeline = aggregationUtil.createPipeline(aggregation, context);

if (LOGGER.isDebugEnabled()) {
LOGGER.debug(
// TODO revert to DEBUG
if (LOGGER.isErrorEnabled()) {
LOGGER.error(
String.format("Executing aggregation: %s in collection %s", serializeToJsonSafely(pipeline), collectionName));
}

@@ -2594,10 +2595,10 @@ protected <S, T> List<T> doFind(String collectionName,
Document mappedFields = queryContext.getMappedFields(entity, EntityProjection.nonProjecting(entityClass));
Document mappedQuery = queryContext.getMappedQuery(entity);

if (LOGGER.isDebugEnabled()) {

// TODO revert to DEBUG
if (LOGGER.isErrorEnabled()) {
Document mappedSort = getMappedSortObject(query, entityClass);
LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
LOGGER.error(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), entityClass,
collectionName));
}
@@ -2623,8 +2624,9 @@ <S, T> List<T> doFind(CollectionPreparer<MongoCollection<Document>> collectionPr
Document mappedQuery = queryContext.getMappedQuery(entity);
Document mappedSort = getMappedSortObject(query, sourceClass);

if (LOGGER.isDebugEnabled()) {
LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
// TODO revert to DEBUG
if (LOGGER.isErrorEnabled()) {
LOGGER.error(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s",
serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), sourceClass,
collectionName));
}
Original file line number Diff line number Diff line change
@@ -33,24 +33,39 @@
public class MongoConversionContext implements ValueConversionContext<MongoPersistentProperty> {

private final PropertyValueProvider<MongoPersistentProperty> accessor; // TODO: generics
private final @Nullable MongoPersistentProperty persistentProperty;
private final MongoConverter mongoConverter;

@Nullable private final MongoPersistentProperty persistentProperty;
@Nullable private final SpELContext spELContext;
@Nullable private final String queryFieldPath;

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor,
@Nullable MongoPersistentProperty persistentProperty, MongoConverter mongoConverter) {
this(accessor, persistentProperty, mongoConverter, null);
this(accessor, mongoConverter, persistentProperty, null);
}

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor,
@Nullable MongoPersistentProperty persistentProperty, MongoConverter mongoConverter,
@Nullable SpELContext spELContext) {
this(accessor, mongoConverter, persistentProperty, spELContext, null);
}

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor, MongoConverter mongoConverter,
@Nullable MongoPersistentProperty persistentProperty, @Nullable String queryFieldPath) {
this(accessor, mongoConverter, persistentProperty, null, queryFieldPath);
}

public MongoConversionContext(PropertyValueProvider<MongoPersistentProperty> accessor,
MongoConverter mongoConverter,
@Nullable MongoPersistentProperty persistentProperty,
@Nullable SpELContext spELContext,
@Nullable String queryFieldPath) {

this.accessor = accessor;
this.persistentProperty = persistentProperty;
this.mongoConverter = mongoConverter;
this.spELContext = spELContext;
this.queryFieldPath = queryFieldPath;
}

@Override
@@ -84,4 +99,9 @@ public <T> T read(@Nullable Object value, TypeInformation<T> target) {
public SpELContext getSpELContext() {
return spELContext;
}

@Nullable
public String getQueryFieldPath() {
return queryFieldPath;
}
}
Original file line number Diff line number Diff line change
@@ -59,6 +59,7 @@
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.convert.MappingMongoConverter.NestedDocument;
import org.springframework.data.mongodb.core.mapping.FieldName;
import org.springframework.data.mongodb.core.mapping.MongoField;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty.PropertyToFieldNameConverter;
@@ -356,9 +357,10 @@ protected Entry<String, Object> getMappedObjectForField(Field field, Object rawV
return createMapEntry(key, getMappedObject(mongoExpression.toDocument(), field.getEntity()));
}

if (isNestedKeyword(rawValue) && !field.isIdField()) {
if (isNestedKeyword(rawValue)) {
Keyword keyword = new Keyword((Document) rawValue);
value = getMappedKeyword(field, keyword);
field = field.with(keyword.getKey());
value = field.isIdField() ? getMappedValue(field, rawValue) : getMappedKeyword(field, keyword);
} else {
value = getMappedValue(field, rawValue);
}
@@ -455,10 +457,19 @@ protected Document getMappedKeyword(Field property, Keyword keyword) {
@Nullable
@SuppressWarnings("unchecked")
protected Object getMappedValue(Field documentField, Object sourceValue) {

Object value = applyFieldTargetTypeHintToValue(documentField, sourceValue);

if (documentField.getProperty() != null
MongoPersistentProperty property = documentField.getProperty();

String queryPath = property != null && !property.getFieldName().equals(documentField.name) ?
property.getFieldName() + "." + documentField.name : documentField.name;

// TODO add flattened path to convert value and remove logging
if (LOGGER.isErrorEnabled()) {
LOGGER.error(" >-|-> " + queryPath);
}

if (property != null
&& converter.getCustomConversions().hasValueConverter(documentField.getProperty())) {

PropertyValueConverter<Object, Object, ValueConversionContext<MongoPersistentProperty>> valueConverter = converter
@@ -668,8 +679,17 @@ private Object convertValue(Field documentField, Object sourceValue, Object valu
PropertyValueConverter<Object, Object, ValueConversionContext<MongoPersistentProperty>> valueConverter) {

MongoPersistentProperty property = documentField.getProperty();

String queryPath = property != null && !property.getFieldName().equals(documentField.name) ?
property.getFieldName() + "." + documentField.name : documentField.name;

// TODO add flattened path to convert value and remove logging
if (LOGGER.isErrorEnabled()) {
LOGGER.error(" >--> " + queryPath);
}

MongoConversionContext conversionContext = new MongoConversionContext(NoPropertyPropertyValueProvider.INSTANCE,
property, converter);
converter, property, queryPath);

/* might be an $in clause with multiple entries */
if (property != null && !property.isCollectionLike() && sourceValue instanceof Collection<?> collection) {
Original file line number Diff line number Diff line change
@@ -66,4 +66,10 @@ public <T> T read(@Nullable Object value, TypeInformation<T> target) {
public <T> T write(@Nullable Object value, TypeInformation<T> target) {
return conversionContext.write(value, target);
}

// TODO QE - add to interface
@Nullable
public String getQueryFieldPath() {
return conversionContext.getQueryFieldPath();
}
}
Original file line number Diff line number Diff line change
@@ -16,9 +16,12 @@
package org.springframework.data.mongodb.core.convert.encryption;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import com.mongodb.client.model.vault.RangeOptions;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.bson.BsonArray;
@@ -27,18 +30,25 @@
import org.bson.BsonValue;
import org.bson.Document;
import org.bson.types.Binary;
import org.jetbrains.annotations.NotNull;
import org.springframework.core.CollectionFactory;
import org.springframework.data.mongodb.core.convert.MongoConversionContext;
import org.springframework.data.mongodb.core.encryption.Encryption;
import org.springframework.data.mongodb.core.encryption.EncryptionContext;
import org.springframework.data.mongodb.core.encryption.EncryptionKey;
import org.springframework.data.mongodb.core.encryption.EncryptionKeyResolver;
import org.springframework.data.mongodb.core.encryption.EncryptionOptions;
import org.springframework.data.mongodb.core.mapping.Encrypted;
import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.data.mongodb.util.BsonUtils;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;

import static java.util.Arrays.asList;
import static java.util.Collections.singletonList;
import static org.springframework.data.mongodb.core.EncryptionAlgorithms.RANGE;

/**
* Default implementation of {@link EncryptingConverter}. Properties used with this converter must be annotated with
* {@link Encrypted @Encrypted} to provide key and algorithm metadata.
@@ -49,12 +59,14 @@
public class MongoEncryptionConverter implements EncryptingConverter<Object, Object> {

private static final Log LOGGER = LogFactory.getLog(MongoEncryptionConverter.class);
private static final String EQUALITY_OPERATOR = "$eq";
private static final List<String> RANGE_OPERATORS = asList("$gt", "$gte", "$lt", "$lte");


private final Encryption<BsonValue, BsonBinary> encryption;
private final EncryptionKeyResolver keyResolver;

public MongoEncryptionConverter(Encryption<BsonValue, BsonBinary> encryption, EncryptionKeyResolver keyResolver) {

this.encryption = encryption;
this.keyResolver = keyResolver;
}
@@ -143,9 +155,9 @@ public Object decrypt(Object encryptedValue, EncryptionContext context) {

@Override
public Object encrypt(Object value, EncryptionContext context) {

if (LOGGER.isDebugEnabled()) {
LOGGER.debug(String.format("Encrypting %s.%s.", getProperty(context).getOwner().getName(),
// TODO revert to DEBUG
if (LOGGER.isErrorEnabled()) {
LOGGER.error(String.format("Encrypting %s.%s.", getProperty(context).getOwner().getName(),
getProperty(context).getName()));
}

@@ -161,8 +173,45 @@ public Object encrypt(Object value, EncryptionContext context) {
getProperty(context).getOwner().getName(), getProperty(context).getName()));
}

EncryptionOptions encryptionOptions = new EncryptionOptions(annotation.algorithm(), keyResolver.getKey(context));
boolean encryptValue = true;
String algorithm = annotation.algorithm();
EncryptionKey key = keyResolver.getKey(context);
EncryptionOptions encryptionOptions;
encryptionOptions = new EncryptionOptions(algorithm, key);

String queryFieldPath = context instanceof ExplicitEncryptionContext explicitEncryptionContext ?
explicitEncryptionContext.getQueryFieldPath() : null;

ExplicitEncrypted explicitEncryptedAnnotation = persistentProperty.findAnnotation(ExplicitEncrypted.class);
if (explicitEncryptedAnnotation != null) {
EncryptionOptions.QueryableEncryptionOptions queryableEncryptionOptions = EncryptionOptions.QueryableEncryptionOptions.none();
String rangeOptions = explicitEncryptedAnnotation.rangeOptions();
if (!rangeOptions.trim().isEmpty()) {
queryableEncryptionOptions = queryableEncryptionOptions.rangeOptions(Document.parse(rangeOptions));
}

if (explicitEncryptedAnnotation.contentionFactor() >= 0) {
queryableEncryptionOptions = queryableEncryptionOptions.contentionFactor(explicitEncryptedAnnotation.contentionFactor());
}

boolean isRangeQuery = algorithm.equalsIgnoreCase(RANGE) && queryFieldPath != null;
if (isRangeQuery) {
encryptValue = false;
queryableEncryptionOptions = queryableEncryptionOptions.queryType("range");
}
encryptionOptions = new EncryptionOptions(algorithm, key, queryableEncryptionOptions);

}

if (encryptValue) {
return encryptValue(value, context, persistentProperty, encryptionOptions);
} else {
return encryptExpression(queryFieldPath, value, encryptionOptions);
}
}

private @NotNull BsonBinary encryptValue(Object value, EncryptionContext context, MongoPersistentProperty persistentProperty,
EncryptionOptions encryptionOptions) {
if (!persistentProperty.isEntity()) {

if (persistentProperty.isCollectionLike()) {
@@ -187,6 +236,32 @@ public Object encrypt(Object value, EncryptionContext context) {
return encryption.encrypt(BsonUtils.simpleToBsonValue(write), encryptionOptions);
}


private BsonValue encryptExpression(String queryFieldPath, Object value, EncryptionOptions encryptionOptions) {
BsonValue doc = BsonUtils.simpleToBsonValue(value);

String fieldName = queryFieldPath;
String queryOperator = EQUALITY_OPERATOR;

int pos = queryFieldPath.lastIndexOf(".$");
if (pos > -1) {
fieldName = queryFieldPath.substring(0, pos);
queryOperator = queryFieldPath.substring(pos + 1);
}

if (!RANGE_OPERATORS.contains(queryOperator)) {
throw new AssertionError(String.format("Not a valid range query. Querying a range encrypted field but the " +
"query operator '%s' for field path '%s' is not a range query.",
queryOperator, queryFieldPath));
}

BsonDocument encryptExpression = new BsonDocument("$and", new BsonArray(
singletonList(new BsonDocument(fieldName, new BsonDocument(queryOperator, doc)))));

BsonDocument result = encryption.encryptExpression(encryptExpression, encryptionOptions);
return result.getArray("$and").get(0).asDocument().getDocument(fieldName).getBinary(queryOperator);
}

private BsonValue collectionLikeToBsonValue(Object value, MongoPersistentProperty property,
EncryptionContext context) {

Original file line number Diff line number Diff line change
@@ -15,6 +15,8 @@
*/
package org.springframework.data.mongodb.core.encryption;

import org.bson.BsonDocument;

/**
* Component responsible for encrypting and decrypting values.
*
@@ -40,4 +42,8 @@ public interface Encryption<S, T> {
*/
S decrypt(T value);

default BsonDocument encryptExpression(BsonDocument value, EncryptionOptions options) {
throw new UnsupportedOperationException("Unsupported encryption method");
}

}
Original file line number Diff line number Diff line change
@@ -15,9 +15,18 @@
*/
package org.springframework.data.mongodb.core.encryption;

import com.mongodb.client.model.vault.RangeOptions;
import org.bson.Document;
import org.springframework.data.mongodb.MongoTransactionManager;
import org.springframework.data.mongodb.util.BsonUtils;
import org.springframework.data.util.Optionals;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;

import java.util.Objects;
import java.util.Optional;

/**
* Options, like the {@link #algorithm()}, to apply when encrypting values.
*
@@ -28,14 +37,20 @@ public class EncryptionOptions {

private final String algorithm;
private final EncryptionKey key;
private final QueryableEncryptionOptions queryableEncryptionOptions;

public EncryptionOptions(String algorithm, EncryptionKey key) {
this(algorithm, key, QueryableEncryptionOptions.NONE);
}

public EncryptionOptions(String algorithm, EncryptionKey key, QueryableEncryptionOptions queryableEncryptionOptions) {
Assert.hasText(algorithm, "Algorithm must not be empty");
Assert.notNull(key, "EncryptionKey must not be empty");
Assert.notNull(key, "QueryableEncryptionOptions must not be empty");

this.key = key;
this.algorithm = algorithm;
this.queryableEncryptionOptions = queryableEncryptionOptions;
}

public EncryptionKey key() {
@@ -46,6 +61,10 @@ public String algorithm() {
return algorithm;
}

public QueryableEncryptionOptions queryableEncryptionOptions() {
return queryableEncryptionOptions;
}

@Override
public boolean equals(Object o) {

@@ -61,19 +80,186 @@ public boolean equals(Object o) {
if (!ObjectUtils.nullSafeEquals(algorithm, that.algorithm)) {
return false;
}
return ObjectUtils.nullSafeEquals(key, that.key);
if (!ObjectUtils.nullSafeEquals(key, that.key)) {
return false;
}

return ObjectUtils.nullSafeEquals(queryableEncryptionOptions, that.queryableEncryptionOptions);
}

@Override
public int hashCode() {

int result = ObjectUtils.nullSafeHashCode(algorithm);
result = 31 * result + ObjectUtils.nullSafeHashCode(key);
result = 31 * result + ObjectUtils.nullSafeHashCode(queryableEncryptionOptions);
return result;
}

@Override
public String toString() {
return "EncryptionOptions{" + "algorithm='" + algorithm + '\'' + ", key=" + key + '}';
return "EncryptionOptions{" + "algorithm='" + algorithm + '\'' + ", key=" + key +
", queryableEncryptionOptions='" + queryableEncryptionOptions + "'}";
}

/**
* Options, like the {@link #getQueryType()}, to apply when encrypting queryable values.
*
* @author Ross Lawley
*/
public static class QueryableEncryptionOptions {

private static final QueryableEncryptionOptions NONE = new QueryableEncryptionOptions(null, null, null);

private final @Nullable String queryType;
private final @Nullable Long contentionFactor;
private final @Nullable Document rangeOptions;

private QueryableEncryptionOptions(@Nullable String queryType, @Nullable Long contentionFactor, @Nullable Document rangeOptions) {
this.queryType = queryType;
this.contentionFactor = contentionFactor;
this.rangeOptions = rangeOptions;
}

/**
* Create an empty {@link QueryableEncryptionOptions}.
*
* @return none {@literal null}.
*/
public static QueryableEncryptionOptions none() {
return NONE;
}

/**
* Define the {@code queryType} to be used for queryable document encryption.
*
* @param queryType can be {@literal null}.
* @return new instance of {@link QueryableEncryptionOptions}.
*/
public QueryableEncryptionOptions queryType(@Nullable String queryType) {
return new QueryableEncryptionOptions(queryType, contentionFactor, rangeOptions);
}

/**
* Define the {@code contentionFactor} to be used for queryable document encryption.
*
* @param contentionFactor can be {@literal null}.
* @return new instance of {@link QueryableEncryptionOptions}.
*/
public QueryableEncryptionOptions contentionFactor(@Nullable Long contentionFactor) {
return new QueryableEncryptionOptions(queryType, contentionFactor, rangeOptions);
}

/**
* Define the {@code rangeOptions} to be used for queryable document encryption.
*
* @param rangeOptions can be {@literal null}.
* @return new instance of {@link QueryableEncryptionOptions}.
*/
public QueryableEncryptionOptions rangeOptions(@Nullable Document rangeOptions) {
return new QueryableEncryptionOptions(queryType, contentionFactor, rangeOptions);
}

/**
* Get the {@code queryType} to apply.
*
* @return {@link Optional#empty()} if not set.
*/
public Optional<String> getQueryType() {
return Optional.ofNullable(queryType);
}

/**
* Get the {@code contentionFactor} to apply.
*
* @return {@link Optional#empty()} if not set.
*/
public Optional<Long> getContentionFactor() {
return Optional.ofNullable(contentionFactor);
}

/**
* Get the {@code rangeOptions} to apply.
*
* @return {@link Optional#empty()} if not set.
*/
public Optional<RangeOptions> getRangeOptions() {
if (rangeOptions == null) {
return Optional.empty();
}
RangeOptions encryptionRangeOptions = new RangeOptions();

if (rangeOptions.containsKey("min")) {
encryptionRangeOptions.min(BsonUtils.simpleToBsonValue(rangeOptions.get("min")));
}
if (rangeOptions.containsKey("max")) {
encryptionRangeOptions.max(BsonUtils.simpleToBsonValue(rangeOptions.get("max")));
}
if (rangeOptions.containsKey("trimFactor")) {
Object trimFactor = rangeOptions.get("trimFactor");
Assert.isInstanceOf(Integer.class, trimFactor,
() -> String.format("Expected to find a %s but it turned out to be %s.", Integer.class,
trimFactor.getClass()));
encryptionRangeOptions.trimFactor((Integer) trimFactor);
}

if (rangeOptions.containsKey("sparsity")) {
Object sparsity = rangeOptions.get("sparsity");
Assert.isInstanceOf(Number.class, sparsity,
() -> String.format("Expected to find a %s but it turned out to be %s.", Long.class,
sparsity.getClass()));
encryptionRangeOptions.sparsity(((Number) sparsity).longValue());
}

if (rangeOptions.containsKey("precision")) {
Object precision = rangeOptions.get("precision");
Assert.isInstanceOf(Number.class, precision,
() -> String.format("Expected to find a %s but it turned out to be %s.", Integer.class,
precision.getClass()));
encryptionRangeOptions.precision(((Number) precision).intValue());
}
return Optional.of(encryptionRangeOptions);
}

/**
* @return {@literal true} if no arguments set.
*/
boolean isEmpty() {
return !Optionals.isAnyPresent(getQueryType(), getContentionFactor(), getRangeOptions());
}

@Override
public String toString() {
return "QueryableEncryptionOptions{" +
"queryType='" + queryType + '\'' +
", contentionFactor=" + contentionFactor +
", rangeOptions=" + rangeOptions +
'}';
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
QueryableEncryptionOptions that = (QueryableEncryptionOptions) o;

if (!ObjectUtils.nullSafeEquals(queryType, that.queryType)) {
return false;
}

if (!ObjectUtils.nullSafeEquals(contentionFactor, that.contentionFactor)) {
return false;
}
return ObjectUtils.nullSafeEquals(rangeOptions, that.rangeOptions);
}

@Override
public int hashCode() {
return Objects.hash(queryType, contentionFactor, rangeOptions);
}
}
}
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
import java.util.function.Supplier;

import org.bson.BsonBinary;
import org.bson.BsonDocument;
import org.bson.BsonValue;
import org.springframework.data.mongodb.core.encryption.EncryptionKey.Type;
import org.springframework.util.Assert;
@@ -59,7 +60,19 @@ public BsonValue decrypt(BsonBinary value) {

@Override
public BsonBinary encrypt(BsonValue value, EncryptionOptions options) {
return getClientEncryption().encrypt(value, createEncryptOptions(options));
}

@Override
public BsonDocument encryptExpression(BsonDocument value, EncryptionOptions options) {
return getClientEncryption().encryptExpression(value, createEncryptOptions(options));
}

public ClientEncryption getClientEncryption() {
return source.get();
}

private EncryptOptions createEncryptOptions(EncryptionOptions options) {
EncryptOptions encryptOptions = new EncryptOptions(options.algorithm());

if (Type.ALT.equals(options.key().type())) {
@@ -68,11 +81,10 @@ public BsonBinary encrypt(BsonValue value, EncryptionOptions options) {
encryptOptions = encryptOptions.keyId((BsonBinary) options.key().value());
}

return getClientEncryption().encrypt(value, encryptOptions);
}

public ClientEncryption getClientEncryption() {
return source.get();
options.queryableEncryptionOptions().getQueryType().map(encryptOptions::queryType);
options.queryableEncryptionOptions().getContentionFactor().map(encryptOptions::contentionFactor);
options.queryableEncryptionOptions().getRangeOptions().map(encryptOptions::rangeOptions);
return encryptOptions;
}

}
Original file line number Diff line number Diff line change
@@ -84,11 +84,18 @@
*/
String keyAltName() default "";

// TODO QE - update docs as well as algorithm.
long contentionFactor() default -1;

// TODO QE - update docs as well as algorithm.
String rangeOptions() default "";

/**
* The {@link EncryptingConverter} type handling the {@literal en-/decryption} of the annotated property.
*
* @return the configured {@link EncryptingConverter}. A {@link MongoEncryptionConverter} by default.
*/
@AliasFor(annotation = ValueConverter.class, value = "value")
Class<? extends PropertyValueConverter> value() default MongoEncryptionConverter.class;

}
Original file line number Diff line number Diff line change
@@ -76,6 +76,18 @@ public abstract class AbstractEncryptionTestBase {

@Autowired MongoTemplate template;

@Test
void canQueryDeterministicallyEncryptedWithQueryScope() {
Person source = new Person();
source.id = "id-1";
source.ssn = "mySecretSSN";

template.save(source);

Person loaded = template.query(Person.class).matching(where("ssn").gte(source.ssn)).firstValue();
assertThat(loaded).isEqualTo(source);
}

@Test // GH-4284
void encryptAndDecryptSimpleValue() {

Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.data.mongodb.core.encryption;

import com.mongodb.ClientEncryptionSettings;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoNamespace;
import com.mongodb.client.MongoClient;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;
import com.mongodb.client.model.CreateCollectionOptions;
import com.mongodb.client.model.CreateEncryptedCollectionParams;
import com.mongodb.client.model.Filters;
import com.mongodb.client.model.IndexOptions;
import com.mongodb.client.model.Indexes;
import com.mongodb.client.model.vault.DataKeyOptions;
import com.mongodb.client.vault.ClientEncryption;
import com.mongodb.client.vault.ClientEncryptions;
import org.bson.BsonArray;
import org.bson.BsonBinary;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
import org.bson.BsonInt64;
import org.bson.BsonNull;
import org.bson.BsonString;
import org.bson.Document;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.data.convert.PropertyValueConverterFactory;
import org.springframework.data.mongodb.config.AbstractMongoClientConfiguration;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.convert.MongoCustomConversions.MongoConverterConfigurationAdapter;
import org.springframework.data.mongodb.core.convert.encryption.MongoEncryptionConverter;
import org.springframework.data.mongodb.core.mapping.ExplicitEncrypted;
import org.springframework.data.mongodb.test.util.MongoClientExtension;
import org.springframework.data.util.Lazy;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit.jupiter.SpringExtension;

import java.security.SecureRandom;
import java.util.Collections;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

import static java.util.Collections.singletonList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.springframework.data.mongodb.core.EncryptionAlgorithms.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic;
import static org.springframework.data.mongodb.core.EncryptionAlgorithms.RANGE;
import static org.springframework.data.mongodb.core.query.Criteria.where;

/**
* @author Ross Lawley
*/
@ExtendWith(MongoClientExtension.class)
@ExtendWith(SpringExtension.class)
@ContextConfiguration(classes = RangeEncryptionTest.EncryptionConfig.class)
public class RangeEncryptionTest {

@Autowired MongoTemplate template;
// TODO
/*
Todo:
- [X] Add {{encryptedFields}} support to {{CreateCollectionOptions}}
- [X] Add {{contentionFactor}} to {{EncryptOptions}}
- [X] Add {{queryType}} to {{EncryptOptions}}
- [X] Add {{RangeOptions}} to {{EncryptOptions}}
- [X] Add {{rangeOptions}} (String / JSON) to {{ExplicitEncrypted}} annotation
- [X] Add {{Range}} to encryption algorithms.
- [ ] Add test cases from the Test Plan
// TODO - add support for Indexed
Test Plan
Setup:
- Create a POJO with the valid range bson data types, annotate the fields with @ExplicitEncrypted.
- Insert test data
- Validate the data has been encrypted in the db.
Single range tests:
- Perform a Range query for each of the encrypted fields
- Validate the expected POJO(s) is turned
Multiple field range tests:
- Perform a Range query on multiple the encrypted fields at once
- Validate the expected POJO(s) is turned
Multiple field tests:
- Perform a Range query on an encrypted fields as well as a non encrypted field
- Validate the expected POJO(s) is turned
*/

@Test
void canEqualityMatchRangeEncryptedField() {
Person source = new Person();
source.id = "id-1";
source.ssn = 101;
template.save(source);

assertThatThrownBy(() -> template.query(Person.class).matching(where("ssn").is(source.ssn)).firstValue())
.isInstanceOf(AssertionError.class)
.hasMessageStartingWith("Not a valid range query. Querying a range encrypted field but " +
"the query operator '$eq' for field path 'ssn' is not a range query.");
}

@Test
void canGreaterThanMatchRangeEncryptedField() {
Person source = new Person();
source.id = "id-1";
source.ssn = 101;
template.save(source);

Person loaded = template.query(Person.class).matching(where("ssn").gte(source.ssn)).firstValue();
assertThat(loaded).isEqualTo(source);
}

protected static class EncryptionConfig extends AbstractMongoClientConfiguration {

@Autowired ApplicationContext applicationContext;

@Override
protected String getDatabaseName() {
return "qe-test";
}

@Bean
public MongoClient mongoClient() {
return super.mongoClient();
}

@Override
protected void configureConverters(MongoConverterConfigurationAdapter converterConfigurationAdapter) {
converterConfigurationAdapter
.registerPropertyValueConverterFactory(PropertyValueConverterFactory.beanFactoryAware(applicationContext))
.useNativeDriverJavaTimeCodecs();
}

@Bean
MongoEncryptionConverter encryptingConverter(MongoClientEncryption mongoClientEncryption) {
Lazy<BsonBinary> lazyDataKey = Lazy.of(() -> {
BsonDocument encryptedFields = new BsonDocument()
.append(
"fields",
new BsonArray(singletonList(new BsonDocument("keyId", BsonNull.VALUE)
.append("path", new BsonString("sid"))
.append("bsonType", new BsonString("int"))
.append(
"queries",
new BsonDocument("queryType", new BsonString("range"))
.append("contention", new BsonInt64(0L))
.append("trimFactor", new BsonInt32(1))
.append("sparsity", new BsonInt64(1))
.append("min", new BsonInt32(0))
.append("max", new BsonInt32(200))))));

try (MongoClient client = mongoClient()) {
MongoDatabase database = client.getDatabase(getDatabaseName());
database.getCollection("test").drop();
BsonDocument local = mongoClientEncryption.getClientEncryption()
.createEncryptedCollection(database, "test",
new CreateCollectionOptions().encryptedFields(encryptedFields),
new CreateEncryptedCollectionParams("local"));
return local.getArray("fields").get(0).asDocument().getBinary("keyId");
}
});
return new MongoEncryptionConverter(mongoClientEncryption,
EncryptionKeyResolver.annotated((ctx) -> EncryptionKey.keyId(lazyDataKey.get())));
}

@Bean
CachingMongoClientEncryption clientEncryption(ClientEncryptionSettings encryptionSettings) {
return new CachingMongoClientEncryption(() -> ClientEncryptions.create(encryptionSettings));
}

@Bean
ClientEncryptionSettings encryptionSettings(MongoClient mongoClient) {

MongoNamespace keyVaultNamespace = new MongoNamespace("encryption.testKeyVault");
MongoCollection<Document> keyVaultCollection = mongoClient.getDatabase(keyVaultNamespace.getDatabaseName())
.getCollection(keyVaultNamespace.getCollectionName());
keyVaultCollection.drop();
// Ensure that two data keys cannot share the same keyAltName.
keyVaultCollection.createIndex(Indexes.ascending("keyAltNames"),
new IndexOptions().unique(true).partialFilterExpression(Filters.exists("keyAltNames")));

MongoCollection<Document> collection = mongoClient.getDatabase(getDatabaseName()).getCollection("test");
collection.drop(); // Clear old data

byte[] localMasterKey = new byte[96];
new SecureRandom().nextBytes(localMasterKey);
Map<String, Map<String, Object>> kmsProviders = Map.of("local", Map.of("key", localMasterKey));

// Create the ClientEncryption instance
return ClientEncryptionSettings.builder() //
.keyVaultMongoClientSettings(
MongoClientSettings.builder().applyConnectionString(new ConnectionString("mongodb://localhost")).build()) //
.keyVaultNamespace(keyVaultNamespace.getFullName()) //
.kmsProviders(kmsProviders) //
.build();
}
}

static class CachingMongoClientEncryption extends MongoClientEncryption implements DisposableBean {

static final AtomicReference<ClientEncryption> cache = new AtomicReference<>();

CachingMongoClientEncryption(Supplier<ClientEncryption> source) {
super(() -> {
ClientEncryption clientEncryption = cache.get();
if (clientEncryption == null) {
clientEncryption = source.get();
cache.set(clientEncryption);
}

return clientEncryption;
});
}

@Override
public void destroy() {
ClientEncryption clientEncryption = cache.get();
if (clientEncryption != null) {
clientEncryption.close();
cache.set(null);
}
}
}

@org.springframework.data.mongodb.core.mapping.Document("test")
static class Person {

String id;
String name;

@ExplicitEncrypted(algorithm = RANGE, contentionFactor = 0L, rangeOptions = "{min: 0, max: 200, trimFactor: 1, sparsity: 1}")
Integer ssn;

public String getId() {
return this.id;
}

public String getName() {
return this.name;
}

public Integer getSsn() {
return this.ssn;
}

public void setId(String id) {
this.id = id;
}

public void setName(String name) {
this.name = name;
}

public void setSsn(Integer ssn) {
this.ssn = ssn;
}


@Override
public boolean equals(Object o) {
if (o == this) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
Person person = (Person) o;
return Objects.equals(id, person.id) && Objects.equals(name, person.name) && Objects.equals(ssn, person.ssn);
}

@Override
public int hashCode() {
return Objects.hash(id, name, ssn);
}

public String toString() {
return "RangeEncryptionTest.Person(id=" + this.getId() + ", name=" + this.getName() + ", ssn=" + this.getSsn() + ")";
}
}

}

Large diffs are not rendered by default.

0 comments on commit 18050c0

Please sign in to comment.