From f54125d65ba1e41ff5f4024235c835d15948c4cb Mon Sep 17 00:00:00 2001
From: Paul Warren <paul_c_warren@yahoo.com>
Date: Tue, 26 May 2020 22:38:51 -0700
Subject: [PATCH] Initial implementation of a current tenant id resolver and
 amazon s3 client provider

#201
---
 .../content/s3/config/S3StoreFactoryBean.java |  2 +-
 .../content/s3/store/DefaultS3StoreImpl.java  | 36 +++++++++--
 .../CurrentTenantIdentifierResolver.java      | 17 ++++++
 .../config/MultiTenantAmazonS3Provider.java   | 20 ++++++
 .../s3/store/DefaultS3StoreImplTest.java      | 61 ++++++++++++++++---
 5 files changed, 124 insertions(+), 12 deletions(-)
 create mode 100644 spring-content-s3/src/main/java/org/springframework/content/s3/config/CurrentTenantIdentifierResolver.java
 create mode 100644 spring-content-s3/src/main/java/org/springframework/content/s3/config/MultiTenantAmazonS3Provider.java

diff --git a/spring-content-s3/src/main/java/internal/org/springframework/content/s3/config/S3StoreFactoryBean.java b/spring-content-s3/src/main/java/internal/org/springframework/content/s3/config/S3StoreFactoryBean.java
index 3973bbed3..89e211d83 100644
--- a/spring-content-s3/src/main/java/internal/org/springframework/content/s3/config/S3StoreFactoryBean.java
+++ b/spring-content-s3/src/main/java/internal/org/springframework/content/s3/config/S3StoreFactoryBean.java
@@ -59,6 +59,6 @@ protected Object getContentStoreImpl() {
 		DefaultResourceLoader loader = new DefaultResourceLoader();
 		loader.addProtocolResolver(s3Protocol);
 
-		return new DefaultS3StoreImpl(loader, s3StorePlacementService, client);
+		return new DefaultS3StoreImpl(loader, s3StorePlacementService, client, null, null);
 	}
 }
diff --git a/spring-content-s3/src/main/java/internal/org/springframework/content/s3/store/DefaultS3StoreImpl.java b/spring-content-s3/src/main/java/internal/org/springframework/content/s3/store/DefaultS3StoreImpl.java
index b81cf6213..02edf2434 100644
--- a/spring-content-s3/src/main/java/internal/org/springframework/content/s3/store/DefaultS3StoreImpl.java
+++ b/spring-content-s3/src/main/java/internal/org/springframework/content/s3/store/DefaultS3StoreImpl.java
@@ -6,6 +6,7 @@
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.springframework.cloud.aws.core.io.s3.SimpleStorageProtocolResolver;
 import org.springframework.content.commons.annotations.ContentId;
 import org.springframework.content.commons.annotations.ContentLength;
 import org.springframework.content.commons.io.DeletableResource;
@@ -16,7 +17,10 @@
 import org.springframework.content.commons.utils.BeanUtils;
 import org.springframework.content.commons.utils.Condition;
 import org.springframework.content.commons.utils.PlacementService;
+import org.springframework.content.s3.config.CurrentTenantIdentifierResolver;
+import org.springframework.content.s3.config.MultiTenantAmazonS3Provider;
 import org.springframework.core.convert.TypeDescriptor;
+import org.springframework.core.io.DefaultResourceLoader;
 import org.springframework.core.io.Resource;
 import org.springframework.core.io.ResourceLoader;
 import org.springframework.core.io.WritableResource;
@@ -42,15 +46,18 @@ public class DefaultS3StoreImpl<S, SID extends Serializable>
 	private ResourceLoader loader;
 	private PlacementService placementService;
 	private AmazonS3 client;
+	private CurrentTenantIdentifierResolver tenantIdResolver;
+	private MultiTenantAmazonS3Provider clientProvider;
 
