diff --git a/docs/asciidoc/modules/ROOT/pages/ml/bedrock.adoc b/docs/asciidoc/modules/ROOT/pages/ml/bedrock.adoc new file mode 100644 index 0000000000..602411e102 --- /dev/null +++ b/docs/asciidoc/modules/ROOT/pages/ml/bedrock.adoc @@ -0,0 +1,11 @@ +TODO: + +- endpoint customizable +- keyId e secretKey, oppure passare direttamente nell'header + + +// todo - documentare che con modelId va su endpoint --> https://bedrock-runtime.us-east-1.amazonaws.com/model//invoke +// e che endpoint ha la priorità + + +// todo - dire che tutte le procedure accettano gli stessi config. \ No newline at end of file diff --git a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc index 2a004be27b..35d0ddf14c 100644 --- a/docs/asciidoc/modules/ROOT/pages/ml/index.adoc +++ b/docs/asciidoc/modules/ROOT/pages/ml/index.adoc @@ -9,3 +9,4 @@ This section includes: * xref::ml/vertexai.adoc[] * xref::ml/openai.adoc[] +* xref::ml/bedrock.adoc[] diff --git a/extended/src/main/java/apoc/ml/bedrock/AmazonRequestSignatureV4Utils.java b/extended/src/main/java/apoc/ml/bedrock/AwsRequestSignatureV4Converter.java similarity index 68% rename from extended/src/main/java/apoc/ml/bedrock/AmazonRequestSignatureV4Utils.java rename to extended/src/main/java/apoc/ml/bedrock/AwsRequestSignatureV4Converter.java index d18826dda2..fa22fc7b49 100644 --- a/extended/src/main/java/apoc/ml/bedrock/AmazonRequestSignatureV4Utils.java +++ b/extended/src/main/java/apoc/ml/bedrock/AwsRequestSignatureV4Converter.java @@ -1,8 +1,8 @@ package apoc.ml.bedrock; import org.apache.commons.lang3.tuple.Pair; -import org.jetbrains.annotations.NotNull; +import java.net.MalformedURLException; import java.net.URL; import java.nio.charset.StandardCharsets; import java.security.MessageDigest; @@ -16,48 +16,38 @@ import javax.crypto.spec.SecretKeySpec; -public class AmazonRequestSignatureV4Utils { +public class AwsRequestSignatureV4Converter { + + public static final String AWS_SERVICE_NAME = 
"bedrock"; /** * Generates signing headers for HTTP request in accordance with Amazon AWS API Signature version 4 process. *

* Following steps outlined here: docs.aws.amazon.com - * - * This method takes many arguments as read-only, but adds necessary headers to @{code headers} argument, which is a map. - * The caller should make sure those parameters are copied to the actual request object. - *

- * The ISO8601 date parameter can be created by making a call to:
- * - {@code java.time.format.DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'").format(ZonedDateTime.now(ZoneOffset.UTC))}
- * or, if you prefer joda:
- * - {@code org.joda.time.format.ISODateTimeFormat.basicDateTimeNoMillis().print(DateTime.now().withZone(DateTimeZone.UTC))} * - * @param method - HTTP request method, (GET|POST|DELETE|PUT|...), e.g., {@link java.net.HttpURLConnection#getRequestMethod()} - * @param headers - HTTP request header map. This map is going to have entries added to it by this method. Initially populated with - * headers to be included in the signature. Like often compulsory 'Host' header. e.g., {@link java.net.HttpURLConnection#getRequestProperties()}. - * @param body - The binary request body, for requests like POST. - * @param awsIdentity - AWS Identity, e.g., "AKIAJTOUYS27JPVRDUYQ" - * @param awsSecret - AWS Secret Key, e.g., "I8Q2hY819e+7KzBnkXj66n1GI9piV+0p3dHglAzQ" - * @param awsRegion - AWS Region, e.g., "us-east-1" - * @param awsService - AWS Service, e.g., "route53" + * @param method - HTTP request method, (GET|POST|DELETE|PUT|...) + * @param conf - The {@link BedrockConfig} + * @param headers - The HTTP headers + * @param body - The HTTP payload in bytes */ public static Map calculateAuthorizationHeaders( - String method, - URL url, // String path, String query, + String method, + BedrockConfig conf, Map headers, - byte[] body, - String awsIdentity, String awsSecret, String awsRegion, String awsService - ) { + byte[] body + ) throws MalformedURLException { headers = new HashMap<>(headers); String isoDateTime = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'").format(ZonedDateTime.now(ZoneOffset.UTC)); - + + URL url = new URL(conf.getEndpoint()); + String host = url.getHost(); String path = url.getPath(); String query = url.getQuery(); - -// try { String bodySha256 = hex(sha256(body)); - String isoJustDate = isoDateTime.substring(0, 8); // Cut the date portion of a string like '20150830T123600Z'; + // create a string like '20150830T123600Z'; + String isoDateOnly = isoDateTime.substring(0, 8); headers.put("Host", host); // headers.put("X-Amz-Content-Sha256", bodySha256); @@ 
-65,11 +55,11 @@ public static Map calculateAuthorizationHeaders( Pair pairSignedHeaderAndCanonicalHash = createCanonicalRequest(method, headers, path, query, bodySha256); - Pair pairCredentialAndStringSign = createStringToSign(awsRegion, awsService, isoDateTime, isoJustDate, pairSignedHeaderAndCanonicalHash); + Pair pairCredentialAndStringSign = createStringToSign(conf.getRegion(), isoDateTime, isoDateOnly, pairSignedHeaderAndCanonicalHash); - String signature = calculateSignature(awsSecret, awsRegion, awsService, isoJustDate, pairCredentialAndStringSign.getRight()); + String signature = calculateSignature(conf.getSecretKey(), conf.getRegion(), isoDateOnly, pairCredentialAndStringSign.getRight()); - String authParameter = "AWS4-HMAC-SHA256 Credential=" + awsIdentity + "/" + pairCredentialAndStringSign.getLeft() + ", SignedHeaders=" + pairSignedHeaderAndCanonicalHash.getLeft() + ", Signature=" + signature; + String authParameter = "AWS4-HMAC-SHA256 Credential=" + conf.getKeyId() + "/" + pairCredentialAndStringSign.getLeft() + ", SignedHeaders=" + pairSignedHeaderAndCanonicalHash.getLeft() + ", Signature=" + signature; headers.put("Authorization", authParameter); return headers; @@ -78,11 +68,11 @@ public static Map calculateAuthorizationHeaders( /** * Based on sigv4-create-string-to-sign */ - private static Pair createStringToSign(String awsRegion, String awsService, String isoDateTime, String isoJustDate, Pair pairSignedHeaderCanonicalHash) { + private static Pair createStringToSign(String awsRegion, String isoDateTime, String isoJustDate, Pair pairSignedHeaderCanonicalHash) { List stringToSignLines = new ArrayList<>(); stringToSignLines.add("AWS4-HMAC-SHA256"); stringToSignLines.add(isoDateTime); - String credentialScope = isoJustDate + "/" + awsRegion + "/" + awsService + "/aws4_request"; + String credentialScope = isoJustDate + "/" + awsRegion + "/" + AWS_SERVICE_NAME + "/aws4_request"; stringToSignLines.add(credentialScope); 
stringToSignLines.add(pairSignedHeaderCanonicalHash.getRight()); String stringToSign = String.join("\n", stringToSignLines); @@ -115,10 +105,10 @@ private static Pair createCanonicalRequest(String method, Mapsigv4-calculate-signature */ - private static String calculateSignature(String awsSecret, String awsRegion, String awsService, String isoJustDate, String stringToSign) { + private static String calculateSignature(String awsSecret, String awsRegion, String isoJustDate, String stringToSign) { byte[] kDate = hmac(("AWS4" + awsSecret).getBytes(StandardCharsets.UTF_8), isoJustDate); byte[] kRegion = hmac(kDate, awsRegion); - byte[] kService = hmac(kRegion, awsService); + byte[] kService = hmac(kRegion, AWS_SERVICE_NAME); byte[] kSigning = hmac(kService, "aws4_request"); return hex(hmac(kSigning, stringToSign)); } diff --git a/extended/src/main/java/apoc/ml/bedrock/Bedrock.java b/extended/src/main/java/apoc/ml/bedrock/Bedrock.java index acf1f6435e..0b0dcf0073 100644 --- a/extended/src/main/java/apoc/ml/bedrock/Bedrock.java +++ b/extended/src/main/java/apoc/ml/bedrock/Bedrock.java @@ -4,261 +4,168 @@ import java.io.InputStream; import java.net.HttpURLConnection; import java.net.URL; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; +import apoc.Description; import apoc.result.ObjectResult; +import apoc.util.ExtendedUtil; import apoc.util.JsonUtil; +import apoc.util.Util; import org.apache.http.HttpResponse; +import org.apache.http.client.HttpClient; import org.apache.http.client.methods.HttpGet; -import org.apache.http.impl.client.DefaultHttpClient; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.jetbrains.annotations.NotNull; +import 
org.neo4j.graphdb.Transaction; import org.neo4j.procedure.Name; import org.neo4j.procedure.Procedure; -import static apoc.ApocConfig.apocConfig; -import static apoc.ExtendedApocConfig.APOC_AWS_KEY_ID; -import static apoc.ExtendedApocConfig.APOC_AWS_SECRET_KEY; -import static apoc.ml.bedrock.AmazonRequestSignatureV4Utils.calculateAuthorizationHeaders; +import static apoc.ml.bedrock.AwsRequestSignatureV4Converter.calculateAuthorizationHeaders; +import static apoc.ml.bedrock.BedrockInvokeConfig.MODEL_ID; import static apoc.util.JsonUtil.OBJECT_MAPPER; import static apoc.util.JsonUtil.streamObjetsFromIStream; +import static apoc.ml.bedrock.BedrockInvokeResult.*; + -/* -TODO: -https://sdk.amazonaws.com/java/api/latest/software/amazon/awssdk/services/bedrock/BedrockClient.html#createProvisionedModelThroughput(software.amazon.awssdk.services.bedrock.model.CreateProvisionedModelThroughputRequest) - - */ public class Bedrock { public static final String ALL = "*/*"; public static final String JSON = "application/json"; - public static final String PNG = "image/png"; - // @Context -// public ApocConfig apocConfig; - - // todo - forse basta creare delle classi che estendono interfaccia e basta... - // farei tipo new CustomModel(...) 
- // todo - implement - enum ModelId { - JURASSIC_2_MID("ai21.j2-mid-v1", ALL, null), - JURASSIC_2_ULTRA("ai21.j2-ultra-v1", ALL, null), - - TITAN_EMBEDDING_G1("amazon.titan-embed-text-v1", ALL, null), - TITAN_TEXT_G1_EXPRESS("amazon.titan-text-express-v1", ALL, null), - - CLAUDE_V1("anthropic.claude-v1", JSON, null), - CLAUDE_V2("anthropic.claude-v2", JSON, null), - CLAUDE_INSTANT("anthropic.claude-instant-v1", JSON, null), - - STABLE_DIFFUSION_XL("stability.stable-diffusion-xl-v0", PNG, "$.artifacts[0]"); -// CUSTOM("idName", ALL, null); - - private final String id; - private final String acceptValue; - private final String jsonPath; - - ModelId(String id, String acceptValue, String jsonPath) { - this.id = id; - this.acceptValue = acceptValue; - this.jsonPath = jsonPath; - } + @Procedure("apoc.ml.bedrock.list") + public Stream list(@Name(value = "config", defaultValue = "{}") Map config) throws IOException { - public String getId() { - return id; - } + BedrockConfig conf = new BedrockModelsConfig(config); + Map headers = Map.of("Content-Type", JSON); - public String getAcceptValue() { - return acceptValue; - } + headers = calculateAuthorizationHeaders("GET", conf, headers, "".getBytes()); - public String getJsonPath() { - return jsonPath; - } - - // todo - forse inutile.. 
- public static ModelId from(String id) { - for (ModelId modelId: ModelId.values()) { - if (modelId.getId().equals(id)) { - return modelId; - } - } - return ModelId.TITAN_EMBEDDING_G1; - } - } - - enum GetModel { - CUSTOM("custom-models"), - FOUNDATION("foundation-models"); - - private final String path; - GetModel(String path) { - this.path = path; - } + String path = "modelSummaries[*]"; - public String getPath() { - return path; - } + CloseableHttpClient httpClient = HttpClientBuilder.create().build(); + return getModelItemResultStream(conf, httpClient,null, headers, path, + objectStream -> objectStream + .flatMap(i -> ((List>) i).stream()) + .map(ModelItemResult::new) + .onClose(() -> Util.close(httpClient)) + ); } - - - - + public static Stream getModelItemResultStream(BedrockConfig conf, HttpClient client, String payload, Map headers, String path, + Function, Stream> function) { + return ExtendedUtil.getModelItemResultStream(conf.getMethod(), client, payload, headers, conf.getEndpoint(), path, List.of(), function); +// +// HttpRequestBase request = ExtendedUtil.fromMethodName(conf.getMethod(), endpoint); +// +// headers.forEach((k, v) -> request.addHeader(k, v.toString())); +// +// try (CloseableHttpClient httpClient = HttpClientBuilder.create().build()) { +// HttpResponse response = httpClient.execute(request); +// +// InputStream stream = response.getEntity().getContent(); +// +// Stream objtream = streamObjetsFromIStream(stream, path, of); +// +// return function.apply(objtream); +//// return objectStream +//// .flatMap(i -> ((List>) i).stream()) +//// .map(ModelItemResult::new); +// } catch (Exception e) { +// throw new RuntimeException(e); +// } + } - // todo - generic function?? 
+// public T streamWithHttpClient(Function action) { +// try (Transaction tx = db.beginTx()) { +// T result = action.apply(tx); +// return result; +// } +// } - public record ModelItemResult(String modelId, String modelArn,String modelName, String providerName, Boolean responseStreamingSupported, - List customizationsSupported, List inferenceTypesSupported, List inputModalities, - List outputModalities) { + @Procedure + @Description("To create a customizabled bedrock call") + public Stream custom(@Name(value = "body") Object body, + @Name(value = "config", defaultValue = "{}") Map config) throws Exception { - public ModelItemResult(Map map) { - this((String) map.get("modelId"), - (String) map.get("modelArn"), - (String) map.get("modelName"), - (String) map.get("providerName"), - (Boolean) map.get("responseStreamingSupported"), - (List) map.get("customizationsSupported"), - (List) map.get("inferenceTypesSupported"), - (List) map.get("inputModalities"), - (List) map.get("outputModalities") - ); - } + return executeInvokeRequest(body, config, null) + .map(ObjectResult::new); } - // TODO - list models?? 
- @Procedure("apoc.ml.bedrock.list") - public Stream list(@Name(value = "type") String type, - @Name(value = "config", defaultValue = "{}") Map config) throws Exception { - GetModel getModel; - try { - getModel = GetModel.valueOf(type); - } catch (IllegalArgumentException e) { - throw new RuntimeException("The type config can be one of the following: " + Arrays.toString(GetModel.values())); - } - - URL url = new URL("https://bedrock.us-east-1.amazonaws.com/" + getModel.getPath()); - // todo - maybe remove "method", "GET" - Map headers = Map.of("Content-Type", "application/json", - "method", "GET"); - BedrockConfig conf = new BedrockConfig(config, url.toString()); - - String payload = ""; - headers = calculateAuthorizationHeaders("GET", url, headers, payload.getBytes(), - conf.getKeyId(), conf.getSecretKey(), "us-east-1", "bedrock"); - - - HttpGet request = new HttpGet(url.toString()); - headers.forEach((k,v) -> request.addHeader(k, v.toString())); - - try (DefaultHttpClient httpClient = new DefaultHttpClient()) { - HttpResponse response = httpClient.execute(request); - - InputStream stream = response.getEntity().getContent(); - - Stream objectStream = streamObjetsFromIStream(stream, "modelSummaries[*]", List.of()); -// Stream objectStream = JsonUtil.loadJson(url.toString(), headers, null, path, true, List.of()); - - return objectStream - .flatMap(i -> ((List>) i).stream()) - .map(ModelItemResult::new); - } + @Procedure("apoc.ml.bedrock.jurassic") + public Stream jurassic2(@Name(value = "body") Object body, + @Name(value = "config", defaultValue = "{}") Map config) throws Exception { + config.putIfAbsent(MODEL_ID, "ai21.j2-ultra-v1"); + return executeInvokeRequest(body, config, null) + .map(AnthropicClaude::from); } - - + @Procedure("apoc.ml.bedrock.anthropic.claude") + public Stream anthropicClaude(@Name(value = "body") Object body, + @Name(value = "config", defaultValue = "{}") Map config) throws Exception { + config.putIfAbsent(MODEL_ID, "anthropic.claude-v1"); - 
// --> TODO: remove final String accessKey, final String secretKey - - -// @Procedure -// public Stream stability(@Name(value = "modelId") String modelId, -// @Name(value = "payload") Object payload, -// @Name(value = "config", defaultValue = "{}") Map config) throws Exception { -// -// } - - @Procedure("apoc.ml.bedrock") - public Stream bedrock(@Name(value = "modelId") String modelId, // todo - remove modelId from generic proc... - @Name(value = "payload") Object payload, - @Name(value = "config", defaultValue = "{}") Map config) throws Exception { - - - // todo - validation: modelId dev'essere non nullo?? - - // todo - validation: config deve avere o session/key oppure nell'header - - ModelId modelId1 = ModelId.from(modelId); - - Stream objectStream = getObjectStream(payload, config, modelId1); - -// List objects = objectStream.toList(); -// System.out.println("objects = " + objects); - - return objectStream.map(ObjectResult::new); + return executeInvokeRequest(body, config, null) + .map(AnthropicClaude::from); } + + @Procedure("apoc.ml.bedrock.titan.embedding") + public Stream titanEmbedding(@Name(value = "body") Object body, + @Name(value = "config", defaultValue = "{}") Map config) throws Exception { + config.putIfAbsent(MODEL_ID, "amazon.titan-embed-text-v1"); - private Stream getObjectStream(Object payload, Map config, ModelId modelId1) throws IOException { - BedrockConfig conf = new BedrockConfig(config); - - // todo - endpoint customizable, document it - + return executeInvokeRequest(body, config, null) + .map(TitanEmbedding::from); + } - Map headers = new HashMap<>(conf.getHeaders()); + @Procedure("apoc.ml.bedrock.stability") + public Stream stability(@Name(value = "body") Object body, + @Name(value = "config", defaultValue = "{}") Map config) throws IOException { + config.putIfAbsent(MODEL_ID, "stability.stable-diffusion-xl-v0"); - headers.putIfAbsent("Content-Type", "application/json"); - headers.putIfAbsent("accept", ALL); -// Map headers = Map.of( -// 
"Content-Type", "application/json", -//// "Authorization", "AWS4-HMAC-SHA256 Credential=AKIASSO3M7CCVJ26AETR/20231017/us-east-1/bedrock/aws4_request, SignedHeaders=content-length;content-type;host;x-amz-date, Signature=013e29432a839f68f9de4934c76d7b478fec597e2980bbbea6faaa1ad6af0320", -// "accept", ALL -//// "x-amz-date", "20231017T135438Z" -//// "Content-Length", "21" -//// "authorization", "AWS4-HMAC-SHA256 Credential=AKIASSO3M7CCVJ26AETR/20231017/us-east-1/bedrock/aws4_request,SignedHeaders=content-length;content-type;host;x-amz-date,Signature=56f22872fe9a68d676bd6e2232a4ee8a22c507bfc2bbb79f789e6f9cdc88b67b" -// ); - String path = null; - - - URL url = new URL(conf.getEndpoint()); - - HttpURLConnection connection = (HttpURLConnection) url.openConnection(); - connection.setRequestMethod("POST"); - System.out.println(connection.getRequestMethod() + " " + url); - - // todo - get region - or customizable - String replace = url.getHost().replace(".amazonaws.com", ""); - String region = replace.substring(replace.lastIndexOf(".") + 1); + return executeInvokeRequest(body, config, "$.artifacts[0]") + .map(StabilityAi::from); + } + private Stream executeInvokeRequest(Object payload, Map config, String path) throws IOException { String payloadString = payload instanceof String ? (String) payload - : OBJECT_MAPPER.writeValueAsString(payload);// "{\"inputText\": \"Provona\"}";// JsonUtil.writeValueAsString("{\"inputText\": \"Provona\"}"); - System.out.println("payloadString = " + payloadString); - - // todo - directly BedrockConfig? 
- headers = calculateAuthorizationHeaders("POST", url, headers, payloadString.getBytes(), + : OBJECT_MAPPER.writeValueAsString(payload); + + BedrockConfig conf = new BedrockInvokeConfig(config); - conf.getKeyId(), conf.getSecretKey(), region, "bedrock"); -// -// // todo - path customizable, document it -// headers = aWSV4Auth.getHeaders(); - System.out.println("headers = " + headers.entrySet().stream().map(Object::toString).collect(Collectors.joining("\n"))); + Map headers = new HashMap<>(conf.getHeaders()); + headers.putIfAbsent("Content-Type", JSON); + headers.putIfAbsent("accept", ALL); + + headers = calculateAuthorizationHeaders("POST", conf, headers, payloadString.getBytes()); -// new HttpPost(url.toString()) -// .addHeader(new Header()); +// Stream objectStream = JsonUtil.loadJson(conf.getEndpoint(), headers, payloadString, path); +// return objectStream; - Stream objectStream = JsonUtil.loadJson(url.toString(), headers, payloadString, path); - return objectStream; + CloseableHttpClient httpClient = HttpClientBuilder.create().build(); +// List objects = getModelItemResultStream(conf, httpClient, payloadString, headers, path, objStream -> objStream) +// .toList(); + return getModelItemResultStream(conf, httpClient, payloadString, headers, path, objStream -> objStream) +// .stream(); + .onClose(() -> Util.close(httpClient)); } // basic_date // basic_date_time_no_millis + + } diff --git a/extended/src/main/java/apoc/ml/bedrock/BedrockConfig.java b/extended/src/main/java/apoc/ml/bedrock/BedrockConfig.java index 32c6b25c20..526ecfd9d5 100644 --- a/extended/src/main/java/apoc/ml/bedrock/BedrockConfig.java +++ b/extended/src/main/java/apoc/ml/bedrock/BedrockConfig.java @@ -7,54 +7,35 @@ import static apoc.ExtendedApocConfig.APOC_AWS_SECRET_KEY; // todo: as a record? 
-public class BedrockConfig { - public static final String MODEL_ID = "modelId"; +public abstract class BedrockConfig { + + abstract String getDefaultEndpoint(Map config); + abstract String getDefaultMethod(); + public static final String SECRET_KEY = "secretKey"; public static final String KEY_ID = "keyId"; - public static final String REGION = "region"; - public static final String ENDPOINT = "endpoint"; + public static final String REGION_KEY = "region"; + public static final String ENDPOINT_KEY = "endpoint"; + public static final String METHOD_KEY = "method"; private final String keyId; private final String secretKey; - // todo - documentare che con modelId va su endpoint --> https://bedrock-runtime.us-east-1.amazonaws.com/model//invoke - // e che endpoint ha la priorità - private final String endpoint; private final String region; + private final String endpoint; + private final String method; - // todo - maybe local, non viene richiamata all'esterno.. -// private final String modelId; + private final Map headers; - private Map headers; - - public BedrockConfig(Map config) { - this(config, null); - } - - public BedrockConfig(Map config, String defaultEndpoint) { + protected BedrockConfig(Map config) { config = config == null ? Map.of() : config; - - // todo - document it + this.keyId = apocConfig().getString(APOC_AWS_KEY_ID, (String) config.get(KEY_ID)); this.secretKey = apocConfig().getString(APOC_AWS_SECRET_KEY, (String) config.get(SECRET_KEY)); - - // todo - extract it? 
- if (defaultEndpoint == null) { - String modelId = (String) config.get(MODEL_ID); - if (modelId != null) { - defaultEndpoint = String.format("https://bedrock-runtime.us-east-1.amazonaws.com/model/%s/invoke", modelId); - } - } - - this.endpoint = getEndpoint(config, defaultEndpoint); - - - // todo - passo modelId: se è valorizzato metto String urlString = String.format("https://bedrock-runtime.us-east-1.amazonaws.com/model/%s/invoke", - // modelId1.getId()); - // se modelId NON è valorizzato faccio getEndpoint() - // se + this.endpoint = getEndpoint(config, getDefaultEndpoint(config)); - this.region = (String) config.getOrDefault(REGION, extractRegionFromEndpoint()); + this.region = (String) config.getOrDefault(REGION_KEY, extractRegionFromEndpoint()); + this.method = (String) config.getOrDefault(METHOD_KEY, getDefaultMethod()); this.headers = (Map) config.getOrDefault("headers", Map.of()); } @@ -67,16 +48,17 @@ private String extractRegionFromEndpoint() { private String getEndpoint(Map config, String defaultEndpoint) { - - - String endpointConfig = (String) config.get(ENDPOINT); + String endpointConfig = (String) config.get(ENDPOINT_KEY); if (endpointConfig != null) { return endpointConfig; } if (defaultEndpoint != null) { return defaultEndpoint; } - throw new RuntimeException("TODO.. 
ENDPOINT ERROR"); + String errMessage = String.format("An endpoint could not be retrieved.\n" + + "Explicit the %s config", + ENDPOINT_KEY); + throw new RuntimeException(errMessage); } public String getKeyId() { @@ -98,4 +80,8 @@ public String getRegion() { public Map getHeaders() { return headers; } + + public String getMethod() { + return method; + } } diff --git a/extended/src/main/java/apoc/ml/bedrock/BedrockInvokeConfig.java b/extended/src/main/java/apoc/ml/bedrock/BedrockInvokeConfig.java new file mode 100644 index 0000000000..a7ba3708b2 --- /dev/null +++ b/extended/src/main/java/apoc/ml/bedrock/BedrockInvokeConfig.java @@ -0,0 +1,24 @@ +package apoc.ml.bedrock; + +import java.util.Map; + +public class BedrockInvokeConfig extends BedrockConfig { + public static final String MODEL_ID = "modelId"; + + public BedrockInvokeConfig(Map config) { + super(config); + } + + @Override + String getDefaultEndpoint(Map config) { + String modelId = (String) config.get(MODEL_ID); + return modelId == null + ? 
null + : String.format("https://bedrock-runtime.us-east-1.amazonaws.com/model/%s/invoke", modelId); + } + + @Override + String getDefaultMethod() { + return "POST"; + } +} \ No newline at end of file diff --git a/extended/src/main/java/apoc/ml/bedrock/BedrockInvokeResult.java b/extended/src/main/java/apoc/ml/bedrock/BedrockInvokeResult.java new file mode 100644 index 0000000000..17c2f2faae --- /dev/null +++ b/extended/src/main/java/apoc/ml/bedrock/BedrockInvokeResult.java @@ -0,0 +1,58 @@ +package apoc.ml.bedrock; + +import java.util.List; +import java.util.Map; + +public class BedrockInvokeResult { + // todo + // todo 2 : https://stackoverflow.com/questions/8571501/how-to-check-whether-a-string-is-base64-encoded-or-not + public record StabilityAi(String base64Image) { + public static StabilityAi from(Object object) { + Map map = (Map) object; + + String base64 = (String) map.get("base64"); + + return new StabilityAi(base64); + } + } + + // todo + public record AnthropicClaude(String completion, String stopReason) { + public static AnthropicClaude from(Object object) { + Map map = (Map) object; + + String completion = (String) map.get("completion"); + String stopReason = (String) map.get("stopReason"); + + return new AnthropicClaude(completion, stopReason); + } + } + + // todo + public record TitanEmbedding(Long inputTextTokenCount, List embedding) { + public static TitanEmbedding from(Object object) { + Map map = (Map) object; + + Long inputTextTokenCount = (Long) map.get("inputTextTokenCount"); + List embedding = (List) map.get("embedding"); + + return new TitanEmbedding(inputTextTokenCount, embedding); + } + } + + // todo + public record Jurassic(Long id, List promptTokens, List completions) { + public static Jurassic from(Object object) { + Map map = (Map) object; + + Long id = (Long) map.get("id"); + + Map prompt = (Map) map.get("prompt"); + List promptTokens = (List) prompt.get("tokens"); + + List completions = (List) map.get("completions"); + + return new 
Jurassic(id, promptTokens, completions); + } + } +} diff --git a/extended/src/main/java/apoc/ml/bedrock/BedrockModelsConfig.java b/extended/src/main/java/apoc/ml/bedrock/BedrockModelsConfig.java new file mode 100644 index 0000000000..47e3ee84f3 --- /dev/null +++ b/extended/src/main/java/apoc/ml/bedrock/BedrockModelsConfig.java @@ -0,0 +1,42 @@ +package apoc.ml.bedrock; + +import java.util.Map; + +public class BedrockModelsConfig extends BedrockConfig { + enum TypeGet { + CUSTOM("custom-models"), + FOUNDATION("foundation-models"); + + private final String path; + + TypeGet(String path) { + this.path = path; + } + + public static String from(String value) { + for (TypeGet typeGet: TypeGet.values()) { + if (typeGet.name().equals(value)) { + return typeGet.path; + } + } + return TypeGet.FOUNDATION.path; + } + } + + public static final String TYPE_GET = "typeGet"; + + public BedrockModelsConfig(Map config) { + super(config); + } + + @Override + String getDefaultEndpoint(Map config) { + String typeGet = TypeGet.from((String) config.get(TYPE_GET)); + return "https://bedrock.us-east-1.amazonaws.com/" + typeGet; + } + + @Override + String getDefaultMethod() { + return "GET"; + } +} diff --git a/extended/src/main/java/apoc/ml/bedrock/BedrockResult.java b/extended/src/main/java/apoc/ml/bedrock/BedrockResult.java deleted file mode 100644 index 91033013e8..0000000000 --- a/extended/src/main/java/apoc/ml/bedrock/BedrockResult.java +++ /dev/null @@ -1,34 +0,0 @@ -package apoc.ml.bedrock; - -import java.util.List; -import java.util.Map; - -public class BedrockResult { - // todo - // todo 2 : https://stackoverflow.com/questions/8571501/how-to-check-whether-a-string-is-base64-encoded-or-not - record StabilityAiResult(String base64Image) {} - - // todo - record AnthropicClaudeResult(String completion, String stopReason) { - public AnthropicClaudeResult(Map map) { - this((String) map.get("completion"), (String) map.get("stopReason")); - } - } - - // todo - public record 
JurassicResult(Long id, List promptTokens, List completions) { - public static JurassicResult from(Map map) { - Long id = (Long) map.get("id"); - - Map prompt = (Map) map.get("prompt"); - List promptTokens = (List) prompt.get("tokens"); - - List completions = (List) map.get("completions"); - - return new JurassicResult(id, promptTokens, completions); - } - } - - // todo - record TitanEmbedding(Long inputTextTokenCount, List embedding) {} -} diff --git a/extended/src/main/java/apoc/ml/bedrock/ModelItemResult.java b/extended/src/main/java/apoc/ml/bedrock/ModelItemResult.java new file mode 100644 index 0000000000..80dac161b6 --- /dev/null +++ b/extended/src/main/java/apoc/ml/bedrock/ModelItemResult.java @@ -0,0 +1,31 @@ +package apoc.ml.bedrock; + +import java.util.List; +import java.util.Map; + +public class ModelItemResult { + public final String modelId; + public final String modelArn; + public final String modelName; + public final String providerName; + + public final Boolean responseStreamingSupported; + + public final List customizationsSupported; + public final List inferenceTypesSupported; + public final List inputModalities; + public final List outputModalities; + + public ModelItemResult(Map map) { + this.modelId = (String) map.get("modelId"); + this.modelArn = (String) map.get("modelArn"); + this.modelName = (String) map.get("modelName"); + this.providerName = (String) map.get("providerName"); + this.responseStreamingSupported = (Boolean) map.get("responseStreamingSupported"); + this.customizationsSupported = (List) map.get("customizationsSupported"); + this.inferenceTypesSupported = (List) map.get("inferenceTypesSupported"); + this.inputModalities = (List) map.get("inputModalities"); + this.outputModalities = (List) map.get("outputModalities"); + } + +} diff --git a/extended/src/main/java/apoc/util/ExtendedUtil.java b/extended/src/main/java/apoc/util/ExtendedUtil.java index 64a6c65a55..e645a29771 100644 --- 
a/extended/src/main/java/apoc/util/ExtendedUtil.java +++ b/extended/src/main/java/apoc/util/ExtendedUtil.java @@ -2,16 +2,103 @@ import static apoc.export.cypher.formatter.CypherFormatterUtils.formatProperties; import static apoc.export.cypher.formatter.CypherFormatterUtils.formatToString; +import static apoc.util.JsonUtil.streamObjetsFromIStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.UnsupportedEncodingException; import java.math.BigInteger; import java.time.Duration; import java.time.temporal.TemporalAccessor; +import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.LongStream; +import java.util.stream.Stream; + +import org.apache.http.HttpResponse; +import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.HttpDelete; +import org.apache.http.client.methods.HttpEntityEnclosingRequestBase; +import org.apache.http.client.methods.HttpHead; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpOptions; +import org.apache.http.client.methods.HttpPatch; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpPut; +import org.apache.http.client.methods.HttpRequestBase; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.DefaultHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; import org.neo4j.graphdb.Entity; public class ExtendedUtil { + + /** + * Get the {@link HttpRequestBase} from the method name + * Similar to aws implementation + */ + public static HttpRequestBase fromMethodName(String method, String uri) { + return switch (method) { + case HttpHead.METHOD_NAME -> new HttpHead(uri); + case HttpGet.METHOD_NAME -> new HttpGet(uri); + case HttpDelete.METHOD_NAME -> new HttpDelete(uri); + case HttpOptions.METHOD_NAME -> new HttpOptions(uri); + case HttpPatch.METHOD_NAME -> new 
HttpPatch(uri); + case HttpPost.METHOD_NAME -> new HttpPost(uri); + case HttpPut.METHOD_NAME -> new HttpPut(uri); + default -> throw new RuntimeException("Unknown HTTP method name: " + method); + }; + } + + /** + * Similar to JsonUtil.loadJson(..) but works e.g. with GET method as well, + * for which it would return a FileNotFoundException + */ + public static Stream getModelItemResultStream(String method, HttpClient httpClient, String payloadString, Map headers, String endpoint, String path, List of, + Function, Stream> function) { + + try { + HttpRequestBase request = fromMethodName(method, endpoint); + + headers.forEach((k, v) -> request.setHeader(k, v.toString())); + + if (request instanceof HttpEntityEnclosingRequestBase entityRequest) { + try { + entityRequest.setEntity(new StringEntity(payloadString)); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } +// try ( +// HttpClient httpClient = HttpClientBuilder.create().build();//) { + +// DefaultHttpClient httpClient = new DefaultHttpClient(); + HttpResponse response = httpClient.execute(request); + + InputStream stream = response.getEntity().getContent(); + + Stream objStream = streamObjetsFromIStream(stream, path, of); + + return function.apply(objStream); +// .onClose(() -> { +// try { +// httpClient.close(); +// } catch (IOException e) { +// throw new RuntimeException(e); +// } +// }); +// return objectStream +// .flatMap(i -> ((List>) i).stream()) +// .map(ModelItemResult::new); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + public static String dateFormat( TemporalAccessor value, String format){ return Util.getFormat(format).format(value); } diff --git a/extended/src/test/java/apoc/ml/bedrock/BedrockIT.java b/extended/src/test/java/apoc/ml/bedrock/BedrockIT.java index 2bf67050ed..f7fc7e1011 100644 --- a/extended/src/test/java/apoc/ml/bedrock/BedrockIT.java +++ b/extended/src/test/java/apoc/ml/bedrock/BedrockIT.java @@ -8,24 +8,77 @@ import 
org.neo4j.test.rule.DbmsRule; import org.neo4j.test.rule.ImpermanentDbmsRule; -import java.util.HashMap; import java.util.List; import java.util.Map; import static apoc.ApocConfig.apocConfig; import static apoc.ExtendedApocConfig.APOC_AWS_KEY_ID; import static apoc.ExtendedApocConfig.APOC_AWS_SECRET_KEY; -import static apoc.ml.bedrock.Bedrock.ModelId.CLAUDE_V1; -import static apoc.ml.bedrock.Bedrock.ModelId.JURASSIC_2_MID; -import static apoc.ml.bedrock.Bedrock.ModelId.STABLE_DIFFUSION_XL; -import static apoc.ml.bedrock.Bedrock.ModelId.TITAN_EMBEDDING_G1; -import static apoc.ml.bedrock.BedrockConfig.MODEL_ID; +import static apoc.ml.bedrock.BedrockConfig.METHOD_KEY; +import static apoc.ml.bedrock.BedrockIT.ModelId.*; +import static apoc.ml.bedrock.BedrockInvokeConfig.MODEL_ID; +import static apoc.util.TestUtil.testCall; import static apoc.util.TestUtil.testResult; +import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeNotNull; +/** + * Todo: extractRegionFromEndpoint() + * TODO: test with wrong method; e.g. DELETE + */ public class BedrockIT { - private static final String BEDROCK_PROC = "call apoc.ml.bedrock($id, $payload, $conf)"; + + public static final Map STABILITY_AI_BODY = Map.of( + "text_prompts", List.of(Map.of("text", "picture of a bird", "weight", 1.0)), + "cfg_scale", 5, + "seed", 123, + "steps", 70, + "style_preset", "photographic" + ); + public static final Map JURASSIC_PAYLOAD = Map.of( + "prompt", "Review: Extremely old cabinets, phone was half broken and full of dust. Bathroom door was broken, bathroom floor was dirty and yellow. Bathroom tiles were falling off. Asked to change my room and the next room was in the same conditions. The most out of date and least maintained hotel i ever been on. 
Extracted sentiment:", + "maxTokens", 50, + "temperature", 0, + "topP", 1.0 + ); + public static final Map ANTHROPIC_CLAUDE = Map.of( + "prompt", "\n\nHuman: Hello world\n\nAssistant:", + "max_tokens_to_sample", 300, + "temperature", 0.5, + "top_k", 250, + "top_p", 1, + "stop_sequences", List.of("\\n\\nHuman:"), + "anthropic_version", "bedrock-2023-05-31" + ); + public static final Map TITAN_PAYLOAD = Map.of("inputText", "Test"); + + enum ModelId { + JURASSIC_2_MID("ai21.j2-mid-v1"), + JURASSIC_2_ULTRA("ai21.j2-ultra-v1"), + + TITAN_EMBEDDING_G1("amazon.titan-embed-text-v1"), + TITAN_TEXT_G1_EXPRESS("amazon.titan-text-express-v1"), + + CLAUDE_V1("anthropic.claude-v1"), + CLAUDE_V2("anthropic.claude-v2"), + CLAUDE_INSTANT("anthropic.claude-instant-v1"), + + STABLE_DIFFUSION_XL("stability.stable-diffusion-xl-v0"); + + private final String id; + + ModelId(String id) { + this.id = id; + } + + public String id() { + return id; + } + } + + + private static final String BEDROCK_PROC = "call apoc.ml.bedrock.custom($payload, $conf)"; private static String keyId; private static String secretKey; @@ -33,11 +86,6 @@ public class BedrockIT { @ClassRule public static DbmsRule db = new ImpermanentDbmsRule(); -// public BedrockIT() { -// this.keyId = keyId; -// this.secretKey = secretKey; -// } - @BeforeClass public static void setUp() throws Exception { @@ -55,24 +103,21 @@ public static void setUp() throws Exception { } @Test - public void testTitanEmbedding() { + public void testCustomWithTitanEmbedding() { String s = db.executeTransactionally(BEDROCK_PROC, - Map.of("id", TITAN_EMBEDDING_G1.getId(), - "payload", Map.of("inputText", "Test"), - "conf", Map.of(MODEL_ID, TITAN_EMBEDDING_G1.getId()) + Map.of("payload", TITAN_PAYLOAD, + "conf", Map.of(MODEL_ID, TITAN_EMBEDDING_G1.id()) ), Result::resultAsString); System.out.println("s = " + s); } - // todo - test payload as a string.. 
@Test public void testStringPayload() { String s = db.executeTransactionally(BEDROCK_PROC, - Map.of("id", TITAN_EMBEDDING_G1.getId(), - "payload", "{\"inputText\": \"Prova\" }", - "conf", Map.of(MODEL_ID, TITAN_EMBEDDING_G1.getId()) + Map.of("payload", "{\"inputText\": \"Prova\" }", + "conf", Map.of(MODEL_ID, TITAN_EMBEDDING_G1.id()) ), Result::resultAsString); System.out.println("s = " + s); @@ -80,44 +125,23 @@ public void testStringPayload() { // TODO - to delete... maybe... just to see the output and create custom procs in case, like OpenAI @Test - public void testAlls() { - HashMap objectObjectHashMap = new HashMap<>(); - - Map payload = Map.of( - "prompt", "Review: Extremely old cabinets, phone was half broken and full of dust. Bathroom door was broken, bathroom floor was dirty and yellow. Bathroom tiles were falling off. Asked to change my room and the next room was in the same conditions. The most out of date and least maintained hotel i ever been on. Extracted sentiment:", - "maxTokens", 50, - "temperature", 0, - "topP", 1.0 - ); + public void testCustomWithJurassic() { // objectObjectHashMap.put(JURASSIC_2_MID, payload); // TODO - prompt and completions for jurassic String s = db.executeTransactionally(BEDROCK_PROC, - Map.of("id", JURASSIC_2_MID.getId(), - "payload", payload, - "conf", Map.of(MODEL_ID, JURASSIC_2_MID.getId()) + Map.of("payload", JURASSIC_PAYLOAD, + "conf", Map.of(MODEL_ID, JURASSIC_2_MID.id()) ), Result::resultAsString); System.out.println("s = " + s); } - @Test public void testAlls23() { - Map payload = Map.of( - "prompt", "\n\nHuman: Hello world\n\nAssistant:", - "max_tokens_to_sample", 300, - "temperature", 0.5, - "top_k", 250, - "top_p", 1, - "stop_sequences", List.of("\\n\\nHuman:"), - "anthropic_version", "bedrock-2023-05-31" - ); - String s = db.executeTransactionally(BEDROCK_PROC, - Map.of("id", CLAUDE_V1.getId(), - "payload", payload, - "conf", Map.of(MODEL_ID, CLAUDE_V1.getId()) + Map.of("payload", ANTHROPIC_CLAUDE, + "conf", 
Map.of(MODEL_ID, CLAUDE_V1.id()) ), Result::resultAsString); System.out.println("s = " + s); @@ -126,17 +150,11 @@ public void testAlls23() { @Test public void testAlls2() { - Map payload = Map.of( - "prompt", "Review: Extremely old cabinets, phone was half broken and full of dust. Bathroom door was broken, bathroom floor was dirty and yellow. Bathroom tiles were falling off. Asked to change my room and the next room was in the same conditions. The most out of date and least maintained hotel i ever been on. Extracted sentiment:", - "maxTokens", 50, - "temperature", 0, - "topP", 1.0 - ); + Map payload = JURASSIC_PAYLOAD; String s = db.executeTransactionally(BEDROCK_PROC, - Map.of("id", JURASSIC_2_MID.getId(), - "payload", payload, - "conf", Map.of(MODEL_ID, JURASSIC_2_MID.getId()) + Map.of("payload", payload, + "conf", Map.of(MODEL_ID, JURASSIC_2_MID.id()) ), Result::resultAsString); System.out.println("s = " + s); @@ -145,26 +163,72 @@ public void testAlls2() { // todo - try another model id @Test public void testImage() { - Map payload = Map.of( - "text_prompts", List.of(Map.of("text", "picture of a bird", "weight", 1.0)), - "cfg_scale", 5, - "seed", 123, - "steps", 70, - "style_preset", "photographic" - ); String s = db.executeTransactionally(BEDROCK_PROC, - Map.of("id", STABLE_DIFFUSION_XL.getId(), - "payload", payload, - "conf", Map.of(MODEL_ID, STABLE_DIFFUSION_XL.getId()) + Map.of("payload", STABILITY_AI_BODY, + "conf", Map.of(MODEL_ID, STABLE_DIFFUSION_XL.id()) ), Result::resultAsString); System.out.println("s = " + s); } + + @Test + public void testStability() { + testCall(db, "call apoc.ml.bedrock.stability($payload)", + Map.of("payload", STABILITY_AI_BODY), + r -> { + Object base64Image = r.get("base64Image"); + System.out.println("base64Image = " + base64Image); + assertNotNull(base64Image); + }); +// String s = db.executeTransactionally("call apoc.ml.bedrock.stability($payload)", +// Map.of("payload", STABILITY_AI_BODY), +// Result::resultAsString); +// 
System.out.println("s = " + s); + } + + @Test + public void testJurassic() { + String s = db.executeTransactionally("call apoc.ml.bedrock.jurassic($payload)", + Map.of("payload", JURASSIC_PAYLOAD), + Result::resultAsString); + System.out.println("s = " + s); + } + + @Test + public void testAnthropicClaude() { + String s = db.executeTransactionally("call apoc.ml.bedrock.anthropic.claude($payload)", + Map.of("payload", ANTHROPIC_CLAUDE), + Result::resultAsString); + System.out.println("s = " + s); + } + + @Test + public void testTitanEmbedding() { + String s = db.executeTransactionally("call apoc.ml.bedrock.titan.embedding($payload)", + Map.of("payload", TITAN_PAYLOAD), + Result::resultAsString); + System.out.println("s = " + s); + } - - // todo - cohere.command-text-v14 + // TODO: try this: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_GetModelInvocationLoggingConfiguration.html + @Test + public void testGetModelInvocation() { + Map conf = Map.of("endpoint", "https://bedrock.us-east-1.amazonaws.com//logging/modelinvocations", + METHOD_KEY, "GET"); + String s = db.executeTransactionally("call apoc.ml.bedrock.custom('', $conf)", + Map.of("conf", conf), + Result::resultAsString); + System.out.println("s = " + s); + } + // todo - try another endpoint via custom... e.g. https://docs.aws.amazon.com/bedrock/latest/APIReference/API_DeleteCustomModel.html + // or - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_GetModelInvocationLoggingConfiguration.html + // todo - try with null value + + + // todo - try create + // todo - payload can be map or string??
// https://bedrock.us-east-1.amazonaws.com/custom-models
/**
 * Calls {@code apoc.ml.bedrock.list} once for every {@code BedrockModelsConfig.TypeGet} value
 * and checks that each returned row carries a Bedrock ARN in its {@code modelArn} column.
 */
@Test
public void testGetModel() {
    for (BedrockModelsConfig.TypeGet model : BedrockModelsConfig.TypeGet.values()) {
        testResult(db, "call apoc.ml.bedrock.list({typeGet: $type})",
                Map.of("type", model.name()),
                r -> r.forEachRemaining(row -> {
                    String modelArn = (String) row.get("modelArn");
                    // Fail with a readable message instead of a NullPointerException when the column is missing.
                    assertNotNull("modelArn missing for typeGet " + model.name(), modelArn);
                    assertTrue("unexpected modelArn: " + modelArn, modelArn.contains("arn:aws:bedrock"));
                }));
    }
}
} // closes class BedrockIT (its closing brace shares this source line)