Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Fixes 3737: Add support for AWS Bedrock for ml procedures #388

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ subprojects {

ext {
// NB: due to version.json generation by parsing this file, the next line must not have any if/then/else logic
neo4jVersion = "5.14.0"
neo4jVersion = "5.12.0"
// instead we apply the override logic here
neo4jVersionEffective = project.hasProperty("neo4jVersionOverride") ? project.getProperty("neo4jVersionOverride") : neo4jVersion
testContainersVersion = '1.18.3'
Expand Down
65 changes: 65 additions & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/bedrock.adoc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
[[aws-bedrock]]
= AWS Bedrock procedures


These procedures leverage the https://aws.amazon.com/bedrock/[Amazon Bedrock API].


Here is a list of all available Aws Bedrock procedures:


[opts=header, cols="1, 4", separator="|"]
|===
|name|description
|apoc.ml.bedrock.list($config)|To get the list of foundation or custom models
|apoc.ml.bedrock.jurassic(body, $config)|To create an API call to `Jurassic-2` model
|apoc.ml.bedrock.anthropic.claude(body, $config)|To create an API call to ``Titan Embedding``s model
|apoc.ml.bedrock.titan.embedding(body, $config)|To create an API call to `Claude` model
|apoc.ml.bedrock.stability(body, $config)|To create an API call to `Stable Diffusion` model
|apoc.ml.bedrock.custom(body, $config)|To create a customizable Bedrock API call
|===

All the procedures, leverage the `apoc.ml.bedrock.custom` procedures,
and support the same config parameter, but unlike the `custom` one,
they have some different default parameters and model id.
Moreover the return data is consistent with the called API,
instead of returning a generic `Object` as a result


== Config

.Config parameters
[opts=header, cols="1,1,1,5"]
|===
| name | type | default | description
| secretKey | String | null | The AWS key ID. We can also evaluate it via `apoc.conf`, with the key `apoc.aws.key.id`. As an alternative to the pair keyId-secretKey, we can directly pass the https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html[aws V4 signature] via the `headers` config
| keyId | String | null | The AWS secret access key. We can also evaluate it via `apoc.conf`, with the key `apoc.aws.secret.id`. As an alternative to the pair keyId-secretKey, we can directly pass the https://docs.aws.amazon.com/AmazonS3/latest/API/sig-v4-header-based-auth.html[aws V4 signature] via the `headers` config
| region | String | the one calculated from `endpoint` config | The AWS region
| endpoint | String | see below | The AWS endpoint.
| method | String | `"POST"` (or `"GET"` with the `apoc.ml.bedrock.list` procedure) | TODO
| headers | Map<String, Object> | `{}` | The additional headers
| modelId | String | see below | The modelId. If created,
|===

modelId and endpoint - TODO

== Usage examples

TODO

=== Custom AWS API Call




- endpoint customizable
- keyId e secretKey, oppure passare direttamente nell'header


// todo - document that modelId goto endpoint --> https://bedrock-runtime.us-east-1.amazonaws.com/model/<MODEL_ID>/invoke
// e che endpoint ha la priorità


// todo - all procs accept same configs.

// todo - apocConfig().setProperty(APOC_IMPORT_FILE_ENABLED, true);
1 change: 1 addition & 0 deletions docs/asciidoc/modules/ROOT/pages/ml/index.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ This section includes:

* xref::ml/vertexai.adoc[]
* xref::ml/openai.adoc[]
* xref::ml/bedrock.adoc[]
2 changes: 2 additions & 0 deletions extended/src/main/java/apoc/ExtendedApocConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ public class ExtendedApocConfig extends LifecycleAdapter
public static final String APOC_UUID_ENABLED_DB = "apoc.uuid.enabled.%s";
public static final String APOC_UUID_FORMAT = "apoc.uuid.format";
public static final String APOC_OPENAI_KEY = "apoc.openai.key";
public static final String APOC_AWS_KEY_ID = "apoc.aws.key.id";
public static final String APOC_AWS_SECRET_KEY = "apoc.aws.secret.key";
public enum UuidFormatType { hex, base64 }

// These were earlier added via the Neo4j config using the ApocSettings.java class
Expand Down
152 changes: 152 additions & 0 deletions extended/src/main/java/apoc/ml/bedrock/AwsSignatureV4Generator.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package apoc.ml.bedrock;

import org.apache.commons.lang3.tuple.Pair;

import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.time.ZoneOffset;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.stream.Collectors;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;


public class AwsSignatureV4Generator {

public static final String AWS_SERVICE_NAME = "bedrock";

/**
* Generates signing headers for HTTP request in accordance with Amazon AWS API Signature version 4 process.
* <p>
* Following steps outlined here: <a href="https://docs.aws.amazon.com/general/latest/gr/signature-version-4.html">docs.aws.amazon.com</a>
* <p>
* @param conf - The {@link BedrockConfig config}
* @param headers - The HTTP headers
* @param body - The HTTP body in bytes
*/
public static Map<String, Object> calculateAuthorizationHeaders(
BedrockConfig conf,
Map<String, Object> headers,
byte[] body
) throws MalformedURLException {
headers = new HashMap<>(headers);

String isoDateTime = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'").format(ZonedDateTime.now(ZoneOffset.UTC));

URL url = new URL(conf.getEndpoint());

String host = url.getHost();
String path = url.getPath();
String query = url.getQuery();

String bodySha256 = hex(toSha256(body));
String isoDateOnly = isoDateTime.substring(0, 8);

headers.put("Host", host);
headers.put("X-Amz-Date", isoDateTime);

Pair<String, String> pairSignedHeaderAndCanonicalHash = createCanonicalRequest(conf.getMethod(), headers, path, query, bodySha256);

Pair<String, String> pairCredentialAndStringSign = createStringToSign(conf.getRegion(), isoDateTime, isoDateOnly, pairSignedHeaderAndCanonicalHash);

String signature = calculateSignature(conf.getSecretKey(), conf.getRegion(), isoDateOnly, pairCredentialAndStringSign.getRight());

createAuthorizationHeader(conf, headers, pairSignedHeaderAndCanonicalHash, pairCredentialAndStringSign, signature);

return headers;
}

private static void createAuthorizationHeader(BedrockConfig conf, Map<String, Object> headers, Pair<String, String> pairSignedHeaderAndCanonicalHash, Pair<String, String> pairCredentialAndStringSign, String signature) {
String authStringParameter = "AWS4-HMAC-SHA256 Credential=" + conf.getKeyId() + "/" + pairCredentialAndStringSign.getLeft()
+ ", SignedHeaders=" + pairSignedHeaderAndCanonicalHash.getLeft()
+ ", Signature=" + signature;

headers.put("Authorization", authStringParameter);
}

/**
* Based on <a href="https://docs.aws.amazon.com/general/latest/gr/sigv4-create-string-to-sign.html">sigv4-create-string-to-sign</a>
*/
private static Pair<String, String> createStringToSign(String awsRegion, String isoDateTime, String isoJustDate, Pair<String, String> pairSignedHeaderCanonicalHash) {
List<String> stringToSignLines = new ArrayList<>();
stringToSignLines.add("AWS4-HMAC-SHA256");
stringToSignLines.add(isoDateTime);
String credentialScope = isoJustDate + "/" + awsRegion + "/" + AWS_SERVICE_NAME + "/aws4_request";
stringToSignLines.add(credentialScope);
stringToSignLines.add(pairSignedHeaderCanonicalHash.getRight());
String stringToSign = String.join("\n", stringToSignLines);
return Pair.of(credentialScope, stringToSign);
}

/**
* Based on <a href="https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html">sigv4-create-canonical-request</a>
*/
private static Pair<String, String> createCanonicalRequest(String method, Map<String, Object> headers, String path, String query, String bodySha256) {
List<String> canonicalRequestLines = new ArrayList<>();
canonicalRequestLines.add(method);
canonicalRequestLines.add(path);
canonicalRequestLines.add(query);
List<String> hashedHeaders = new ArrayList<>();
List<String> headerKeysSorted = headers.keySet().stream().sorted(Comparator.comparing(e -> e.toLowerCase(Locale.US))).toList();
for (String key : headerKeysSorted) {
hashedHeaders.add(key.toLowerCase(Locale.US));
canonicalRequestLines.add(key.toLowerCase(Locale.US) + ":" + normalizeSpaces((String) headers.get(key)));
}
canonicalRequestLines.add(null); // new line required after headers
String signedHeaders = String.join(";", hashedHeaders);
canonicalRequestLines.add(signedHeaders);
canonicalRequestLines.add(bodySha256);
String canonicalRequestBody = canonicalRequestLines.stream().map(line -> line == null ? "" : line).collect(Collectors.joining("\n"));
String canonicalRequestHash = hex(toSha256(canonicalRequestBody.getBytes(StandardCharsets.UTF_8)));
return Pair.of(signedHeaders, canonicalRequestHash);
}

/**
* Based on <a href="https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html">sigv4-calculate-signature</a>
*/
private static String calculateSignature(String awsSecret, String awsRegion, String isoJustDate, String stringToSign) {
byte[] kDate = toHmac(("AWS4" + awsSecret).getBytes(StandardCharsets.UTF_8), isoJustDate);
byte[] kRegion = toHmac(kDate, awsRegion);
byte[] kService = toHmac(kRegion, AWS_SERVICE_NAME);
byte[] kSigning = toHmac(kService, "aws4_request");
return hex(toHmac(kSigning, stringToSign));
}

private static String normalizeSpaces(String value) {
return value.replaceAll("\\s+", " ").trim();
}

public static String hex(byte[] a) {
StringBuilder sb = new StringBuilder(a.length * 2);
for(byte b: a) {
sb.append(String.format("%02x", b));
}
return sb.toString();
}

private static byte[] toSha256(byte[] bytes) {
try {
MessageDigest digest = MessageDigest.getInstance("SHA-256");
digest.update(bytes);
return digest.digest();
} catch (Exception e) {
throw new RuntimeException(e);
}
}

public static byte[] toHmac(byte[] key, String msg) {
try {
Mac mac = Mac.getInstance("HmacSHA256");
mac.init(new SecretKeySpec(key, "HmacSHA256"));
return mac.doFinal(msg.getBytes(StandardCharsets.UTF_8));
} catch (Exception e) {
throw new RuntimeException(e);
}
}

}
127 changes: 127 additions & 0 deletions extended/src/main/java/apoc/ml/bedrock/Bedrock.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
package apoc.ml.bedrock;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;

import apoc.Description;
import apoc.result.ObjectResult;
import apoc.util.ExtendedUtil;
import apoc.util.Util;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.http.impl.client.CloseableHttpClient;

import org.apache.http.impl.client.HttpClientBuilder;
import org.neo4j.procedure.Name;
import org.neo4j.procedure.Procedure;

import static apoc.ml.bedrock.AwsSignatureV4Generator.calculateAuthorizationHeaders;
import static apoc.ml.bedrock.BedrockInvokeConfig.MODEL_ID;
import static apoc.util.JsonUtil.OBJECT_MAPPER;
import static apoc.ml.bedrock.BedrockInvokeResult.*;
import static apoc.ml.bedrock.BedrockUtil.ModelId.*;
import static apoc.ml.bedrock.BedrockUtil.ALL;
import static apoc.ml.bedrock.BedrockUtil.JSON;


public class Bedrock {

@Procedure
public Stream<ModelItemResult> list(@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

BedrockConfig conf = new BedrockGetModelsConfig(config);

return executeRequestCommon(null, "modelSummaries[*]", conf)
.flatMap(i -> ((List<Map<String, Object>>) i).stream())
.map(ModelItemResult::new);
}

@Procedure
@Description("To create a customizable bedrock call")
public Stream<ObjectResult> custom(@Name(value = "body") Object body,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

return executeCustomRequest(body, config, null)
.map(ObjectResult::new);
}

@Procedure
public Stream<Jurassic> jurassic(@Name(value = "body") Object body,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {

config.putIfAbsent(MODEL_ID, JURASSIC_2_ULTRA.id());

return executeRequestReturningMap(body, config, null)
.map(Jurassic::from);
}

@Procedure("apoc.ml.bedrock.anthropic.claude")
public Stream<AnthropicClaude> anthropicClaude(@Name(value = "body") Object body,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
config.putIfAbsent(MODEL_ID, CLAUDE_V2.id());

return executeRequestReturningMap(body, config, null)
.map(AnthropicClaude::from);
}

@Procedure("apoc.ml.bedrock.titan.embedding")
public Stream<TitanEmbedding> titanEmbedding(@Name(value = "body") Object body,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
config.putIfAbsent(MODEL_ID, TITAN_EMBEDDING_G1.id());

return executeRequestReturningMap(body, config, null)
.map(TitanEmbedding::from);
}

@Procedure("apoc.ml.bedrock.stability")
public Stream<StabilityAi> stability(@Name(value = "body") Object body,
@Name(value = "config", defaultValue = "{}") Map<String, Object> config) {
config.putIfAbsent(MODEL_ID, STABLE_DIFFUSION_XL.id());

return executeRequestReturningMap(body, config, "$.artifacts[0]")
.map(StabilityAi::from);
}


private Stream<Map<String, Object>> executeRequestReturningMap(Object body, Map<String, Object> config, String path) {
return executeCustomRequest(body, config, path)
.map(i -> (Map<String, Object>) i);
}

private Stream<Object> executeCustomRequest(Object body, Map<String, Object> config, String path) {
BedrockConfig conf = new BedrockInvokeConfig(config);

return executeRequestCommon(body, path, conf);
}

private Stream<Object> executeRequestCommon(Object body, String path, BedrockConfig conf) {
try {
String bodyString = getBodyAsString(body);
Map<String, Object> headers = new HashMap<>(conf.getHeaders());
headers.putIfAbsent("Content-Type", JSON);
headers.putIfAbsent("accept", ALL);

headers = calculateAuthorizationHeaders(conf, headers, bodyString.getBytes());

CloseableHttpClient httpClient = HttpClientBuilder.create().build();

return ExtendedUtil.getModelItemResultStream(conf.getMethod(), httpClient, bodyString, headers, conf.getEndpoint(), path, List.of())
.onClose(() -> Util.close(httpClient));
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private String getBodyAsString(Object body) throws JsonProcessingException {
if (body == null) {
return "";
}
if (body instanceof String bodyString) {
return bodyString;
}
return OBJECT_MAPPER.writeValueAsString(body);
}

}
Loading