Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancing CosmosTemplate to Support Multi-Tenancy at a DB Level #32516

Merged
merged 25 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e7e0165
Proof of concept that we can write to two databases from the same ses…
trande4884 Dec 1, 2022
7516e97
Improving the changes to CosmosTemplate and the test case.
trande4884 Dec 5, 2022
0125a30
Moving default setNameAndCreateDatabase() logic into CosmosTemplate.
trande4884 Dec 5, 2022
644697b
Improving unit test.
trande4884 Dec 5, 2022
b9c6510
Changing function name to be a more accurate description of the funct…
trande4884 Dec 5, 2022
4b974b9
Updating changelog
trande4884 Dec 8, 2022
ccb755e
Removing unused imports.
trande4884 Dec 8, 2022
f4751ab
Code cleanup.
trande4884 Dec 8, 2022
4fdcc45
Refactoring CosmosTemplate to now store the CosmosFactory on the temp…
trande4884 Dec 12, 2022
7933806
Updating changelog.
trande4884 Dec 13, 2022
8af75dc
Making the requested updates in the PR. Adding CosmosFactory to React…
trande4884 Dec 13, 2022
cbea46d
Making updates for PR comments.
trande4884 Dec 16, 2022
6b1d18b
Fixing updates to unit test.
trande4884 Dec 16, 2022
948a0f2
Fixing readme
trande4884 Dec 16, 2022
7b82f12
Adding file needed for readme.
trande4884 Dec 16, 2022
b3339d8
Fixing snippet for readme.
trande4884 Dec 16, 2022
6e78dab
Fixing snippet for readme.
trande4884 Dec 16, 2022
40af243
Updating readme.
trande4884 Dec 16, 2022
f7b58dc
Adding javadoc.
trande4884 Dec 16, 2022
e6b9dfc
Fixing unit test.
trande4884 Dec 20, 2022
6674719
Testing.
trande4884 Dec 20, 2022
55e793d
Testing breaking out setup to be before unit test runs.
trande4884 Dec 20, 2022
52ba220
Renaming file.
trande4884 Dec 20, 2022
63777fa
Adding new test config for MultiTenantDB test.
trande4884 Dec 20, 2022
466e5db
Adding cleanup to unit test.
trande4884 Dec 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.data.cosmos.core;

import com.azure.cosmos.CosmosAsyncClient;
import com.azure.spring.data.cosmos.CosmosFactory;
import com.azure.spring.data.cosmos.config.CosmosConfig;
import com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter;
import org.springframework.data.auditing.IsNewAwareAuditingHandler;

