diff --git a/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/core/MultiTenantDBCosmosFactory.java b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/core/MultiTenantDBCosmosFactory.java new file mode 100644 index 0000000000000..d7212896d6e74 --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/core/MultiTenantDBCosmosFactory.java @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.data.cosmos.core; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.spring.data.cosmos.CosmosFactory; + +/** + * Example for extending CosmosFactory for Mutli-Tenancy at the database level + */ +public class MultiTenantDBCosmosFactory extends CosmosFactory { + + public String manuallySetDatabaseName; + + /** + * Validate config and initialization + * + * @param cosmosAsyncClient cosmosAsyncClient + * @param databaseName databaseName + */ + public MultiTenantDBCosmosFactory(CosmosAsyncClient cosmosAsyncClient, String databaseName) { + super(cosmosAsyncClient, databaseName); + + this.manuallySetDatabaseName = databaseName; + } + + @Override + public String getDatabaseName() { + return this.manuallySetDatabaseName; + } +} diff --git a/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/core/MultiTenantDBCosmosFactoryIT.java b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/core/MultiTenantDBCosmosFactoryIT.java new file mode 100644 index 0000000000000..dc16252d651a3 --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/core/MultiTenantDBCosmosFactoryIT.java @@ -0,0 +1,133 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.data.cosmos.core; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosAsyncDatabase; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.cosmos.CosmosException; +import com.azure.cosmos.models.PartitionKey; +import com.azure.spring.data.cosmos.CosmosFactory; +import com.azure.spring.data.cosmos.IntegrationTestCollectionManager; +import com.azure.spring.data.cosmos.config.CosmosConfig; +import com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter; +import com.azure.spring.data.cosmos.core.mapping.CosmosMappingContext; +import com.azure.spring.data.cosmos.domain.Person; +import com.azure.spring.data.cosmos.repository.MultiTenantTestRepositoryConfig; +import com.azure.spring.data.cosmos.repository.support.CosmosEntityInformation; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.autoconfigure.domain.EntityScanner; +import org.springframework.context.ApplicationContext; +import org.springframework.data.annotation.Persistent; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; + +import java.util.ArrayList; +import java.util.List; + +import static com.azure.spring.data.cosmos.common.TestConstants.ADDRESSES; +import static com.azure.spring.data.cosmos.common.TestConstants.AGE; +import static com.azure.spring.data.cosmos.common.TestConstants.FIRST_NAME; +import static com.azure.spring.data.cosmos.common.TestConstants.HOBBIES; +import static com.azure.spring.data.cosmos.common.TestConstants.ID_1; +import static com.azure.spring.data.cosmos.common.TestConstants.ID_2; +import static com.azure.spring.data.cosmos.common.TestConstants.LAST_NAME; +import static com.azure.spring.data.cosmos.common.TestConstants.PASSPORT_IDS_BY_COUNTRY; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertEquals; + +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration(classes = MultiTenantTestRepositoryConfig.class) +public class MultiTenantDBCosmosFactoryIT { + + private final String testDB1 = "Database1"; + private final String testDB2 = "Database2"; + + private final Person TEST_PERSON_1 = new Person(ID_1, FIRST_NAME, LAST_NAME, HOBBIES, ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY); + private final Person TEST_PERSON_2 = new Person(ID_2, FIRST_NAME, LAST_NAME, HOBBIES, ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY); + + @ClassRule + public static final IntegrationTestCollectionManager collectionManager = new IntegrationTestCollectionManager(); + + @Autowired + private ApplicationContext applicationContext; + @Autowired + private CosmosConfig cosmosConfig; + @Autowired + private CosmosClientBuilder cosmosClientBuilder; + + private MultiTenantDBCosmosFactory cosmosFactory; + private CosmosTemplate cosmosTemplate; + private CosmosAsyncClient client; + private CosmosEntityInformation personInfo; + + @Before + public void setUp() throws ClassNotFoundException { + /// Setup + client = CosmosFactory.createCosmosAsyncClient(cosmosClientBuilder); + cosmosFactory = new MultiTenantDBCosmosFactory(client, testDB1); + final CosmosMappingContext mappingContext = new CosmosMappingContext(); + + try { + mappingContext.setInitialEntitySet(new EntityScanner(this.applicationContext).scan(Persistent.class)); + } catch (Exception e) { + Assert.fail(); + } + + final MappingCosmosConverter cosmosConverter = new MappingCosmosConverter(mappingContext, null); + cosmosTemplate = new CosmosTemplate(cosmosFactory, cosmosConfig, cosmosConverter, null); + personInfo = new CosmosEntityInformation<>(Person.class); + } + + @Test + public void testGetDatabaseFunctionality() { + // Create DB1 and add TEST_PERSON_1 to it + cosmosTemplate.createContainerIfNotExists(personInfo); + cosmosTemplate.deleteAll(personInfo.getContainerName(), Person.class); + assertThat(cosmosFactory.getDatabaseName()).isEqualTo(testDB1); + cosmosTemplate.insert(TEST_PERSON_1, new PartitionKey(personInfo.getPartitionKeyFieldValue(TEST_PERSON_1))); + + // Create DB2 and add TEST_PERSON_2 to it + cosmosFactory.manuallySetDatabaseName = testDB2; + cosmosTemplate.createContainerIfNotExists(personInfo); + cosmosTemplate.deleteAll(personInfo.getContainerName(), Person.class); + assertThat(cosmosFactory.getDatabaseName()).isEqualTo(testDB2); + cosmosTemplate.insert(TEST_PERSON_2, new PartitionKey(personInfo.getPartitionKeyFieldValue(TEST_PERSON_2))); + + // Check that DB2 has the correct contents + List expectedResultsDB2 = new ArrayList<>(); + expectedResultsDB2.add(TEST_PERSON_2); + Iterable iterableDB2 = cosmosTemplate.findAll(personInfo.getContainerName(), Person.class); + List resultDB2 = new ArrayList<>(); + iterableDB2.forEach(resultDB2::add); + Assert.assertEquals(expectedResultsDB2, resultDB2); + + // Check that DB1 has the correct contents + cosmosFactory.manuallySetDatabaseName = testDB1; + List expectedResultsDB1 = new ArrayList<>(); + expectedResultsDB1.add(TEST_PERSON_1); + Iterable iterableDB1 = cosmosTemplate.findAll(personInfo.getContainerName(), Person.class); + List resultDB1 = new ArrayList<>(); + iterableDB1.forEach(resultDB1::add); + Assert.assertEquals(expectedResultsDB1, resultDB1); + + //Cleanup + deleteDatabaseIfExists(testDB1); + deleteDatabaseIfExists(testDB2); + } + + private void deleteDatabaseIfExists(String dbName) { + CosmosAsyncDatabase database = client.getDatabase(dbName); + try { + database.delete().block(); + } catch (CosmosException ex) { + assertEquals(ex.getStatusCode(), 404); + } + } +} diff --git a/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiCosmosTemplateIT.java b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiCosmosTemplateIT.java index 0ae1bb6c320a8..02b00ca05833f 100644 --- a/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiCosmosTemplateIT.java +++ b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiCosmosTemplateIT.java @@ -4,6 +4,7 @@ import com.azure.cosmos.CosmosAsyncClient; import com.azure.cosmos.models.PartitionKey; +import com.azure.spring.data.cosmos.CosmosFactory; import com.azure.spring.data.cosmos.ReactiveIntegrationTestCollectionManager; import com.azure.spring.data.cosmos.common.TestConstants; import com.azure.spring.data.cosmos.core.ReactiveCosmosTemplate; @@ -80,10 +81,12 @@ public void testSecondaryTemplateWithDiffDatabase() { @Test public void testSingleCosmosClientForMultipleCosmosTemplate() throws IllegalAccessException { - final Field cosmosAsyncClient = FieldUtils.getDeclaredField(ReactiveCosmosTemplate.class, - "cosmosAsyncClient", true); - CosmosAsyncClient client1 = (CosmosAsyncClient) cosmosAsyncClient.get(secondaryReactiveCosmosTemplate); - CosmosAsyncClient client2 = (CosmosAsyncClient) cosmosAsyncClient.get(secondaryDiffDatabaseReactiveCosmosTemplate); + final Field cosmosFactory = FieldUtils.getDeclaredField(ReactiveCosmosTemplate.class, + "cosmosFactory", true); + CosmosFactory cosmosFactory1 = (CosmosFactory) cosmosFactory.get(secondaryReactiveCosmosTemplate); + CosmosAsyncClient client1 = cosmosFactory1.getCosmosAsyncClient(); + CosmosFactory cosmosFactory2 = (CosmosFactory) cosmosFactory.get(secondaryDiffDatabaseReactiveCosmosTemplate); + CosmosAsyncClient client2 = cosmosFactory2.getCosmosAsyncClient(); Assertions.assertThat(client1).isEqualTo(client2); } } diff --git a/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiTenantTestRepositoryConfig.java b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiTenantTestRepositoryConfig.java new file mode 100644 index 0000000000000..e28533a40997f --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiTenantTestRepositoryConfig.java @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +package com.azure.spring.data.cosmos.repository; + +import com.azure.cosmos.CosmosAsyncClient; +import com.azure.cosmos.CosmosClientBuilder; +import com.azure.spring.data.cosmos.common.ResponseDiagnosticsTestUtils; +import com.azure.spring.data.cosmos.common.TestConstants; +import com.azure.spring.data.cosmos.config.AbstractCosmosConfiguration; +import com.azure.spring.data.cosmos.config.CosmosConfig; +import com.azure.spring.data.cosmos.core.MultiTenantDBCosmosFactory; +import com.azure.spring.data.cosmos.core.mapping.event.SimpleCosmosMappingEventListener; +import com.azure.spring.data.cosmos.repository.config.EnableCosmosRepositories; +import com.azure.spring.data.cosmos.repository.config.EnableReactiveCosmosRepositories; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.PropertySource; +import org.springframework.util.StringUtils; + +import java.util.Arrays; +import java.util.Collection; + +@Configuration +@PropertySource(value = { "classpath:application.properties" }) +@EnableCosmosRepositories +@EnableReactiveCosmosRepositories +public class MultiTenantTestRepositoryConfig extends AbstractCosmosConfiguration { + @Value("${cosmos.uri:}") + private String cosmosDbUri; + + @Value("${cosmos.key:}") + private String cosmosDbKey; + + @Value("${cosmos.database:}") + private String database; + + @Value("${cosmos.queryMetricsEnabled}") + private boolean queryMetricsEnabled; + + @Value("${cosmos.maxDegreeOfParallelism}") + private int maxDegreeOfParallelism; + + @Value("${cosmos.maxBufferedItemCount}") + private int maxBufferedItemCount; + + @Value("${cosmos.responseContinuationTokenLimitInKb}") + private int responseContinuationTokenLimitInKb; + + @Bean + public ResponseDiagnosticsTestUtils responseDiagnosticsTestUtils() { + return new ResponseDiagnosticsTestUtils(); + } + + @Bean + public CosmosClientBuilder cosmosClientBuilder() { + return new CosmosClientBuilder() + .key(cosmosDbKey) + .endpoint(cosmosDbUri) + .contentResponseOnWriteEnabled(true); + } + + @Bean + public MultiTenantDBCosmosFactory cosmosFactory(CosmosAsyncClient cosmosAsyncClient) { + return new MultiTenantDBCosmosFactory(cosmosAsyncClient, getDatabaseName()); + } + + @Bean + @Override + public CosmosConfig cosmosConfig() { + return CosmosConfig.builder() + .enableQueryMetrics(queryMetricsEnabled) + .maxDegreeOfParallelism(maxDegreeOfParallelism) + .maxBufferedItemCount(maxBufferedItemCount) + .responseContinuationTokenLimitInKb(responseContinuationTokenLimitInKb) + .responseDiagnosticsProcessor(responseDiagnosticsTestUtils().getResponseDiagnosticsProcessor()) + .build(); + } + + @Override + protected String getDatabaseName() { + return StringUtils.hasText(this.database) ? this.database : TestConstants.DB_NAME; + } + + @Override + protected Collection getMappingBasePackages() { + final Package mappingBasePackage = getClass().getPackage(); + final String entityPackage = "com.azure.spring.data.cosmos.domain"; + return Arrays.asList(mappingBasePackage.getName(), entityPackage); + } + + @Bean + SimpleCosmosMappingEventListener simpleMappingEventListener() { + return new SimpleCosmosMappingEventListener(); + } +} diff --git a/sdk/cosmos/azure-spring-data-cosmos/CHANGELOG.md b/sdk/cosmos/azure-spring-data-cosmos/CHANGELOG.md index 6a9bfcfc05a06..2d039b74398cf 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-spring-data-cosmos/CHANGELOG.md @@ -3,6 +3,7 @@ ### 3.31.0-beta.1 (Unreleased) #### Features Added +* Added support for multi-tenancy at the Database level via `CosmosFactory` - See [PR 32516](https://github.com/Azure/azure-sdk-for-java/pull/32516) #### Breaking Changes diff --git a/sdk/cosmos/azure-spring-data-cosmos/README.md b/sdk/cosmos/azure-spring-data-cosmos/README.md index b76af386b06ff..e68f6de593fb7 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/README.md +++ b/sdk/cosmos/azure-spring-data-cosmos/README.md @@ -934,6 +934,32 @@ public class MultiDatabaseApplication implements CommandLineRunner { } ``` +### Multi-Tenancy at the Database Level +- Azure-spring-data-cosmos supports multi-tenancy at the database level configuration by extending `CosmosFactory` and overriding the getDatabaseName() function. +```java readme-sample-MultiTenantDBCosmosFactory +public class MultiTenantDBCosmosFactory extends CosmosFactory { + + private String tenantId; + + /** + * Validate config and initialization + * + * @param cosmosAsyncClient cosmosAsyncClient + * @param databaseName databaseName + */ + public MultiTenantDBCosmosFactory(CosmosAsyncClient cosmosAsyncClient, String databaseName) { + super(cosmosAsyncClient, databaseName); + + this.tenantId = databaseName; + } + + @Override + public String getDatabaseName() { + return this.getCosmosAsyncClient().getDatabase(this.tenantId).toString(); + } +} +``` + ## Beta version package Beta version built from `main` branch are available, you can refer to the [instruction](https://github.com/Azure/azure-sdk-for-java/blob/main/CONTRIBUTING.md#nightly-package-builds) to use beta version packages. diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/CosmosFactory.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/CosmosFactory.java index 51721280b8471..dfd05f94d4a02 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/CosmosFactory.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/CosmosFactory.java @@ -20,7 +20,10 @@ public class CosmosFactory { private final CosmosAsyncClient cosmosAsyncClient; - private final String databaseName; + /** + * Database Name to be used for operations. + */ + protected String databaseName; private static final String USER_AGENT_SUFFIX = Constants.USER_AGENT_SUFFIX + PropertyLoader.getProjectVersion(); diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java index f1b8d79d8a09c..4768cd768e9e4 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java @@ -74,13 +74,12 @@ public class CosmosTemplate implements CosmosOperations, ApplicationContextAware private final MappingCosmosConverter mappingCosmosConverter; private final IsNewAwareAuditingHandler cosmosAuditingHandler; - private final String databaseName; + private final CosmosFactory cosmosFactory; private final ResponseDiagnosticsProcessor responseDiagnosticsProcessor; private final boolean queryMetricsEnabled; private final int maxDegreeOfParallelism; private final int maxBufferedItemCount; private final int responseContinuationTokenLimitInKb; - private final CosmosAsyncClient cosmosAsyncClient; private final DatabaseThroughputConfig databaseThroughputConfig; private ApplicationContext applicationContext; @@ -129,8 +128,7 @@ public CosmosTemplate(CosmosFactory cosmosFactory, Assert.notNull(mappingCosmosConverter, "MappingCosmosConverter must not be null!"); this.mappingCosmosConverter = mappingCosmosConverter; this.cosmosAuditingHandler = cosmosAuditingHandler; - this.cosmosAsyncClient = cosmosFactory.getCosmosAsyncClient(); - this.databaseName = cosmosFactory.getDatabaseName(); + this.cosmosFactory = cosmosFactory; this.responseDiagnosticsProcessor = cosmosConfig.getResponseDiagnosticsProcessor(); this.queryMetricsEnabled = cosmosConfig.isQueryMetricsEnabled(); this.maxDegreeOfParallelism = cosmosConfig.getMaxDegreeOfParallelism(); @@ -152,6 +150,14 @@ public CosmosTemplate(CosmosFactory cosmosFactory, this(cosmosFactory, cosmosConfig, mappingCosmosConverter, null); } + private String getDatabaseName() { + return this.cosmosFactory.getDatabaseName(); + } + + private CosmosAsyncClient getCosmosAsyncClient() { + return this.cosmosFactory.getCosmosAsyncClient(); + } + /** * Sets the application context * @@ -207,14 +213,14 @@ public T insert(String containerName, T objectToSave, PartitionKey partition final JsonNode originalItem = mappingCosmosConverter.writeJsonNode(objectToSave); - LOGGER.debug("execute createItem in database {} container {}", this.databaseName, + LOGGER.debug("execute createItem in database {} container {}", this.getDatabaseName(), containerName); final CosmosItemRequestOptions options = new CosmosItemRequestOptions(); // if the partition key is null, SDK will get the partitionKey from the object - final CosmosItemResponse response = cosmosAsyncClient - .getDatabase(this.databaseName) + final CosmosItemResponse response = this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .createItem(originalItem, partitionKey, options) .publishOn(Schedulers.parallel()) @@ -258,8 +264,8 @@ public T findById(Object id, Class domainType, PartitionKey partitionKey) Assert.notNull(partitionKey, "partitionKey should not be null"); String idToQuery = CosmosUtils.getStringIDValue(id); final String containerName = getContainerName(domainType); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .readItem(idToQuery, partitionKey, JsonNode.class) .publishOn(Schedulers.parallel()) @@ -295,8 +301,8 @@ public T findById(String containerName, Object id, Class domainType) { options.setMaxDegreeOfParallelism(this.maxDegreeOfParallelism); options.setMaxBufferedItemCount(this.maxBufferedItemCount); options.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(sqlQuerySpec, options, JsonNode.class) .byPage() @@ -355,7 +361,7 @@ public T upsertAndReturnEntity(String containerName, T object) { final JsonNode originalItem = mappingCosmosConverter.writeJsonNode(object); - LOGGER.debug("execute upsert item in database {} container {}", this.databaseName, + LOGGER.debug("execute upsert item in database {} container {}", this.getDatabaseName(), containerName); @SuppressWarnings("unchecked") final Class domainType = (Class) object.getClass(); @@ -363,8 +369,8 @@ public T upsertAndReturnEntity(String containerName, T object) { final CosmosItemRequestOptions options = new CosmosItemRequestOptions(); applyVersioning(domainType, originalItem, options); - final CosmosItemResponse cosmosItemResponse = cosmosAsyncClient - .getDatabase(this.databaseName) + final CosmosItemResponse cosmosItemResponse = this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .upsertItem(originalItem, options) .publishOn(Schedulers.parallel()) @@ -423,8 +429,8 @@ public Iterable findAll(PartitionKey partitionKey, final Class domainT cosmosQueryRequestOptions.setMaxBufferedItemCount(this.maxBufferedItemCount); cosmosQueryRequestOptions.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems("SELECT * FROM r", cosmosQueryRequestOptions, JsonNode.class) .byPage() @@ -458,7 +464,7 @@ public void deleteAll(@NonNull String containerName, @NonNull Class domainTyp @Override public void deleteContainer(@NonNull String containerName) { Assert.hasText(containerName, "containerName should have text."); - cosmosAsyncClient.getDatabase(this.databaseName) + this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .delete() .publishOn(Schedulers.parallel()) @@ -499,7 +505,7 @@ public CosmosContainerProperties createContainerIfNotExists(CosmosEntityInformat cosmosContainerProperties.setUniqueKeyPolicy(uniqueKeyPolicy); } - CosmosAsyncDatabase cosmosAsyncDatabase = cosmosAsyncClient + CosmosAsyncDatabase cosmosAsyncDatabase = this.getCosmosAsyncClient() .getDatabase(cosmosDatabaseResponse.getProperties().getId()); Mono cosmosContainerResponseMono; @@ -530,20 +536,21 @@ public CosmosContainerProperties createContainerIfNotExists(CosmosEntityInformat private Mono createDatabaseIfNotExists() { if (databaseThroughputConfig == null) { - return cosmosAsyncClient - .createDatabaseIfNotExists(this.databaseName); + return this.getCosmosAsyncClient() + .createDatabaseIfNotExists(this.getDatabaseName()); } else { ThroughputProperties throughputProperties = databaseThroughputConfig.isAutoScale() ? ThroughputProperties.createAutoscaledThroughput(databaseThroughputConfig.getRequestUnits()) : ThroughputProperties.createManualThroughput(databaseThroughputConfig.getRequestUnits()); - return cosmosAsyncClient - .createDatabaseIfNotExists(this.databaseName, throughputProperties); + return this.getCosmosAsyncClient() + .createDatabaseIfNotExists(this.getDatabaseName(), throughputProperties); } } @Override public CosmosContainerProperties getContainerProperties(String containerName) { - final CosmosContainerResponse response = cosmosAsyncClient.getDatabase(this.databaseName) + final CosmosContainerResponse response = this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .read() .block(); @@ -554,7 +561,8 @@ public CosmosContainerProperties getContainerProperties(String containerName) { @Override public CosmosContainerProperties replaceContainerProperties(String containerName, CosmosContainerProperties properties) { - CosmosContainerResponse response = this.cosmosAsyncClient.getDatabase(this.databaseName) + CosmosContainerResponse response = this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .replace(properties) .block(); @@ -595,14 +603,14 @@ private void deleteById(String containerName, Object id, PartitionKey partitionK CosmosItemRequestOptions options) { Assert.hasText(containerName, "containerName should not be null, empty or only whitespaces"); String idToDelete = CosmosUtils.getStringIDValue(id); - LOGGER.debug("execute deleteById in database {} container {}", this.databaseName, + LOGGER.debug("execute deleteById in database {} container {}", this.getDatabaseName(), containerName); if (partitionKey == null) { partitionKey = PartitionKey.NONE; } - cosmosAsyncClient.getDatabase(this.databaseName) + this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .deleteItem(idToDelete, partitionKey, options) .publishOn(Schedulers.parallel()) @@ -756,7 +764,7 @@ private Slice sliceQuery(SqlQuerySpec querySpec, }); CosmosAsyncContainer container = - cosmosAsyncClient.getDatabase(this.databaseName).getContainer(containerName); + this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()).getContainer(containerName); Flux> feedResponseFlux; /* @@ -914,7 +922,7 @@ private Long getCountValue(SqlQuerySpec querySpec, String containerName) { private Flux> executeQuery(SqlQuerySpec sqlQuerySpec, String containerName, CosmosQueryRequestOptions options) { - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(sqlQuerySpec, options, JsonNode.class) .byPage(); @@ -935,8 +943,8 @@ private Flux findItemsAsFlux(@NonNull CosmosQuery query, cosmosQueryRequestOptions.setPartitionKey(new PartitionKey(o)); }); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(sqlQuerySpec, cosmosQueryRequestOptions, JsonNode.class) .byPage() @@ -960,8 +968,8 @@ private Flux getJsonNodeFluxFromQuerySpec( cosmosQueryRequestOptions.setMaxBufferedItemCount(this.maxBufferedItemCount); cosmosQueryRequestOptions.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(sqlQuerySpec, cosmosQueryRequestOptions, JsonNode.class) .byPage() @@ -992,8 +1000,8 @@ private T deleteItem(@NonNull JsonNode jsonNode, final CosmosItemRequestOptions options = new CosmosItemRequestOptions(); applyVersioning(domainType, jsonNode, options); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .deleteItem(jsonNode, options) .publishOn(Schedulers.parallel()) diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java index a533050586953..9becf4409696c 100644 --- a/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java +++ b/sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java @@ -58,14 +58,13 @@ public class ReactiveCosmosTemplate implements ReactiveCosmosOperations, Applica private static final Logger LOGGER = LoggerFactory.getLogger(ReactiveCosmosTemplate.class); + private final CosmosFactory cosmosFactory; private final MappingCosmosConverter mappingCosmosConverter; - private final String databaseName; private final ResponseDiagnosticsProcessor responseDiagnosticsProcessor; private final boolean queryMetricsEnabled; private final int maxDegreeOfParallelism; private final int maxBufferedItemCount; private final int responseContinuationTokenLimitInKb; - private final CosmosAsyncClient cosmosAsyncClient; private final IsNewAwareAuditingHandler cosmosAuditingHandler; private final DatabaseThroughputConfig databaseThroughputConfig; @@ -116,8 +115,7 @@ public ReactiveCosmosTemplate(CosmosFactory cosmosFactory, Assert.notNull(mappingCosmosConverter, "MappingCosmosConverter must not be null!"); this.mappingCosmosConverter = mappingCosmosConverter; - this.cosmosAsyncClient = cosmosFactory.getCosmosAsyncClient(); - this.databaseName = cosmosFactory.getDatabaseName(); + this.cosmosFactory = cosmosFactory; this.responseDiagnosticsProcessor = cosmosConfig.getResponseDiagnosticsProcessor(); this.queryMetricsEnabled = cosmosConfig.isQueryMetricsEnabled(); this.maxDegreeOfParallelism = cosmosConfig.getMaxDegreeOfParallelism(); @@ -140,6 +138,14 @@ public ReactiveCosmosTemplate(CosmosFactory cosmosFactory, this(cosmosFactory, cosmosConfig, mappingCosmosConverter, null); } + private String getDatabaseName() { + return this.cosmosFactory.getDatabaseName(); + } + + private CosmosAsyncClient getCosmosAsyncClient() { + return this.cosmosFactory.getCosmosAsyncClient(); + } + /** * @param applicationContext the application context * @throws BeansException the bean exception @@ -175,7 +181,7 @@ public Mono createContainerIfNotExists(CosmosEntityInfo } CosmosAsyncDatabase database = - cosmosAsyncClient.getDatabase(cosmosDatabaseResponse.getProperties().getId()); + this.getCosmosAsyncClient().getDatabase(cosmosDatabaseResponse.getProperties().getId()); Mono cosmosContainerResponseMono; if (information.getRequestUnit() == null) { @@ -205,20 +211,20 @@ public Mono createContainerIfNotExists(CosmosEntityInfo private Mono createDatabaseIfNotExists() { if (databaseThroughputConfig == null) { - return cosmosAsyncClient - .createDatabaseIfNotExists(this.databaseName); + return this.getCosmosAsyncClient() + .createDatabaseIfNotExists(this.getDatabaseName()); } else { ThroughputProperties throughputProperties = databaseThroughputConfig.isAutoScale() ? ThroughputProperties.createAutoscaledThroughput(databaseThroughputConfig.getRequestUnits()) : ThroughputProperties.createManualThroughput(databaseThroughputConfig.getRequestUnits()); - return cosmosAsyncClient - .createDatabaseIfNotExists(this.databaseName, throughputProperties); + return this.getCosmosAsyncClient() + .createDatabaseIfNotExists(this.getDatabaseName(), throughputProperties); } } @Override public Mono getContainerProperties(String containerName) { - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .read() .map(CosmosContainerResponse::getProperties); @@ -227,7 +233,7 @@ public Mono getContainerProperties(String containerNa @Override public Mono replaceContainerProperties(String containerName, CosmosContainerProperties properties) { - return this.cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .replace(properties) .map(CosmosContainerResponse::getProperties); @@ -273,8 +279,8 @@ public Flux findAll(PartitionKey partitionKey, Class domainType) { cosmosQueryRequestOptions.setMaxBufferedItemCount(this.maxBufferedItemCount); cosmosQueryRequestOptions.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems("SELECT * FROM r", cosmosQueryRequestOptions, JsonNode.class) .byPage() @@ -325,7 +331,7 @@ public Mono findById(String containerName, Object id, Class domainType options.setMaxBufferedItemCount(this.maxBufferedItemCount); options.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb); - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(sqlQuerySpec, options, JsonNode.class) .byPage() @@ -360,7 +366,7 @@ public Mono findById(Object id, Class domainType, PartitionKey partiti String idToFind = CosmosUtils.getStringIDValue(id); final String containerName = getContainerName(domainType); - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .readItem(idToFind, partitionKey, JsonNode.class) .publishOn(Schedulers.parallel()) @@ -418,8 +424,8 @@ public Mono insert(String containerName, T objectToSave, final JsonNode originalItem = mappingCosmosConverter.writeJsonNode(objectToSave); final CosmosItemRequestOptions options = new CosmosItemRequestOptions(); // if the partition key is null, SDK will get the partitionKey from the object - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .createItem(originalItem, partitionKey, options) .publishOn(Schedulers.parallel()) @@ -481,7 +487,7 @@ public Mono upsert(String containerName, T object) { applyVersioning(object.getClass(), originalItem, options); - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .upsertItem(originalItem, options) .publishOn(Schedulers.parallel()) @@ -517,7 +523,7 @@ private Mono deleteById(String containerName, Object id, PartitionKey part partitionKey = PartitionKey.NONE; } - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .deleteItem(idToDelete, partitionKey, cosmosItemRequestOptions) .publishOn(Schedulers.parallel()) @@ -683,7 +689,7 @@ private Flux runQuery(SqlQuerySpec querySpec, Class domainType) { options.setMaxDegreeOfParallelism(this.maxDegreeOfParallelism); options.setMaxBufferedItemCount(this.maxBufferedItemCount); options.setResponseContinuationTokenLimitInKb(this.responseContinuationTokenLimitInKb); - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(querySpec, options, JsonNode.class) .byPage() @@ -721,7 +727,7 @@ private Flux> executeQuery(SqlQuerySpec sqlQuerySpec, String containerName, CosmosQueryRequestOptions options) { - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(sqlQuerySpec, options, JsonNode.class) .byPage() @@ -738,7 +744,7 @@ private Flux> executeQuery(SqlQuerySpec sqlQuerySpec, @Override public void deleteContainer(@NonNull String containerName) { Assert.hasText(containerName, "containerName should have text."); - cosmosAsyncClient.getDatabase(this.databaseName) + this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .delete() .doOnNext(cosmosContainerResponse -> @@ -782,8 +788,8 @@ private Flux findItems(@NonNull CosmosQuery query, cosmosQueryRequestOptions.setPartitionKey(new PartitionKey(o)); }); - return cosmosAsyncClient - .getDatabase(this.databaseName) + return this.getCosmosAsyncClient() + .getDatabase(this.getDatabaseName()) .getContainer(containerName) .queryItems(sqlQuerySpec, cosmosQueryRequestOptions, JsonNode.class) .byPage() @@ -804,7 +810,7 @@ private Mono deleteItem(@NonNull JsonNode jsonNode, final CosmosItemRequestOptions options = new CosmosItemRequestOptions(); applyVersioning(domainType, jsonNode, options); - return cosmosAsyncClient.getDatabase(this.databaseName) + return this.getCosmosAsyncClient().getDatabase(this.getDatabaseName()) .getContainer(containerName) .deleteItem(jsonNode, options) .publishOn(Schedulers.parallel()) diff --git a/sdk/cosmos/azure-spring-data-cosmos/src/samples/java/com/azure/spring/data/cosmos/MultiTenantDBCosmosFactory.java b/sdk/cosmos/azure-spring-data-cosmos/src/samples/java/com/azure/spring/data/cosmos/MultiTenantDBCosmosFactory.java new file mode 100644 index 0000000000000..fe951fff0b515 --- /dev/null +++ b/sdk/cosmos/azure-spring-data-cosmos/src/samples/java/com/azure/spring/data/cosmos/MultiTenantDBCosmosFactory.java @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package com.azure.spring.data.cosmos; + +import com.azure.cosmos.CosmosAsyncClient; + +/** + * Example for extending CosmosFactory for Mutli-Tenancy at the database level + */ +// BEGIN: readme-sample-MultiTenantDBCosmosFactory +public class MultiTenantDBCosmosFactory extends CosmosFactory { + + private String tenantId; + + /** + * Validate config and initialization + * + * @param cosmosAsyncClient cosmosAsyncClient + * @param databaseName databaseName + */ + public MultiTenantDBCosmosFactory(CosmosAsyncClient cosmosAsyncClient, String databaseName) { + super(cosmosAsyncClient, databaseName); + + this.tenantId = databaseName; + } + + @Override + public String getDatabaseName() { + return this.getCosmosAsyncClient().getDatabase(this.tenantId).toString(); + } +} +// END: readme-sample-MultiTenantDBCosmosFactory