diff --git a/pom.xml b/pom.xml
index 381af0fa..fc65fe7a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -23,6 +23,18 @@
+
+	<dependencyManagement>
+		<dependencies>
+			<dependency>
+				<groupId>software.amazon.awssdk</groupId>
+				<artifactId>bom</artifactId>
+				<version>${amazon.sdk.v2.version}</version>
+				<type>pom</type>
+				<scope>import</scope>
+			</dependency>
+		</dependencies>
+	</dependencyManagement>
@@ -225,6 +237,23 @@
 			<artifactId>guava</artifactId>
 			<version>31.1-jre</version>
+		<dependency>
+			<groupId>org.opensearch.client</groupId>
+			<artifactId>opensearch-java</artifactId>
+			<version>2.8.1</version>
+		</dependency>
+		<dependency>
+			<groupId>software.amazon.awssdk</groupId>
+			<artifactId>opensearchserverless</artifactId>
+		</dependency>
+		<dependency>
+			<groupId>software.amazon.awssdk</groupId>
+			<artifactId>apache-client</artifactId>
+		</dependency>
+		<dependency>
+			<groupId>software.amazon.awssdk</groupId>
+			<artifactId>bedrockagent</artifactId>
+		</dependency>
@@ -275,6 +304,7 @@
1.78.1
1.12.296
+		<amazon.sdk.v2.version>2.29.34</amazon.sdk.v2.version>
2.17.1
5.4.1
5.4.1
diff --git a/src/main/java/org/sagebionetworks/template/CloudFormationClient.java b/src/main/java/org/sagebionetworks/template/CloudFormationClient.java
index d4ddde1a..56cd2021 100644
--- a/src/main/java/org/sagebionetworks/template/CloudFormationClient.java
+++ b/src/main/java/org/sagebionetworks/template/CloudFormationClient.java
@@ -1,6 +1,7 @@
package org.sagebionetworks.template;
import java.util.Optional;
+import java.util.Set;
import java.util.stream.Stream;
import com.amazonaws.services.cloudformation.model.AmazonCloudFormationException;
@@ -20,7 +21,7 @@ public interface CloudFormationClient {
* @param stackName
* @return
*/
- public boolean doesStackNameExist(String stackName);
+ boolean doesStackNameExist(String stackName);
/**
* Describe a stack given its name.
@@ -29,7 +30,7 @@ public interface CloudFormationClient {
* @return
* @throws AmazonCloudFormationException When the stack does not exist.
*/
-	public Optional<Stack> describeStack(String stackName);
+	Optional<Stack> describeStack(String stackName);
/**
* Update a stack with the given name using the provided template body.
@@ -38,7 +39,7 @@ public interface CloudFormationClient {
* @param templateBody
* @return StackId
*/
- public void updateStack(CreateOrUpdateStackRequest request);
+ void updateStack(CreateOrUpdateStackRequest request);
/**
* Create a stack with the given name using the provided template body.
@@ -47,7 +48,7 @@ public interface CloudFormationClient {
* @param templateBody
* @return StackId
*/
- public void createStack(CreateOrUpdateStackRequest request);
+ void createStack(CreateOrUpdateStackRequest request);
/**
* If a stack does not exist the stack will be created else the stack will be
@@ -57,7 +58,7 @@ public interface CloudFormationClient {
* @param templateBody
* @return StackId
*/
- public void createOrUpdateStack(CreateOrUpdateStackRequest request);
+ void createOrUpdateStack(CreateOrUpdateStackRequest request);
/**
* Wait for the given stack to complete.
@@ -65,25 +66,34 @@ public interface CloudFormationClient {
* @return
* @throws InterruptedException
*/
-	public Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException;
+	Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException;
+
+	/**
+	 * Wait for the given stack to complete, handling any wait condition whose handler is provided in the set (each handler is mapped to the logical id of its wait condition)
+	 * @param stackName
+	 * @param waitConditionHandlers
+	 * @return
+	 * @throws InterruptedException
+	 */
+	Optional<Stack> waitForStackToComplete(String stackName, Set<WaitConditionHandler> waitConditionHandlers) throws InterruptedException;
/**
*
* @param stackName
* @return
*/
- public String getOutput(String stackName, String outputKey);
+ String getOutput(String stackName, String outputKey);
/**
* Stream over all stacks.
* @return
*/
-	public Stream<Stack> streamOverAllStacks();
+	Stream<Stack> streamOverAllStacks();
/**
* Delete a stack by name
* @param stackName
*/
- public void deleteStack(String stackName);
+ void deleteStack(String stackName);
}
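
A minimal usage sketch of the new overload (the client, handler set, and stack name here are hypothetical; in this PR the handler set is injected via Guice, as wired in TemplateGuiceModule below):

```java
import java.util.Set;

import com.amazonaws.services.cloudformation.model.Stack;

// Sketch: block until the stack completes, servicing each wait condition with
// its matching handler along the way.
Stack waitWithHandlers(CloudFormationClient client, Set<WaitConditionHandler> handlers) throws InterruptedException {
	return client.waitForStackToComplete("my-shared-resources-stack", handlers)
			.orElseThrow(() -> new IllegalStateException("Stack does not exist"));
}
```
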
diff --git a/src/main/java/org/sagebionetworks/template/CloudFormationClientImpl.java b/src/main/java/org/sagebionetworks/template/CloudFormationClientImpl.java
index b1ac6f01..a82cb9ca 100644
--- a/src/main/java/org/sagebionetworks/template/CloudFormationClientImpl.java
+++ b/src/main/java/org/sagebionetworks/template/CloudFormationClientImpl.java
@@ -3,12 +3,16 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Collections;
+import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.UUID;
import java.util.function.Function;
+import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
@@ -21,11 +25,15 @@
import com.amazonaws.services.cloudformation.model.CreateStackRequest;
import com.amazonaws.services.cloudformation.model.CreateStackResult;
import com.amazonaws.services.cloudformation.model.DeleteStackRequest;
-import com.amazonaws.services.cloudformation.model.DeleteStackResult;
+import com.amazonaws.services.cloudformation.model.DescribeStackEventsRequest;
import com.amazonaws.services.cloudformation.model.DescribeStacksRequest;
import com.amazonaws.services.cloudformation.model.DescribeStacksResult;
import com.amazonaws.services.cloudformation.model.Output;
+import com.amazonaws.services.cloudformation.model.ResourceSignalStatus;
+import com.amazonaws.services.cloudformation.model.ResourceStatus;
+import com.amazonaws.services.cloudformation.model.SignalResourceRequest;
import com.amazonaws.services.cloudformation.model.Stack;
+import com.amazonaws.services.cloudformation.model.StackEvent;
import com.amazonaws.services.cloudformation.model.StackStatus;
import com.amazonaws.services.cloudformation.model.UpdateStackRequest;
import com.amazonaws.services.cloudformation.model.UpdateStackResult;
@@ -225,7 +233,19 @@ public boolean isStartedInUpdateRollbackComplete(String stackName) {
@Override
 	public Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException {
+ return waitForStackToComplete(stackName, Collections.emptySet());
+ }
+
+ @Override
+	public Optional<Stack> waitForStackToComplete(String stackName, Set<WaitConditionHandler> waitConditionHandlers) throws InterruptedException {
boolean startedInUpdateRollbackComplete = isStartedInUpdateRollbackComplete(stackName); // Initial state
+
+		Map<String, WaitConditionHandler> waitConditionHandlerMap = waitConditionHandlers.stream()
+				.collect(Collectors.toMap(WaitConditionHandler::getWaitConditionId, Function.identity()));
+
+ // To avoid re-processing the same wait condition multiple times we need to keep track of them
+		Set<String> processedWaitConditionSet = new HashSet<>();
+
long start = threadProvider.currentTimeMillis();
while (true) {
long elapse = threadProvider.currentTimeMillis() - start;
@@ -246,10 +266,10 @@ public Optional<Stack> waitForStackToComplete(String stackName) throws Interrupt
return optional;
case CREATE_IN_PROGRESS:
case UPDATE_IN_PROGRESS:
+ handleWaitConditions(stackName, waitConditionHandlerMap, processedWaitConditionSet);
case DELETE_IN_PROGRESS:
case UPDATE_COMPLETE_CLEANUP_IN_PROGRESS:
- logger.info("Waiting for stack: '" + stackName + "' to complete. Current status: " + status.name()
- + "...");
+ logger.info("Waiting for stack: '" + stackName + "' to complete. Current status: " + status.name() + "...");
threadProvider.sleep(SLEEP_TIME);
break;
case UPDATE_ROLLBACK_COMPLETE:
@@ -262,6 +282,84 @@ public Optional<Stack> waitForStackToComplete(String stackName) throws Interrupt
}
}
}
+
+	void handleWaitConditions(String stackName, Map<String, WaitConditionHandler> waitConditionHandlers, Set<String> processedWaitConditionSet) {
+ if (waitConditionHandlers.isEmpty()) {
+ return;
+ }
+
+		Set<String> waitConditionEventIds = new HashSet<>();
+
+		List<StackEvent> waitConditionEvents = cloudFormationClient.describeStackEvents(
+ new DescribeStackEventsRequest().withStackName(stackName)
+ )
+ .getStackEvents()
+ .stream()
+ .filter(event -> "AWS::CloudFormation::WaitCondition".equals(event.getResourceType()))
+ // We only need the latest event for each wait condition
+ .filter(event -> waitConditionEventIds.add(event.getLogicalResourceId()))
+ .filter(event -> ResourceStatus.CREATE_IN_PROGRESS.equals(ResourceStatus.fromValue(event.getResourceStatus())))
+ .collect(Collectors.toList());
+
+ for (StackEvent waitConditionEvent : waitConditionEvents) {
+ String waitConditionId = waitConditionEvent.getLogicalResourceId();
+
+ if (processedWaitConditionSet.contains(waitConditionId)) {
+ logger.warn("Wait condition {} already processed, skipping.", waitConditionId);
+ continue;
+ }
+
+ logger.info("Processing wait condition {} (Status: {}, Reason: {})...", waitConditionId, waitConditionEvent.getResourceStatus(), waitConditionEvent.getResourceStatusReason());
+
+ WaitConditionHandler waitConditionHandler = waitConditionHandlers.get(waitConditionId);
+
+ if (waitConditionHandler == null) {
+
+ cloudFormationClient.signalResource(new SignalResourceRequest()
+ .withStackName(stackName)
+ .withLogicalResourceId(waitConditionId)
+ .withStatus(ResourceSignalStatus.FAILURE)
+ .withUniqueId("handler-not-found")
+ );
+
+				throw new IllegalStateException("Processing wait condition " + waitConditionId + " failed: could not find a handler.");
+
+ } else {
+ logger.info("Processing wait condition {} started...", waitConditionId);
+
+ try {
+ waitConditionHandler.handle(waitConditionEvent).ifPresentOrElse(signalId -> {
+ logger.info("Processing wait condition {} completed with signal {}.", waitConditionId, signalId);
+
+ cloudFormationClient.signalResource(new SignalResourceRequest()
+ .withStackName(stackName)
+ .withLogicalResourceId(waitConditionId)
+ .withStatus(ResourceSignalStatus.SUCCESS)
+ .withUniqueId(signalId)
+ );
+
+ processedWaitConditionSet.add(waitConditionId);
+ }, () -> {
+ logger.info("Processing wait condition {} didn't return a signal, will process later.", waitConditionId);
+ });
+
+ } catch (Exception e) {
+ logger.error("Processing wait condition {} failed exceptionally: ", waitConditionId, e);
+
+ cloudFormationClient.signalResource(new SignalResourceRequest()
+ .withStackName(stackName)
+ .withLogicalResourceId(waitConditionId)
+ .withStatus(ResourceSignalStatus.FAILURE)
+ .withUniqueId("handler-failed")
+ );
+
+ throw new IllegalStateException("Processing wait condition " + waitConditionId + " failed.", e);
+ }
+ }
+
+ }
+
+ }
@Override
public String getOutput(String stackName, String outputKey) {
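
The event filter in `handleWaitConditions` leans on two details: DescribeStackEvents returns events in reverse chronological order (per the AWS documentation), and `Set.add` returns false for an element already present, so the first stateful filter keeps only the most recent event per logical resource id. A self-contained sketch of that idiom:

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

public class LatestEventPerResourceDemo {
	public static void main(String[] args) {
		// Events listed newest-first, mirroring the DescribeStackEvents ordering.
		List<String[]> events = List.of(
				new String[] { "WaitConditionA", "CREATE_COMPLETE" },
				new String[] { "WaitConditionB", "CREATE_IN_PROGRESS" },
				new String[] { "WaitConditionA", "CREATE_IN_PROGRESS" });

		Set<String> seen = new HashSet<>();

		// Set.add returns false once an id has been seen, so only the newest
		// event for each logical resource id survives the filter.
		List<String[]> latest = events.stream()
				.filter(event -> seen.add(event[0]))
				.collect(Collectors.toList());

		// Prints WaitConditionA -> CREATE_COMPLETE, then WaitConditionB -> CREATE_IN_PROGRESS
		latest.forEach(event -> System.out.println(event[0] + " -> " + event[1]));
	}
}
```
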
diff --git a/src/main/java/org/sagebionetworks/template/Constants.java b/src/main/java/org/sagebionetworks/template/Constants.java
index f261ff88..aab02c19 100644
--- a/src/main/java/org/sagebionetworks/template/Constants.java
+++ b/src/main/java/org/sagebionetworks/template/Constants.java
@@ -221,6 +221,7 @@ public class Constants {
public static final String HOSTED_ZONE = "hostedZone";
public static final String VPC_ENDPOINTS_COLOR = "vpcEndpointsColor";
public static final String VPC_ENDPOINTS_AZ = "vpcEndpointsAz";
+ public static final String IDENTITY_ARN = "identityArn";
public static final String CAPABILITY_NAMED_IAM = "CAPABILITY_NAMED_IAM";
public static final String OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT = "RepositoryDBEndpoint";
diff --git a/src/main/java/org/sagebionetworks/template/OpenSearchClientFactory.java b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactory.java
new file mode 100644
index 00000000..a8763190
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactory.java
@@ -0,0 +1,9 @@
+package org.sagebionetworks.template;
+
+import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
+
+public interface OpenSearchClientFactory {
+
+ OpenSearchIndicesClient getIndicesClient(String collectionEndpoint);
+
+}
diff --git a/src/main/java/org/sagebionetworks/template/OpenSearchClientFactoryImpl.java b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactoryImpl.java
new file mode 100644
index 00000000..b85743ce
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactoryImpl.java
@@ -0,0 +1,33 @@
+package org.sagebionetworks.template;
+
+import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
+import org.opensearch.client.transport.aws.AwsSdk2Transport;
+import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
+
+import com.google.inject.Inject;
+
+import software.amazon.awssdk.http.SdkHttpClient;
+import software.amazon.awssdk.regions.Region;
+
+public class OpenSearchClientFactoryImpl implements OpenSearchClientFactory {
+
+ private SdkHttpClient httpClient;
+
+ @Inject
+ public OpenSearchClientFactoryImpl(SdkHttpClient httpClient) {
+ this.httpClient = httpClient;
+ }
+
+	@Override
+	public OpenSearchIndicesClient getIndicesClient(String collectionEndpoint) {
+ return new OpenSearchIndicesClient(
+ new AwsSdk2Transport(
+ httpClient,
+ collectionEndpoint.replace("https://", ""),
+ "aoss",
+ Region.US_EAST_1,
+ AwsSdk2TransportOptions.builder().build()
+ )
+ );
+ }
+
+}
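
A usage sketch for the factory (the endpoint value is illustrative; in this PR it comes from the CollectionDetail returned by the OpenSearch Serverless management API). The implementation strips the https:// scheme because AwsSdk2Transport expects a bare host, and it signs requests with the aoss service name used by OpenSearch Serverless:

```java
// Hypothetical endpoint; real ones look like https://<collection-id>.<region>.aoss.amazonaws.com
OpenSearchIndicesClient indices = openSearchClientFactory
		.getIndicesClient("https://abc123.us-east-1.aoss.amazonaws.com");

// Requests are now SigV4-signed against the collection host; may throw IOException.
boolean exists = indices.exists(req -> req.index("vector-idx")).value();
```
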
diff --git a/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java b/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java
index 366c2ad4..d081361a 100644
--- a/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java
+++ b/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java
@@ -1,35 +1,17 @@
package org.sagebionetworks.template;
-import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
-import com.amazonaws.regions.Regions;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
-import com.amazonaws.services.cloudformation.AmazonCloudFormation;
-import com.amazonaws.services.cloudformation.AmazonCloudFormationClientBuilder;
-import com.amazonaws.services.ec2.AmazonEC2;
-import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
-import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
-import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalkClientBuilder;
-import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancing;
-import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancingClientBuilder;
-import com.amazonaws.services.glue.AWSGlue;
-import com.amazonaws.services.glue.AWSGlueClientBuilder;
-import com.amazonaws.services.kms.AWSKMS;
-import com.amazonaws.services.kms.AWSKMSAsyncClientBuilder;
-import com.amazonaws.services.lambda.AWSLambda;
-import com.amazonaws.services.lambda.AWSLambdaClientBuilder;
-import com.amazonaws.services.route53.AmazonRoute53;
-import com.amazonaws.services.route53.AmazonRoute53ClientBuilder;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.AmazonS3ClientBuilder;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
-import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
-import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
-import com.amazonaws.services.simpleemail.AmazonSimpleEmailService;
-import com.amazonaws.services.simpleemail.AmazonSimpleEmailServiceClientBuilder;
-import com.google.inject.Provides;
-import com.google.inject.multibindings.Multibinder;
+import static org.sagebionetworks.template.Constants.APPCONFIG_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.ATHENA_QUERIES_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.CLOUDWATCH_LOGS_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.DATAWAREHOUSE_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.KINESIS_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.LOAD_BALANCER_ALARM_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.S3_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.SNS_AND_SQS_CONFIG_FILE;
+import static org.sagebionetworks.template.TemplateUtils.loadFromJsonFile;
+
+import java.io.IOException;
+
import org.apache.http.client.HttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.velocity.app.VelocityEngine;
@@ -97,6 +79,8 @@
import org.sagebionetworks.template.repo.beanstalk.ssl.CertificateBuilderImpl;
import org.sagebionetworks.template.repo.beanstalk.ssl.ElasticBeanstalkExtentionBuilder;
import org.sagebionetworks.template.repo.beanstalk.ssl.ElasticBeanstalkExtentionBuilderImpl;
+import org.sagebionetworks.template.repo.bedrock.SynapseHelpCollectionIndexCreation;
+import org.sagebionetworks.template.repo.bedrock.SynapseHelpKnowledgeBaseDataSourceSync;
import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsConfig;
import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsConfigValidator;
import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsVelocityContextProvider;
@@ -123,17 +107,41 @@
import org.sagebionetworks.war.WarAppender;
import org.sagebionetworks.war.WarAppenderImpl;
-import java.io.IOException;
-import static org.sagebionetworks.template.Constants.ATHENA_QUERIES_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.CLOUDWATCH_LOGS_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.DATAWAREHOUSE_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.KINESIS_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.LOAD_BALANCER_ALARM_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.S3_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.SNS_AND_SQS_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.APPCONFIG_CONFIG_FILE;
+import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
+import com.amazonaws.regions.Regions;
+import com.amazonaws.services.athena.AmazonAthena;
+import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
+import com.amazonaws.services.cloudformation.AmazonCloudFormation;
+import com.amazonaws.services.cloudformation.AmazonCloudFormationClientBuilder;
+import com.amazonaws.services.ec2.AmazonEC2;
+import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
+import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
+import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalkClientBuilder;
+import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancing;
+import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancingClientBuilder;
+import com.amazonaws.services.glue.AWSGlue;
+import com.amazonaws.services.glue.AWSGlueClientBuilder;
+import com.amazonaws.services.kms.AWSKMS;
+import com.amazonaws.services.kms.AWSKMSAsyncClientBuilder;
+import com.amazonaws.services.lambda.AWSLambda;
+import com.amazonaws.services.lambda.AWSLambdaClientBuilder;
+import com.amazonaws.services.route53.AmazonRoute53;
+import com.amazonaws.services.route53.AmazonRoute53ClientBuilder;
+import com.amazonaws.services.s3.AmazonS3;
+import com.amazonaws.services.s3.AmazonS3ClientBuilder;
+import com.amazonaws.services.secretsmanager.AWSSecretsManager;
+import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
+import com.amazonaws.services.simpleemail.AmazonSimpleEmailService;
+import com.amazonaws.services.simpleemail.AmazonSimpleEmailServiceClientBuilder;
+import com.google.inject.Provides;
+import com.google.inject.multibindings.Multibinder;
-import static org.sagebionetworks.template.TemplateUtils.loadFromJsonFile;
+import software.amazon.awssdk.http.apache.ApacheHttpClient;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.bedrockagent.BedrockAgentClient;
+import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;
public class TemplateGuiceModule extends com.google.inject.AbstractModule {
@@ -191,6 +199,11 @@ protected void configure() {
velocityContextProviderMultibinder.addBinding().to(KinesisFirehoseVelocityContextProvider.class);
velocityContextProviderMultibinder.addBinding().to(RecurrentAthenaQueryContextProvider.class);
velocityContextProviderMultibinder.addBinding().to(BedrockAgentContextProvider.class);
+
+		Multibinder<WaitConditionHandler> waitConditionHandlerBinder = Multibinder.newSetBinder(binder(), WaitConditionHandler.class);
+
+ waitConditionHandlerBinder.addBinding().to(SynapseHelpCollectionIndexCreation.class);
+ waitConditionHandlerBinder.addBinding().to(SynapseHelpKnowledgeBaseDataSourceSync.class);
}
/**
@@ -366,5 +379,20 @@ public S3TransferManagerFactory provideS3TransferManagerFactory(AmazonS3 s3Clien
public DataWarehouseConfig dataWarehouseConfigProvider() throws IOException {
return new DataWarehouseConfigValidator(loadFromJsonFile(DATAWAREHOUSE_CONFIG_FILE, DataWarehouseConfig.class)).validate();
}
-
+
+ @Provides
+ public OpenSearchServerlessClient ossManagementClientProvider() {
+ return OpenSearchServerlessClient.builder().region(Region.US_EAST_1).build();
+ }
+
+ @Provides
+ public BedrockAgentClient bedrockAgentClientProvider() {
+ return BedrockAgentClient.builder().region(Region.US_EAST_1).build();
+ }
+
+ @Provides
+ public OpenSearchClientFactory openSearchClientFactoryProvider() {
+ return new OpenSearchClientFactoryImpl(ApacheHttpClient.builder().build());
+ }
+
}
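
With the Multibinder above, Guice aggregates every bound handler into a single injectable Set<WaitConditionHandler>. A sketch of the consuming side (hypothetical class name, mirroring the RepositoryTemplateBuilderImpl constructor change in this PR):

```java
import java.util.Set;

import com.google.inject.Inject;

public class WaitConditionConsumer {

	private final Set<WaitConditionHandler> waitConditionHandlers;

	@Inject
	public WaitConditionConsumer(Set<WaitConditionHandler> waitConditionHandlers) {
		// The set contains one handler instance per addBinding() call in the module.
		this.waitConditionHandlers = waitConditionHandlers;
	}
}
```
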
diff --git a/src/main/java/org/sagebionetworks/template/WaitConditionHandler.java b/src/main/java/org/sagebionetworks/template/WaitConditionHandler.java
new file mode 100644
index 00000000..d560858d
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/WaitConditionHandler.java
@@ -0,0 +1,26 @@
+package org.sagebionetworks.template;
+
+import java.util.Optional;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+
+/**
+ * Interface for a handler of a wait condition defined in a cloud formation template. Note that wait condition updates are not supported: a handler is invoked only while the stack is being created.
+ */
+public interface WaitConditionHandler {
+
+	/**
+	 * @return The logical resource id of the wait condition defined in the cloud formation template; this handler will be mapped to the returned id
+	 */
+	String getWaitConditionId();
+
+	/**
+	 * Invoked when the latest wait condition event matches the {@link #getWaitConditionId()} and the event is in the CREATE_IN_PROGRESS status. The
+	 * implementation should be idempotent since it may be invoked multiple times during the stack creation.
+	 *
+	 * @param stackEvent
+	 * @return An optional signal id to send back to cloud formation if the condition could be processed
+	 */
+	Optional<String> handle(StackEvent stackEvent) throws InterruptedException;
+
+}
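
A minimal sketch of an implementation honoring this contract (the logical id and readiness check are hypothetical); returning an empty Optional defers the signal so the poller invokes the handler again on the next cycle:

```java
import java.util.Optional;

import com.amazonaws.services.cloudformation.model.StackEvent;

public class ExampleWaitConditionHandler implements WaitConditionHandler {

	@Override
	public String getWaitConditionId() {
		// Must match the logical id of an AWS::CloudFormation::WaitCondition in the template.
		return "ExampleResourceWaitCondition";
	}

	@Override
	public Optional<String> handle(StackEvent stackEvent) {
		if (!externalResourceReady()) {
			// No signal yet: the stack poller will invoke this handler again.
			return Optional.empty();
		}
		// Idempotent: returning the same signal id on a later invocation is harmless.
		return Optional.of("example-resource-ready");
	}

	private boolean externalResourceReady() {
		// Placeholder for a real readiness check.
		return true;
	}
}
```
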
diff --git a/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java b/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java
index d70c86a9..aba02ee5 100644
--- a/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java
+++ b/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java
@@ -1,47 +1,6 @@
package org.sagebionetworks.template.repo;
-import com.amazonaws.services.cloudformation.model.Output;
-import com.amazonaws.services.cloudformation.model.Parameter;
-import com.amazonaws.services.cloudformation.model.Stack;
-import com.amazonaws.services.cloudformation.model.Tag;
-import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
-import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsRequest;
-import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsResult;
-import com.amazonaws.services.elasticbeanstalk.model.PlatformSummary;
-import com.google.inject.Inject;
-import org.apache.logging.log4j.Logger;
-import org.apache.velocity.Template;
-import org.apache.velocity.VelocityContext;
-import org.apache.velocity.app.VelocityEngine;
-import org.json.JSONObject;
-import org.sagebionetworks.template.CloudFormationClient;
-import org.sagebionetworks.template.ConfigurationPropertyNotFound;
-import org.sagebionetworks.template.Constants;
-import org.sagebionetworks.template.CreateOrUpdateStackRequest;
-import org.sagebionetworks.template.Ec2Client;
-import org.sagebionetworks.template.LoggerFactory;
-import org.sagebionetworks.template.StackTagsProvider;
-import org.sagebionetworks.template.config.RepoConfiguration;
-import org.sagebionetworks.template.config.TimeToLive;
-import org.sagebionetworks.template.repo.beanstalk.ArtifactCopy;
-import org.sagebionetworks.template.repo.beanstalk.BeanstalkUtils;
-import org.sagebionetworks.template.repo.beanstalk.ElasticBeanstalkSolutionStackNameProvider;
-import org.sagebionetworks.template.repo.beanstalk.EnvironmentDescriptor;
-import org.sagebionetworks.template.repo.beanstalk.EnvironmentType;
-import org.sagebionetworks.template.repo.beanstalk.SecretBuilder;
-import org.sagebionetworks.template.repo.beanstalk.SourceBundle;
-import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsVelocityContextProvider;
-
-import java.io.StringWriter;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Set;
-import java.util.StringJoiner;
-import java.util.stream.Collectors;
-
import static org.sagebionetworks.template.Constants.ADMIN_RULE_ACTION;
import static org.sagebionetworks.template.Constants.BEANSTALK_INSTANCES_SUBNETS;
import static org.sagebionetworks.template.Constants.CAPABILITY_NAMED_IAM;
@@ -53,11 +12,12 @@
import static org.sagebionetworks.template.Constants.DATA_CDN_DOMAIN_NAME_FMT;
import static org.sagebionetworks.template.Constants.DB_ENDPOINT_SUFFIX;
import static org.sagebionetworks.template.Constants.DELETION_POLICY;
-import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.EC2_INSTANCE_MEMORY;
+import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.ENVIRONMENT;
import static org.sagebionetworks.template.Constants.EXCEPTION_THROWER;
import static org.sagebionetworks.template.Constants.GLOBAL_RESOURCES_EXPORT_PREFIX;
+import static org.sagebionetworks.template.Constants.IDENTITY_ARN;
import static org.sagebionetworks.template.Constants.INSTANCE;
import static org.sagebionetworks.template.Constants.JSON_INDENT;
import static org.sagebionetworks.template.Constants.MACHINE_TYPES;
@@ -73,8 +33,8 @@
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_SSL_ARN;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_VERSION;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_DATA_CDN_KEYPAIR_ID;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_MEMORY;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_AMAZONLINUX;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_JAVA;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_TOMCAT;
@@ -86,19 +46,19 @@
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_INSTANCE_CLASS;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_IOPS;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MAX_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MULTI_AZ;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_STORAGE_TYPE;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ROUTE_53_HOSTED_ZONE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_STACK;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_INSTANCE_COUNT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_INSTANCE_CLASS;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_IOPS;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_MAX_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_STORAGE_TYPE;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_VPC_SUBNET_COLOR;
import static org.sagebionetworks.template.Constants.REPO_BEANSTALK_NUMBER;
import static org.sagebionetworks.template.Constants.SHARED_EXPORT_PREFIX;
@@ -111,6 +71,51 @@
import static org.sagebionetworks.template.Constants.VPC_EXPORT_PREFIX;
import static org.sagebionetworks.template.Constants.VPC_SUBNET_COLOR;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.StringJoiner;
+import java.util.stream.Collectors;
+
+import org.apache.logging.log4j.Logger;
+import org.apache.velocity.Template;
+import org.apache.velocity.VelocityContext;
+import org.apache.velocity.app.VelocityEngine;
+import org.json.JSONObject;
+import org.sagebionetworks.template.CloudFormationClient;
+import org.sagebionetworks.template.ConfigurationPropertyNotFound;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.CreateOrUpdateStackRequest;
+import org.sagebionetworks.template.Ec2Client;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.StackTagsProvider;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+import org.sagebionetworks.template.config.TimeToLive;
+import org.sagebionetworks.template.repo.beanstalk.ArtifactCopy;
+import org.sagebionetworks.template.repo.beanstalk.BeanstalkUtils;
+import org.sagebionetworks.template.repo.beanstalk.ElasticBeanstalkSolutionStackNameProvider;
+import org.sagebionetworks.template.repo.beanstalk.EnvironmentDescriptor;
+import org.sagebionetworks.template.repo.beanstalk.EnvironmentType;
+import org.sagebionetworks.template.repo.beanstalk.SecretBuilder;
+import org.sagebionetworks.template.repo.beanstalk.SourceBundle;
+import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsVelocityContextProvider;
+
+import com.amazonaws.services.cloudformation.model.Output;
+import com.amazonaws.services.cloudformation.model.Parameter;
+import com.amazonaws.services.cloudformation.model.Stack;
+import com.amazonaws.services.cloudformation.model.Tag;
+import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
+import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsRequest;
+import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsResult;
+import com.amazonaws.services.elasticbeanstalk.model.PlatformSummary;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
+import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest;
+import com.google.inject.Inject;
+
public class RepositoryTemplateBuilderImpl implements RepositoryTemplateBuilder {
 	public static final List<String> MACHINE_TYPE_LIST = List.of("Workers", "Repository");
 	public static final List<String> POOL_TYPE_LIST = List.of("Idgen", "Main", "Migration", "Tables");
@@ -128,6 +133,8 @@ public class RepositoryTemplateBuilderImpl implements RepositoryTemplateBuilder
private final CloudwatchLogsVelocityContextProvider cwlContextProvider;
private final AWSElasticBeanstalk beanstalkClient;
private final TimeToLive timeToLive;
+ private final AWSSecurityTokenService stsClient;
+	private final Set<WaitConditionHandler> waitConditionHandlers;
@Inject
public RepositoryTemplateBuilderImpl(CloudFormationClient cloudFormationClient, VelocityEngine velocityEngine,
@@ -135,7 +142,8 @@ public RepositoryTemplateBuilderImpl(CloudFormationClient cloudFormationClient,
 			SecretBuilder secretBuilder, Set<VelocityContextProvider> contextProviders,
ElasticBeanstalkSolutionStackNameProvider elasticBeanstalkDefaultAMIEncrypter,
StackTagsProvider stackTagsProvider, CloudwatchLogsVelocityContextProvider cloudwatchLogsVelocityContextProvider,
- Ec2Client ec2Client, AWSElasticBeanstalk beanstalkClient, TimeToLive ttl) {
+ Ec2Client ec2Client, AWSElasticBeanstalk beanstalkClient, TimeToLive ttl,
+			AWSSecurityTokenService stsClient, Set<WaitConditionHandler> waitConditionHandlers) {
super();
this.cloudFormationClient = cloudFormationClient;
this.ec2Client = ec2Client;
@@ -150,6 +158,8 @@ public RepositoryTemplateBuilderImpl(CloudFormationClient cloudFormationClient,
this.cwlContextProvider = cloudwatchLogsVelocityContextProvider;
this.beanstalkClient = beanstalkClient;
this.timeToLive = ttl;
+ this.stsClient = stsClient;
+ this.waitConditionHandlers = waitConditionHandlers;
}
public String getActualBeanstalkAmazonLinuxPlatform() {
@@ -185,12 +195,12 @@ public void buildAndDeploy() throws InterruptedException {
buildAndDeployStack(context, sharedResourceStackName, TEMPALTE_SHARED_RESOUCES_MAIN_JSON_VTP, sharedParameters);
// Wait for the shared resources to complete
- Stack sharedStackResults = cloudFormationClient.waitForStackToComplete(sharedResourceStackName).orElseThrow(()->new IllegalStateException("Stack does not exist: "+sharedResourceStackName));
-
+ Stack sharedStackResults = cloudFormationClient.waitForStackToComplete(sharedResourceStackName, waitConditionHandlers).orElseThrow(()->new IllegalStateException("Stack does not exist: "+sharedResourceStackName));
+
// Build each bean stalk environment.
 		List<String> environmentNames = buildEnvironments(sharedStackResults);
}
-
+
/**
* Build all of the environments
* @param sharedStackResults
@@ -307,7 +317,9 @@ void buildAndDeployStack(VelocityContext context, String stackName, String templ
*/
VelocityContext createSharedContext() {
VelocityContext context = new VelocityContext();
+
String stack = config.getProperty(PROPERTY_KEY_STACK);
+
context.put(STACK, stack);
context.put(INSTANCE, config.getProperty(PROPERTY_KEY_INSTANCE));
context.put(MACHINE_TYPES, MACHINE_TYPE_LIST);
@@ -321,8 +333,7 @@ VelocityContext createSharedContext() {
context.put(CTXT_ENABLE_ENHANCED_RDS_MONITORING, config.getProperty(PROPERTY_KEY_ENABLE_RDS_ENHANCED_MONITORING));
context.put(ADMIN_RULE_ACTION, Constants.isProd(stack) ? "Block:{}" : "Count:{}");
- context.put(DELETION_POLICY,
- Constants.isProd(stack) ? DeletionPolicy.Retain.name() : DeletionPolicy.Delete.name());
+ context.put(DELETION_POLICY, Constants.isProd(stack) ? DeletionPolicy.Retain.name() : DeletionPolicy.Delete.name());
// Create the descriptors for all of the database.
context.put(DATABASE_DESCRIPTORS, createDatabaseDescriptors());
@@ -331,6 +342,8 @@ VelocityContext createSharedContext() {
provider.addToContext(context);
}
+ context.put(IDENTITY_ARN, stsClient.getCallerIdentity(new GetCallerIdentityRequest()).getArn());
+
RegularExpressions.bindRegexToContext(context);
return context;
diff --git a/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java b/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java
index 9a9a85b2..d3d70849 100644
--- a/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java
+++ b/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java
@@ -6,6 +6,7 @@
import java.util.StringJoiner;
import org.apache.velocity.VelocityContext;
+import org.json.JSONArray;
import org.json.JSONObject;
import org.sagebionetworks.template.TemplateUtils;
import org.sagebionetworks.template.config.RepoConfiguration;
@@ -14,7 +15,7 @@
import com.google.inject.Inject;
public class BedrockAgentContextProvider implements VelocityContextProvider {
-
+
private final RepoConfiguration repoConfig;
@Inject
@@ -28,11 +29,37 @@ public void addToContext(VelocityContext context) {
String stack = repoConfig.getProperty(PROPERTY_KEY_STACK);
String instance = repoConfig.getProperty(PROPERTY_KEY_INSTANCE);
String agentName = new StringJoiner("-").add(stack).add(instance).add("agent").toString();
- JSONObject baseTemplate = new JSONObject(
- TemplateUtils.loadContentFromFile("templates/repo/agent/bedrock_agent_template.json"));
+
+ JSONObject baseTemplate = new JSONObject(TemplateUtils.loadContentFromFile("templates/repo/agent/bedrock_agent_template.json"));
JSONObject resources = baseTemplate.getJSONObject("Resources");
+
+		// Since the agent template is shared with external people, we need to rewrite it to replace parameters that do not exist in our template
+ JSONArray bedrockAgentRoleKbResource = resources
+ .getJSONObject("bedrockAgentRole")
+ .getJSONObject("Properties")
+ .getJSONArray("Policies")
+ .getJSONObject(0)
+ .getJSONObject("PolicyDocument")
+ .getJSONArray("Statement")
+ .getJSONObject(1)
+ .getJSONArray("Fn::If")
+ .getJSONObject(1)
+ .getJSONArray("Resource");
+
+ bedrockAgentRoleKbResource.put(0, new JSONObject("{ \"Fn::GetAtt\": [\"SynapseHelpKnowledgeBase\", \"KnowledgeBaseArn\"] }"));
+
JSONObject bedrockAgentProps = resources.getJSONObject("bedrockAgent").getJSONObject("Properties");
+
+ JSONObject kbProperty = bedrockAgentProps
+ .getJSONObject("KnowledgeBases")
+ .getJSONArray("Fn::If")
+ .getJSONArray(1)
+ .getJSONObject(0);
+
+ kbProperty.getJSONObject("KnowledgeBaseId").put("Ref", "SynapseHelpKnowledgeBase");
+ kbProperty.put("Description", baseTemplate.getJSONObject("Parameters").getJSONObject("knowledgeBaseDescription").getString("Default"));
+
bedrockAgentProps.put("AgentName", agentName);
String json = resources.toString();
context.put("bedrock_agent_resouces", "," + json.substring(1, json.length()-1));
diff --git a/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreation.java b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreation.java
new file mode 100644
index 00000000..009cf514
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreation.java
@@ -0,0 +1,104 @@
+package org.sagebionetworks.template.repo.bedrock;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import org.apache.logging.log4j.Logger;
+import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.OpenSearchClientFactory;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+import com.google.inject.Inject;
+
+import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionDetail;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionStatus;
+
+/**
+ * A bedrock knowledge base that uses an open search collection requires the index to exist before the knowledge base is
+ * created. Since index creation is an opensearch API operation with no cloudformation resource for it, we need to invoke
+ * the opensearch API as part of a wait condition in the stack. Note that a wait condition is only processed during
+ * the stack creation, so the index cannot be updated.
+ */
+public class SynapseHelpCollectionIndexCreation implements WaitConditionHandler {
+
+ private static final String IDX_NAME = "vector-idx";
+
+ private Logger logger;
+
+ private RepoConfiguration config;
+
+ private OpenSearchServerlessClient ossManagementClient;
+
+ private OpenSearchClientFactory openSearchClientFactory;
+
+ @Inject
+ public SynapseHelpCollectionIndexCreation(LoggerFactory loggerFactory, RepoConfiguration config, OpenSearchServerlessClient ossClient, OpenSearchClientFactory openSearchClientFactory) {
+ this.logger = loggerFactory.getLogger(SynapseHelpCollectionIndexCreation.class);
+ this.config = config;
+ this.ossManagementClient = ossClient;
+ this.openSearchClientFactory = openSearchClientFactory;
+ }
+
+ @Override
+ public String getWaitConditionId() {
+ return "SynapseHelpCollectionCreateIndexWaitCondition";
+ }
+
+ @Override
+	public Optional<String> handle(StackEvent stackEvent) {
+ String collectionName = config.getProperty(Constants.PROPERTY_KEY_STACK) + "-" + config.getProperty(Constants.PROPERTY_KEY_INSTANCE) + "-synhelp";
+
+ CollectionDetail collection = ossManagementClient.batchGetCollection(req -> req
+ .names(collectionName)
+ ).collectionDetails().stream().findFirst().orElseThrow();
+
+ if (!CollectionStatus.ACTIVE.equals(collection.status())) {
+ logger.warn("Collection {} not ready, status: {}", collectionName, collection.status());
+ return Optional.empty();
+ }
+
+ OpenSearchIndicesClient client = openSearchClientFactory.getIndicesClient(collection.collectionEndpoint());
+
+ try {
+ if (client.exists(req -> req.index(IDX_NAME)).value()) {
+ logger.warn("Index {} already exists.", IDX_NAME);
+ return Optional.of("index-already-exists");
+ }
+
+ logger.info("Index {} does not exist, creating...", IDX_NAME);
+
+ client.create(req -> req
+ .index(IDX_NAME)
+ .settings(settings -> settings.knn(true).knnAlgoParamEfSearch(512))
+ .mappings(mappings -> mappings
+ .properties("text_vector", p -> p
+ .knnVector(vector -> vector
+ .dimension(1024)
+ .method(method -> method
+ .name("hnsw")
+ .engine("faiss")
+ .spaceType("l2")
+ )
+ )
+ )
+ .properties("text_raw", p -> p.text(text -> text.index(true)))
+ .properties("text_metadata", p -> p.text(text -> text.index(false)))
+ )
+ );
+
+ logger.info("Index {} creation completed.", IDX_NAME);
+
+ return Optional.of("index-creation-complete");
+
+ } catch (IOException e) {
+ throw new IllegalStateException(e);
+ }
+
+ }
+}
diff --git a/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSync.java b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSync.java
new file mode 100644
index 00000000..a6d83cf9
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSync.java
@@ -0,0 +1,127 @@
+package org.sagebionetworks.template.repo.bedrock;
+
+import java.util.Optional;
+
+import org.apache.commons.codec.digest.DigestUtils;
+import org.apache.logging.log4j.Logger;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.ThreadProvider;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+import com.google.inject.Inject;
+
+import software.amazon.awssdk.services.bedrockagent.BedrockAgentClient;
+import software.amazon.awssdk.services.bedrockagent.model.DataSourceSummary;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJob;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJobStatistics;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJobSummary;
+import software.amazon.awssdk.services.bedrockagent.model.KnowledgeBaseSummary;
+
+/**
+ * When a bedrock knowledge base is created its data source needs to be synchronized. We do this after the datasource is created,
+ * through a wait condition that invokes the bedrock APIs.
+ */
+public class SynapseHelpKnowledgeBaseDataSourceSync implements WaitConditionHandler {
+
+ private static final long SLEEP_MS = 10_000;
+
+ private Logger logger;
+ private RepoConfiguration config;
+ private ThreadProvider threadProvider;
+ private BedrockAgentClient bedrockAgentClient;
+
+ @Inject
+ public SynapseHelpKnowledgeBaseDataSourceSync(LoggerFactory loggerFactory, RepoConfiguration config, ThreadProvider threadProvider, BedrockAgentClient bedrockAgentClient) {
+ this.logger = loggerFactory.getLogger(SynapseHelpKnowledgeBaseDataSourceSync.class);
+ this.config = config;
+ this.threadProvider = threadProvider;
+ this.bedrockAgentClient = bedrockAgentClient;
+ }
+
+ @Override
+ public String getWaitConditionId() {
+ return "SynapseHelpKnowledgeBaseDataSourceSyncWaitCondition";
+ }
+
+ @Override
+	public Optional<String> handle(StackEvent stackEvent) throws InterruptedException {
+ String stackPrefix = config.getProperty(Constants.PROPERTY_KEY_STACK) + "-" + config.getProperty(Constants.PROPERTY_KEY_INSTANCE);
+
+ String knowledgeBaseName = stackPrefix + "-synhelp-knowledge-base";
+ String knowledgeBaseId = bedrockAgentClient.listKnowledgeBasesPaginator(req -> {})
+ .knowledgeBaseSummaries().stream()
+ .filter(kb -> kb.name().equals(knowledgeBaseName))
+ .findFirst()
+ .map(KnowledgeBaseSummary::knowledgeBaseId)
+ .orElseThrow();
+
+ String dataSourceName = stackPrefix + "-synhelp-datasource";
+ String dataSourceId = bedrockAgentClient.listDataSourcesPaginator(req -> req.knowledgeBaseId(knowledgeBaseId))
+ .dataSourceSummaries().stream()
+ .filter(dataSource -> dataSource.name().equals(dataSourceName))
+ .findFirst()
+ .map(DataSourceSummary::dataSourceId)
+ .orElseThrow();
+
+		Optional<IngestionJobSummary> existingJob = bedrockAgentClient.listIngestionJobs(req -> req.knowledgeBaseId(knowledgeBaseId).dataSourceId(dataSourceId).maxResults(1))
+ .ingestionJobSummaries()
+ .stream()
+ .findFirst();
+
+ if (existingJob.isPresent()) {
+ logger.warn("Sync job {} already exists (Status: {}).", existingJob.get().ingestionJobId(), existingJob.get().statusAsString());
+ return Optional.of("sync-started");
+ }
+
+ String clientToken = DigestUtils.sha256Hex(knowledgeBaseId + " - " + dataSourceId);
+
+ IngestionJob job = bedrockAgentClient.startIngestionJob(req -> req
+ .clientToken(clientToken)
+ .dataSourceId(dataSourceId)
+ .knowledgeBaseId(knowledgeBaseId)
+ ).ingestionJob();
+
+ String jobId = job.ingestionJobId();
+ boolean done = false;
+
+ do {
+ logger.info("Waiting for sync job {} to complete (Status: {}).", job.ingestionJobId(), job.statusAsString());
+
+ threadProvider.sleep(SLEEP_MS);
+
+ job = bedrockAgentClient.getIngestionJob(req -> req
+ .ingestionJobId(jobId)
+ .knowledgeBaseId(knowledgeBaseId)
+ .dataSourceId(dataSourceId)
+ ).ingestionJob();
+
+ switch (job.status()) {
+ case COMPLETE:
+ IngestionJobStatistics stats = job.statistics();
+
+ logger.info("Sync job {} completed (Documents Scanned: {}, Documents Indexed: {}, Documents Failed: {}).",
+ job.ingestionJobId(),
+ stats.numberOfDocumentsScanned(),
+ stats.numberOfNewDocumentsIndexed(),
+ stats.numberOfDocumentsFailed()
+ );
+
+ done = true;
+ break;
+ case FAILED:
+ case STOPPED:
+ case UNKNOWN_TO_SDK_VERSION:
+					throw new IllegalStateException("Sync job " + jobId + " failed (Status: " + job.statusAsString() + ", Failures: " + job.failureReasons() + ")");
+ default:
+ break;
+ }
+
+ } while (!done);
+
+ return Optional.of("sync-completed");
+ }
+
+}
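
Worth noting: the client token is derived deterministically from the knowledge base and data source ids, so if the deployer crashes and re-runs, Bedrock treats the repeated startIngestionJob call as the same request instead of starting a duplicate sync. A small sketch (ids hypothetical):

```java
import org.apache.commons.codec.digest.DigestUtils;

String knowledgeBaseId = "KB12345678"; // hypothetical
String dataSourceId = "DS12345678";    // hypothetical

// Same inputs always hash to the same token, making the start call idempotent.
String clientToken = DigestUtils.sha256Hex(knowledgeBaseId + " - " + dataSourceId);
System.out.println(clientToken); // stable 64-character hex string
```
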
diff --git a/src/main/resources/templates/repo/agent/bedrock_agent_template.json b/src/main/resources/templates/repo/agent/bedrock_agent_template.json
index 01b9d29a..675ea471 100644
--- a/src/main/resources/templates/repo/agent/bedrock_agent_template.json
+++ b/src/main/resources/templates/repo/agent/bedrock_agent_template.json
@@ -1,13 +1,28 @@
{
"AWSTemplateFormatVersion": "2010-09-09",
- "Description": "This tempalte contains all of the resources needed to create the base Synapse Bedrock Agent",
+ "Description": "This template contains all of the resources needed to create the base Synapse Bedrock Agent.",
"Parameters": {
"agentName": {
"Description": "Provide a unique name for this bedrock agent.",
"Type": "String",
"AllowedPattern": "^([0-9a-zA-Z][_-]?){1,100}$"
+ },
+ "knowledgeBaseId": {
+ "Description": "Provide the id of a knowledge base that will be associated to this bedrock agent.",
+ "Type": "String",
+ "Default": "",
+ "AllowedPattern": "^([0-9a-zA-Z]{10})?$"
+ },
+ "knowledgeBaseDescription": {
+ "Description": "Provide the description for the knowledge base that will be associated to this bedrock agent.",
+ "Type": "String",
+ "Default": "This knowledge base contains the Synapse help documentation. You can use it to answer questions around Synapse usage and its features.",
+ "MaxLength": 200
}
},
+ "Conditions": {
+ "AttachKnowledgeBase": {"Fn::Not": [{"Fn::Equals" : [{"Ref" : "knowledgeBaseId"}, ""]}]}
+ },
"Resources": {
"bedrockAgentRole": {
"Type": "AWS::IAM::Role",
@@ -44,12 +59,34 @@
"Statement": [
{
"Effect": "Allow",
- "Action": "bedrock:InvokeModel",
+ "Action": [
+ "bedrock:InvokeModel",
+ "bedrock:InvokeModelWithResponseStream"
+ ],
"Resource": [
{
"Fn::Sub": "arn:aws:bedrock:${AWS::Region}::foundation-model/*"
}
]
+ },
+ {
+ "Fn::If" : [ "AttachKnowledgeBase",
+ {
+ "Effect": "Allow",
+ "Action": [
+ "bedrock:Retrieve",
+ "bedrock:RetrieveAndGenerate"
+ ],
+ "Resource": [
+ {
+ "Fn::Sub": "arn:aws:bedrock:${AWS::Region}:${AWS::AccountId}:knowledge-base/${knowledgeBaseId}"
+ }
+ ]
+ },
+ {
+ "Ref" : "AWS::NoValue"
+ }
+ ]
}
]
}
@@ -136,6 +173,17 @@
"Arn"
]
},
+ "KnowledgeBases": {
+ "Fn::If" : [ "AttachKnowledgeBase",
+ [{
+ "KnowledgeBaseId" : { "Ref" : "knowledgeBaseId" },
+ "Description" : { "Ref" : "knowledgeBaseDescription" }
+ }],
+ {
+ "Ref" : "AWS::NoValue"
+ }
+ ]
+ },
"AutoPrepare": true,
"Description": "Test of the use of actions groups to allow the agent to make Synapse API calls.",
"FoundationModel": "anthropic.claude-3-sonnet-20240229-v1:0",
diff --git a/src/main/resources/templates/repo/bedrock-knowledge-base-template.json.vpt b/src/main/resources/templates/repo/bedrock-knowledge-base-template.json.vpt
new file mode 100644
index 00000000..acdbbfe1
--- /dev/null
+++ b/src/main/resources/templates/repo/bedrock-knowledge-base-template.json.vpt
@@ -0,0 +1,192 @@
+ ,
+ "SynapseHelpKnowledgeBaseExecutionRole": {
+ "Type": "AWS::IAM::Role",
+ "Properties": {
+ "AssumeRolePolicyDocument": {
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Principal": {
+ "Service": "bedrock.amazonaws.com"
+ },
+ "Action": "sts:AssumeRole",
+ "Condition": {
+ "StringEquals": {
+ "aws:SourceAccount": {
+ "Ref": "AWS::AccountId"
+ }
+ },
+ "ArnLike": {
+ "aws:SourceArn": {
+ "Fn::Sub": "arn:aws:bedrock:#[[${AWS::Region}:${AWS::AccountId}]]#:knowledge-base/*"
+ }
+ }
+ }
+ }
+ ]
+ }
+ }
+ },
+ "SynapseHelpCollectionDeployerDataAccessPolicy": {
+ "Type":"AWS::OpenSearchServerless::AccessPolicy",
+ "Properties": {
+ "Name":"${stack}-${instance}-synhelp-deployer",
+ "Type":"data",
+ "Description":"Access policy for the deployer to create the index for the synapse help collection",
+ "Policy": {
+ "Fn::Sub": ["[{\"Rules\":[{\"ResourceType\":\"index\",\"Resource\":[\"index/${stack}-${instance}-synhelp/*\"],\"Permission\":[\"aoss:CreateIndex\",\"aoss:DescribeIndex\"]}],\"Principal\":[\"#[[${deployerRoleArn}]]#\"]}]",
+ {
+ "deployerRoleArn" : "${identityArn}"
+ }
+ ]
+ }
+ }
+ },
+ "SynapseHelpCollectionKnowledgeBaseDataAccessPolicy": {
+ "Type":"AWS::OpenSearchServerless::AccessPolicy",
+ "Properties": {
+ "Name":"${stack}-${instance}-synhelp-bedrock",
+ "Type":"data",
+ "Description":"Access policy for the knowledge base to sync and query the synapse help collection",
+ "Policy": {
+ "Fn::Sub": ["[{\"Rules\":[{\"ResourceType\":\"index\",\"Resource\":[\"index/${stack}-${instance}-synhelp/*\"],\"Permission\":[\"aoss:CreateIndex\",\"aoss:UpdateIndex\",\"aoss:DescribeIndex\",\"aoss:ReadDocument\",\"aoss:WriteDocument\"]},{\"ResourceType\":\"collection\",\"Resource\":[\"collection/${stack}-${instance}-synhelp\"],\"Permission\":[\"aoss:DescribeCollectionItems\",\"aoss:CreateCollectionItems\",\"aoss:UpdateCollectionItems\"]}],\"Principal\":[\"#[[${executionRoleArn}]]#\"]}]",
+ {
+ "executionRoleArn": { "Fn::GetAtt": ["SynapseHelpKnowledgeBaseExecutionRole", "Arn"] }
+ }
+ ]
+ }
+ }
+ },
+ "SynapseHelpCollectionEncryptionPolicy": {
+ "Type":"AWS::OpenSearchServerless::SecurityPolicy",
+ "Properties": {
+ "Name":"${stack}-${instance}-synhelp",
+ "Type":"encryption",
+ "Description":"Encryption policy for the synapse help collection",
+ "Policy":"{\"Rules\":[{\"ResourceType\":\"collection\",\"Resource\":[\"collection/${stack}-${instance}-synhelp\"]}],\"AWSOwnedKey\":true}"
+ }
+ },
+ "SynapseHelpCollectionNetworkPolicy": {
+ "Type":"AWS::OpenSearchServerless::SecurityPolicy",
+ "Properties": {
+ "Name":"${stack}-${instance}-synhelp",
+ "Type":"network",
+			"Description":"Network policy for the synapse help collection",
+ "Policy": {
+ "Fn::Sub": [
+ "[{\"Rules\":[{\"ResourceType\":\"collection\",\"Resource\":[\"collection/${stack}-${instance}-synhelp\"]}],\"AllowFromPublic\":false,\"SourceServices\":[\"bedrock.amazonaws.com\"],\"SourceVPCEs\":[\"#[[${vpcEndpointId}]]#\"]}]",
+ {
+ "vpcEndpointId": { "Fn::ImportValue": "${vpcExportPrefix}-open-search-vpce" }
+ }
+ ]
+ }
+ }
+ },
+ "SynapseHelpCollection": {
+ "Type":"AWS::OpenSearchServerless::Collection",
+ "Properties": {
+ "Name":"${stack}-${instance}-synhelp",
+ "Type":"VECTORSEARCH",
+ "Description":"Vector store collection for synapse help documentation",
+ "StandbyReplicas": "DISABLED"
+ },
+ "DependsOn": ["SynapseHelpCollectionEncryptionPolicy", "SynapseHelpCollectionNetworkPolicy", "SynapseHelpCollectionDeployerDataAccessPolicy", "SynapseHelpCollectionKnowledgeBaseDataAccessPolicy"]
+ },
+	"SynapseHelpKnowledgeBaseExecutionRolePolicy": {
+ "Type":"AWS::IAM::RolePolicy",
+ "Properties": {
+			"PolicyName": "SynapseHelpKnowledgeBaseExecutionRolePolicy",
+ "PolicyDocument": {
+ "Version": "2012-10-17",
+ "Statement": [
+ {
+ "Effect": "Allow",
+ "Action": "bedrock:InvokeModel",
+ "Resource": [
+ {
+ "Fn::Sub": "arn:aws:bedrock:#[[${AWS::Region}]]#::foundation-model/*"
+ }
+ ]
+ },
+ {
+ "Effect": "Allow",
+ "Action": "aoss:APIAccessAll",
+ "Resource": [
+ {
+ "Fn::GetAtt" : ["SynapseHelpCollection", "Arn"]
+ }
+ ]
+ }
+ ]
+ },
+ "RoleName": {
+ "Ref": "SynapseHelpKnowledgeBaseExecutionRole"
+ }
+ }
+ },
+ "SynapseHelpCollectionCreateIndexWaitCondition": {
+ "Type": "AWS::CloudFormation::WaitCondition",
+ "DependsOn": [ "SynapseHelpCollection" ],
+ "CreationPolicy" : {
+ "ResourceSignal" : {
+ "Timeout" : "PT3M",
+ "Count" : "1"
+ }
+ }
+ },
+ "SynapseHelpKnowledgeBase": {
+ "Type": "AWS::Bedrock::KnowledgeBase",
+		"DependsOn": [ "SynapseHelpCollectionCreateIndexWaitCondition", "SynapseHelpKnowledgeBaseExecutionRolePolicy" ],
+ "Properties": {
+ "Name": "${stack}-${instance}-synhelp-knowledge-base",
+		"Description": "Knowledge base for the synapse help documentation",
+ "RoleArn": { "Fn::GetAtt" : ["SynapseHelpKnowledgeBaseExecutionRole", "Arn"] },
+ "KnowledgeBaseConfiguration": {
+ "Type": "VECTOR",
+ "VectorKnowledgeBaseConfiguration": {
+ "EmbeddingModelArn": { "Fn::Sub": "arn:aws:bedrock:#[[${AWS::Region}]]#::foundation-model/amazon.titan-embed-text-v2:0" }
+ }
+ },
+ "StorageConfiguration": {
+ "Type": "OPENSEARCH_SERVERLESS",
+ "OpensearchServerlessConfiguration": {
+ "CollectionArn": { "Fn::GetAtt" : ["SynapseHelpCollection", "Arn"] },
+ "VectorIndexName": "vector-idx",
+ "FieldMapping": {
+ "VectorField": "text_vector",
+ "TextField": "text_raw",
+ "MetadataField": "text_metadata"
+ }
+ }
+ }
+ }
+ },
+ "SynapseHelpKnowledgeBaseDataSource": {
+ "Type": "AWS::Bedrock::DataSource",
+ "Properties": {
+ "KnowledgeBaseId": { "Ref": "SynapseHelpKnowledgeBase" },
+ "Name": "${stack}-${instance}-synhelp-datasource",
+			"Description": "The datasource for the synapse help documentation; the data is crawled from the synapse help website",
+ "DataSourceConfiguration": {
+ "Type": "WEB",
+ "WebConfiguration": {
+ "SourceConfiguration": {
+ "UrlConfiguration": {
+ "SeedUrls": [ { "Url": "https://help.synapse.org" } ]
+ }
+ }
+ }
+ }
+ }
+ },
+ "SynapseHelpKnowledgeBaseDataSourceSyncWaitCondition": {
+ "Type": "AWS::CloudFormation::WaitCondition",
+ "DependsOn": "SynapseHelpKnowledgeBaseDataSource",
+ "CreationPolicy" : {
+ "ResourceSignal" : {
+ "Timeout" : "PT10M",
+ "Count" : "1"
+ }
+ }
+ }
\ No newline at end of file
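
Both WaitCondition resources above rely on a CreationPolicy that blocks stack creation until one resource signal arrives, or fails the stack when the timeout (PT3M/PT10M) expires. In this PR the signal is sent by CloudFormationClientImpl; the equivalent direct call with the SDK v1 client used in this codebase looks like this (stack name illustrative):

```java
import com.amazonaws.services.cloudformation.AmazonCloudFormation;
import com.amazonaws.services.cloudformation.model.ResourceSignalStatus;
import com.amazonaws.services.cloudformation.model.SignalResourceRequest;

// One SUCCESS signal satisfies Count: 1 and lets the stack proceed past the
// wait condition; a FAILURE signal rolls the stack back instead.
void signalIndexCreated(AmazonCloudFormation cloudFormation) {
	cloudFormation.signalResource(new SignalResourceRequest()
			.withStackName("repo-shared-resources-stack") // illustrative
			.withLogicalResourceId("SynapseHelpCollectionCreateIndexWaitCondition")
			.withStatus(ResourceSignalStatus.SUCCESS)
			.withUniqueId("index-creation-complete"));
}
```
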
diff --git a/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt b/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt
index fc589266..cbeb598b 100644
--- a/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt
+++ b/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt
@@ -9,6 +9,9 @@
},
#parse("templates/repo/time-to-live-parameter.vpt")
},
+ "Conditions": {
+ "AttachKnowledgeBase": {"Fn::Equals": ["true", "true"]}
+ },
"Resources": {
"${stack}${instance}DBSubnetGroup": {
"Type": "AWS::RDS::DBSubnetGroup",
@@ -745,6 +748,7 @@
#parse("templates/repo/web/acl/web-acl-template.json.vpt")
#parse("templates/repo/appconfig-template.json.vpt")
#parse("templates/repo/api-gateway-template.json.vpt")
+ #parse("templates/repo/bedrock-knowledge-base-template.json.vpt")
${bedrock_agent_resouces}
},
"Outputs": {
@@ -858,6 +862,27 @@
]
}
}
+ },
+ "SynapseHelpCollectionEndpoint": {
+ "Value": {
+ "Fn::GetAtt" : ["SynapseHelpCollection", "CollectionEndpoint"]
+ },
+ "Export": {
+ "Name": {
+ "Fn::Join": [
+ "-",
+ [
+ {
+ "Ref": "AWS::Region"
+ },
+ {
+ "Ref": "AWS::StackName"
+ },
+ "Synapse-Help-Collection-Endpoint"
+ ]
+ ]
+ }
+ }
}
#foreach( $descriptor in ${databaseDescriptors} )
,
diff --git a/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java b/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java
index 2ead4c5f..d90c603b 100644
--- a/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java
+++ b/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java
@@ -1,6 +1,7 @@
package org.sagebionetworks.template;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
@@ -13,8 +14,10 @@
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
import java.util.List;
import java.util.Optional;
+import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -35,11 +38,17 @@
import com.amazonaws.services.cloudformation.model.CreateStackRequest;
import com.amazonaws.services.cloudformation.model.CreateStackResult;
import com.amazonaws.services.cloudformation.model.DeleteStackRequest;
+import com.amazonaws.services.cloudformation.model.DescribeStackEventsRequest;
+import com.amazonaws.services.cloudformation.model.DescribeStackEventsResult;
import com.amazonaws.services.cloudformation.model.DescribeStacksRequest;
import com.amazonaws.services.cloudformation.model.DescribeStacksResult;
import com.amazonaws.services.cloudformation.model.Output;
import com.amazonaws.services.cloudformation.model.Parameter;
+import com.amazonaws.services.cloudformation.model.ResourceSignalStatus;
+import com.amazonaws.services.cloudformation.model.ResourceStatus;
+import com.amazonaws.services.cloudformation.model.SignalResourceRequest;
import com.amazonaws.services.cloudformation.model.Stack;
+import com.amazonaws.services.cloudformation.model.StackEvent;
import com.amazonaws.services.cloudformation.model.StackStatus;
import com.amazonaws.services.cloudformation.model.UpdateStackRequest;
import com.amazonaws.services.cloudformation.model.UpdateStackResult;
@@ -68,6 +77,8 @@ public class CloudFormationClientImplTest {
Logger mockLogger;
@Mock
ThreadProvider mockThreadProvider;
+ @Mock
+ WaitConditionHandler mockWaitConditionHandler;
@Captor
ArgumentCaptor<DescribeStacksRequest> describeStackRequestCapture;
@@ -475,6 +486,263 @@ public void testWaitForStackToCompleteUpdateRollbackCompleteToUpdateRollBackComp
Assertions.assertNotNull(resultStack.getStackStatus());
Assertions.assertEquals(StackStatus.UPDATE_ROLLBACK_COMPLETE, StackStatus.fromValue(resultStack.getStackStatus()));
}
+
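+ // The tests below exercise the new waitForStackToComplete(String, Set<WaitConditionHandler>) overload:
+ // while the stack is in progress the client polls the stack events, dispatches each wait condition
+ // event to the handler whose id matches the event's logical resource id, and signals the resource
+ // with the handler's result.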
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlers() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ .withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS);
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId);
+ when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.of("done"));
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(waitConditionEvent)
+ );
+
+ // call under test
+ Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+
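+ // A non-empty handler result is reported as a SUCCESS signal, with the result used as the unique id.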
+ verify(mockCloudFormationClient).signalResource(new SignalResourceRequest()
+ .withLogicalResourceId(waitConditionId)
+ .withStackName(stackName)
+ .withStatus(ResourceSignalStatus.SUCCESS)
+ .withUniqueId("done")
+ );
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlersAndMultipleEvents() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ .withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS)
+ .withEventId("last");
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId);
+ when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.of("done"));
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(
+ waitConditionEvent,
+ new StackEvent().withResourceType("AWS::CloudFormation::WaitCondition").withLogicalResourceId("anotherWaitConditionId").withResourceStatus(ResourceStatus.CREATE_COMPLETE),
+ new StackEvent().withResourceType("anotherType").withLogicalResourceId("anotherResourceId").withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS),
+ // Another event for the same condition id, should be discarded
+ new StackEvent().withResourceType("AWS::CloudFormation::WaitCondition").withLogicalResourceId(waitConditionId).withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS).withEventId("previous")
+ )
+ );
+
+ // call under test
+ Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+
+ verify(mockCloudFormationClient).signalResource(new SignalResourceRequest()
+ .withLogicalResourceId(waitConditionId)
+ .withStackName(stackName)
+ .withStatus(ResourceSignalStatus.SUCCESS)
+ .withUniqueId("done")
+ );
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlersAndAlreadyProcessed() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ .withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS);
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId);
+ when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.of("done"));
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(waitConditionEvent)
+ );
+
+ // call under test
+ Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+
+ verify(mockCloudFormationClient, times(2)).describeStackEvents(any());
+
+ // Should be invoked only once
+ verify(mockWaitConditionHandler).handle(waitConditionEvent);
+ verify(mockCloudFormationClient).signalResource(new SignalResourceRequest()
+ .withLogicalResourceId(waitConditionId)
+ .withStackName(stackName)
+ .withStatus(ResourceSignalStatus.SUCCESS)
+ .withUniqueId("done")
+ );
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlersAndNoMatchingHandler() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ .withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS)
+ .withEventId("last");
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId + "-mismatching");
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(
+ waitConditionEvent
+ )
+ );
+
+ IllegalStateException result = assertThrows(IllegalStateException.class, () -> {
+ // call under test
+ client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+ });
+
+ assertEquals("Processing wait condition waitConditionId failed: could not find an handler.", result.getMessage());
+
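+ // A wait condition without a matching handler is expected to be signaled as a FAILURE so that the
+ // stack fails fast instead of waiting for the condition timeout.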
+ verify(mockCloudFormationClient).signalResource(new SignalResourceRequest()
+ .withLogicalResourceId(waitConditionId)
+ .withStackName(stackName)
+ .withStatus(ResourceSignalStatus.FAILURE)
+ .withUniqueId("handler-not-found")
+ );
+
+ verifyNoMoreInteractions(mockWaitConditionHandler);
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlersAndException() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ .withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS)
+ .withEventId("last");
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId);
+
+ RuntimeException cause = new RuntimeException("processing error");
+
+ when(mockWaitConditionHandler.handle(waitConditionEvent)).thenThrow(cause);
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(
+ waitConditionEvent
+ )
+ );
+
+ IllegalStateException result = assertThrows(IllegalStateException.class, () -> {
+ // call under test
+ client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+ });
+
+ assertEquals("Processing wait condition waitConditionId failed.", result.getMessage());
+ assertEquals(cause, result.getCause());
+
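+ // A handler exception is likewise expected to be reported as a FAILURE signal before being rethrown.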
+ verify(mockCloudFormationClient).signalResource(new SignalResourceRequest()
+ .withLogicalResourceId(waitConditionId)
+ .withStackName(stackName)
+ .withStatus(ResourceSignalStatus.FAILURE)
+ .withUniqueId("handler-failed")
+ );
+
+ verifyNoMoreInteractions(mockWaitConditionHandler);
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlersAndNoSignal() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ .withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS)
+ .withEventId("last");
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId);
+ when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.empty());
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(
+ waitConditionEvent
+ )
+ );
+
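+ // An empty handler result means the wait condition is not ready yet, so no signal is expected.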
+ // call under test
+ Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+
+ verifyNoMoreInteractions(mockCloudFormationClient, mockWaitConditionHandler);
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithEmptyWaitConditionHandlers() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ // call under test
+ Stack resultStack = client.waitForStackToComplete(stackName, Collections.emptySet()).get();
+
+ verifyNoMoreInteractions(mockCloudFormationClient, mockWaitConditionHandler);
+ }
@Test
public void testGetOutput() {
diff --git a/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java b/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java
index e41481c3..cb4f6ffb 100644
--- a/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java
+++ b/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java
@@ -21,9 +21,10 @@
import static org.sagebionetworks.template.Constants.DATABASE_DESCRIPTORS;
import static org.sagebionetworks.template.Constants.DB_ENDPOINT_SUFFIX;
import static org.sagebionetworks.template.Constants.DELETION_POLICY;
-import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.EC2_INSTANCE_MEMORY;
+import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.ENVIRONMENT;
+import static org.sagebionetworks.template.Constants.IDENTITY_ARN;
import static org.sagebionetworks.template.Constants.INSTANCE;
import static org.sagebionetworks.template.Constants.NOSNAPSHOT;
import static org.sagebionetworks.template.Constants.OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT;
@@ -36,8 +37,8 @@
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_SSL_ARN;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_VERSION;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_DATA_CDN_KEYPAIR_ID;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_MEMORY;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_AMAZONLINUX;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_JAVA;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_TOMCAT;
@@ -49,19 +50,19 @@
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_INSTANCE_CLASS;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_IOPS;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MAX_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MULTI_AZ;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_STORAGE_TYPE;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ROUTE_53_HOSTED_ZONE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_STACK;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_INSTANCE_COUNT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_INSTANCE_CLASS;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_IOPS;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_MAX_ALLOCATED_STORAGE;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_STORAGE_TYPE;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT;
import static org.sagebionetworks.template.Constants.PROPERTY_KEY_VPC_SUBNET_COLOR;
import static org.sagebionetworks.template.Constants.REPO_BEANSTALK_NUMBER;
import static org.sagebionetworks.template.Constants.SHARED_EXPORT_PREFIX;
@@ -100,6 +101,7 @@
import org.sagebionetworks.template.LoggerFactory;
import org.sagebionetworks.template.StackTagsProvider;
import org.sagebionetworks.template.TemplateGuiceModule;
+import org.sagebionetworks.template.WaitConditionHandler;
import org.sagebionetworks.template.config.RepoConfiguration;
import org.sagebionetworks.template.config.TimeToLive;
import org.sagebionetworks.template.repo.agent.BedrockAgentContextProvider;
@@ -123,6 +125,8 @@
import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsResult;
import com.amazonaws.services.elasticbeanstalk.model.PlatformFilter;
import com.amazonaws.services.elasticbeanstalk.model.PlatformSummary;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
+import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@@ -157,6 +161,10 @@ public class RepositoryTemplateBuilderImplTest {
private CloudwatchLogsVelocityContextProvider mockCwlContextProvider;
@Mock
private TimeToLive mockTimeToLive;
+ @Mock
+ private AWSSecurityTokenService mockStsClient;
+ @Mock
+ private WaitConditionHandler mockWaitConditionHandler;
@Captor
private ArgumentCaptor<CreateOrUpdateStackRequest> requestCaptor;
@@ -188,10 +196,12 @@ public void before() throws InterruptedException {
expectedTags.add(t);
when(mockLoggerFactory.getLogger(any())).thenReturn(mockLogger);
+
builder = new RepositoryTemplateBuilderImpl(mockCloudFormationClient, velocityEngine, config, mockLoggerFactory,
mockArtifactCopy, mockSecretBuilder, Sets.newHashSet(mockContextProvider1, mockContextProvider2, new BedrockAgentContextProvider(config)),
mockElasticBeanstalkSolutionStackNameProvider, mockStackTagsProvider, mockCwlContextProvider,
- mockEc2Client, mockBeanstalkClient, mockTimeToLive);
+ mockEc2Client, mockBeanstalkClient, mockTimeToLive, mockStsClient, Set.of(mockWaitConditionHandler));
+
builderSpy = Mockito.spy(builder);
stack = "dev";
@@ -210,7 +220,7 @@ public void before() throws InterruptedException {
Output tableDBOutput2 = new Output();
tableDBOutput2.withOutputKey(stack + instance + "Table1" + OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT);
tableDBOutput2.withOutputValue(stack + "-" + instance + "-table-1." + databaseEndpointSuffix);
-
+
sharedResouces.withOutputs(dbOut, tableDBOutput1, tableDBOutput2);
secretsSouce = new SourceBundle("secretBucket", "secretKey");
@@ -221,17 +231,27 @@ public void before() throws InterruptedException {
}
private void configureStack(String inputStack) throws InterruptedException {
+ when(mockStsClient.getCallerIdentity(any())).thenReturn(new GetCallerIdentityResult().withArn("currentIdentityArn"));
stack = inputStack;
+
when(config.getProperty(PROPERTY_KEY_STACK)).thenReturn(stack);
+
sharedResouces = new Stack();
- Output dbOut = new Output();
- dbOut.withOutputKey(stack + instance + OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT);
+
databaseEndpointSuffix = "something.amazon.com";
- dbOut.withOutputValue(stack + "-" + instance + "-db." + databaseEndpointSuffix);
- sharedResouces.withOutputs(dbOut);
-
- when(mockCloudFormationClient.waitForStackToComplete(any(String.class)))
- .thenReturn(Optional.of(sharedResouces));
+
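+ // In addition to the DB endpoint, the shared stack now exposes the OpenSearch Serverless
+ // collection endpoint output that the builder reads back after deployment.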
+ sharedResouces.withOutputs(
+ new Output()
+ .withOutputKey(stack + instance + OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT)
+ .withOutputValue(stack + "-" + instance + "-db." + databaseEndpointSuffix),
+ new Output()
+ .withOutputKey("SynapseHelpCollectionEndpoint")
+ .withOutputValue("synhelp-endpoint")
+ );
+
+ when(mockCloudFormationClient.waitForStackToComplete(any(String.class), any())).thenReturn(Optional.of(sharedResouces));
+
}
@Test
@@ -300,6 +320,8 @@ public void testBuildAndDeployProd() throws InterruptedException {
builder.buildAndDeploy();
verify(mockCloudFormationClient, times(4)).createOrUpdateStack(requestCaptor.capture());
+ verify(mockCloudFormationClient).waitForStackToComplete("prod-101-shared-resources", Set.of(mockWaitConditionHandler));
+
List<CreateOrUpdateStackRequest> list = requestCaptor.getAllValues();
CreateOrUpdateStackRequest request = list.get(0);
assertEquals("prod-101-shared-resources", request.getStackName());
@@ -354,7 +376,9 @@ public void testBuildAndDeployProd() throws InterruptedException {
assertEquals(15000, tDbProps.getInt("StorageThroughput"));
assertFalse(resources.has("WebhookTestApi"));
-
+ assertTrue(resources.has("SynapseHelpCollection"));
+ assertTrue(resources.has("SynapseHelpKnowledgeBaseExecutionRole"));
+ assertTrue(resources.has("SynapseHelpKnowledgeBase"));
assertTrue(resources.has("bedrockAgentRole"));
assertTrue(resources.has("bedrockAgent"));
assertEquals("prod-101-agent", resources.getJSONObject("bedrockAgent").getJSONObject("Properties").get("AgentName"));
@@ -546,6 +570,8 @@ public void testBuildAndDeployDev() throws InterruptedException {
builder.buildAndDeploy();
verify(mockCloudFormationClient, times(4)).createOrUpdateStack(requestCaptor.capture());
+ verify(mockCloudFormationClient).waitForStackToComplete("dev-101-shared-resources", Set.of(mockWaitConditionHandler));
+
List<CreateOrUpdateStackRequest> list = requestCaptor.getAllValues();
CreateOrUpdateStackRequest request = list.get(0);
assertEquals("dev-101-shared-resources", request.getStackName());
@@ -598,7 +624,9 @@ public void testBuildAndDeployDev() throws InterruptedException {
assertEquals(15000, tDbProps.getInt("StorageThroughput"));
assertTrue(resources.has("WebhookTestApi"));
-
+ assertTrue(resources.has("SynapseHelpCollection"));
+ assertTrue(resources.has("SynapseHelpKnowledgeBaseExecutionRole"));
+ assertTrue(resources.has("SynapseHelpKnowledgeBase"));
assertTrue(resources.has("bedrockAgentRole"));
assertTrue(resources.has("bedrockAgent"));
assertEquals("dev-101-agent", resources.getJSONObject("bedrockAgent").getJSONObject("Properties").get("AgentName"));
@@ -905,7 +933,7 @@ public void testCreateContext() {
when(config.getProperty(PROPERTY_KEY_RDS_REPO_SNAPSHOT_IDENTIFIER)).thenReturn(NOSNAPSHOT);
String[] noSnapshots = new String[] { NOSNAPSHOT };
when(config.getComaSeparatedProperty(PROPERTY_KEY_RDS_TABLES_SNAPSHOT_IDENTIFIERS)).thenReturn(noSnapshots);
-
+ when(mockStsClient.getCallerIdentity(any())).thenReturn(new GetCallerIdentityResult().withArn("currentIdentityArn"));
// call under test
VelocityContext context = builder.createSharedContext();
@@ -988,7 +1016,8 @@ public void testCreateContextProd() {
when(config.getProperty(PROPERTY_KEY_RDS_REPO_SNAPSHOT_IDENTIFIER)).thenReturn(NOSNAPSHOT);
String[] noSnapshots = new String[] { NOSNAPSHOT };
when(config.getComaSeparatedProperty(PROPERTY_KEY_RDS_TABLES_SNAPSHOT_IDENTIFIERS)).thenReturn(noSnapshots);
-
+ when(mockStsClient.getCallerIdentity(any())).thenReturn(new GetCallerIdentityResult().withArn("currentIdentityArn"));
+
// call under test
VelocityContext context = builder.createSharedContext();
@@ -1001,6 +1030,7 @@ public void testCreateContextProd() {
assertEquals("Block:{}", context.get(ADMIN_RULE_ACTION));
assertEquals("Retain", context.get(DELETION_POLICY));
+ assertEquals("currentIdentityArn", context.get(IDENTITY_ARN));
}
diff --git a/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreationTest.java b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreationTest.java
new file mode 100644
index 00000000..89aff5f0
--- /dev/null
+++ b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreationTest.java
@@ -0,0 +1,248 @@
+package org.sagebionetworks.template.repo.bedrock;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+import java.util.NoSuchElementException;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+import org.apache.logging.log4j.Logger;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.client.opensearch._types.mapping.Property;
+import org.opensearch.client.opensearch.indices.CreateIndexRequest;
+import org.opensearch.client.opensearch.indices.CreateIndexResponse;
+import org.opensearch.client.opensearch.indices.ExistsRequest;
+import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
+import org.opensearch.client.transport.endpoints.BooleanResponse;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.OpenSearchClientFactory;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+
+import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;
+import software.amazon.awssdk.services.opensearchserverless.model.BatchGetCollectionRequest;
+import software.amazon.awssdk.services.opensearchserverless.model.BatchGetCollectionResponse;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionDetail;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionStatus;
+
+@ExtendWith(MockitoExtension.class)
+public class SynapseHelpCollectionIndexCreationTest {
+
+ private static final String COLLECTION_ENDPOINT = "endpoint";
+
+ @Mock
+ private LoggerFactory mockLoggerFactory;
+
+ @Mock
+ private OpenSearchServerlessClient mockOssManagementClient;
+
+ @Mock
+ private OpenSearchClientFactory mockOpenSearchClientFactory;
+
+ @Mock
+ private RepoConfiguration mockConfig;
+
+ private WaitConditionHandler handler;
+
+ @Mock
+ private Logger mockLogger;
+
+ @Mock
+ private StackEvent mockStackEvent;
+
+ @Mock
+ private OpenSearchIndicesClient mockOpenSearchIndicesClient;
+
+ @Captor
+ private ArgumentCaptor<Consumer<BatchGetCollectionRequest.Builder>> getCollectionRequestCaptor;
+
+ @Captor
+ private ArgumentCaptor<ExistsRequest> existRequestCaptor;
+
+ @Captor
+ private ArgumentCaptor<CreateIndexRequest> createRequestCaptor;
+
+ @BeforeEach
+ public void before() {
+ when(mockLoggerFactory.getLogger(any())).thenReturn(mockLogger);
+ handler = new SynapseHelpCollectionIndexCreation(mockLoggerFactory, mockConfig, mockOssManagementClient, mockOpenSearchClientFactory);
+ }
+
+ @Test
+ public void testGetWaitConditionId() {
+ // Call under test
+ assertEquals("SynapseHelpCollectionCreateIndexWaitCondition", handler.getWaitConditionId());
+ }
+
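+ // Happy path: the collection is ACTIVE and the vector index does not exist yet, so the handler
+ // creates the index and signals completion.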
+ @Test
+ public void testHandleWithActiveCollection() throws IOException, InterruptedException {
+
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder()
+ .collectionDetails(CollectionDetail.builder().status(CollectionStatus.ACTIVE).collectionEndpoint(COLLECTION_ENDPOINT).build()).build()
+ );
+
+ when(mockOpenSearchClientFactory.getIndicesClient(COLLECTION_ENDPOINT)).thenReturn(mockOpenSearchIndicesClient);
+
+ when(mockOpenSearchIndicesClient.exists(existRequestCaptor.capture())).thenReturn(new BooleanResponse(false));
+ when(mockOpenSearchIndicesClient.create(createRequestCaptor.capture())).thenReturn(
+ CreateIndexResponse.of(resp -> resp
+ .index("vector-idx")
+ .acknowledged(true)
+ .shardsAcknowledged(true)
+ )
+ );
+
+ // Call under test
+ assertEquals(Optional.of("index-creation-complete"), handler.handle(mockStackEvent));
+
+ assertEquals(
+ BatchGetCollectionRequest.builder().names("dev-101-synhelp").build(),
+ BatchGetCollectionRequest.builder().applyMutation(getCollectionRequestCaptor.getValue()).build()
+ );
+
+ ExistsRequest existRequest = existRequestCaptor.getValue();
+
+ assertEquals(List.of("vector-idx"), existRequest.index());
+
+ CreateIndexRequest createRequest = createRequestCaptor.getValue();
+
+ assertEquals("vector-idx", createRequest.index());
+ assertTrue(createRequest.settings().knn());
+ assertEquals(512, createRequest.settings().knnAlgoParamEfSearch());
+
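+ // The mapping must line up with the knowledge base FieldMapping: a kNN vector field sized for the
+ // 1024-dimensional amazon.titan-embed-text-v2:0 embeddings, plus the raw text and metadata fields.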
+ Property textVectorProp = createRequest.mappings().properties().get("text_vector");
+
+ assertEquals(1024, textVectorProp.knnVector().dimension());
+ assertEquals("hnsw", textVectorProp.knnVector().method().name());
+ assertEquals("faiss", textVectorProp.knnVector().method().engine());
+ assertEquals("l2", textVectorProp.knnVector().method().spaceType());
+ assertTrue(createRequest.mappings().properties().get("text_raw").text().index());
+ assertFalse(createRequest.mappings().properties().get("text_metadata").text().index());
+
+ }
+
+ @Test
+ public void testHandleWithInactiveCollection() throws IOException, InterruptedException {
+
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder()
+ .collectionDetails(CollectionDetail.builder().status(CollectionStatus.CREATING).collectionEndpoint(COLLECTION_ENDPOINT).build()).build()
+ );
+
+ // Call under test
+ assertEquals(Optional.empty(), handler.handle(mockStackEvent));
+
+ verifyNoMoreInteractions(mockOpenSearchIndicesClient);
+
+ }
+
+ @Test
+ public void testHandleWithCollectionNotFound() throws IOException, InterruptedException {
+
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder()
+ .collectionDetails(Collections.emptyList()).build()
+ );
+
+ assertThrows(NoSuchElementException.class, () -> {
+ // Call under test
+ handler.handle(mockStackEvent);
+ });
+
+ verifyNoMoreInteractions(mockOpenSearchIndicesClient);
+
+ }
+
+ @Test
+ public void testHandleWithExistingIndex() throws IOException, InterruptedException {
+
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder()
+ .collectionDetails(CollectionDetail.builder().status(CollectionStatus.ACTIVE).collectionEndpoint(COLLECTION_ENDPOINT).build()).build()
+ );
+
+ when(mockOpenSearchClientFactory.getIndicesClient(COLLECTION_ENDPOINT)).thenReturn(mockOpenSearchIndicesClient);
+
+ when(mockOpenSearchIndicesClient.exists(existRequestCaptor.capture())).thenReturn(new BooleanResponse(true));
+
+ // Call under test
+ assertEquals(Optional.of("index-already-exists"), handler.handle(mockStackEvent));
+
+ assertEquals(
+ BatchGetCollectionRequest.builder().names("dev-101-synhelp").build(),
+ BatchGetCollectionRequest.builder().applyMutation(getCollectionRequestCaptor.getValue()).build()
+ );
+
+ ExistsRequest existRequest = existRequestCaptor.getValue();
+
+ assertEquals(List.of("vector-idx"), existRequest.index());
+
+ verifyNoMoreInteractions(mockOpenSearchIndicesClient);
+
+ }
+
+ @Test
+ public void testHandleWithIOException() throws IOException {
+
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder()
+ .collectionDetails(CollectionDetail.builder().status(CollectionStatus.ACTIVE).collectionEndpoint(COLLECTION_ENDPOINT).build()).build()
+ );
+
+ when(mockOpenSearchClientFactory.getIndicesClient(COLLECTION_ENDPOINT)).thenReturn(mockOpenSearchIndicesClient);
+
+ IOException ex = new IOException("nope");
+
+ when(mockOpenSearchIndicesClient.exists(existRequestCaptor.capture())).thenThrow(ex);
+
+ IllegalStateException result = assertThrows(IllegalStateException.class, () -> {
+ // Call under test
+ handler.handle(mockStackEvent);
+ });
+
+ assertEquals(ex, result.getCause());
+
+ assertEquals(
+ BatchGetCollectionRequest.builder().names("dev-101-synhelp").build(),
+ BatchGetCollectionRequest.builder().applyMutation(getCollectionRequestCaptor.getValue()).build()
+ );
+
+ ExistsRequest existRequest = existRequestCaptor.getValue();
+
+ assertEquals(List.of("vector-idx"), existRequest.index());
+
+ verifyNoMoreInteractions(mockOpenSearchIndicesClient);
+
+ }
+
+}
diff --git a/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSyncTest.java b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSyncTest.java
new file mode 100644
index 00000000..26aa28b8
--- /dev/null
+++ b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSyncTest.java
@@ -0,0 +1,388 @@
+package org.sagebionetworks.template.repo.bedrock;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.when;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.NoSuchElementException;
+import java.util.Optional;
+import java.util.function.Consumer;
+
+import org.apache.logging.log4j.Logger;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.ThreadProvider;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+
+import software.amazon.awssdk.services.bedrockagent.BedrockAgentClient;
+import software.amazon.awssdk.services.bedrockagent.model.DataSourceSummary;
+import software.amazon.awssdk.services.bedrockagent.model.GetIngestionJobRequest;
+import software.amazon.awssdk.services.bedrockagent.model.GetIngestionJobResponse;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJob;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJobStatus;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJobSummary;
+import software.amazon.awssdk.services.bedrockagent.model.KnowledgeBaseSummary;
+import software.amazon.awssdk.services.bedrockagent.model.ListDataSourcesRequest;
+import software.amazon.awssdk.services.bedrockagent.model.ListDataSourcesResponse;
+import software.amazon.awssdk.services.bedrockagent.model.ListIngestionJobsRequest;
+import software.amazon.awssdk.services.bedrockagent.model.ListIngestionJobsResponse;
+import software.amazon.awssdk.services.bedrockagent.model.ListKnowledgeBasesRequest;
+import software.amazon.awssdk.services.bedrockagent.model.ListKnowledgeBasesResponse;
+import software.amazon.awssdk.services.bedrockagent.model.StartIngestionJobRequest;
+import software.amazon.awssdk.services.bedrockagent.model.StartIngestionJobResponse;
+import software.amazon.awssdk.services.bedrockagent.paginators.ListDataSourcesIterable;
+import software.amazon.awssdk.services.bedrockagent.paginators.ListKnowledgeBasesIterable;
+
+@ExtendWith(MockitoExtension.class)
+public class SynapseHelpKnowledgeBaseDataSourceSyncTest {
+
+ private static final String KNOWLEDGE_BASE_ID = "123";
+ private static final String DATA_SOURCE_ID = "456";
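+ // Client token the handler is expected to pass to StartIngestionJob (presumably a stable hash, so
+ // that retried deployments reuse the same ingestion job instead of starting a duplicate).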
+ private static final String JOB_HASH = "027161e3d3c63c5680c1e4d38da31892f5a32797600fb68bd1e9ab4bc95ceb8a";
+ private static final String JOB_ID = "job-id";
+
+ @Mock
+ private LoggerFactory mockLoggerFactory;
+ @Mock
+ private BedrockAgentClient mockBedrockAgentClient;
+ @Mock
+ private RepoConfiguration mockConfig;
+ @Mock
+ private ThreadProvider mockThreadProvider;
+
+ private WaitConditionHandler handler;
+
+ @Mock
+ private Logger mockLogger;
+
+ @Mock
+ private StackEvent mockStackEvent;
+
+ @Captor
+ private ArgumentCaptor<Consumer<ListKnowledgeBasesRequest.Builder>> listKnowledgeBasesRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<ListDataSourcesRequest.Builder>> listDataSourceRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<ListIngestionJobsRequest.Builder>> listIngestionRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<StartIngestionJobRequest.Builder>> startIngestionJobRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<GetIngestionJobRequest.Builder>> getIngestionJobRequestCaptor;
+
+ @BeforeEach
+ public void before() {
+ when(mockLoggerFactory.getLogger(any())).thenReturn(mockLogger);
+ handler = new SynapseHelpKnowledgeBaseDataSourceSync(mockLoggerFactory, mockConfig, mockThreadProvider, mockBedrockAgentClient);
+ }
+
+ @Test
+ public void testGetWaitConditionId() {
+ assertEquals("SynapseHelpKnowledgeBaseDataSourceSyncWaitCondition", handler.getWaitConditionId());
+ }
+
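+ // Happy path: the handler resolves the knowledge base and data source by their stack-derived names,
+ // finds no previous ingestion job, starts a new one and polls it until it is COMPLETE.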
+ @Test
+ public void testHandle() throws InterruptedException {
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ String knowledgeBaseName = "dev-101-synhelp-knowledge-base";
+
+ when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn(
+ ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of(
+ KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn(
+ new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build())
+ );
+
+ String dataSourceName = "dev-101-synhelp-datasource";
+
+ when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn(
+ ListDataSourcesResponse.builder().dataSourceSummaries(List.of(
+ DataSourceSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).name(dataSourceName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn(
+ new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build())
+ );
+
+ when(mockBedrockAgentClient.listIngestionJobs(listIngestionRequestCaptor.capture())).thenReturn(ListIngestionJobsResponse.builder().build());
+
+ IngestionJob job = IngestionJob.builder()
+ .knowledgeBaseId(KNOWLEDGE_BASE_ID)
+ .dataSourceId(DATA_SOURCE_ID)
+ .ingestionJobId(JOB_ID)
+ .status(IngestionJobStatus.STARTING)
+ .build();
+
+ when(mockBedrockAgentClient.startIngestionJob(startIngestionJobRequestCaptor.capture())).thenReturn(
+ StartIngestionJobResponse.builder().ingestionJob(job).build()
+ );
+
+ when(mockBedrockAgentClient.getIngestionJob(getIngestionJobRequestCaptor.capture())).thenReturn(
+ GetIngestionJobResponse.builder().ingestionJob(job.copy(b -> b
+ .status(IngestionJobStatus.COMPLETE)
+ .statistics(stats -> stats
+ .numberOfDocumentsScanned(10L)
+ .numberOfNewDocumentsIndexed(5L)
+ .numberOfDocumentsFailed(5L)
+ )
+ )).build()
+ );
+
+ // Call under test
+ assertEquals(Optional.of("sync-completed"), handler.handle(mockStackEvent));
+
+ assertEquals(
+ ListKnowledgeBasesRequest.builder().build(),
+ ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(),
+ ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ ListIngestionJobsRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).maxResults(1).build(),
+ ListIngestionJobsRequest.builder().applyMutation(listIngestionRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ StartIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).clientToken(JOB_HASH).build(),
+ StartIngestionJobRequest.builder().applyMutation(startIngestionJobRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ GetIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).ingestionJobId(JOB_ID).build(),
+ GetIngestionJobRequest.builder().applyMutation(getIngestionJobRequestCaptor.getValue()).build()
+ );
+ }
+
+ @Test
+ public void testHandleWithKnowledgeBaseNotFound() throws InterruptedException {
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn(
+ ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(Collections.emptyList()).build()
+ );
+
+ when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn(
+ new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build())
+ );
+
+ assertThrows(NoSuchElementException.class, () -> {
+ // Call under test
+ handler.handle(mockStackEvent);
+ });
+
+ assertEquals(
+ ListKnowledgeBasesRequest.builder().build(),
+ ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build()
+ );
+
+ verifyNoMoreInteractions(mockBedrockAgentClient);
+ }
+
+ @Test
+ public void testHandleWithDataSourceNotFound() throws InterruptedException {
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ String knowledgeBaseName = "dev-101-synhelp-knowledge-base";
+
+ when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn(
+ ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of(
+ KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn(
+ new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build())
+ );
+
+ when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn(
+ ListDataSourcesResponse.builder().dataSourceSummaries(Collections.emptyList()).build()
+ );
+
+ when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn(
+ new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build())
+ );
+
+ assertThrows(NoSuchElementException.class, () -> {
+ // Call under test
+ handler.handle(mockStackEvent);
+ });
+
+ assertEquals(
+ ListKnowledgeBasesRequest.builder().build(),
+ ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(),
+ ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build()
+ );
+
+ verifyNoMoreInteractions(mockBedrockAgentClient);
+ }
+
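+ // If an ingestion job is already listed for the data source, no new job is started and the sync is
+ // simply reported as started.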
+ @Test
+ public void testHandleWithJobAlreadyStarted() throws InterruptedException {
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ String knowledgeBaseName = "dev-101-synhelp-knowledge-base";
+
+ when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn(
+ ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of(
+ KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn(
+ new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build())
+ );
+
+ String dataSourceName = "dev-101-synhelp-datasource";
+
+ when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn(
+ ListDataSourcesResponse.builder().dataSourceSummaries(List.of(
+ DataSourceSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).name(dataSourceName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn(
+ new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build())
+ );
+
+ when(mockBedrockAgentClient.listIngestionJobs(listIngestionRequestCaptor.capture())).thenReturn(
+ ListIngestionJobsResponse.builder().ingestionJobSummaries(List.of(
+ IngestionJobSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).ingestionJobId(JOB_ID).build()
+ )).build()
+ );
+
+ // Call under test
+ assertEquals(Optional.of("sync-started"), handler.handle(mockStackEvent));
+
+ assertEquals(
+ ListKnowledgeBasesRequest.builder().build(),
+ ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(),
+ ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ ListIngestionJobsRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).maxResults(1).build(),
+ ListIngestionJobsRequest.builder().applyMutation(listIngestionRequestCaptor.getValue()).build()
+ );
+
+ verifyNoMoreInteractions(mockBedrockAgentClient);
+ }
+
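+ // A FAILED ingestion job is expected to surface as an IllegalStateException carrying the job status
+ // and failure reasons.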
+ @Test
+ public void testHandleWithJobFailure() throws InterruptedException {
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ String knowledgeBaseName = "dev-101-synhelp-knowledge-base";
+
+ when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn(
+ ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of(
+ KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn(
+ new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build())
+ );
+
+ String dataSourceName = "dev-101-synhelp-datasource";
+
+ when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn(
+ ListDataSourcesResponse.builder().dataSourceSummaries(List.of(
+ DataSourceSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).name(dataSourceName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn(
+ new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build())
+ );
+
+ when(mockBedrockAgentClient.listIngestionJobs(listIngestionRequestCaptor.capture())).thenReturn(ListIngestionJobsResponse.builder().build());
+
+ IngestionJob job = IngestionJob.builder()
+ .knowledgeBaseId(KNOWLEDGE_BASE_ID)
+ .dataSourceId(DATA_SOURCE_ID)
+ .ingestionJobId(JOB_ID)
+ .status(IngestionJobStatus.STARTING)
+ .build();
+
+ when(mockBedrockAgentClient.startIngestionJob(startIngestionJobRequestCaptor.capture())).thenReturn(
+ StartIngestionJobResponse.builder().ingestionJob(job).build()
+ );
+
+ when(mockBedrockAgentClient.getIngestionJob(getIngestionJobRequestCaptor.capture())).thenReturn(
+ GetIngestionJobResponse.builder().ingestionJob(job).build(),
+ GetIngestionJobResponse.builder().ingestionJob(job.copy(b -> b
+ .status(IngestionJobStatus.FAILED)
+ .failureReasons("Some failure")
+ )).build()
+ );
+
+ assertEquals("Sync job job-id failed (Status: FAILED, Failures: [Some failure])", assertThrows(IllegalStateException.class, () -> {
+ // Call under test
+ handler.handle(mockStackEvent);
+ }).getMessage());
+
+ verify(mockBedrockAgentClient, times(2)).getIngestionJob(getIngestionJobRequestCaptor.capture());
+
+ assertEquals(
+ ListKnowledgeBasesRequest.builder().build(),
+ ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(),
+ ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ ListIngestionJobsRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).maxResults(1).build(),
+ ListIngestionJobsRequest.builder().applyMutation(listIngestionRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ StartIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).clientToken(JOB_HASH).build(),
+ StartIngestionJobRequest.builder().applyMutation(startIngestionJobRequestCaptor.getValue()).build()
+ );
+
+ assertEquals(
+ GetIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).ingestionJobId(JOB_ID).build(),
+ GetIngestionJobRequest.builder().applyMutation(getIngestionJobRequestCaptor.getValue()).build()
+ );
+ }
+}