/**
* Template class for cosmos db
*/
public class MultiTenantDBCosmosFactory extends CosmosFactory {

public String databaseName;

/**
* Validate config and initialization
*
* @param cosmosAsyncClient cosmosAsyncClient
* @param databaseName databaseName
*/
public MultiTenantDBCosmosFactory(CosmosAsyncClient cosmosAsyncClient, String databaseName) {
super(cosmosAsyncClient, databaseName);

this.databaseName = databaseName;
}

public String getDatabaseName() {
return this.databaseName;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.spring.data.cosmos.core;

import com.azure.cosmos.CosmosAsyncClient;
import com.azure.cosmos.CosmosClientBuilder;
import com.azure.cosmos.models.PartitionKey;
import com.azure.spring.data.cosmos.CosmosFactory;
import com.azure.spring.data.cosmos.IntegrationTestCollectionManager;
import com.azure.spring.data.cosmos.config.CosmosConfig;
import com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter;
import com.azure.spring.data.cosmos.core.mapping.CosmosMappingContext;
import com.azure.spring.data.cosmos.domain.Person;
import com.azure.spring.data.cosmos.repository.TestRepositoryConfig;
import com.azure.spring.data.cosmos.repository.support.CosmosEntityInformation;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.domain.EntityScanner;
import org.springframework.context.ApplicationContext;
import org.springframework.data.annotation.Persistent;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;

import java.util.ArrayList;
import java.util.List;

import static com.azure.spring.data.cosmos.common.TestConstants.ADDRESSES;
import static com.azure.spring.data.cosmos.common.TestConstants.AGE;
import static com.azure.spring.data.cosmos.common.TestConstants.FIRST_NAME;
import static com.azure.spring.data.cosmos.common.TestConstants.HOBBIES;
import static com.azure.spring.data.cosmos.common.TestConstants.ID_1;
import static com.azure.spring.data.cosmos.common.TestConstants.ID_2;
import static com.azure.spring.data.cosmos.common.TestConstants.LAST_NAME;
import static com.azure.spring.data.cosmos.common.TestConstants.PASSPORT_IDS_BY_COUNTRY;
import static org.assertj.core.api.Assertions.assertThat;

@RunWith(SpringJUnit4ClassRunner.class)
@ContextConfiguration(classes = TestRepositoryConfig.class)
public class MultiTenantDBCosmosFactoryUnitTest {

private final String testDB1 = "Database1";
private final String testDB2 = "Database2";

private final Person TEST_PERSON_1 = new Person(ID_1, FIRST_NAME, LAST_NAME, HOBBIES, ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY);
private final Person TEST_PERSON_2 = new Person(ID_2, FIRST_NAME, LAST_NAME, HOBBIES, ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY);

@ClassRule
public static final IntegrationTestCollectionManager collectionManager = new IntegrationTestCollectionManager();

@Autowired
private ApplicationContext applicationContext;
@Autowired
private CosmosConfig cosmosConfig;
@Autowired
private CosmosClientBuilder cosmosClientBuilder;

@Test
public void testGetDatabaseFunctionality() {
/// Setup
CosmosAsyncClient client = CosmosFactory.createCosmosAsyncClient(cosmosClientBuilder);
MultiTenantDBCosmosFactory cosmosFactory = new MultiTenantDBCosmosFactory(client, testDB1);
final CosmosMappingContext mappingContext = new CosmosMappingContext();

try {
mappingContext.setInitialEntitySet(new EntityScanner(this.applicationContext).scan(Persistent.class));
} catch (Exception e) {
Assert.fail();
}

final MappingCosmosConverter cosmosConverter = new MappingCosmosConverter(mappingContext, null);
CosmosTemplate cosmosTemplate = new CosmosTemplate(cosmosFactory, cosmosConfig, cosmosConverter, null);
CosmosEntityInformation<Person, String> personInfo = new CosmosEntityInformation<>(Person.class);

// Create DB1 and add TEST_PERSON_1 to it
cosmosTemplate.createContainerIfNotExists(personInfo);
cosmosTemplate.deleteAll(personInfo.getContainerName(), Person.class);
assertThat(cosmosTemplate.getDatabaseName()).isEqualTo(testDB1);
cosmosTemplate.insert(TEST_PERSON_1, new PartitionKey(personInfo.getPartitionKeyFieldValue(TEST_PERSON_1)));

// Create DB2 and add TEST_PERSON_2 to it
cosmosFactory.databaseName = testDB2;
cosmosTemplate.createContainerIfNotExists(personInfo);
cosmosTemplate.deleteAll(personInfo.getContainerName(), Person.class);
assertThat(cosmosTemplate.getDatabaseName()).isEqualTo(testDB2);
cosmosTemplate.insert(TEST_PERSON_2, new PartitionKey(personInfo.getPartitionKeyFieldValue(TEST_PERSON_2)));

// Check that DB2 has the correct contents
List<Person> expectedResultsDB2 = new ArrayList<>();
expectedResultsDB2.add(TEST_PERSON_2);
Iterable<Person> iterableDB2 = cosmosTemplate.findAll(personInfo.getContainerName(), Person.class);
List<Person> resultDB2 = new ArrayList<>();
iterableDB2.forEach(resultDB2::add);
Assert.assertEquals(expectedResultsDB2, resultDB2);

// Check that DB1 has the correct contents
cosmosFactory.databaseName = testDB1;
List<Person> expectedResultsDB1 = new ArrayList<>();
expectedResultsDB1.add(TEST_PERSON_1);
Iterable<Person> iterableDB1 = cosmosTemplate.findAll(personInfo.getContainerName(), Person.class);
List<Person> resultDB1 = new ArrayList<>();
iterableDB1.forEach(resultDB1::add);
Assert.assertEquals(expectedResultsDB1, resultDB1);
}
}
1 change: 1 addition & 0 deletions sdk/cosmos/azure-spring-data-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
### 3.31.0-beta.1 (Unreleased)

#### Features Added
* Added support for multi-tenancy at the Database level via `CosmosFactory` - See [PR 32516](https://github.com/Azure/azure-sdk-for-java/pull/32516)

#### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class CosmosFactory {

private final CosmosAsyncClient cosmosAsyncClient;

private final String databaseName;
protected String databaseName;

private static final String USER_AGENT_SUFFIX =
Constants.USER_AGENT_SUFFIX + PropertyLoader.getProjectVersion();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,12 @@ public class CosmosTemplate implements CosmosOperations, ApplicationContextAware
private final MappingCosmosConverter mappingCosmosConverter;
private final IsNewAwareAuditingHandler cosmosAuditingHandler;

private final String databaseName;
private final CosmosFactory cosmosFactory;
private final ResponseDiagnosticsProcessor responseDiagnosticsProcessor;
private final boolean queryMetricsEnabled;
private final int maxDegreeOfParallelism;
private final int maxBufferedItemCount;
private final int responseContinuationTokenLimitInKb;
private final CosmosAsyncClient cosmosAsyncClient;
private final DatabaseThroughputConfig databaseThroughputConfig;

private ApplicationContext applicationContext;
Expand Down Expand Up @@ -129,8 +128,7 @@ public CosmosTemplate(CosmosFactory cosmosFactory,
Assert.notNull(mappingCosmosConverter, "MappingCosmosConverter must not be null!");
this.mappingCosmosConverter = mappingCosmosConverter;
this.cosmosAuditingHandler = cosmosAuditingHandler;
this.cosmosAsyncClient = cosmosFactory.getCosmosAsyncClient();
this.databaseName = cosmosFactory.getDatabaseName();
this.cosmosFactory = cosmosFactory;
this.responseDiagnosticsProcessor = cosmosConfig.getResponseDiagnosticsProcessor();
this.queryMetricsEnabled = cosmosConfig.isQueryMetricsEnabled();
this.maxDegreeOfParallelism = cosmosConfig.getMaxDegreeOfParallelism();
Expand All @@ -152,6 +150,14 @@ public CosmosTemplate(CosmosFactory cosmosFactory,
this(cosmosFactory, cosmosConfig, mappingCosmosConverter, null);
}

public String getDatabaseName() {
return this.cosmosFactory.getDatabaseName();
}

public CosmosAsyncClient getCosmosAsyncClient() {
return this.cosmosFactory.getCosmosAsyncClient();
}

/**
* Sets the application context
*
Expand Down Expand Up @@ -207,14 +213,14 @@ public <T> T insert(String containerName, T objectToSave, PartitionKey partition

final JsonNode originalItem = mappingCosmosConverter.writeJsonNode(objectToSave);

LOGGER.debug("execute createItem in database {} container {}", this.databaseName,
LOGGER.debug("execute createItem in database {} container {}", this.getDatabaseName(),
containerName);

final CosmosItemRequestOptions options = new CosmosItemRequestOptions();

// if the partition key is null, SDK will get the partitionKey from the object
final CosmosItemResponse<JsonNode> response = cosmosAsyncClient
.getDatabase(this.databaseName)
final CosmosItemResponse<JsonNode> response = this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.createItem(originalItem, partitionKey, options)
.publishOn(Schedulers.parallel())
Expand Down Expand Up @@ -258,8 +264,8 @@ public <T> T findById(Object id, Class<T> domainType, PartitionKey partitionKey)
Assert.notNull(partitionKey, "partitionKey should not be null");
String idToQuery = CosmosUtils.getStringIDValue(id);
final String containerName = getContainerName(domainType);
return cosmosAsyncClient
.getDatabase(this.databaseName)
return this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.readItem(idToQuery, partitionKey, JsonNode.class)
.publishOn(Schedulers.parallel())
Expand Down Expand Up @@ -295,8 +301,8 @@ public <T> T findById(String containerName, Object id, Class<T> domainType) {
options.setMaxDegreeOfParallelism(this.maxDegreeOfParallelism);
options.setMaxBufferedItemCount(this.maxBufferedItemCount);
options.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb);
return cosmosAsyncClient
.getDatabase(this.databaseName)
return this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.queryItems(sqlQuerySpec, options, JsonNode.class)
.byPage()
Expand Down Expand Up @@ -355,16 +361,16 @@ public <T> T upsertAndReturnEntity(String containerName, T object) {

final JsonNode originalItem = mappingCosmosConverter.writeJsonNode(object);

LOGGER.debug("execute upsert item in database {} container {}", this.databaseName,
LOGGER.debug("execute upsert item in database {} container {}", this.getDatabaseName(),
containerName);

@SuppressWarnings("unchecked") final Class<T> domainType = (Class<T>) object.getClass();

final CosmosItemRequestOptions options = new CosmosItemRequestOptions();
applyVersioning(domainType, originalItem, options);

final CosmosItemResponse<JsonNode> cosmosItemResponse = cosmosAsyncClient
.getDatabase(this.databaseName)
final CosmosItemResponse<JsonNode> cosmosItemResponse = this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.upsertItem(originalItem, options)
.publishOn(Schedulers.parallel())
Expand Down Expand Up @@ -423,8 +429,8 @@ public <T> Iterable<T> findAll(PartitionKey partitionKey, final Class<T> domainT
cosmosQueryRequestOptions.setMaxBufferedItemCount(this.maxBufferedItemCount);
cosmosQueryRequestOptions.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb);

return cosmosAsyncClient
.getDatabase(this.databaseName)
return this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.queryItems("SELECT * FROM r", cosmosQueryRequestOptions, JsonNode.class)
.byPage()
Expand Down Expand Up @@ -458,7 +464,7 @@ public void deleteAll(@NonNull String containerName, @NonNull Class<?> domainTyp
@Override
public void deleteContainer(@NonNull String containerName) {
Assert.hasText(containerName, "containerName should have text.");
cosmosAsyncClient.getDatabase(this.databaseName)
this.getCosmosAsyncClient().getDatabase(this.getDatabaseName())
.getContainer(containerName)
.delete()
.publishOn(Schedulers.parallel())
Expand Down Expand Up @@ -499,7 +505,7 @@ public CosmosContainerProperties createContainerIfNotExists(CosmosEntityInformat
cosmosContainerProperties.setUniqueKeyPolicy(uniqueKeyPolicy);
}

CosmosAsyncDatabase cosmosAsyncDatabase = cosmosAsyncClient
CosmosAsyncDatabase cosmosAsyncDatabase = this.getCosmosAsyncClient()
.getDatabase(cosmosDatabaseResponse.getProperties().getId());
Mono<CosmosContainerResponse> cosmosContainerResponseMono;

Expand Down Expand Up @@ -530,20 +536,21 @@ public CosmosContainerProperties createContainerIfNotExists(CosmosEntityInformat

private Mono<CosmosDatabaseResponse> createDatabaseIfNotExists() {
if (databaseThroughputConfig == null) {
return cosmosAsyncClient
.createDatabaseIfNotExists(this.databaseName);
return this.getCosmosAsyncClient()
.createDatabaseIfNotExists(this.getDatabaseName());
} else {
ThroughputProperties throughputProperties = databaseThroughputConfig.isAutoScale()
? ThroughputProperties.createAutoscaledThroughput(databaseThroughputConfig.getRequestUnits())
: ThroughputProperties.createManualThroughput(databaseThroughputConfig.getRequestUnits());
return cosmosAsyncClient
.createDatabaseIfNotExists(this.databaseName, throughputProperties);
return this.getCosmosAsyncClient()
.createDatabaseIfNotExists(this.getDatabaseName(), throughputProperties);
}
}

@Override
public CosmosContainerProperties getContainerProperties(String containerName) {
final CosmosContainerResponse response = cosmosAsyncClient.getDatabase(this.databaseName)
final CosmosContainerResponse response = this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.read()
.block();
Expand All @@ -554,7 +561,8 @@ public CosmosContainerProperties getContainerProperties(String containerName) {
@Override
public CosmosContainerProperties replaceContainerProperties(String containerName,
CosmosContainerProperties properties) {
CosmosContainerResponse response = this.cosmosAsyncClient.getDatabase(this.databaseName)
CosmosContainerResponse response = this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.replace(properties)
.block();
Expand Down Expand Up @@ -595,14 +603,14 @@ private void deleteById(String containerName, Object id, PartitionKey partitionK
CosmosItemRequestOptions options) {
Assert.hasText(containerName, "containerName should not be null, empty or only whitespaces");
String idToDelete = CosmosUtils.getStringIDValue(id);
LOGGER.debug("execute deleteById in database {} container {}", this.databaseName,
LOGGER.debug("execute deleteById in database {} container {}", this.getDatabaseName(),
containerName);

if (partitionKey == null) {
partitionKey = PartitionKey.NONE;
}

cosmosAsyncClient.getDatabase(this.databaseName)
this.getCosmosAsyncClient().getDatabase(this.getDatabaseName())
.getContainer(containerName)
.deleteItem(idToDelete, partitionKey, options)
.publishOn(Schedulers.parallel())
Expand Down Expand Up @@ -756,7 +764,7 @@ private <T> Slice<T> sliceQuery(SqlQuerySpec querySpec,
});

CosmosAsyncContainer container =
cosmosAsyncClient.getDatabase(this.databaseName).getContainer(containerName);
this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()).getContainer(containerName);

Flux<FeedResponse<JsonNode>> feedResponseFlux;
/*
Expand Down Expand Up @@ -914,7 +922,7 @@ private Long getCountValue(SqlQuerySpec querySpec, String containerName) {
private Flux<FeedResponse<JsonNode>> executeQuery(SqlQuerySpec sqlQuerySpec,
String containerName,
CosmosQueryRequestOptions options) {
return cosmosAsyncClient.getDatabase(this.databaseName)
return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName())
.getContainer(containerName)
.queryItems(sqlQuerySpec, options, JsonNode.class)
.byPage();
Expand All @@ -935,8 +943,8 @@ private <T> Flux<JsonNode> findItemsAsFlux(@NonNull CosmosQuery query,
cosmosQueryRequestOptions.setPartitionKey(new PartitionKey(o));
});

return cosmosAsyncClient
.getDatabase(this.databaseName)
return this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.queryItems(sqlQuerySpec, cosmosQueryRequestOptions, JsonNode.class)
.byPage()
Expand All @@ -960,8 +968,8 @@ private Flux<JsonNode> getJsonNodeFluxFromQuerySpec(
cosmosQueryRequestOptions.setMaxBufferedItemCount(this.maxBufferedItemCount);
cosmosQueryRequestOptions.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb);

return cosmosAsyncClient
.getDatabase(this.databaseName)
return this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.queryItems(sqlQuerySpec, cosmosQueryRequestOptions, JsonNode.class)
.byPage()
Expand Down Expand Up @@ -992,8 +1000,8 @@ private <T> T deleteItem(@NonNull JsonNode jsonNode,
final CosmosItemRequestOptions options = new CosmosItemRequestOptions();
applyVersioning(domainType, jsonNode, options);

return cosmosAsyncClient
.getDatabase(this.databaseName)
return this.getCosmosAsyncClient()
.getDatabase(this.getDatabaseName())
.getContainer(containerName)
.deleteItem(jsonNode, options)
.publishOn(Schedulers.parallel())
Expand Down