PLFM-8614: Add KB to the synapse agent with help docs #681

Merged
merged 35 commits into from
Dec 19, 2024
Commits (35)
e9f0ba2
PLFM-8614: Add open search collection for synapse help
marcomarasca Dec 13, 2024
b177a09
PLFM-8614: Create a dedicated stack for the bedrock agent
marcomarasca Dec 14, 2024
6219402
PLFM-8614: Testing check for collection index
marcomarasca Dec 14, 2024
c6a0478
PLFM-8614: Try to create an index
marcomarasca Dec 14, 2024
e13944a
PLFM-8614: Testing index creation
marcomarasca Dec 14, 2024
c8f0d62
PLFM-8614: Test wait condition
marcomarasca Dec 15, 2024
b9da81f
Revert "PLFM-8614: Create a dedicated stack for the bedrock agent"
marcomarasca Dec 15, 2024
e14a7e7
PLFM-8614: Add dep on policies
marcomarasca Dec 15, 2024
d5a57ad
PLFM-8614: A stack in creation doesn't have outputs yet
marcomarasca Dec 15, 2024
c94d1aa
PLFM-8614: Fix null pointer
marcomarasca Dec 15, 2024
0871ffd
PLFM-8614: Fix property name
marcomarasca Dec 16, 2024
f0a9ab9
PLFM-8614: Only process latest event
marcomarasca Dec 16, 2024
5c31954
PLFM-8614: Fix vector index params
marcomarasca Dec 16, 2024
f46f94c
PLFM-8614: Trigger new wait condition
marcomarasca Dec 16, 2024
7ff145a
PLFM-8614: Fix logging messages
marcomarasca Dec 16, 2024
987aac0
PLFM-8614: Add KB resources
marcomarasca Dec 16, 2024
d841dd4
PLFM-8614: Sync datasource on creation
marcomarasca Dec 16, 2024
0f3a5fc
PLFM-8614: Try without creating index
marcomarasca Dec 16, 2024
42d7df4
PLFM-8614: The index needs to exists before the KB can be created
marcomarasca Dec 16, 2024
fbff200
PLFM-8614: Simplify data access policy
marcomarasca Dec 16, 2024
652bb9e
PLFM-8614: Needs separate access policies
marcomarasca Dec 16, 2024
c603eba
PLFM-8614: Try to create the access policy before the collection
marcomarasca Dec 16, 2024
6d3358a
PLFM-8614: Minor refactor
marcomarasca Dec 16, 2024
9c1771d
PLFM-8614: Split role policy from role
marcomarasca Dec 16, 2024
4b929d2
PLFM-8614: Tweek OSS access policies
marcomarasca Dec 16, 2024
8fa033a
PLFM-8614: Fix OSS permissions
marcomarasca Dec 16, 2024
c429f46
PLFM-8614: Avoid re-processing wait conditions
marcomarasca Dec 17, 2024
6fafa99
PLFM-8614: Add unit test for wait condition handlers
marcomarasca Dec 17, 2024
f780cd2
PLFM-8614: Modify agent template for optional knowledge base
marcomarasca Dec 17, 2024
e02cd43
PLFM-8614: Oh well, let's hack the cloudformation template because re…
marcomarasca Dec 17, 2024
0598882
PLFM-8614: Of course the KB Arn attribute is named differently
marcomarasca Dec 17, 2024
1ddb797
PLFM-8614: Try without the KB
marcomarasca Dec 17, 2024
114f528
PLFM-8614: It works!!
marcomarasca Dec 17, 2024
0f00053
PLFM-8614: Typo
marcomarasca Dec 17, 2024
52f9c87
PLFM-8614: Allow response streaming
marcomarasca Dec 17, 2024
30 changes: 30 additions & 0 deletions pom.xml
@@ -23,6 +23,18 @@
</url>
</repository>
</repositories>

<dependencyManagement>
<dependencies>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bom</artifactId>
<version>${amazon.sdk.v2.version}</version>
<type>pom</type>
<scope>import</scope>
</dependency>
</dependencies>
</dependencyManagement>

<dependencies>
<dependency>
@@ -225,6 +237,23 @@
<artifactId>guava</artifactId>
<version>31.1-jre</version>
</dependency>
<dependency>
<groupId>org.opensearch.client</groupId>
<artifactId>opensearch-java</artifactId>
<version>2.8.1</version>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>opensearchserverless</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>apache-client</artifactId>
</dependency>
<dependency>
<groupId>software.amazon.awssdk</groupId>
<artifactId>bedrockagent</artifactId>
</dependency>
</dependencies>
<build>
<resources>
@@ -275,6 +304,7 @@
<properties>
<bouncycastle.version>1.78.1</bouncycastle.version>
<amazon.sdk.version>1.12.296</amazon.sdk.version>
<amazon.sdk.v2.version>2.29.34</amazon.sdk.v2.version>
<log4j.version>2.17.1</log4j.version>
<junit.jupiter.version>5.4.1</junit.jupiter.version>
<junit.vintage.version>5.4.1</junit.vintage.version>
@@ -1,6 +1,7 @@
package org.sagebionetworks.template;

import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import com.amazonaws.services.cloudformation.model.AmazonCloudFormationException;
@@ -20,7 +21,7 @@ public interface CloudFormationClient {
* @param stackName
* @return
*/
public boolean doesStackNameExist(String stackName);
boolean doesStackNameExist(String stackName);

/**
* Describe a stack given its name.
@@ -29,7 +30,7 @@ public interface CloudFormationClient {
* @return
* @throws AmazonCloudFormationException When the stack does not exist.
*/
public Optional<Stack> describeStack(String stackName);
Optional<Stack> describeStack(String stackName);

/**
* Update a stack with the given name using the provided template body.
@@ -38,7 +39,7 @@
* @param templateBody
* @return StackId
*/
public void updateStack(CreateOrUpdateStackRequest request);
void updateStack(CreateOrUpdateStackRequest request);

/**
* Create a stack with the given name using the provided template body.
@@ -47,7 +48,7 @@
* @param templateBody
* @return StackId
*/
public void createStack(CreateOrUpdateStackRequest request);
void createStack(CreateOrUpdateStackRequest request);

/**
* If a stack does not exist the stack will be created else the stack will be
@@ -57,33 +58,42 @@
* @param templateBody
* @return StackId
*/
public void createOrUpdateStack(CreateOrUpdateStackRequest request);
void createOrUpdateStack(CreateOrUpdateStackRequest request);

/**
* Wait for the given stack to complete.
* @param stackName
* @return
* @throws InterruptedException
*/
public Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException;
Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException;

/**
* Wait for the given stack to complete and handle any wait condition provided in the given set of handlers (each handler is matched to its wait condition by the condition's logical id)
* @param stackName
* @param waitConditionHandlers
* @return
* @throws InterruptedException
*/
Optional<Stack> waitForStackToComplete(String stackName, Set<WaitConditionHandler> waitConditionHandlers) throws InterruptedException;

/**
* Get the value of the output with the given key for the stack with the given name.
* @param stackName
* @param outputKey
* @return The output value
*/
public String getOutput(String stackName, String outputKey);
String getOutput(String stackName, String outputKey);

/**
* Stream over all stacks.
* @return
*/
public Stream<Stack> streamOverAllStacks();
Stream<Stack> streamOverAllStacks();

/**
* Delete a stack by name
* @param stackName
*/
public void deleteStack(String stackName);
void deleteStack(String stackName);

}
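
Since the WaitConditionHandler interface itself is not part of this excerpt, the following is a minimal sketch of what an implementation might look like. Its shape (getWaitConditionId() plus handle(StackEvent) returning an optional signal id) is inferred from how CloudFormationClientImpl uses it below; the class name and wait condition id are hypothetical.

package org.sagebionetworks.template;

import java.util.Optional;

import com.amazonaws.services.cloudformation.model.StackEvent;

// Hypothetical handler for a wait condition with logical id "HelpIndexWaitCondition".
// The interface shape is inferred from its usage in CloudFormationClientImpl.
public class HelpIndexWaitConditionHandler implements WaitConditionHandler {

	@Override
	public String getWaitConditionId() {
		return "HelpIndexWaitCondition";
	}

	@Override
	public Optional<String> handle(StackEvent event) {
		// Return a unique signal id once the external work (e.g. creating the vector index) is done,
		// or Optional.empty() so the condition is retried on a later polling iteration.
		return isIndexReady() ? Optional.of("help-index-created") : Optional.empty();
	}

	private boolean isIndexReady() {
		// Placeholder for the actual readiness check against the OpenSearch collection.
		return true;
	}
}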
@@ -3,12 +3,16 @@
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

@@ -21,11 +25,15 @@
import com.amazonaws.services.cloudformation.model.CreateStackRequest;
import com.amazonaws.services.cloudformation.model.CreateStackResult;
import com.amazonaws.services.cloudformation.model.DeleteStackRequest;
import com.amazonaws.services.cloudformation.model.DeleteStackResult;
import com.amazonaws.services.cloudformation.model.DescribeStackEventsRequest;
import com.amazonaws.services.cloudformation.model.DescribeStacksRequest;
import com.amazonaws.services.cloudformation.model.DescribeStacksResult;
import com.amazonaws.services.cloudformation.model.Output;
import com.amazonaws.services.cloudformation.model.ResourceSignalStatus;
import com.amazonaws.services.cloudformation.model.ResourceStatus;
import com.amazonaws.services.cloudformation.model.SignalResourceRequest;
import com.amazonaws.services.cloudformation.model.Stack;
import com.amazonaws.services.cloudformation.model.StackEvent;
import com.amazonaws.services.cloudformation.model.StackStatus;
import com.amazonaws.services.cloudformation.model.UpdateStackRequest;
import com.amazonaws.services.cloudformation.model.UpdateStackResult;
@@ -225,7 +233,19 @@ public boolean isStartedInUpdateRollbackComplete(String stackName) {

@Override
public Optional<Stack> waitForStackToComplete(String stackName) throws InterruptedException {
return waitForStackToComplete(stackName, Collections.emptySet());
}

@Override
public Optional<Stack> waitForStackToComplete(String stackName, Set<WaitConditionHandler> waitConditionHandlers) throws InterruptedException {
boolean startedInUpdateRollbackComplete = isStartedInUpdateRollbackComplete(stackName); // Initial state

Map<String, WaitConditionHandler> waitConditionHandlerMap = waitConditionHandlers.stream()
.collect(Collectors.toMap(WaitConditionHandler::getWaitConditionId, Function.identity()));

// To avoid re-processing the same wait condition multiple times we need to keep track of them
Set<String> processedWaitConditionSet = new HashSet<>();

long start = threadProvider.currentTimeMillis();
while (true) {
long elapse = threadProvider.currentTimeMillis() - start;
@@ -246,10 +266,10 @@ public Optional<Stack> waitForStackToComplete(String stackName) throws Interrupt
return optional;
case CREATE_IN_PROGRESS:
case UPDATE_IN_PROGRESS:
handleWaitConditions(stackName, waitConditionHandlerMap, processedWaitConditionSet);
case DELETE_IN_PROGRESS:
case UPDATE_COMPLETE_CLEANUP_IN_PROGRESS:
logger.info("Waiting for stack: '" + stackName + "' to complete. Current status: " + status.name()
+ "...");
logger.info("Waiting for stack: '" + stackName + "' to complete. Current status: " + status.name() + "...");
threadProvider.sleep(SLEEP_TIME);
break;
case UPDATE_ROLLBACK_COMPLETE:
@@ -262,6 +282,84 @@ public Optional<Stack> waitForStackToComplete(String stackName) throws Interrupt
}
}
}

void handleWaitConditions(String stackName, Map<String, WaitConditionHandler> waitConditionHandlers, Set<String> processedWaitConditionSet) {
if (waitConditionHandlers.isEmpty()) {
return;
}

Set<String> waitConditionEventIds = new HashSet<>();

List<StackEvent> waitConditionEvents = cloudFormationClient.describeStackEvents(
new DescribeStackEventsRequest().withStackName(stackName)
)
.getStackEvents()
.stream()
.filter(event -> "AWS::CloudFormation::WaitCondition".equals(event.getResourceType()))
// We only need the latest event for each wait condition
.filter(event -> waitConditionEventIds.add(event.getLogicalResourceId()))
.filter(event -> ResourceStatus.CREATE_IN_PROGRESS.equals(ResourceStatus.fromValue(event.getResourceStatus())))
.collect(Collectors.toList());

for (StackEvent waitConditionEvent : waitConditionEvents) {
String waitConditionId = waitConditionEvent.getLogicalResourceId();

if (processedWaitConditionSet.contains(waitConditionId)) {
logger.warn("Wait condition {} already processed, skipping.", waitConditionId);
continue;
}

logger.info("Processing wait condition {} (Status: {}, Reason: {})...", waitConditionId, waitConditionEvent.getResourceStatus(), waitConditionEvent.getResourceStatusReason());

WaitConditionHandler waitConditionHandler = waitConditionHandlers.get(waitConditionId);

if (waitConditionHandler == null) {

cloudFormationClient.signalResource(new SignalResourceRequest()
.withStackName(stackName)
.withLogicalResourceId(waitConditionId)
.withStatus(ResourceSignalStatus.FAILURE)
.withUniqueId("handler-not-found")
);

throw new IllegalStateException("Processing wait condition " + waitConditionId + " failed: could not find an handler.");

} else {
logger.info("Processing wait condition {} started...", waitConditionId);

try {
waitConditionHandler.handle(waitConditionEvent).ifPresentOrElse(signalId -> {
logger.info("Processing wait condition {} completed with signal {}.", waitConditionId, signalId);

cloudFormationClient.signalResource(new SignalResourceRequest()
.withStackName(stackName)
.withLogicalResourceId(waitConditionId)
.withStatus(ResourceSignalStatus.SUCCESS)
.withUniqueId(signalId)
);

processedWaitConditionSet.add(waitConditionId);
}, () -> {
logger.info("Processing wait condition {} didn't return a signal, will process later.", waitConditionId);
});

} catch (Exception e) {
logger.error("Processing wait condition {} failed exceptionally: ", waitConditionId, e);

cloudFormationClient.signalResource(new SignalResourceRequest()
.withStackName(stackName)
.withLogicalResourceId(waitConditionId)
.withStatus(ResourceSignalStatus.FAILURE)
.withUniqueId("handler-failed")
);

throw new IllegalStateException("Processing wait condition " + waitConditionId + " failed.", e);
}
}

}

}

@Override
public String getOutput(String stackName, String outputKey) {
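
As a hedged usage sketch of the implementation above (the stack name is hypothetical, and HelpIndexWaitConditionHandler is the illustrative handler sketched earlier):

// Register one handler per wait condition declared in the template.
Set<WaitConditionHandler> handlers = Set.of(new HelpIndexWaitConditionHandler());

// Polls the stack until it reaches a terminal state; while CREATE/UPDATE is in progress,
// each unprocessed wait condition event is dispatched to its handler and then signaled with
// SUCCESS (using the returned signal id), or FAILURE if no handler is found or the handler throws.
Optional<Stack> stack = cloudFormationClient.waitForStackToComplete("bedrock-agent-stack", handlers);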
1 change: 1 addition & 0 deletions src/main/java/org/sagebionetworks/template/Constants.java
@@ -221,6 +221,7 @@ public class Constants {
public static final String HOSTED_ZONE = "hostedZone";
public static final String VPC_ENDPOINTS_COLOR = "vpcEndpointsColor";
public static final String VPC_ENDPOINTS_AZ = "vpcEndpointsAz";
public static final String IDENTITY_ARN = "identityArn";

public static final String CAPABILITY_NAMED_IAM = "CAPABILITY_NAMED_IAM";
public static final String OUTPUT_NAME_SUFFIX_REPOSITORY_DB_ENDPOINT = "RepositoryDBEndpoint";
@@ -0,0 +1,9 @@
package org.sagebionetworks.template;

import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;

public interface OpenSearchClientFactory {

OpenSearchIndicesClient getIndicesClient(String collectionEndpoint);

}
@@ -0,0 +1,33 @@
package org.sagebionetworks.template;

import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;
import org.opensearch.client.transport.aws.AwsSdk2Transport;
import org.opensearch.client.transport.aws.AwsSdk2TransportOptions;

import com.google.inject.Inject;

import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.regions.Region;

public class OpenSearchClientFactoryImpl implements OpenSearchClientFactory {

private SdkHttpClient httpClient;

@Inject
public OpenSearchClientFactoryImpl(SdkHttpClient httpClient) {
this.httpClient = httpClient;
}

public OpenSearchIndicesClient getIndicesClient(String collectionEndpoint) {
return new OpenSearchIndicesClient(
new AwsSdk2Transport(
httpClient,
collectionEndpoint.replace("https://", ""),
"aoss",
Region.US_EAST_1,
AwsSdk2TransportOptions.builder().build()
)
);
}

}
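
To round out the picture, a hedged usage sketch of the factory: the collection endpoint and index name are placeholders rather than values from this PR, ApacheHttpClient (from the apache-client artifact added to the pom) is assumed as the SdkHttpClient implementation, and the actual vector index settings used by the knowledge base are not shown in this excerpt.

import org.opensearch.client.opensearch.indices.OpenSearchIndicesClient;

import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.http.apache.ApacheHttpClient;

public class OpenSearchClientFactoryExample {

	public static void main(String[] args) throws Exception {
		SdkHttpClient httpClient = ApacheHttpClient.builder().build();

		OpenSearchClientFactory factory = new OpenSearchClientFactoryImpl(httpClient);

		// The collection endpoint would come from the OpenSearch Serverless collection output.
		OpenSearchIndicesClient indices = factory.getIndicesClient("https://example-collection.us-east-1.aoss.amazonaws.com");

		// Creates a (hypothetical) index; the transport signs requests with SigV4 for the "aoss" service.
		indices.create(b -> b.index("synapse-help-index"));
	}
}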