-	public DefaultS3StoreImpl(ResourceLoader loader, PlacementService placementService,
-			AmazonS3 client/*, S3ObjectIdResolver idResolver, String defaultBucket*/) {
+	public DefaultS3StoreImpl(ResourceLoader loader, PlacementService placementService, AmazonS3 client, CurrentTenantIdentifierResolver tenantIdResolver, MultiTenantAmazonS3Provider provider) {
 		Assert.notNull(loader, "loader must be specified");
 		Assert.notNull(placementService, "placementService must be specified");
 		Assert.notNull(client, "client must be specified");
 		this.loader = loader;
 		this.placementService = placementService;
 		this.client = client;
+		this.tenantIdResolver = tenantIdResolver;
+		this.clientProvider = provider;
 	}
 
 	@Override
@@ -101,8 +108,29 @@ protected Resource getResourceInternal(S3ObjectId id) {
             location = placementService.convert(objectId, String.class);
             location = absolutify(bucket, location);
         }
-		Resource resource = loader.getResource(location);
-		return new S3StoreResource(client, bucket, resource);
+
+        AmazonS3 clientToUse = client;
+        ResourceLoader loaderToUse = loader;
+        if (tenantIdResolver != null && clientProvider != null) {
+			String tenantId = tenantIdResolver.resolveCurrentTenantIdentifier();
+			if (tenantId != null) {
+				AmazonS3 client = clientProvider.getAmazonS3(tenantId);
+
+				if (client != null) {
+					SimpleStorageProtocolResolver s3Protocol = new SimpleStorageProtocolResolver(client);
+					s3Protocol.afterPropertiesSet();
+
+					DefaultResourceLoader loader = new DefaultResourceLoader();
+					loader.addProtocolResolver(s3Protocol);
+
+					clientToUse = client;
+					loaderToUse = loader;
+				}
+			}
+		}
+
+		Resource resource = loaderToUse.getResource(location);
+		return new S3StoreResource(clientToUse, bucket, resource);
 	}
 
 	@Override
diff --git a/spring-content-s3/src/main/java/org/springframework/content/s3/config/CurrentTenantIdentifierResolver.java b/spring-content-s3/src/main/java/org/springframework/content/s3/config/CurrentTenantIdentifierResolver.java
new file mode 100644
index 000000000..cf01d3e81
--- /dev/null
+++ b/spring-content-s3/src/main/java/org/springframework/content/s3/config/CurrentTenantIdentifierResolver.java
@@ -0,0 +1,17 @@
+package org.springframework.content.s3.config;
+
+
+/**
+ * A callback, used by the S3 default store implementation when returning a Resource, for identifying the current
+ * tenant identifier.  Subsequently, used by the {@link MultiTenantAmazonS3Provider} to get a specific AmazonS3 client
+ * to use for the Resource being returned.
+ */
+public interface CurrentTenantIdentifierResolver {
+
+    /**
+     * Return the current tenant identifier, or null if one cannot be established
+     *
+     * @return current tenant identifer, or null
+     */
+    String resolveCurrentTenantIdentifier();
+}
diff --git a/spring-content-s3/src/main/java/org/springframework/content/s3/config/MultiTenantAmazonS3Provider.java b/spring-content-s3/src/main/java/org/springframework/content/s3/config/MultiTenantAmazonS3Provider.java
new file mode 100644
index 000000000..6acb0c26e
--- /dev/null
+++ b/spring-content-s3/src/main/java/org/springframework/content/s3/config/MultiTenantAmazonS3Provider.java
@@ -0,0 +1,20 @@
+package org.springframework.content.s3.config;
+
+import com.amazonaws.services.s3.AmazonS3;
+
+/**
+ * A callback, used by the S3 default store implementation when returning a resource, in order to establish which
+ * AmazonS3 client to provide to the Resource being returned.
+ *
+ * The tenantId is resolved by the {@link CurrentTenantIdentifierResolver}
+ */
+public interface MultiTenantAmazonS3Provider {
+
+    /**
+     * The AmazonS3 client to use, or null
+     *
+     * @param tenantId the current tenant identifier
+     * @return  the AmazonS3 client to use, or null
+     */
+    AmazonS3 getAmazonS3(String tenantId);
+}
diff --git a/spring-content-s3/src/test/java/internal/org/springframework/content/s3/store/DefaultS3StoreImplTest.java b/spring-content-s3/src/test/java/internal/org/springframework/content/s3/store/DefaultS3StoreImplTest.java
index 1006f7c9a..84602d233 100644
--- a/spring-content-s3/src/test/java/internal/org/springframework/content/s3/store/DefaultS3StoreImplTest.java
+++ b/spring-content-s3/src/test/java/internal/org/springframework/content/s3/store/DefaultS3StoreImplTest.java
@@ -18,6 +18,8 @@
 import org.springframework.content.commons.utils.PlacementServiceImpl;
 import org.springframework.content.s3.Bucket;
 import org.springframework.content.s3.S3ObjectIdResolver;
+import org.springframework.content.s3.config.CurrentTenantIdentifierResolver;
+import org.springframework.content.s3.config.MultiTenantAmazonS3Provider;
 import org.springframework.core.convert.ConversionFailedException;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.core.io.DefaultResourceLoader;
