diff --git a/pom.xml b/pom.xml
index 381af0fa..fc65fe7a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -23,6 +23,18 @@
+	<dependencyManagement>
+		<dependencies>
+			<dependency>
+				<groupId>software.amazon.awssdk</groupId>
+				<artifactId>bom</artifactId>
+				<version>${amazon.sdk.v2.version}</version>
+				<type>pom</type>
+				<scope>import</scope>
+			</dependency>
+		</dependencies>
+	</dependencyManagement>
+
@@ -225,6 +237,23 @@
 			<artifactId>guava</artifactId>
 			<version>31.1-jre</version>
 		</dependency>
+		<dependency>
+			<groupId>org.opensearch.client</groupId>
+			<artifactId>opensearch-java</artifactId>
+			<version>2.8.1</version>
+		</dependency>
+		<dependency>
+			<groupId>software.amazon.awssdk</groupId>
+			<artifactId>opensearchserverless</artifactId>
+		</dependency>
+		<dependency>
+			<groupId>software.amazon.awssdk</groupId>
+			<artifactId>apache-client</artifactId>
+		</dependency>
+		<dependency>
+			<groupId>software.amazon.awssdk</groupId>
+			<artifactId>bedrockagent</artifactId>
+		</dependency>
@@ -275,6 +304,7 @@
 		1.78.1
 		1.12.296
+		<amazon.sdk.v2.version>2.29.34</amazon.sdk.v2.version>
 		2.17.1
 		5.4.1
 		5.4.1
diff --git a/src/main/java/org/sagebionetworks/template/CloudFormationClient.java b/src/main/java/org/sagebionetworks/template/CloudFormationClient.java
index d4ddde1a..56cd2021 100644
--- a/src/main/java/org/sagebionetworks/template/CloudFormationClient.java
+++ b/src/main/java/org/sagebionetworks/template/CloudFormationClient.java
@@ -1,6 +1,7 @@
 package org.sagebionetworks.template;
 
 import java.util.Optional;
+import java.util.Set;
 import java.util.stream.Stream;
 
 import com.amazonaws.services.cloudformation.model.AmazonCloudFormationException;
@@ -20,7 +21,7 @@ public interface CloudFormationClient {
	 * @param stackName
	 * @return
	 */
-	public boolean doesStackNameExist(String stackName);
+	boolean doesStackNameExist(String stackName);
 
	/**
	 * Describe a stack given its name.
	 *
@@ -29,7 +30,7 @@
	 * @return
	 * @throws AmazonCloudFormationException When the stack does not exist.
	 */
-	public Optional<Stack> describeStack(String stackName);
+	Optional<Stack> describeStack(String stackName);
 
	/**
	 * Update a stack with the given name using the provided template body.
	 *
@@ -38,7 +39,7 @@
	 * @param templateBody
	 * @return StackId
	 */
-	public void updateStack(CreateOrUpdateStackRequest request);
+	void updateStack(CreateOrUpdateStackRequest request);
 
	/**
	 * Create a stack with the given name using the provided template body.
	 *
@@ -47,7 +48,7 @@
	 * @param templateBody
	 * @return StackId
	 */
-	public void createStack(CreateOrUpdateStackRequest request);
+	void createStack(CreateOrUpdateStackRequest request);
 
	/**
	 * If a stack does not exist the stack will be created else the stack will be
	 *
@@ -57,7 +58,7 @@
	 * @param templateBody
	 * @return StackId
	 */
-	public void createOrUpdateStack(CreateOrUpdateStackRequest request);
+	void createOrUpdateStack(CreateOrUpdateStackRequest request);
 
	/**
	 * Wait for the given stack to complete.
	 *
@@ -65,25 +66,34 @@
	 * @return
	 * @throws InterruptedException
	 */
-	public Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException;
+	Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException;
+
+	/**
+	 * Wait for the given stack to complete and handle any wait condition whose handler is provided
+	 * in the given set (each handler is keyed by the logical id of its wait condition).
+	 * @param stackName
+	 * @param waitConditionHandlers
+	 * @return
+	 * @throws InterruptedException
+	 */
+	Optional<Stack> waitForStackToComplete(String stackName, Set<WaitConditionHandler> waitConditionHandlers) throws InterruptedException;
 
	/**
	 *
	 * @param stackName
	 * @return
	 */
-	public String getOutput(String stackName, String outputKey);
+	String getOutput(String stackName, String outputKey);
 
	/**
	 * Stream over all stacks.
	 * @return
	 */
-	public Stream<Stack> streamOverAllStacks();
+	Stream<Stack> streamOverAllStacks();
 
	/**
	 * Delete a stack by name
	 * @param stackName
	 */
-	public void deleteStack(String stackName);
+	void deleteStack(String stackName);
 
 }
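For illustration, a minimal sketch of how a caller might use the new overload; the stack name and handler variable are hypothetical placeholders, not part of this change:

    // Hypothetical usage sketch: wait for the stack while servicing its wait conditions.
    Set<WaitConditionHandler> handlers = Set.of(myWaitConditionHandler);
    Stack shared = cloudFormationClient.waitForStackToComplete("my-shared-stack", handlers)
        .orElseThrow(() -> new IllegalStateException("Stack does not exist: my-shared-stack"));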
Current status: " + status.name() - + "..."); + logger.info("Waiting for stack: '" + stackName + "' to complete. Current status: " + status.name() + "..."); threadProvider.sleep(SLEEP_TIME); break; case UPDATE_ROLLBACK_COMPLETE: @@ -262,6 +282,84 @@ public Optional waitForStackToComplete(String stackName) throws Interrupt } } } + + void handleWaitConditions(String stackName, Map waitConditionHandlers, Set processedWaitConditionSet) { + if (waitConditionHandlers.isEmpty()) { + return; + } + + Set waitConditionEventIds = new HashSet<>(); + + List waitConditionEvents = cloudFormationClient.describeStackEvents( + new DescribeStackEventsRequest().withStackName(stackName) + ) + .getStackEvents() + .stream() + .filter(event -> "AWS::CloudFormation::WaitCondition".equals(event.getResourceType())) + // We only need the latest event for each wait condition + .filter(event -> waitConditionEventIds.add(event.getLogicalResourceId())) + .filter(event -> ResourceStatus.CREATE_IN_PROGRESS.equals(ResourceStatus.fromValue(event.getResourceStatus()))) + .collect(Collectors.toList()); + + for (StackEvent waitConditionEvent : waitConditionEvents) { + String waitConditionId = waitConditionEvent.getLogicalResourceId(); + + if (processedWaitConditionSet.contains(waitConditionId)) { + logger.warn("Wait condition {} already processed, skipping.", waitConditionId); + continue; + } + + logger.info("Processing wait condition {} (Status: {}, Reason: {})...", waitConditionId, waitConditionEvent.getResourceStatus(), waitConditionEvent.getResourceStatusReason()); + + WaitConditionHandler waitConditionHandler = waitConditionHandlers.get(waitConditionId); + + if (waitConditionHandler == null) { + + cloudFormationClient.signalResource(new SignalResourceRequest() + .withStackName(stackName) + .withLogicalResourceId(waitConditionId) + .withStatus(ResourceSignalStatus.FAILURE) + .withUniqueId("handler-not-found") + ); + + throw new IllegalStateException("Processing wait condition " + waitConditionId + " failed: could not find an handler."); + + } else { + logger.info("Processing wait condition {} started...", waitConditionId); + + try { + waitConditionHandler.handle(waitConditionEvent).ifPresentOrElse(signalId -> { + logger.info("Processing wait condition {} completed with signal {}.", waitConditionId, signalId); + + cloudFormationClient.signalResource(new SignalResourceRequest() + .withStackName(stackName) + .withLogicalResourceId(waitConditionId) + .withStatus(ResourceSignalStatus.SUCCESS) + .withUniqueId(signalId) + ); + + processedWaitConditionSet.add(waitConditionId); + }, () -> { + logger.info("Processing wait condition {} didn't return a signal, will process later.", waitConditionId); + }); + + } catch (Exception e) { + logger.error("Processing wait condition {} failed exceptionally: ", waitConditionId, e); + + cloudFormationClient.signalResource(new SignalResourceRequest() + .withStackName(stackName) + .withLogicalResourceId(waitConditionId) + .withStatus(ResourceSignalStatus.FAILURE) + .withUniqueId("handler-failed") + ); + + throw new IllegalStateException("Processing wait condition " + waitConditionId + " failed.", e); + } + } + + } + + } @Override public String getOutput(String stackName, String outputKey) { diff --git a/src/main/java/org/sagebionetworks/template/Constants.java b/src/main/java/org/sagebionetworks/template/Constants.java index f261ff88..aab02c19 100644 --- a/src/main/java/org/sagebionetworks/template/Constants.java +++ b/src/main/java/org/sagebionetworks/template/Constants.java @@ -221,6 +221,7 @@ 
diff --git a/src/main/java/org/sagebionetworks/template/Constants.java b/src/main/java/org/sagebionetworks/template/Constants.java
index f261ff88..aab02c19 100644
--- a/src/main/java/org/sagebionetworks/template/Constants.java
+++ b/src/main/java/org/sagebionetworks/template/Constants.java
@@ -221,6 +221,7 @@ public class Constants {
	public static final String HOSTED_ZONE = "hostedZone";
	public static final String VPC_ENDPOINTS_COLOR = "vpcEndpointsColor";
	public static final String VPC_ENDPOINTS_AZ = "vpcEndpointsAz";
+	public static final String IDENTITY_ARN = "identityArn";
 
	public static final String CAPABILITY_NAMED_IAM = "CAPABILITY_NAMED_IAM";
	public static final String OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT = "RepositoryDBEndpoint";
diff --git a/src/main/java/org/sagebionetworks/template/OpenSearchClientFactory.java b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactory.java
new file mode 100644
index 00000000..a8763190
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactory.java
@@ -0,0 +1,9 @@
+package org.sagebionetworks.template;
+
+import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
+
+public interface OpenSearchClientFactory {
+
+	OpenSearchIndicesClient getIndicesClient(String collectionEndpoint);
+
+}
diff --git a/src/main/java/org/sagebionetworks/template/OpenSearchClientFactoryImpl.java b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactoryImpl.java
new file mode 100644
index 00000000..b85743ce
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/OpenSearchClientFactoryImpl.java
@@ -0,0 +1,33 @@
+package org.sagebionetworks.template;
+
+import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
+import org.opensearch.client.transport.aws.AwsSdk2Transport;
+import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;
+
+import com.google.inject.Inject;
+
+import software.amazon.awssdk.http.SdkHttpClient;
+import software.amazon.awssdk.regions.Region;
+
+public class OpenSearchClientFactoryImpl implements OpenSearchClientFactory {
+
+	private SdkHttpClient httpClient;
+
+	@Inject
+	public OpenSearchClientFactoryImpl(SdkHttpClient httpClient) {
+		this.httpClient = httpClient;
+	}
+
+	public OpenSearchIndicesClient getIndicesClient(String collectionEndpoint) {
+		return new OpenSearchIndicesClient(
+			new AwsSdk2Transport(
+				httpClient,
+				collectionEndpoint.replace("https://", ""),
+				"aoss",
+				Region.US_EAST_1,
+				AwsSdk2TransportOptions.builder().build()
+			)
+		);
+	}
+
+}
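As a quick illustration of the factory's intended use (the endpoint below is a placeholder for the SynapseHelpCollectionEndpoint stack output added later in this change; the index name mirrors the one created by the wait condition handler):

    // Hypothetical sketch, not part of this change; exists(...) throws IOException.
    OpenSearchIndicesClient indices = factory.getIndicesClient("https://example123.us-east-1.aoss.amazonaws.com");
    boolean exists = indices.exists(req -> req.index("vector-idx")).value();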
diff --git a/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java b/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java
index 366c2ad4..d081361a 100644
--- a/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java
+++ b/src/main/java/org/sagebionetworks/template/TemplateGuiceModule.java
@@ -1,35 +1,17 @@
 package org.sagebionetworks.template;
 
-import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
-import com.amazonaws.regions.Regions;
-import com.amazonaws.services.athena.AmazonAthena;
-import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
-import com.amazonaws.services.cloudformation.AmazonCloudFormation;
-import com.amazonaws.services.cloudformation.AmazonCloudFormationClientBuilder;
-import com.amazonaws.services.ec2.AmazonEC2;
-import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
-import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
-import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalkClientBuilder;
-import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancing;
-import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancingClientBuilder;
-import com.amazonaws.services.glue.AWSGlue;
-import com.amazonaws.services.glue.AWSGlueClientBuilder;
-import com.amazonaws.services.kms.AWSKMS;
-import com.amazonaws.services.kms.AWSKMSAsyncClientBuilder;
-import com.amazonaws.services.lambda.AWSLambda;
-import com.amazonaws.services.lambda.AWSLambdaClientBuilder;
-import com.amazonaws.services.route53.AmazonRoute53;
-import com.amazonaws.services.route53.AmazonRoute53ClientBuilder;
-import com.amazonaws.services.s3.AmazonS3;
-import com.amazonaws.services.s3.AmazonS3ClientBuilder;
-import com.amazonaws.services.secretsmanager.AWSSecretsManager;
-import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
-import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
-import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
-import com.amazonaws.services.simpleemail.AmazonSimpleEmailService;
-import com.amazonaws.services.simpleemail.AmazonSimpleEmailServiceClientBuilder;
-import com.google.inject.Provides;
-import com.google.inject.multibindings.Multibinder;
+import static org.sagebionetworks.template.Constants.APPCONFIG_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.ATHENA_QUERIES_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.CLOUDWATCH_LOGS_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.DATAWAREHOUSE_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.KINESIS_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.LOAD_BALANCER_ALARM_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.S3_CONFIG_FILE;
+import static org.sagebionetworks.template.Constants.SNS_AND_SQS_CONFIG_FILE;
+import static org.sagebionetworks.template.TemplateUtils.loadFromJsonFile;
+
+import java.io.IOException;
+
 import org.apache.http.client.HttpClient;
 import org.apache.http.impl.client.HttpClientBuilder;
 import org.apache.velocity.app.VelocityEngine;
@@ -97,6 +79,8 @@
 import org.sagebionetworks.template.repo.beanstalk.ssl.CertificateBuilderImpl;
 import org.sagebionetworks.template.repo.beanstalk.ssl.ElasticBeanstalkExtentionBuilder;
 import org.sagebionetworks.template.repo.beanstalk.ssl.ElasticBeanstalkExtentionBuilderImpl;
+import org.sagebionetworks.template.repo.bedrock.SynapseHelpCollectionIndexCreation;
+import org.sagebionetworks.template.repo.bedrock.SynapseHelpKnowledgeBaseDataSourceSync;
 import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsConfig;
 import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsConfigValidator;
 import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsVelocityContextProvider;
@@ -123,17 +107,41 @@
 import org.sagebionetworks.war.WarAppender;
 import org.sagebionetworks.war.WarAppenderImpl;
 
-import java.io.IOException;
-import static org.sagebionetworks.template.Constants.ATHENA_QUERIES_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.CLOUDWATCH_LOGS_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.DATAWAREHOUSE_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.KINESIS_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.LOAD_BALANCER_ALARM_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.S3_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.SNS_AND_SQS_CONFIG_FILE;
-import static org.sagebionetworks.template.Constants.APPCONFIG_CONFIG_FILE;
+import com.amazonaws.auth.DefaultAWSCredentialsProviderChain;
+import com.amazonaws.regions.Regions;
+import com.amazonaws.services.athena.AmazonAthena;
+import com.amazonaws.services.athena.AmazonAthenaClientBuilder;
+import com.amazonaws.services.cloudformation.AmazonCloudFormation;
+import com.amazonaws.services.cloudformation.AmazonCloudFormationClientBuilder;
+import com.amazonaws.services.ec2.AmazonEC2;
+import com.amazonaws.services.ec2.AmazonEC2ClientBuilder;
+import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
+import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalkClientBuilder;
+import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancing;
+import com.amazonaws.services.elasticloadbalancingv2.AmazonElasticLoadBalancingClientBuilder;
+import com.amazonaws.services.glue.AWSGlue;
+import com.amazonaws.services.glue.AWSGlueClientBuilder;
+import com.amazonaws.services.kms.AWSKMS;
+import com.amazonaws.services.kms.AWSKMSAsyncClientBuilder;
+import com.amazonaws.services.lambda.AWSLambda;
+import com.amazonaws.services.lambda.AWSLambdaClientBuilder;
+import com.amazonaws.services.route53.AmazonRoute53;
+import com.amazonaws.services.route53.AmazonRoute53ClientBuilder;
+import com.amazonaws.services.s3.AmazonS3;
+import com.amazonaws.services.s3.AmazonS3ClientBuilder;
+import com.amazonaws.services.secretsmanager.AWSSecretsManager;
+import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
+import com.amazonaws.services.simpleemail.AmazonSimpleEmailService;
+import com.amazonaws.services.simpleemail.AmazonSimpleEmailServiceClientBuilder;
+import com.google.inject.Provides;
+import com.google.inject.multibindings.Multibinder;
 
-import static org.sagebionetworks.template.TemplateUtils.loadFromJsonFile;
+import software.amazon.awssdk.http.apache.ApacheHttpClient;
+import software.amazon.awssdk.regions.Region;
+import software.amazon.awssdk.services.bedrockagent.BedrockAgentClient;
+import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;
 
 public class TemplateGuiceModule extends com.google.inject.AbstractModule {
 
@@ -191,6 +199,11 @@ protected void configure() {
		velocityContextProviderMultibinder.addBinding().to(KinesisFirehoseVelocityContextProvider.class);
		velocityContextProviderMultibinder.addBinding().to(RecurrentAthenaQueryContextProvider.class);
		velocityContextProviderMultibinder.addBinding().to(BedrockAgentContextProvider.class);
+
+		Multibinder<WaitConditionHandler> waitConditionHandlerBinder = Multibinder.newSetBinder(binder(), WaitConditionHandler.class);
+
+		waitConditionHandlerBinder.addBinding().to(SynapseHelpCollectionIndexCreation.class);
+		waitConditionHandlerBinder.addBinding().to(SynapseHelpKnowledgeBaseDataSourceSync.class);
	}
 
	/**
@@ -366,5 +379,20 @@ public S3TransferManagerFactory provideS3TransferManagerFactory(AmazonS3 s3Clien
	public DataWarehouseConfig dataWarehouseConfigProvider() throws IOException {
		return new DataWarehouseConfigValidator(loadFromJsonFile(DATAWAREHOUSE_CONFIG_FILE, DataWarehouseConfig.class)).validate();
	}
-
+
+	@Provides
+	public OpenSearchServerlessClient ossManagementClientProvider() {
+		return OpenSearchServerlessClient.builder().region(Region.US_EAST_1).build();
+	}
+
+	@Provides
+	public BedrockAgentClient bedrockAgentClientProvider() {
+		return BedrockAgentClient.builder().region(Region.US_EAST_1).build();
+	}
+
+	@Provides
+	public OpenSearchClientFactory openSearchClientFactoryProvider() {
+		return new OpenSearchClientFactoryImpl(ApacheHttpClient.builder().build());
+	}
+
 }
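With the Multibinder registration above, Guice aggregates every bound handler into a single injectable Set. A minimal sketch of the receiving side (the class name is illustrative; the real consumer is RepositoryTemplateBuilderImpl below):

    // Illustrative only: any @Inject constructor can receive the full handler set.
    @Inject
    public SomeDeployer(Set<WaitConditionHandler> waitConditionHandlers) {
        // Guice collects every addBinding() above into this set
        this.waitConditionHandlers = waitConditionHandlers;
    }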
diff --git a/src/main/java/org/sagebionetworks/template/WaitConditionHandler.java b/src/main/java/org/sagebionetworks/template/WaitConditionHandler.java
new file mode 100644
index 00000000..d560858d
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/WaitConditionHandler.java
@@ -0,0 +1,26 @@
+package org.sagebionetworks.template;
+
+import java.util.Optional;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+
+/**
+ * Interface for a handler of a wait condition defined in a CloudFormation template. Note that wait condition updates are not supported: a handler is invoked only when the stack is created.
+ */
+public interface WaitConditionHandler {
+
+	/**
+	 * @return The logical resource id of the wait condition defined in the CloudFormation template; this handler is mapped to the returned id
+	 */
+	String getWaitConditionId();
+
+	/**
+	 * When the last wait condition event matches the {@link #getWaitConditionId()} and the event is in the CREATE_IN_PROGRESS status this handler will
+	 * be invoked. The handle implementation should be idempotent since it is possible that it is invoked multiple times during the stack creation.
+	 *
+	 * @param stackEvent
+	 * @return An optional signal id to send back to CloudFormation if the condition could be processed
+	 */
+	Optional<String> handle(StackEvent stackEvent) throws InterruptedException;
+
+}
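A minimal illustrative implementation of this contract (the condition id and helper are hypothetical; returning an empty Optional asks the polling loop to try again on its next pass):

    // Hypothetical handler sketch, not part of this change.
    public class MyWaitConditionHandler implements WaitConditionHandler {
        @Override
        public String getWaitConditionId() {
            return "MyWaitCondition"; // must match the logical id in the template
        }

        @Override
        public Optional<String> handle(StackEvent stackEvent) {
            boolean done = doIdempotentWork(); // hypothetical helper
            // A present value triggers a SUCCESS signal with this unique id.
            return done ? Optional.of("my-work-complete") : Optional.empty();
        }
    }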
diff --git a/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java b/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java
index d70c86a9..aba02ee5 100644
--- a/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java
+++ b/src/main/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImpl.java
@@ -1,47 +1,6 @@
 package org.sagebionetworks.template.repo;
 
-import com.amazonaws.services.cloudformation.model.Output;
-import com.amazonaws.services.cloudformation.model.Parameter;
-import com.amazonaws.services.cloudformation.model.Stack;
-import com.amazonaws.services.cloudformation.model.Tag;
-import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
-import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsRequest;
-import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsResult;
-import com.amazonaws.services.elasticbeanstalk.model.PlatformSummary;
-import com.google.inject.Inject;
-import org.apache.logging.log4j.Logger;
-import org.apache.velocity.Template;
-import org.apache.velocity.VelocityContext;
-import org.apache.velocity.app.VelocityEngine;
-import org.json.JSONObject;
-import org.sagebionetworks.template.CloudFormationClient;
-import org.sagebionetworks.template.ConfigurationPropertyNotFound;
-import org.sagebionetworks.template.Constants;
-import org.sagebionetworks.template.CreateOrUpdateStackRequest;
-import org.sagebionetworks.template.Ec2Client;
-import org.sagebionetworks.template.LoggerFactory;
-import org.sagebionetworks.template.StackTagsProvider;
-import org.sagebionetworks.template.config.RepoConfiguration;
-import org.sagebionetworks.template.config.TimeToLive;
-import org.sagebionetworks.template.repo.beanstalk.ArtifactCopy;
-import org.sagebionetworks.template.repo.beanstalk.BeanstalkUtils;
-import org.sagebionetworks.template.repo.beanstalk.ElasticBeanstalkSolutionStackNameProvider;
-import org.sagebionetworks.template.repo.beanstalk.EnvironmentDescriptor;
-import org.sagebionetworks.template.repo.beanstalk.EnvironmentType;
-import org.sagebionetworks.template.repo.beanstalk.SecretBuilder;
-import org.sagebionetworks.template.repo.beanstalk.SourceBundle;
-import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsVelocityContextProvider;
-
-import java.io.StringWriter;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Set;
-import java.util.StringJoiner;
-import java.util.stream.Collectors;
-
 import static org.sagebionetworks.template.Constants.ADMIN_RULE_ACTION;
 import static org.sagebionetworks.template.Constants.BEANSTALK_INSTANCES_SUBNETS;
 import static org.sagebionetworks.template.Constants.CAPABILITY_NAMED_IAM;
@@ -53,11 +12,12 @@
 import static org.sagebionetworks.template.Constants.DATA_CDN_DOMAIN_NAME_FMT;
 import static org.sagebionetworks.template.Constants.DB_ENDPOINT_SUFFIX;
 import static org.sagebionetworks.template.Constants.DELETION_POLICY;
-import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE;
 import static org.sagebionetworks.template.Constants.EC2_INSTANCE_MEMORY;
+import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE;
 import static org.sagebionetworks.template.Constants.ENVIRONMENT;
 import static org.sagebionetworks.template.Constants.EXCEPTION_THROWER;
 import static org.sagebionetworks.template.Constants.GLOBAL_RESOURCES_EXPORT_PREFIX;
+import static org.sagebionetworks.template.Constants.IDENTITY_ARN;
 import static org.sagebionetworks.template.Constants.INSTANCE;
 import static org.sagebionetworks.template.Constants.JSON_INDENT;
 import static org.sagebionetworks.template.Constants.MACHINE_TYPES;
@@ -73,8 +33,8 @@
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_SSL_ARN;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_VERSION;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_DATA_CDN_KEYPAIR_ID;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_MEMORY;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_AMAZONLINUX;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_JAVA;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_TOMCAT;
@@ -86,19 +46,19 @@
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_ALLOCATED_STORAGE;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_INSTANCE_CLASS;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_IOPS;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MAX_ALLOCATED_STORAGE;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MULTI_AZ;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_STORAGE_TYPE;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ROUTE_53_HOSTED_ZONE;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_STACK;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_INSTANCE_COUNT;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_ALLOCATED_STORAGE;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_INSTANCE_CLASS;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_IOPS;
-import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_MAX_ALLOCATED_STORAGE;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_STORAGE_TYPE;
+import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT;
 import static org.sagebionetworks.template.Constants.PROPERTY_KEY_VPC_SUBNET_COLOR;
 import static org.sagebionetworks.template.Constants.REPO_BEANSTALK_NUMBER;
 import static org.sagebionetworks.template.Constants.SHARED_EXPORT_PREFIX;
@@ -111,6 +71,51 @@
 import static org.sagebionetworks.template.Constants.VPC_EXPORT_PREFIX;
 import static org.sagebionetworks.template.Constants.VPC_SUBNET_COLOR;
 
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Set;
+import java.util.StringJoiner;
+import java.util.stream.Collectors;
+
+import org.apache.logging.log4j.Logger;
+import org.apache.velocity.Template;
+import org.apache.velocity.VelocityContext;
+import org.apache.velocity.app.VelocityEngine;
+import org.json.JSONObject;
+import org.sagebionetworks.template.CloudFormationClient;
+import org.sagebionetworks.template.ConfigurationPropertyNotFound;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.CreateOrUpdateStackRequest;
+import org.sagebionetworks.template.Ec2Client;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.StackTagsProvider;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+import org.sagebionetworks.template.config.TimeToLive;
+import org.sagebionetworks.template.repo.beanstalk.ArtifactCopy;
+import org.sagebionetworks.template.repo.beanstalk.BeanstalkUtils;
+import org.sagebionetworks.template.repo.beanstalk.ElasticBeanstalkSolutionStackNameProvider;
+import org.sagebionetworks.template.repo.beanstalk.EnvironmentDescriptor;
+import org.sagebionetworks.template.repo.beanstalk.EnvironmentType;
+import org.sagebionetworks.template.repo.beanstalk.SecretBuilder;
+import org.sagebionetworks.template.repo.beanstalk.SourceBundle;
+import org.sagebionetworks.template.repo.cloudwatchlogs.CloudwatchLogsVelocityContextProvider;
+
+import com.amazonaws.services.cloudformation.model.Output;
+import com.amazonaws.services.cloudformation.model.Parameter;
+import com.amazonaws.services.cloudformation.model.Stack;
+import com.amazonaws.services.cloudformation.model.Tag;
+import com.amazonaws.services.elasticbeanstalk.AWSElasticBeanstalk;
+import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsRequest;
+import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsResult;
+import com.amazonaws.services.elasticbeanstalk.model.PlatformSummary;
+import com.amazonaws.services.securitytoken.AWSSecurityTokenService;
+import com.amazonaws.services.securitytoken.model.GetCallerIdentityRequest;
+import com.google.inject.Inject;
+
 public class RepositoryTemplateBuilderImpl implements RepositoryTemplateBuilder {
 
	public static final List<String> MACHINE_TYPE_LIST = List.of("Workers", "Repository");
	public static final List<String> POOL_TYPE_LIST = List.of("Idgen", "Main", "Migration", "Tables");
@@ -128,6 +133,8 @@ public class RepositoryTemplateBuilderImpl implements RepositoryTemplateBuilder
	private final CloudwatchLogsVelocityContextProvider cwlContextProvider;
	private final AWSElasticBeanstalk beanstalkClient;
	private final TimeToLive timeToLive;
+	private final AWSSecurityTokenService stsClient;
+	private final Set<WaitConditionHandler> waitConditionHandlers;
 
	@Inject
	public RepositoryTemplateBuilderImpl(CloudFormationClient cloudFormationClient, VelocityEngine velocityEngine,
@@ -135,7 +142,8 @@ public RepositoryTemplateBuilderImpl(CloudFormationClient cloudFormationClient,
			SecretBuilder secretBuilder, Set<VelocityContextProvider> contextProviders,
			ElasticBeanstalkSolutionStackNameProvider elasticBeanstalkDefaultAMIEncrypter, StackTagsProvider stackTagsProvider,
			CloudwatchLogsVelocityContextProvider cloudwatchLogsVelocityContextProvider,
-			Ec2Client ec2Client, AWSElasticBeanstalk beanstalkClient, TimeToLive ttl) {
+			Ec2Client ec2Client, AWSElasticBeanstalk beanstalkClient, TimeToLive ttl,
+			AWSSecurityTokenService stsClient, Set<WaitConditionHandler> waitConditionHandlers) {
		super();
		this.cloudFormationClient = cloudFormationClient;
		this.ec2Client = ec2Client;
@@ -150,6 +158,8 @@ public RepositoryTemplateBuilderImpl(CloudFormationClient cloudFormationClient,
		this.cwlContextProvider = cloudwatchLogsVelocityContextProvider;
		this.beanstalkClient = beanstalkClient;
		this.timeToLive = ttl;
+		this.stsClient = stsClient;
+		this.waitConditionHandlers = waitConditionHandlers;
	}
 
	public String getActualBeanstalkAmazonLinuxPlatform() {
@@ -185,12 +195,12 @@ public void buildAndDeploy() throws InterruptedException {
		buildAndDeployStack(context, sharedResourceStackName, TEMPALTE_SHARED_RESOUCES_MAIN_JSON_VTP, sharedParameters);
		// Wait for the shared resources to complete
-		Stack sharedStackResults = cloudFormationClient.waitForStackToComplete(sharedResourceStackName).orElseThrow(()->new IllegalStateException("Stack does not exist: "+sharedResourceStackName));
-
+		Stack sharedStackResults = cloudFormationClient.waitForStackToComplete(sharedResourceStackName, waitConditionHandlers).orElseThrow(()->new IllegalStateException("Stack does not exist: "+sharedResourceStackName));
+
		// Build each bean stalk environment.
		List<String> environmentNames = buildEnvironments(sharedStackResults);
	}
-
+
	/**
	 * Build all of the environments
	 * @param sharedStackResults
@@ -307,7 +317,9 @@ void buildAndDeployStack(VelocityContext context, String stackName, String templ
	 */
	VelocityContext createSharedContext() {
		VelocityContext context = new VelocityContext();
+
		String stack = config.getProperty(PROPERTY_KEY_STACK);
+
		context.put(STACK, stack);
		context.put(INSTANCE, config.getProperty(PROPERTY_KEY_INSTANCE));
		context.put(MACHINE_TYPES, MACHINE_TYPE_LIST);
@@ -321,8 +333,7 @@ VelocityContext createSharedContext() {
		context.put(CTXT_ENABLE_ENHANCED_RDS_MONITORING, config.getProperty(PROPERTY_KEY_ENABLE_RDS_ENHANCED_MONITORING));
		context.put(ADMIN_RULE_ACTION, Constants.isProd(stack) ? "Block:{}" : "Count:{}");
-		context.put(DELETION_POLICY,
-				Constants.isProd(stack) ? DeletionPolicy.Retain.name() : DeletionPolicy.Delete.name());
+		context.put(DELETION_POLICY, Constants.isProd(stack) ? DeletionPolicy.Retain.name() : DeletionPolicy.Delete.name());
 
		// Create the descriptors for all of the databases.
		context.put(DATABASE_DESCRIPTORS, createDatabaseDescriptors());
@@ -331,6 +342,8 @@ VelocityContext createSharedContext() {
			provider.addToContext(context);
		}
 
+		context.put(IDENTITY_ARN, stsClient.getCallerIdentity(new GetCallerIdentityRequest()).getArn());
+
		RegularExpressions.bindRegexToContext(context);
 
		return context;
diff --git a/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java b/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java
index 9a9a85b2..d3d70849 100644
--- a/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java
+++ b/src/main/java/org/sagebionetworks/template/repo/agent/BedrockAgentContextProvider.java
@@ -6,6 +6,7 @@
 import java.util.StringJoiner;
 
 import org.apache.velocity.VelocityContext;
+import org.json.JSONArray;
 import org.json.JSONObject;
 import org.sagebionetworks.template.TemplateUtils;
 import org.sagebionetworks.template.config.RepoConfiguration;
@@ -14,7 +15,7 @@
 import com.google.inject.Inject;
 
 public class BedrockAgentContextProvider implements VelocityContextProvider {
-
+
	private final RepoConfiguration repoConfig;
 
	@Inject
@@ -28,11 +29,37 @@ public void addToContext(VelocityContext context) {
		String stack = repoConfig.getProperty(PROPERTY_KEY_STACK);
		String instance = repoConfig.getProperty(PROPERTY_KEY_INSTANCE);
		String agentName = new StringJoiner("-").add(stack).add(instance).add("agent").toString();
-		JSONObject baseTemplate = new JSONObject(
-				TemplateUtils.loadContentFromFile("templates/repo/agent/bedrock_agent_template.json"));
+
+		JSONObject baseTemplate = new JSONObject(TemplateUtils.loadContentFromFile("templates/repo/agent/bedrock_agent_template.json"));
 
		JSONObject resources = baseTemplate.getJSONObject("Resources");
+
+		// Since the agent template is shared with external users, we need to patch it, replacing parameters that do not exist in our template
+		JSONArray bedrockAgentRoleKbResource = resources
+			.getJSONObject("bedrockAgentRole")
+			.getJSONObject("Properties")
+			.getJSONArray("Policies")
+			.getJSONObject(0)
+			.getJSONObject("PolicyDocument")
+			.getJSONArray("Statement")
+			.getJSONObject(1)
+			.getJSONArray("Fn::If")
+			.getJSONObject(1)
+			.getJSONArray("Resource");
+
+		bedrockAgentRoleKbResource.put(0, new JSONObject("{ \"Fn::GetAtt\": [\"SynapseHelpKnowledgeBase\", \"KnowledgeBaseArn\"] }"));
+
		JSONObject bedrockAgentProps = resources.getJSONObject("bedrockAgent").getJSONObject("Properties");
+
+		JSONObject kbProperty = bedrockAgentProps
+			.getJSONObject("KnowledgeBases")
+			.getJSONArray("Fn::If")
+			.getJSONArray(1)
+			.getJSONObject(0);
+
+		kbProperty.getJSONObject("KnowledgeBaseId").put("Ref", "SynapseHelpKnowledgeBase");
+		kbProperty.put("Description", baseTemplate.getJSONObject("Parameters").getJSONObject("knowledgeBaseDescription").getString("Default"));
+
		bedrockAgentProps.put("AgentName", agentName);
 
		String json = resources.toString();
 
		context.put("bedrock_agent_resouces", "," + json.substring(1, json.length()-1));
diff --git a/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreation.java b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreation.java
new file mode 100644
index 00000000..009cf514
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreation.java
@@ -0,0 +1,104 @@
+package org.sagebionetworks.template.repo.bedrock;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import org.apache.logging.log4j.Logger;
+import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.OpenSearchClientFactory;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+import com.google.inject.Inject;
+
+import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionDetail;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionStatus;
+
+/**
+ * A Bedrock knowledge base that uses an OpenSearch collection requires the index to exist before the knowledge base
+ * is created. Since index creation is an OpenSearch API operation with no corresponding CloudFormation resource, we
+ * invoke the OpenSearch API as part of a wait condition in the stack. Note that a wait condition is only processed
+ * during stack creation, so the index cannot be updated.
+ */
+public class SynapseHelpCollectionIndexCreation implements WaitConditionHandler {
+
+	private static final String IDX_NAME = "vector-idx";
+
+	private Logger logger;
+
+	private RepoConfiguration config;
+
+	private OpenSearchServerlessClient ossManagementClient;
+
+	private OpenSearchClientFactory openSearchClientFactory;
+
+	@Inject
+	public SynapseHelpCollectionIndexCreation(LoggerFactory loggerFactory, RepoConfiguration config, OpenSearchServerlessClient ossClient, OpenSearchClientFactory openSearchClientFactory) {
+		this.logger = loggerFactory.getLogger(SynapseHelpCollectionIndexCreation.class);
+		this.config = config;
+		this.ossManagementClient = ossClient;
+		this.openSearchClientFactory = openSearchClientFactory;
+	}
+
+	@Override
+	public String getWaitConditionId() {
+		return "SynapseHelpCollectionCreateIndexWaitCondition";
+	}
+
+	@Override
+	public Optional<String> handle(StackEvent stackEvent) {
+		String collectionName = config.getProperty(Constants.PROPERTY_KEY_STACK) + "-" + config.getProperty(Constants.PROPERTY_KEY_INSTANCE) + "-synhelp";
+
+		CollectionDetail collection = ossManagementClient.batchGetCollection(req -> req
+			.names(collectionName)
+		).collectionDetails().stream().findFirst().orElseThrow();
+
+		if (!CollectionStatus.ACTIVE.equals(collection.status())) {
+			logger.warn("Collection {} not ready, status: {}", collectionName, collection.status());
+			return Optional.empty();
+		}
+
+		OpenSearchIndicesClient client = openSearchClientFactory.getIndicesClient(collection.collectionEndpoint());
+
+		try {
+			if (client.exists(req -> req.index(IDX_NAME)).value()) {
+				logger.warn("Index {} already exists.", IDX_NAME);
+				return Optional.of("index-already-exists");
+			}
+
+			logger.info("Index {} does not exist, creating...", IDX_NAME);
+
+			client.create(req -> req
+				.index(IDX_NAME)
+				.settings(settings -> settings.knn(true).knnAlgoParamEfSearch(512))
+				.mappings(mappings -> mappings
+					.properties("text_vector", p -> p
+						.knnVector(vector -> vector
+							.dimension(1024)
+							.method(method -> method
+								.name("hnsw")
+								.engine("faiss")
+								.spaceType("l2")
+							)
+						)
+					)
+					.properties("text_raw", p -> p.text(text -> text.index(true)))
+					.properties("text_metadata", p -> p.text(text -> text.index(false)))
+				)
+			);
+
+			logger.info("Index {} creation completed.", IDX_NAME);
+
+			return Optional.of("index-creation-complete");
+
+		} catch (IOException e) {
+			throw new IllegalStateException(e);
+		}
+
+	}
+}
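For illustration, a hedged post-deploy sanity check one could run against the created index (GetIndexResponse is from opensearch-java; the factory and collection variables are as in the handler above):

    // Hypothetical verification sketch, not part of this change; get(...) throws IOException.
    GetIndexResponse index = openSearchClientFactory
        .getIndicesClient(collection.collectionEndpoint())
        .get(req -> req.index("vector-idx"));

Note that the 1024 vector dimension above has to line up with the embedding model configured for the knowledge base: the template below uses amazon.titan-embed-text-v2:0, which emits 1024-dimensional vectors by default.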
diff --git a/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSync.java b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSync.java
new file mode 100644
index 00000000..a6d83cf9
--- /dev/null
+++ b/src/main/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSync.java
@@ -0,0 +1,127 @@
+package org.sagebionetworks.template.repo.bedrock;
+
+import java.util.Optional;
+
+import org.apache.commons.codec.digest.DigestUtils;
+import org.apache.logging.log4j.Logger;
+import org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.ThreadProvider;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+import com.google.inject.Inject;
+
+import software.amazon.awssdk.services.bedrockagent.BedrockAgentClient;
+import software.amazon.awssdk.services.bedrockagent.model.DataSourceSummary;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJob;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJobStatistics;
+import software.amazon.awssdk.services.bedrockagent.model.IngestionJobSummary;
+import software.amazon.awssdk.services.bedrockagent.model.KnowledgeBaseSummary;
+
+/**
+ * When a Bedrock knowledge base is created, its data source needs to be synchronized. We do this after the data source
+ * is created, through a wait condition, using the Bedrock APIs.
+ */
+public class SynapseHelpKnowledgeBaseDataSourceSync implements WaitConditionHandler {
+
+	private static final long SLEEP_MS = 10_000;
+
+	private Logger logger;
+	private RepoConfiguration config;
+	private ThreadProvider threadProvider;
+	private BedrockAgentClient bedrockAgentClient;
+
+	@Inject
+	public SynapseHelpKnowledgeBaseDataSourceSync(LoggerFactory loggerFactory, RepoConfiguration config, ThreadProvider threadProvider, BedrockAgentClient bedrockAgentClient) {
+		this.logger = loggerFactory.getLogger(SynapseHelpKnowledgeBaseDataSourceSync.class);
+		this.config = config;
+		this.threadProvider = threadProvider;
+		this.bedrockAgentClient = bedrockAgentClient;
+	}
+
+	@Override
+	public String getWaitConditionId() {
+		return "SynapseHelpKnowledgeBaseDataSourceSyncWaitCondition";
+	}
+
+	@Override
+	public Optional<String> handle(StackEvent stackEvent) throws InterruptedException {
+		String stackPrefix = config.getProperty(Constants.PROPERTY_KEY_STACK) + "-" + config.getProperty(Constants.PROPERTY_KEY_INSTANCE);
+
+		String knowledgeBaseName = stackPrefix + "-synhelp-knowledge-base";
+		String knowledgeBaseId = bedrockAgentClient.listKnowledgeBasesPaginator(req -> {})
+			.knowledgeBaseSummaries().stream()
+			.filter(kb -> kb.name().equals(knowledgeBaseName))
+			.findFirst()
+			.map(KnowledgeBaseSummary::knowledgeBaseId)
+			.orElseThrow();
+
+		String dataSourceName = stackPrefix + "-synhelp-datasource";
+		String dataSourceId = bedrockAgentClient.listDataSourcesPaginator(req -> req.knowledgeBaseId(knowledgeBaseId))
+			.dataSourceSummaries().stream()
+			.filter(dataSource -> dataSource.name().equals(dataSourceName))
+			.findFirst()
+			.map(DataSourceSummary::dataSourceId)
+			.orElseThrow();
+
+		Optional<IngestionJobSummary> existingJob = bedrockAgentClient.listIngestionJobs(req -> req.knowledgeBaseId(knowledgeBaseId).dataSourceId(dataSourceId).maxResults(1))
+			.ingestionJobSummaries()
+			.stream()
+			.findFirst();
+
+		if (existingJob.isPresent()) {
+			logger.warn("Sync job {} already exists (Status: {}).", existingJob.get().ingestionJobId(), existingJob.get().statusAsString());
+			return Optional.of("sync-started");
+		}
+
+		String clientToken = DigestUtils.sha256Hex(knowledgeBaseId + " - " + dataSourceId);
+
+		IngestionJob job = bedrockAgentClient.startIngestionJob(req -> req
+			.clientToken(clientToken)
+			.dataSourceId(dataSourceId)
+			.knowledgeBaseId(knowledgeBaseId)
+		).ingestionJob();
+
+		String jobId = job.ingestionJobId();
+		boolean done = false;
+
+		do {
+			logger.info("Waiting for sync job {} to complete (Status: {}).", job.ingestionJobId(), job.statusAsString());
+
+			threadProvider.sleep(SLEEP_MS);
+
+			job = bedrockAgentClient.getIngestionJob(req -> req
+				.ingestionJobId(jobId)
+				.knowledgeBaseId(knowledgeBaseId)
+				.dataSourceId(dataSourceId)
+			).ingestionJob();
+
+			switch (job.status()) {
+			case COMPLETE:
+				IngestionJobStatistics stats = job.statistics();
+
+				logger.info("Sync job {} completed (Documents Scanned: {}, Documents Indexed: {}, Documents Failed: {}).",
+					job.ingestionJobId(),
+					stats.numberOfDocumentsScanned(),
+					stats.numberOfNewDocumentsIndexed(),
+					stats.numberOfDocumentsFailed()
+				);
+
+				done = true;
+				break;
+			case FAILED:
+			case STOPPED:
+			case UNKNOWN_TO_SDK_VERSION:
+				throw new IllegalStateException("Sync job " + jobId + " failed (Status: " + job.statusAsString() + ", Failures: " + job.failureReasons().toString() +")");
+			default:
+				break;
+			}
+
+		} while (!done);
+
+		return Optional.of("sync-completed");
+	}
+
+}
diff --git a/src/main/resources/templates/repo/agent/bedrock_agent_template.json b/src/main/resources/templates/repo/agent/bedrock_agent_template.json
index 01b9d29a..675ea471 100644
--- a/src/main/resources/templates/repo/agent/bedrock_agent_template.json
+++ b/src/main/resources/templates/repo/agent/bedrock_agent_template.json
@@ -1,13 +1,28 @@
 {
	"AWSTemplateFormatVersion": "2010-09-09",
-	"Description": "This tempalte contains all of the resources needed to create the base Synapse Bedrock Agent",
+	"Description": "This template contains all of the resources needed to create the base Synapse Bedrock Agent.",
	"Parameters": {
		"agentName": {
			"Description": "Provide a unique name for this bedrock agent.",
			"Type": "String",
			"AllowedPattern": "^([0-9a-zA-Z][_-]?){1,100}$"
+		},
+		"knowledgeBaseId": {
+			"Description": "Provide the id of a knowledge base that will be associated with this bedrock agent.",
+			"Type": "String",
+			"Default": "",
+			"AllowedPattern": "^([0-9a-zA-Z]{10})?$"
+		},
+		"knowledgeBaseDescription": {
+			"Description": "Provide the description for the knowledge base that will be associated with this bedrock agent.",
+			"Type": "String",
+			"Default": "This knowledge base contains the Synapse help documentation. You can use it to answer questions around Synapse usage and its features.",
+			"MaxLength": 200
		}
	},
+	"Conditions": {
+		"AttachKnowledgeBase": {"Fn::Not": [{"Fn::Equals" : [{"Ref" : "knowledgeBaseId"}, ""]}]}
+	},
	"Resources": {
		"bedrockAgentRole": {
			"Type": "AWS::IAM::Role",
@@ -44,12 +59,34 @@
					"Statement": [
						{
							"Effect": "Allow",
-							"Action": "bedrock:InvokeModel",
+							"Action": [
+								"bedrock:InvokeModel",
+								"bedrock:InvokeModelWithResponseStream"
+							],
							"Resource": [
								{
									"Fn::Sub": "arn:aws:bedrock:${AWS::Region}::foundation-model/*"
								}
							]
+						},
+						{
+							"Fn::If" : [ "AttachKnowledgeBase",
+								{
+									"Effect": "Allow",
+									"Action": [
+										"bedrock:Retrieve",
+										"bedrock:RetrieveAndGenerate"
+									],
+									"Resource": [
+										{
+											"Fn::Sub": "arn:aws:bedrock:${AWS::Region}:${AWS::AccountId}:knowledge-base/${knowledgeBaseId}"
+										}
+									]
+								},
+								{
+									"Ref" : "AWS::NoValue"
+								}
+							]
						}
					]
				}
@@ -136,6 +173,17 @@
						"Arn"
					]
				},
+				"KnowledgeBases": {
+					"Fn::If" : [ "AttachKnowledgeBase",
+						[{
+							"KnowledgeBaseId" : { "Ref" : "knowledgeBaseId" },
+							"Description" : { "Ref" : "knowledgeBaseDescription" }
+						}],
+						{
+							"Ref" : "AWS::NoValue"
+						}
+					]
+				},
				"AutoPrepare": true,
				"Description": "Test of the use of action groups to allow the agent to make Synapse API calls.",
				"FoundationModel": "anthropic.claude-3-sonnet-20240229-v1:0",
diff --git a/src/main/resources/templates/repo/bedrock-knowledge-base-template.json.vpt b/src/main/resources/templates/repo/bedrock-knowledge-base-template.json.vpt
new file mode 100644
index 00000000..acdbbfe1
--- /dev/null
+++ b/src/main/resources/templates/repo/bedrock-knowledge-base-template.json.vpt
@@ -0,0 +1,192 @@
+	,
+	"SynapseHelpKnowledgeBaseExecutionRole": {
+		"Type": "AWS::IAM::Role",
+		"Properties": {
+			"AssumeRolePolicyDocument": {
+				"Version": "2012-10-17",
+				"Statement": [
+					{
+						"Effect": "Allow",
+						"Principal": {
+							"Service": "bedrock.amazonaws.com"
+						},
+						"Action": "sts:AssumeRole",
+						"Condition": {
+							"StringEquals": {
+								"aws:SourceAccount": {
+									"Ref": "AWS::AccountId"
+								}
+							},
+							"ArnLike": {
+								"aws:SourceArn": {
+									"Fn::Sub": "arn:aws:bedrock:#[[${AWS::Region}:${AWS::AccountId}]]#:knowledge-base/*"
+								}
+							}
+						}
+					}
+				]
+			}
+		}
+	},
+	"SynapseHelpCollectionDeployerDataAccessPolicy": {
+		"Type":"AWS::OpenSearchServerless::AccessPolicy",
+		"Properties": {
+			"Name":"${stack}-${instance}-synhelp-deployer",
+			"Type":"data",
+			"Description":"Access policy for the deployer to create the index for the synapse help collection",
+			"Policy": {
+				"Fn::Sub": ["[{\"Rules\":[{\"ResourceType\":\"index\",\"Resource\":[\"index/${stack}-${instance}-synhelp/*\"],\"Permission\":[\"aoss:CreateIndex\",\"aoss:DescribeIndex\"]}],\"Principal\":[\"#[[${deployerRoleArn}]]#\"]}]",
+					{
+						"deployerRoleArn" : "${identityArn}"
+					}
+				]
+			}
+		}
+	},
+	"SynapseHelpCollectionKnowledgeBaseDataAccessPolicy": {
+		"Type":"AWS::OpenSearchServerless::AccessPolicy",
+		"Properties": {
+			"Name":"${stack}-${instance}-synhelp-bedrock",
+			"Type":"data",
+			"Description":"Access policy for the knowledge base to sync and query the synapse help collection",
+			"Policy": {
+				"Fn::Sub": ["[{\"Rules\":[{\"ResourceType\":\"index\",\"Resource\":[\"index/${stack}-${instance}-synhelp/*\"],\"Permission\":[\"aoss:CreateIndex\",\"aoss:UpdateIndex\",\"aoss:DescribeIndex\",\"aoss:ReadDocument\",\"aoss:WriteDocument\"]},{\"ResourceType\":\"collection\",\"Resource\":[\"collection/${stack}-${instance}-synhelp\"],\"Permission\":[\"aoss:DescribeCollectionItems\",\"aoss:CreateCollectionItems\",\"aoss:UpdateCollectionItems\"]}],\"Principal\":[\"#[[${executionRoleArn}]]#\"]}]",
+					{
+						"executionRoleArn": { "Fn::GetAtt": ["SynapseHelpKnowledgeBaseExecutionRole", "Arn"] }
+					}
+				]
+			}
+		}
+	},
+	"SynapseHelpCollectionEncryptionPolicy": {
+		"Type":"AWS::OpenSearchServerless::SecurityPolicy",
+		"Properties": {
+			"Name":"${stack}-${instance}-synhelp",
+			"Type":"encryption",
+			"Description":"Encryption policy for the synapse help collection",
+			"Policy":"{\"Rules\":[{\"ResourceType\":\"collection\",\"Resource\":[\"collection/${stack}-${instance}-synhelp\"]}],\"AWSOwnedKey\":true}"
+		}
+	},
+	"SynapseHelpCollectionNetworkPolicy": {
+		"Type":"AWS::OpenSearchServerless::SecurityPolicy",
+		"Properties": {
+			"Name":"${stack}-${instance}-synhelp",
+			"Type":"network",
+			"Description":"Network policy for the synapse help collection",
+			"Policy": {
+				"Fn::Sub": [
+					"[{\"Rules\":[{\"ResourceType\":\"collection\",\"Resource\":[\"collection/${stack}-${instance}-synhelp\"]}],\"AllowFromPublic\":false,\"SourceServices\":[\"bedrock.amazonaws.com\"],\"SourceVPCEs\":[\"#[[${vpcEndpointId}]]#\"]}]",
+					{
+						"vpcEndpointId": { "Fn::ImportValue": "${vpcExportPrefix}-open-search-vpce" }
+					}
+				]
+			}
+		}
+	},
+	"SynapseHelpCollection": {
+		"Type":"AWS::OpenSearchServerless::Collection",
+		"Properties": {
+			"Name":"${stack}-${instance}-synhelp",
+			"Type":"VECTORSEARCH",
+			"Description":"Vector store collection for synapse help documentation",
+			"StandbyReplicas": "DISABLED"
+		},
+		"DependsOn": ["SynapseHelpCollectionEncryptionPolicy", "SynapseHelpCollectionNetworkPolicy", "SynapseHelpCollectionDeployerDataAccessPolicy", "SynapseHelpCollectionKnowledgeBaseDataAccessPolicy"]
+	},
+	"SynapseHelpKnowledgeBaseExecutionRolePolicy": {
+		"Type":"AWS::IAM::RolePolicy",
+		"Properties": {
+			"PolicyName": "SynapseHelpKnowledgeBaseExecutionRolePolicy",
+			"PolicyDocument": {
+				"Version": "2012-10-17",
+				"Statement": [
+					{
+						"Effect": "Allow",
+						"Action": "bedrock:InvokeModel",
+						"Resource": [
+							{
+								"Fn::Sub": "arn:aws:bedrock:#[[${AWS::Region}]]#::foundation-model/*"
+							}
+						]
+					},
+					{
+						"Effect": "Allow",
+						"Action": "aoss:APIAccessAll",
+						"Resource": [
+							{
+								"Fn::GetAtt" : ["SynapseHelpCollection", "Arn"]
+							}
+						]
+					}
+				]
+			},
+			"RoleName": {
+				"Ref": "SynapseHelpKnowledgeBaseExecutionRole"
+			}
+		}
+	},
+	"SynapseHelpCollectionCreateIndexWaitCondition": {
+		"Type": "AWS::CloudFormation::WaitCondition",
+		"DependsOn": [ "SynapseHelpCollection" ],
+		"CreationPolicy" : {
+			"ResourceSignal" : {
+				"Timeout" : "PT3M",
+				"Count" : "1"
+			}
+		}
+	},
+	"SynapseHelpKnowledgeBase": {
+		"Type": "AWS::Bedrock::KnowledgeBase",
+		"DependsOn": [ "SynapseHelpCollectionCreateIndexWaitCondition", "SynapseHelpKnowledgeBaseExecutionRolePolicy" ],
+		"Properties": {
+			"Name": "${stack}-${instance}-synhelp-knowledge-base",
+			"Description": "Knowledge base for the synapse help documentation",
+			"RoleArn": { "Fn::GetAtt" : ["SynapseHelpKnowledgeBaseExecutionRole", "Arn"] },
+			"KnowledgeBaseConfiguration": {
+				"Type": "VECTOR",
+				"VectorKnowledgeBaseConfiguration": {
+					"EmbeddingModelArn": { "Fn::Sub": "arn:aws:bedrock:#[[${AWS::Region}]]#::foundation-model/amazon.titan-embed-text-v2:0" }
+				}
+			},
+			"StorageConfiguration": {
+				"Type": "OPENSEARCH_SERVERLESS",
+				"OpensearchServerlessConfiguration": {
+					"CollectionArn": { "Fn::GetAtt" : ["SynapseHelpCollection", "Arn"] },
+					"VectorIndexName": "vector-idx",
+					"FieldMapping": {
+						"VectorField": "text_vector",
+						"TextField": "text_raw",
+						"MetadataField": "text_metadata"
+					}
+				}
+			}
+		}
+	},
"KnowledgeBaseId": { "Ref": "SynapseHelpKnowledgeBase" }, + "Name": "${stack}-${instance}-synhelp-datasource", + "Description": "The datasource for the synapse help document, the data is crawled from the synapse help website", + "DataSourceConfiguration": { + "Type": "WEB", + "WebConfiguration": { + "SourceConfiguration": { + "UrlConfiguration": { + "SeedUrls": [ { "Url": "https://help.synapse.org" } ] + } + } + } + } + } + }, + "SynapseHelpKnowledgeBaseDataSourceSyncWaitCondition": { + "Type": "AWS::CloudFormation::WaitCondition", + "DependsOn": "SynapseHelpKnowledgeBaseDataSource", + "CreationPolicy" : { + "ResourceSignal" : { + "Timeout" : "PT10M", + "Count" : "1" + } + } + } \ No newline at end of file diff --git a/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt b/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt index fc589266..cbeb598b 100644 --- a/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt +++ b/src/main/resources/templates/repo/main-repo-shared-resources-template.json.vpt @@ -9,6 +9,9 @@ }, #parse("templates/repo/time-to-live-parameter.vpt") }, + "Conditions": { + "AttachKnowledgeBase": {"Fn::Equals": ["true", "true"]} + }, "Resources": { "${stack}${instance}DBSubnetGroup": { "Type": "AWS::RDS::DBSubnetGroup", @@ -745,6 +748,7 @@ #parse("templates/repo/web/acl/web-acl-template.json.vpt") #parse("templates/repo/appconfig-template.json.vpt") #parse("templates/repo/api-gateway-template.json.vpt") + #parse("templates/repo/bedrock-knowledge-base-template.json.vpt") ${bedrock_agent_resouces} }, "Outputs": { @@ -858,6 +862,27 @@ ] } } + }, + "SynapseHelpCollectionEndpoint": { + "Value": { + "Fn::GetAtt" : ["SynapseHelpCollection", "CollectionEndpoint"] + }, + "Export": { + "Name": { + "Fn::Join": [ + "-", + [ + { + "Ref": "AWS::Region" + }, + { + "Ref": "AWS::StackName" + }, + "Synapse-Help-Collection-Endpoint" + ] + ] + } + } } #foreach( $descriptor in ${databaseDescriptors} ) , diff --git a/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java b/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java index 2ead4c5f..d90c603b 100644 --- a/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java +++ b/src/test/java/org/sagebionetworks/template/CloudFormationClientImplTest.java @@ -1,6 +1,7 @@ package org.sagebionetworks.template; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; @@ -13,8 +14,10 @@ import java.net.MalformedURLException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collectors; @@ -35,11 +38,17 @@ import com.amazonaws.services.cloudformation.model.CreateStackRequest; import com.amazonaws.services.cloudformation.model.CreateStackResult; import com.amazonaws.services.cloudformation.model.DeleteStackRequest; +import com.amazonaws.services.cloudformation.model.DescribeStackEventsRequest; +import com.amazonaws.services.cloudformation.model.DescribeStackEventsResult; import com.amazonaws.services.cloudformation.model.DescribeStacksRequest; import com.amazonaws.services.cloudformation.model.DescribeStacksResult; import 
com.amazonaws.services.cloudformation.model.Output; import com.amazonaws.services.cloudformation.model.Parameter; +import com.amazonaws.services.cloudformation.model.ResourceSignalStatus; +import com.amazonaws.services.cloudformation.model.ResourceStatus; +import com.amazonaws.services.cloudformation.model.SignalResourceRequest; import com.amazonaws.services.cloudformation.model.Stack; +import com.amazonaws.services.cloudformation.model.StackEvent; import com.amazonaws.services.cloudformation.model.StackStatus; import com.amazonaws.services.cloudformation.model.UpdateStackRequest; import com.amazonaws.services.cloudformation.model.UpdateStackResult; @@ -68,6 +77,8 @@ public class CloudFormationClientImplTest { Logger mockLogger; @Mock ThreadProvider mockThreadProvider; + @Mock + WaitConditionHandler mockWaitConditionHandler; @Captor ArgumentCaptor describeStackRequestCapture; @@ -475,6 +486,263 @@ public void testWaitForStackToCompleteUpdateRollbackCompleteToUpdateRollBackComp Assertions.assertNotNull(resultStack.getStackStatus()); Assertions.assertEquals(StackStatus.UPDATE_ROLLBACK_COMPLETE, StackStatus.fromValue(resultStack.getStackStatus())); } + + @Test + public void testWaitForStackToCompleteWithWaitConditionHandlers() throws InterruptedException { + initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + + when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn( + initDescribeResult, + describeResult, + new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE)) + ); + + String waitConditionId = "waitConditionId"; + StackEvent waitConditionEvent = new StackEvent() + .withResourceType("AWS::CloudFormation::WaitCondition") + .withLogicalResourceId(waitConditionId) + .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS); + + when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId); + when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.of("done")); + when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn( + new DescribeStackEventsResult().withStackEvents(waitConditionEvent) + ); + + // call under test + Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get(); + + verify(mockCloudFormationClient).signalResource(new SignalResourceRequest() + .withLogicalResourceId(waitConditionId) + .withStackName(stackName) + .withStatus(ResourceSignalStatus.SUCCESS) + .withUniqueId("done") + ); + } + + @Test + public void testWaitForStackToCompleteWithWaitConditionHandlersAndMultipleEvents() throws InterruptedException { + initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + + when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn( + initDescribeResult, + describeResult, + new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE)) + ); + + String waitConditionId = "waitConditionId"; + + StackEvent waitConditionEvent = new StackEvent() + .withResourceType("AWS::CloudFormation::WaitCondition") + .withLogicalResourceId(waitConditionId) + .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS) + .withEventId("last"); + + when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId); + 
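// The handler returns "done" as its unique id; the client is expected to echo it back in the SUCCESS resource signal verified below.
+ 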
when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.of("done")); + when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn( + new DescribeStackEventsResult().withStackEvents( + waitConditionEvent, + new StackEvent().withResourceType("AWS::CloudFormation::WaitCondition").withLogicalResourceId("anotherWaitConditionId").withResourceStatus(ResourceStatus.CREATE_COMPLETE), + new StackEvent().withResourceType("anotherType").withLogicalResourceId("anotherResourceId").withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS), + // Another event for the same condition id, should be discarded + new StackEvent().withResourceType("AWS::CloudFormation::WaitCondition").withLogicalResourceId(waitConditionId).withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS).withEventId("previous") + ) + ); + + // call under test + Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get(); + + verify(mockCloudFormationClient).signalResource(new SignalResourceRequest() + .withLogicalResourceId(waitConditionId) + .withStackName(stackName) + .withStatus(ResourceSignalStatus.SUCCESS) + .withUniqueId("done") + ); + } + + @Test + public void testWaitForStackToCompleteWithWaitConditionHandlersAndAlreadyProcessed() throws InterruptedException { + initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + + when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn( + initDescribeResult, + describeResult, + describeResult, + new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE)) + ); + + String waitConditionId = "waitConditionId"; + + StackEvent waitConditionEvent = new StackEvent() + .withResourceType("AWS::CloudFormation::WaitCondition") + .withLogicalResourceId(waitConditionId) + .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS); + + when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId); + when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.of("done")); + when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn( + new DescribeStackEventsResult().withStackEvents(waitConditionEvent) + ); + + // call under test + Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get(); + + verify(mockCloudFormationClient, times(2)).describeStackEvents(any()); + + // Should be invoked only once + verify(mockWaitConditionHandler).handle(waitConditionEvent); + verify(mockCloudFormationClient).signalResource(new SignalResourceRequest() + .withLogicalResourceId(waitConditionId) + .withStackName(stackName) + .withStatus(ResourceSignalStatus.SUCCESS) + .withUniqueId("done") + ); + } + + @Test + public void testWaitForStackToCompleteWithWaitConditionHandlersAndNoMatchingHandler() throws InterruptedException { + initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + + when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn( + initDescribeResult, + describeResult, + new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE)) + ); + + String waitConditionId = "waitConditionId"; + + StackEvent waitConditionEvent = new StackEvent() + .withResourceType("AWS::CloudFormation::WaitCondition") + 
.withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS)
+ .withEventId("last");
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId + "-mismatching");
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(
+ waitConditionEvent
+ )
+ );
+
+ IllegalStateException result = assertThrows(IllegalStateException.class, () -> {
+ // call under test
+ Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+ });
+
+ assertEquals("Processing wait condition waitConditionId failed: could not find an handler.", result.getMessage());
+
+ verify(mockCloudFormationClient).signalResource(new SignalResourceRequest()
+ .withLogicalResourceId(waitConditionId)
+ .withStackName(stackName)
+ .withStatus(ResourceSignalStatus.FAILURE)
+ .withUniqueId("handler-not-found")
+ );
+
+ verifyNoMoreInteractions(mockWaitConditionHandler);
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlersAndException() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ .withLogicalResourceId(waitConditionId)
+ .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS)
+ .withEventId("last");
+
+ when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId);
+
+ RuntimeException cause = new RuntimeException("processing error");
+
+ when(mockWaitConditionHandler.handle(waitConditionEvent)).thenThrow(cause);
+ when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn(
+ new DescribeStackEventsResult().withStackEvents(
+ waitConditionEvent
+ )
+ );
+
+ IllegalStateException result = assertThrows(IllegalStateException.class, () -> {
+ // call under test
+ client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get();
+ });
+
+ assertEquals("Processing wait condition waitConditionId failed.", result.getMessage());
+ assertEquals(cause, result.getCause());
+
+ verify(mockCloudFormationClient).signalResource(new SignalResourceRequest()
+ .withLogicalResourceId(waitConditionId)
+ .withStackName(stackName)
+ .withStatus(ResourceSignalStatus.FAILURE)
+ .withUniqueId("handler-failed")
+ );
+
+ verifyNoMoreInteractions(mockWaitConditionHandler);
+ }
+
+ @Test
+ public void testWaitForStackToCompleteWithWaitConditionHandlersAndNoSignal() throws InterruptedException {
+ initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+ stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS);
+
+ when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn(
+ initDescribeResult,
+ describeResult,
+ new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE))
+ );
+
+ String waitConditionId = "waitConditionId";
+
+ StackEvent waitConditionEvent = new StackEvent()
+ .withResourceType("AWS::CloudFormation::WaitCondition")
+ 
.withLogicalResourceId(waitConditionId) + .withResourceStatus(ResourceStatus.CREATE_IN_PROGRESS) + .withEventId("last"); + + when(mockWaitConditionHandler.getWaitConditionId()).thenReturn(waitConditionId); + when(mockWaitConditionHandler.handle(waitConditionEvent)).thenReturn(Optional.empty()); + when(mockCloudFormationClient.describeStackEvents(new DescribeStackEventsRequest().withStackName(stackName))).thenReturn( + new DescribeStackEventsResult().withStackEvents( + waitConditionEvent + ) + ); + + // call under test + Stack resultStack = client.waitForStackToComplete(stackName, Set.of(mockWaitConditionHandler)).get(); + + verifyNoMoreInteractions(mockCloudFormationClient, mockWaitConditionHandler); + } + + @Test + public void testWaitForStackToCompleteWithEmptyWaitConditionHandlers() throws InterruptedException { + initStack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + stack.setStackStatus(StackStatus.CREATE_IN_PROGRESS); + + when(mockCloudFormationClient.describeStacks(any(DescribeStacksRequest.class))).thenReturn( + initDescribeResult, + describeResult, + new DescribeStacksResult().withStacks(new Stack().withStackStatus(StackStatus.CREATE_COMPLETE)) + ); + + // call under test + Stack resultStack = client.waitForStackToComplete(stackName, Collections.emptySet()).get(); + + verifyNoMoreInteractions(mockCloudFormationClient, mockWaitConditionHandler); + } @Test public void testGetOutput() { diff --git a/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java b/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java index e41481c3..cb4f6ffb 100644 --- a/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java +++ b/src/test/java/org/sagebionetworks/template/repo/RepositoryTemplateBuilderImplTest.java @@ -21,9 +21,10 @@ import static org.sagebionetworks.template.Constants.DATABASE_DESCRIPTORS; import static org.sagebionetworks.template.Constants.DB_ENDPOINT_SUFFIX; import static org.sagebionetworks.template.Constants.DELETION_POLICY; -import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE; import static org.sagebionetworks.template.Constants.EC2_INSTANCE_MEMORY; +import static org.sagebionetworks.template.Constants.EC2_INSTANCE_TYPE; import static org.sagebionetworks.template.Constants.ENVIRONMENT; +import static org.sagebionetworks.template.Constants.IDENTITY_ARN; import static org.sagebionetworks.template.Constants.INSTANCE; import static org.sagebionetworks.template.Constants.NOSNAPSHOT; import static org.sagebionetworks.template.Constants.OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT; @@ -36,8 +37,8 @@ import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_SSL_ARN; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_BEANSTALK_VERSION; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_DATA_CDN_KEYPAIR_ID; -import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_MEMORY; +import static org.sagebionetworks.template.Constants.PROPERTY_KEY_EC2_INSTANCE_TYPE; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_AMAZONLINUX; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_JAVA; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ELASTICBEANSTALK_IMAGE_VERSION_TOMCAT; @@ -49,19 +50,19 @@ import static 
org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_ALLOCATED_STORAGE; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_INSTANCE_CLASS; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_IOPS; -import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MAX_ALLOCATED_STORAGE; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_MULTI_AZ; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_STORAGE_TYPE; +import static org.sagebionetworks.template.Constants.PROPERTY_KEY_REPO_RDS_THROUGHPUT; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_ROUTE_53_HOSTED_ZONE; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_STACK; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_INSTANCE_COUNT; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_ALLOCATED_STORAGE; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_INSTANCE_CLASS; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_IOPS; -import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_MAX_ALLOCATED_STORAGE; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_STORAGE_TYPE; +import static org.sagebionetworks.template.Constants.PROPERTY_KEY_TABLES_RDS_THROUGHPUT; import static org.sagebionetworks.template.Constants.PROPERTY_KEY_VPC_SUBNET_COLOR; import static org.sagebionetworks.template.Constants.REPO_BEANSTALK_NUMBER; import static org.sagebionetworks.template.Constants.SHARED_EXPORT_PREFIX; @@ -100,6 +101,7 @@ import org.sagebionetworks.template.LoggerFactory; import org.sagebionetworks.template.StackTagsProvider; import org.sagebionetworks.template.TemplateGuiceModule; +import org.sagebionetworks.template.WaitConditionHandler; import org.sagebionetworks.template.config.RepoConfiguration; import org.sagebionetworks.template.config.TimeToLive; import org.sagebionetworks.template.repo.agent.BedrockAgentContextProvider; @@ -123,6 +125,8 @@ import com.amazonaws.services.elasticbeanstalk.model.ListPlatformVersionsResult; import com.amazonaws.services.elasticbeanstalk.model.PlatformFilter; import com.amazonaws.services.elasticbeanstalk.model.PlatformSummary; +import com.amazonaws.services.securitytoken.AWSSecurityTokenService; +import com.amazonaws.services.securitytoken.model.GetCallerIdentityResult; import com.google.common.collect.Lists; import com.google.common.collect.Sets; @@ -157,6 +161,10 @@ public class RepositoryTemplateBuilderImplTest { private CloudwatchLogsVelocityContextProvider mockCwlContextProvider; @Mock private TimeToLive mockTimeToLive; + @Mock + private AWSSecurityTokenService mockStsClient; + @Mock + private WaitConditionHandler mockWaitConditionHandler; @Captor private ArgumentCaptor requestCaptor; @@ -188,10 +196,12 @@ public void before() throws InterruptedException { expectedTags.add(t); when(mockLoggerFactory.getLogger(any())).thenReturn(mockLogger); + builder = new RepositoryTemplateBuilderImpl(mockCloudFormationClient, velocityEngine, config, mockLoggerFactory, mockArtifactCopy, mockSecretBuilder, Sets.newHashSet(mockContextProvider1, mockContextProvider2, new BedrockAgentContextProvider(config)), mockElasticBeanstalkSolutionStackNameProvider, 
mockStackTagsProvider, mockCwlContextProvider, - mockEc2Client, mockBeanstalkClient, mockTimeToLive); + mockEc2Client, mockBeanstalkClient, mockTimeToLive, mockStsClient, Set.of(mockWaitConditionHandler)); + builderSpy = Mockito.spy(builder); stack = "dev"; @@ -210,7 +220,7 @@ public void before() throws InterruptedException { Output tableDBOutput2 = new Output(); tableDBOutput2.withOutputKey(stack + instance + "Table1" + OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT); tableDBOutput2.withOutputValue(stack + "-" + instance + "-table-1." + databaseEndpointSuffix); - + sharedResouces.withOutputs(dbOut, tableDBOutput1, tableDBOutput2); secretsSouce = new SourceBundle("secretBucket", "secretKey"); @@ -221,17 +231,27 @@ public void before() throws InterruptedException { } private void configureStack(String inputStack) throws InterruptedException { + when(mockStsClient.getCallerIdentity(any())).thenReturn(new GetCallerIdentityResult().withArn("currentIdentityArn")); stack = inputStack; + when(config.getProperty(PROPERTY_KEY_STACK)).thenReturn(stack); + + sharedResouces = new Stack(); - Output dbOut = new Output(); - dbOut.withOutputKey(stack + instance + OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT); + databaseEndpointSuffix = "something.amazon.com"; - dbOut.withOutputValue(stack + "-" + instance + "-db." + databaseEndpointSuffix); - sharedResouces.withOutputs(dbOut); - - when(mockCloudFormationClient.waitForStackToComplete(any(String.class))) - .thenReturn(Optional.of(sharedResouces)); + + sharedResouces.withOutputs( + new Output() + .withOutputKey(stack + instance + OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT) + .withOutputValue(stack + "-" + instance + "-db." + databaseEndpointSuffix), + new Output() + .withOutputKey("SynapseHelpCollectionEndpoint") + .withOutputValue("synhelp-endpoint") + ); + + when(mockCloudFormationClient.waitForStackToComplete(any(String.class), any())).thenReturn(Optional.of(sharedResouces)); + } @Test @@ -300,6 +320,8 @@ public void testBuildAndDeployProd() throws InterruptedException { builder.buildAndDeploy(); verify(mockCloudFormationClient, times(4)).createOrUpdateStack(requestCaptor.capture()); + verify(mockCloudFormationClient).waitForStackToComplete("prod-101-shared-resources", Set.of(mockWaitConditionHandler)); + List list = requestCaptor.getAllValues(); CreateOrUpdateStackRequest request = list.get(0); assertEquals("prod-101-shared-resources", request.getStackName()); @@ -354,7 +376,9 @@ public void testBuildAndDeployProd() throws InterruptedException { assertEquals(15000, tDbProps.getInt("StorageThroughput")); assertFalse(resources.has("WebhookTestApi")); - + assertTrue(resources.has("SynapseHelpCollection")); + assertTrue(resources.has("SynapseHelpKnowledgeBaseExecutionRole")); + assertTrue(resources.has("SynapseHelpKnowledgeBase")); assertTrue(resources.has("bedrockAgentRole")); assertTrue(resources.has("bedrockAgent")); assertEquals("prod-101-agent", resources.getJSONObject("bedrockAgent").getJSONObject("Properties").get("AgentName")); @@ -546,6 +570,8 @@ public void testBuildAndDeployDev() throws InterruptedException { builder.buildAndDeploy(); verify(mockCloudFormationClient, times(4)).createOrUpdateStack(requestCaptor.capture()); + verify(mockCloudFormationClient).waitForStackToComplete("dev-101-shared-resources", Set.of(mockWaitConditionHandler)); + List list = requestCaptor.getAllValues(); CreateOrUpdateStackRequest request = list.get(0); assertEquals("dev-101-shared-resources", request.getStackName()); @@ -598,7 +624,9 @@ public void testBuildAndDeployDev() 
throws InterruptedException { assertEquals(15000, tDbProps.getInt("StorageThroughput")); assertTrue(resources.has("WebhookTestApi")); - + assertTrue(resources.has("SynapseHelpCollection")); + assertTrue(resources.has("SynapseHelpKnowledgeBaseExecutionRole")); + assertTrue(resources.has("SynapseHelpKnowledgeBase")); assertTrue(resources.has("bedrockAgentRole")); assertTrue(resources.has("bedrockAgent")); assertEquals("dev-101-agent", resources.getJSONObject("bedrockAgent").getJSONObject("Properties").get("AgentName")); @@ -905,7 +933,7 @@ public void testCreateContext() { when(config.getProperty(PROPERTY_KEY_RDS_REPO_SNAPSHOT_IDENTIFIER)).thenReturn(NOSNAPSHOT); String[] noSnapshots = new String[] { NOSNAPSHOT }; when(config.getComaSeparatedProperty(PROPERTY_KEY_RDS_TABLES_SNAPSHOT_IDENTIFIERS)).thenReturn(noSnapshots); - + when(mockStsClient.getCallerIdentity(any())).thenReturn(new GetCallerIdentityResult().withArn("currentIdentityArn")); // call under test VelocityContext context = builder.createSharedContext(); @@ -988,7 +1016,8 @@ public void testCreateContextProd() { when(config.getProperty(PROPERTY_KEY_RDS_REPO_SNAPSHOT_IDENTIFIER)).thenReturn(NOSNAPSHOT); String[] noSnapshots = new String[] { NOSNAPSHOT }; when(config.getComaSeparatedProperty(PROPERTY_KEY_RDS_TABLES_SNAPSHOT_IDENTIFIERS)).thenReturn(noSnapshots); - + when(mockStsClient.getCallerIdentity(any())).thenReturn(new GetCallerIdentityResult().withArn("currentIdentityArn")); + // call under test VelocityContext context = builder.createSharedContext(); @@ -1001,6 +1030,7 @@ public void testCreateContextProd() { assertEquals("Block:{}", context.get(ADMIN_RULE_ACTION)); assertEquals("Retain", context.get(DELETION_POLICY)); + assertEquals("currentIdentityArn", context.get(IDENTITY_ARN)); } diff --git a/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreationTest.java b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreationTest.java new file mode 100644 index 00000000..89aff5f0 --- /dev/null +++ b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpCollectionIndexCreationTest.java @@ -0,0 +1,248 @@ +package org.sagebionetworks.template.repo.bedrock; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.function.Consumer; + +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.client.opensearch._types.mapping.Property; +import org.opensearch.client.opensearch.indices.CreateIndexRequest; +import org.opensearch.client.opensearch.indices.CreateIndexResponse; +import org.opensearch.client.opensearch.indices.ExistsRequest; +import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient; +import org.opensearch.client.transport.endpoints.BooleanResponse; +import 
org.sagebionetworks.template.Constants;
+import org.sagebionetworks.template.LoggerFactory;
+import org.sagebionetworks.template.OpenSearchClientFactory;
+import org.sagebionetworks.template.WaitConditionHandler;
+import org.sagebionetworks.template.config.RepoConfiguration;
+
+import com.amazonaws.services.cloudformation.model.StackEvent;
+
+import software.amazon.awssdk.services.opensearchserverless.OpenSearchServerlessClient;
+import software.amazon.awssdk.services.opensearchserverless.model.BatchGetCollectionRequest;
+import software.amazon.awssdk.services.opensearchserverless.model.BatchGetCollectionResponse;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionDetail;
+import software.amazon.awssdk.services.opensearchserverless.model.CollectionStatus;
+
+@ExtendWith(MockitoExtension.class)
+public class SynapseHelpCollectionIndexCreationTest {
+
+ private final static String COLLECTION_ENDPOINT = "endpoint";
+
+ @Mock
+ private LoggerFactory mockLoggerFactory;
+
+ @Mock
+ private OpenSearchServerlessClient mockOssManagementClient;
+
+ @Mock
+ private OpenSearchClientFactory mockOpenSearchClientFactory;
+
+ @Mock
+ private RepoConfiguration mockConfig;
+
+ private WaitConditionHandler handler;
+
+ @Mock
+ private Logger mockLogger;
+
+ @Mock
+ private StackEvent mockStackEvent;
+
+ @Mock
+ private OpenSearchIndicesClient mockOpenSearchIndicesClient;
+
+ @Captor
+ private ArgumentCaptor<Consumer<BatchGetCollectionRequest.Builder>> getCollectionRequestCaptor;
+
+ @Captor
+ private ArgumentCaptor<ExistsRequest> existRequestCaptor;
+
+ @Captor
+ private ArgumentCaptor<CreateIndexRequest> createRequestCaptor;
+
+ @BeforeEach
+ public void before() {
+ when(mockLoggerFactory.getLogger(any())).thenReturn(mockLogger);
+ handler = new SynapseHelpCollectionIndexCreation(mockLoggerFactory, mockConfig, mockOssManagementClient, mockOpenSearchClientFactory);
+ }
+
+ @Test
+ public void testGetWaitConditionId() {
+ // Call under test
+ assertEquals("SynapseHelpCollectionCreateIndexWaitCondition", handler.getWaitConditionId());
+ }
+
+ @Test
+ public void testHandleWithActiveCollection() throws IOException, InterruptedException {
+
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder()
+ .collectionDetails(CollectionDetail.builder().status(CollectionStatus.ACTIVE).collectionEndpoint(COLLECTION_ENDPOINT).build()).build()
+ );
+
+ when(mockOpenSearchClientFactory.getIndicesClient(COLLECTION_ENDPOINT)).thenReturn(mockOpenSearchIndicesClient);
+
+ when(mockOpenSearchIndicesClient.exists(existRequestCaptor.capture())).thenReturn(new BooleanResponse(false));
+ when(mockOpenSearchIndicesClient.create(createRequestCaptor.capture())).thenReturn(
+ CreateIndexResponse.of(resp -> resp
+ .index("vector-idx")
+ .acknowledged(true)
+ .shardsAcknowledged(true)
+ )
+ );
+
+ // Call under test
+ assertEquals(Optional.of("index-creation-complete"), handler.handle(mockStackEvent));
+
+ assertEquals(
+ BatchGetCollectionRequest.builder().names("dev-101-synhelp").build(),
+ BatchGetCollectionRequest.builder().applyMutation(getCollectionRequestCaptor.getValue()).build()
+ );
+
+ ExistsRequest existRequest = existRequestCaptor.getValue();
+
+ assertEquals(List.of("vector-idx"), existRequest.index());
+
+ CreateIndexRequest createRequest = createRequestCaptor.getValue();
+
+ assertEquals("vector-idx", createRequest.index());
+ 
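// The created index must enable k-NN search (ef_search=512) and map the vector, raw-text and metadata fields referenced by the knowledge base template.
+ 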
assertTrue(createRequest.settings().knn()); + assertEquals(512, createRequest.settings().knnAlgoParamEfSearch()); + + Property textVectorProp = createRequest.mappings().properties().get("text_vector"); + + assertEquals(1024, textVectorProp.knnVector().dimension()); + assertEquals("hnsw", textVectorProp.knnVector().method().name()); + assertEquals("faiss", textVectorProp.knnVector().method().engine()); + assertEquals("l2", textVectorProp.knnVector().method().spaceType()); + assertTrue(createRequest.mappings().properties().get("text_raw").text().index()); + assertFalse(createRequest.mappings().properties().get("text_metadata").text().index()); + + } + + @Test + public void testHandleWithInactiveCollection() throws IOException, InterruptedException { + + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder() + .collectionDetails(CollectionDetail.builder().status(CollectionStatus.CREATING).collectionEndpoint(COLLECTION_ENDPOINT).build()).build() + ); + + // Call under test + assertEquals(Optional.empty(), handler.handle(mockStackEvent)); + + verifyNoMoreInteractions(mockOpenSearchIndicesClient); + + } + + @Test + public void testHandleWithCollectionNotFound() throws IOException, InterruptedException { + + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder() + .collectionDetails(Collections.emptyList()).build() + ); + + assertThrows(NoSuchElementException.class, () -> { + // Call under test + handler.handle(mockStackEvent); + }); + + verifyNoMoreInteractions(mockOpenSearchIndicesClient); + + } + + @Test + public void testHandleWithExistingIndex() throws IOException, InterruptedException { + + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder() + .collectionDetails(CollectionDetail.builder().status(CollectionStatus.ACTIVE).collectionEndpoint(COLLECTION_ENDPOINT).build()).build() + ); + + when(mockOpenSearchClientFactory.getIndicesClient(COLLECTION_ENDPOINT)).thenReturn(mockOpenSearchIndicesClient); + + when(mockOpenSearchIndicesClient.exists(existRequestCaptor.capture())).thenReturn(new BooleanResponse(true)); + + // Call under test + assertEquals(Optional.of("index-already-exists"), handler.handle(mockStackEvent)); + + assertEquals( + BatchGetCollectionRequest.builder().names("dev-101-synhelp").build(), + BatchGetCollectionRequest.builder().applyMutation(getCollectionRequestCaptor.getValue()).build() + ); + + ExistsRequest existRequest = existRequestCaptor.getValue(); + + assertEquals(List.of("vector-idx"), existRequest.index()); + + verifyNoMoreInteractions(mockOpenSearchIndicesClient); + + } + + @Test + public void testHandleWithIOException() throws IOException { + + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + 
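// The collection lookup succeeds, but the index existence check below throws; the handler is expected to wrap the IOException in an IllegalStateException.
+ 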
when(mockOssManagementClient.batchGetCollection(getCollectionRequestCaptor.capture())).thenReturn(BatchGetCollectionResponse.builder() + .collectionDetails(CollectionDetail.builder().status(CollectionStatus.ACTIVE).collectionEndpoint(COLLECTION_ENDPOINT).build()).build() + ); + + when(mockOpenSearchClientFactory.getIndicesClient(COLLECTION_ENDPOINT)).thenReturn(mockOpenSearchIndicesClient); + + IOException ex = new IOException("nope"); + + when(mockOpenSearchIndicesClient.exists(existRequestCaptor.capture())).thenThrow(ex); + + IllegalStateException result = assertThrows(IllegalStateException.class, () -> { + // Call under test + handler.handle(mockStackEvent); + }); + + assertEquals(ex, result.getCause()); + + assertEquals( + BatchGetCollectionRequest.builder().names("dev-101-synhelp").build(), + BatchGetCollectionRequest.builder().applyMutation(getCollectionRequestCaptor.getValue()).build() + ); + + ExistsRequest existRequest = existRequestCaptor.getValue(); + + assertEquals(List.of("vector-idx"), existRequest.index()); + + verifyNoMoreInteractions(mockOpenSearchIndicesClient); + + } + + +} diff --git a/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSyncTest.java b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSyncTest.java new file mode 100644 index 00000000..26aa28b8 --- /dev/null +++ b/src/test/java/org/sagebionetworks/template/repo/bedrock/SynapseHelpKnowledgeBaseDataSourceSyncTest.java @@ -0,0 +1,388 @@ +package org.sagebionetworks.template.repo.bedrock; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Optional; +import java.util.function.Consumer; + +import org.apache.logging.log4j.Logger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.sagebionetworks.template.Constants; +import org.sagebionetworks.template.LoggerFactory; +import org.sagebionetworks.template.ThreadProvider; +import org.sagebionetworks.template.WaitConditionHandler; +import org.sagebionetworks.template.config.RepoConfiguration; + +import com.amazonaws.services.cloudformation.model.StackEvent; + +import software.amazon.awssdk.services.bedrockagent.BedrockAgentClient; +import software.amazon.awssdk.services.bedrockagent.model.DataSourceSummary; +import software.amazon.awssdk.services.bedrockagent.model.GetIngestionJobRequest; +import software.amazon.awssdk.services.bedrockagent.model.GetIngestionJobResponse; +import software.amazon.awssdk.services.bedrockagent.model.IngestionJob; +import software.amazon.awssdk.services.bedrockagent.model.IngestionJobStatus; +import software.amazon.awssdk.services.bedrockagent.model.IngestionJobSummary; +import software.amazon.awssdk.services.bedrockagent.model.KnowledgeBaseSummary; +import software.amazon.awssdk.services.bedrockagent.model.ListDataSourcesRequest; +import software.amazon.awssdk.services.bedrockagent.model.ListDataSourcesResponse; 
+import software.amazon.awssdk.services.bedrockagent.model.ListIngestionJobsRequest;
+import software.amazon.awssdk.services.bedrockagent.model.ListIngestionJobsResponse;
+import software.amazon.awssdk.services.bedrockagent.model.ListKnowledgeBasesRequest;
+import software.amazon.awssdk.services.bedrockagent.model.ListKnowledgeBasesResponse;
+import software.amazon.awssdk.services.bedrockagent.model.StartIngestionJobRequest;
+import software.amazon.awssdk.services.bedrockagent.model.StartIngestionJobResponse;
+import software.amazon.awssdk.services.bedrockagent.paginators.ListDataSourcesIterable;
+import software.amazon.awssdk.services.bedrockagent.paginators.ListKnowledgeBasesIterable;
+
+@ExtendWith(MockitoExtension.class)
+public class SynapseHelpKnowledgeBaseDataSourceSyncTest {
+
+ private static final String KNOWLEDGE_BASE_ID = "123";
+ private static final String DATA_SOURCE_ID = "456";
+ private static final String JOB_HASH = "027161e3d3c63c5680c1e4d38da31892f5a32797600fb68bd1e9ab4bc95ceb8a";
+ private static final String JOB_ID = "job-id";
+
+ @Mock
+ private LoggerFactory mockLoggerFactory;
+ @Mock
+ private BedrockAgentClient mockBedrockAgentClient;
+ @Mock
+ private RepoConfiguration mockConfig;
+ @Mock
+ private ThreadProvider mockThreadProvider;
+
+ private WaitConditionHandler handler;
+
+ @Mock
+ private Logger mockLogger;
+
+ @Mock
+ private StackEvent mockStackEvent;
+
+ @Captor
+ private ArgumentCaptor<Consumer<ListKnowledgeBasesRequest.Builder>> listKnowledgeBasesRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<ListDataSourcesRequest.Builder>> listDataSourceRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<ListIngestionJobsRequest.Builder>> listIngestionRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<StartIngestionJobRequest.Builder>> startIngestionJobRequestCaptor;
+ @Captor
+ private ArgumentCaptor<Consumer<GetIngestionJobRequest.Builder>> getIngestionJobRequestCaptor;
+
+ @BeforeEach
+ public void before() {
+ when(mockLoggerFactory.getLogger(any())).thenReturn(mockLogger);
+ handler = new SynapseHelpKnowledgeBaseDataSourceSync(mockLoggerFactory, mockConfig, mockThreadProvider, mockBedrockAgentClient);
+ }
+
+ @Test
+ public void testGetWaitConditionId() {
+ assertEquals("SynapseHelpKnowledgeBaseDataSourceSyncWaitCondition", handler.getWaitConditionId());
+ }
+
+ @Test
+ public void testHandle() throws InterruptedException {
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev");
+ when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101");
+
+ String knowledgeBaseName = "dev-101-synhelp-knowledge-base";
+
+ when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn(
+ ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of(
+ KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn(
+ new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build())
+ );
+
+ String dataSourceName = "dev-101-synhelp-datasource";
+
+ when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn(
+ ListDataSourcesResponse.builder().dataSourceSummaries(List.of(
+ DataSourceSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).name(dataSourceName).build()
+ )).build()
+ );
+
+ when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn(
+ new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build())
+ );
+
+ 
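// No prior ingestion job is returned below, so the handler is expected to start a new job and poll it until completion.
+ 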
when(mockBedrockAgentClient.listIngestionJobs(listIngestionRequestCaptor.capture())).thenReturn(ListIngestionJobsResponse.builder().build()); + + IngestionJob job = IngestionJob.builder() + .knowledgeBaseId(KNOWLEDGE_BASE_ID) + .dataSourceId(DATA_SOURCE_ID) + .ingestionJobId(JOB_ID) + .status(IngestionJobStatus.STARTING) + .build(); + + when(mockBedrockAgentClient.startIngestionJob(startIngestionJobRequestCaptor.capture())).thenReturn( + StartIngestionJobResponse.builder().ingestionJob(job).build() + ); + + when(mockBedrockAgentClient.getIngestionJob(getIngestionJobRequestCaptor.capture())).thenReturn( + GetIngestionJobResponse.builder().ingestionJob(job.copy(b -> b + .status(IngestionJobStatus.COMPLETE) + .statistics(stats -> stats + .numberOfDocumentsScanned(10L) + .numberOfNewDocumentsIndexed(5L) + .numberOfDocumentsFailed(5L) + ) + )).build() + ); + + // Call under test + assertEquals(Optional.of("sync-completed"), handler.handle(mockStackEvent)); + + assertEquals( + ListKnowledgeBasesRequest.builder().build(), + ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build() + ); + + assertEquals( + ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(), + ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build() + ); + + assertEquals( + ListIngestionJobsRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).maxResults(1).build(), + ListIngestionJobsRequest.builder().applyMutation(listIngestionRequestCaptor.getValue()).build() + ); + + assertEquals( + StartIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).clientToken(JOB_HASH).build(), + StartIngestionJobRequest.builder().applyMutation(startIngestionJobRequestCaptor.getValue()).build() + ); + + assertEquals( + GetIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).ingestionJobId(JOB_ID).build(), + GetIngestionJobRequest.builder().applyMutation(getIngestionJobRequestCaptor.getValue()).build() + ); + } + + @Test + public void testHandleWithKnowledgeBaseNotFound() throws InterruptedException { + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn( + ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(Collections.emptyList()).build() + ); + + when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn( + new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build()) + ); + + assertThrows(NoSuchElementException.class, () -> { + // Call under test + handler.handle(mockStackEvent); + }); + + assertEquals( + ListKnowledgeBasesRequest.builder().build(), + ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build() + ); + + verifyNoMoreInteractions(mockBedrockAgentClient); + } + + @Test + public void testHandleWithDataSourceNotFound() throws InterruptedException { + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + String knowledgeBaseName = "dev-101-synhelp-knowledge-base"; + + when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn( + 
ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of( + KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build() + )).build() + ); + + when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn( + new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build()) + ); + + when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn( + ListDataSourcesResponse.builder().dataSourceSummaries(Collections.emptyList()).build() + ); + + when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn( + new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build()) + ); + + assertThrows(NoSuchElementException.class, () -> { + // Call under test + handler.handle(mockStackEvent); + }); + + assertEquals( + ListKnowledgeBasesRequest.builder().build(), + ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build() + ); + + assertEquals( + ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(), + ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build() + ); + + verifyNoMoreInteractions(mockBedrockAgentClient); + } + + @Test + public void testHandleWithJobAlreadyStarted() throws InterruptedException { + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + String knowledgeBaseName = "dev-101-synhelp-knowledge-base"; + + when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn( + ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of( + KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build() + )).build() + ); + + when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn( + new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build()) + ); + + String dataSourceName = "dev-101-synhelp-datasource"; + + when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn( + ListDataSourcesResponse.builder().dataSourceSummaries(List.of( + DataSourceSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).name(dataSourceName).build() + )).build() + ); + + when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn( + new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build()) + ); + + when(mockBedrockAgentClient.listIngestionJobs(listIngestionRequestCaptor.capture())).thenReturn( + ListIngestionJobsResponse.builder().ingestionJobSummaries(List.of( + IngestionJobSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).ingestionJobId(JOB_ID).build() + )).build() + ); + + // Call under test + assertEquals(Optional.of("sync-started"), handler.handle(mockStackEvent)); + + assertEquals( + ListKnowledgeBasesRequest.builder().build(), + ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build() + ); + + assertEquals( + ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(), + ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build() + ); + + 
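// The job listing was capped at a single result; the presence of an existing summary alone means the sync is already running, so no new job is started.
+ 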
assertEquals( + ListIngestionJobsRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).maxResults(1).build(), + ListIngestionJobsRequest.builder().applyMutation(listIngestionRequestCaptor.getValue()).build() + ); + + verifyNoMoreInteractions(mockBedrockAgentClient); + } + + @Test + public void testHandleWithJobFailure() throws InterruptedException { + when(mockConfig.getProperty(Constants.PROPERTY_KEY_STACK)).thenReturn("dev"); + when(mockConfig.getProperty(Constants.PROPERTY_KEY_INSTANCE)).thenReturn("101"); + + String knowledgeBaseName = "dev-101-synhelp-knowledge-base"; + + when(mockBedrockAgentClient.listKnowledgeBases(any(ListKnowledgeBasesRequest.class))).thenReturn( + ListKnowledgeBasesResponse.builder().knowledgeBaseSummaries(List.of( + KnowledgeBaseSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).name(knowledgeBaseName).build() + )).build() + ); + + when(mockBedrockAgentClient.listKnowledgeBasesPaginator(listKnowledgeBasesRequestCaptor.capture())).thenReturn( + new ListKnowledgeBasesIterable(mockBedrockAgentClient, ListKnowledgeBasesRequest.builder().build()) + ); + + String dataSourceName = "dev-101-synhelp-datasource"; + + when(mockBedrockAgentClient.listDataSources(any(ListDataSourcesRequest.class))).thenReturn( + ListDataSourcesResponse.builder().dataSourceSummaries(List.of( + DataSourceSummary.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).name(dataSourceName).build() + )).build() + ); + + when(mockBedrockAgentClient.listDataSourcesPaginator(listDataSourceRequestCaptor.capture())).thenReturn( + new ListDataSourcesIterable(mockBedrockAgentClient, ListDataSourcesRequest.builder().build()) + ); + + when(mockBedrockAgentClient.listIngestionJobs(listIngestionRequestCaptor.capture())).thenReturn(ListIngestionJobsResponse.builder().build()); + + IngestionJob job = IngestionJob.builder() + .knowledgeBaseId(KNOWLEDGE_BASE_ID) + .dataSourceId(DATA_SOURCE_ID) + .ingestionJobId(JOB_ID) + .status(IngestionJobStatus.STARTING) + .build(); + + when(mockBedrockAgentClient.startIngestionJob(startIngestionJobRequestCaptor.capture())).thenReturn( + StartIngestionJobResponse.builder().ingestionJob(job).build() + ); + + when(mockBedrockAgentClient.getIngestionJob(getIngestionJobRequestCaptor.capture())).thenReturn( + GetIngestionJobResponse.builder().ingestionJob(job).build(), + GetIngestionJobResponse.builder().ingestionJob(job.copy(b -> b + .status(IngestionJobStatus.FAILED) + .failureReasons("Some failure") + )).build() + ); + + assertEquals("Sync job job-id failed (Status: FAILED, Failures: [Some failure])", assertThrows(IllegalStateException.class, () -> { + // Call under test + handler.handle(mockStackEvent); + }).getMessage()); + + verify(mockBedrockAgentClient, times(2)).getIngestionJob(getIngestionJobRequestCaptor.capture()); + + assertEquals( + ListKnowledgeBasesRequest.builder().build(), + ListKnowledgeBasesRequest.builder().applyMutation(listKnowledgeBasesRequestCaptor.getValue()).build() + ); + + assertEquals( + ListDataSourcesRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).build(), + ListDataSourcesRequest.builder().applyMutation(listDataSourceRequestCaptor.getValue()).build() + ); + + assertEquals( + ListIngestionJobsRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).maxResults(1).build(), + ListIngestionJobsRequest.builder().applyMutation(listIngestionRequestCaptor.getValue()).build() + ); + + assertEquals( + 
StartIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).clientToken(JOB_HASH).build(), + StartIngestionJobRequest.builder().applyMutation(startIngestionJobRequestCaptor.getValue()).build() + ); + + assertEquals( + GetIngestionJobRequest.builder().knowledgeBaseId(KNOWLEDGE_BASE_ID).dataSourceId(DATA_SOURCE_ID).ingestionJobId(JOB_ID).build(), + GetIngestionJobRequest.builder().applyMutation(getIngestionJobRequestCaptor.getValue()).build() + ); + } +}
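
Editor's note: the WaitConditionHandler contract exercised throughout this patch is not itself visible in any hunk above. The sketch below reconstructs its likely shape purely from the call sites and tests in this diff (getWaitConditionId keying a handler to a wait condition's logical id, and handle returning the unique id for the resource signal or empty to retry later). Treat the javadoc wording and the exact throws clause as the editor's inference, not the committed source.

    package org.sagebionetworks.template;

    import java.util.Optional;

    import com.amazonaws.services.cloudformation.model.StackEvent;

    /**
     * Inferred sketch: a handler keyed to a single CloudFormation wait condition,
     * invoked while the stack is in progress and a matching wait-condition event
     * is observed in the CREATE_IN_PROGRESS state.
     */
    public interface WaitConditionHandler {

        /**
         * @return The logical resource id of the wait condition this handler takes
         * care of (e.g. "SynapseHelpCollectionCreateIndexWaitCondition").
         */
        String getWaitConditionId();

        /**
         * Performs the out-of-band work the wait condition is waiting for, such as
         * creating the vector index or starting the data source sync.
         *
         * @return The unique id to send with the SUCCESS resource signal once the
         * work is done, or Optional.empty() if no signal should be sent yet (the
         * handler is retried on the next polling iteration); exceptions thrown here
         * are reported by the caller as a FAILURE signal.
         */
        Optional<String> handle(StackEvent event) throws InterruptedException;
    }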