@@ -49,7 +51,11 @@ public class DefaultS3StoreImplTest {
 
 	private ResourceLoader loader;
 	private PlacementService placementService;
-	private AmazonS3 client;
+	private AmazonS3 client, client2;
+
+	private CurrentTenantIdentifierResolver tenantIdResolver;
+	private MultiTenantAmazonS3Provider clientProvider;
+
 	private S3ObjectIdResolver resolver;
 	private String defaultBucket;
 
@@ -65,6 +71,7 @@ public class DefaultS3StoreImplTest {
 	private InputStream result;
 	private Exception e;
 
+
 	{
 		Describe("DefaultS3StoreImpl", () -> {
 			BeforeEach(() -> {
@@ -92,8 +99,7 @@ public String convert(String source) {
 							loader = new DefaultResourceLoader();
 							((DefaultResourceLoader)loader).addProtocolResolver(s3Protocol);
 
-							s3ObjectIdBasedStore = new DefaultS3StoreImpl<ContentProperty, S3ObjectId>(
-									loader, placementService, client);
+							s3ObjectIdBasedStore = new DefaultS3StoreImpl<>(loader, placementService, client, null, null);
 						});
 						JustBeforeEach(() -> {
 							try {
@@ -111,8 +117,7 @@ public String convert(String source) {
 					});
 					Context("given the store's ID is a custom ID type", () -> {
 						JustBeforeEach(() -> {
-							customS3ContentIdBasedStore = new DefaultS3StoreImpl<ContentProperty, CustomContentId>(
-									loader, placementService, client/*, resolver, defaultBucket*/);
+							customS3ContentIdBasedStore = new DefaultS3StoreImpl<>(loader, placementService, client, null, null);
 
 							try {
 								r = customS3ContentIdBasedStore.getResource(customId);
@@ -241,12 +246,54 @@ public String getKey(CustomContentId idOrEntity) {
 							});
 						});
 					});
+					Context("given a multi tenant configuration", () -> {
+						JustBeforeEach(() -> {
+							placementService = new PlacementServiceImpl();
+							S3StoreConfiguration.addDefaultS3ObjectIdConverters(placementService, defaultBucket);
+							s3ObjectIdBasedStore = new DefaultS3StoreImpl<>(loader, placementService, client, tenantIdResolver, clientProvider);
+
+							try {
+								r = s3ObjectIdBasedStore.getResource(new S3ObjectId("some-bucket", "some-object-id"));
+							}
+							catch (Exception e) {
+								this.e = e;
+							}
+						});
+
+						BeforeEach(() -> {
+							client2 = mock(AmazonS3.class);
+							tenantIdResolver = new CurrentTenantIdentifierResolver() {
+								@Override
+								public String resolveCurrentTenantIdentifier() {
+									return "client2";
+								}
+							};
+							clientProvider = new MultiTenantAmazonS3Provider(){
+								@Override
+								public AmazonS3 getAmazonS3(String tenantId) {
+									if ("client1".equals(tenantId)) {
+										return client;
+									} else if ("client2".equals(tenantId)) {
+										return client2;
+									}
+									throw new IllegalArgumentException("not a valid tenant id");
+								};
+							};
+						});
+
+						It("should fetch the resource using the correct client", () -> {
+							assertThat(e, is(nullValue()));
+							assertThat(r, is(instanceOf(S3StoreResource.class)));
+							assertThat(((S3StoreResource)r).getClient(), is(client2));
+							assertThat(r.getDescription(), is(format("Amazon s3 resource [bucket='%s' and object='%s']","some-bucket", "some-object-id")));
+						});
+					});
 				});
 			});
 
 			Describe("AssociativeStore", () -> {
 				JustBeforeEach(() -> {
-					s3StoreImpl = new DefaultS3StoreImpl<ContentProperty, String>(loader,placementService,client/*, resolver, defaultBucket*/);
+					s3StoreImpl = new DefaultS3StoreImpl<ContentProperty, String>(loader,placementService,client,null,null);
 				});
 				Context("#getResource", () -> {
 					JustBeforeEach(() -> {
@@ -416,7 +463,7 @@ public S3ObjectId convert(TestEntity source) {
 
 			Describe("ContentStore", () -> {
 				JustBeforeEach(() -> {
-					s3StoreImpl = new DefaultS3StoreImpl<ContentProperty, String>(loader,placementService,client);
+					s3StoreImpl = new DefaultS3StoreImpl<ContentProperty, String>(loader,placementService,client,null,null);
 				});
 				Context("#setContent", () -> {
 					BeforeEach(() -> {