From 9ecd30225b47264b1e013b2e9273a0b6c4f24a82 Mon Sep 17 00:00:00 2001 From: Maleehak Date: Tue, 16 Apr 2024 00:06:18 +0500 Subject: [PATCH] refactor predict class --- .../example/BriefMe/BriefMeApplication.java | 3 - .../properties/VertexAIProperties.java | 34 +++++ ...{VertexAIPrompt.java => VertexAIData.java} | 2 +- .../impl/PredictTextSummarizationSample.java | 126 ++++++++++-------- 4 files changed, 109 insertions(+), 56 deletions(-) create mode 100644 backend/src/main/java/com/example/BriefMe/properties/VertexAIProperties.java rename backend/src/main/java/com/example/BriefMe/request/{VertexAIPrompt.java => VertexAIData.java} (82%) diff --git a/backend/src/main/java/com/example/BriefMe/BriefMeApplication.java b/backend/src/main/java/com/example/BriefMe/BriefMeApplication.java index 7270623..fdafa74 100644 --- a/backend/src/main/java/com/example/BriefMe/BriefMeApplication.java +++ b/backend/src/main/java/com/example/BriefMe/BriefMeApplication.java @@ -13,8 +13,5 @@ public class BriefMeApplication { public static void main(String[] args) throws IOException { SpringApplication.run(BriefMeApplication.class, args); - PredictTextSummarizationSample predictTextSummarization = new PredictTextSummarizationSample(); - predictTextSummarization.predict(); - } } diff --git a/backend/src/main/java/com/example/BriefMe/properties/VertexAIProperties.java b/backend/src/main/java/com/example/BriefMe/properties/VertexAIProperties.java new file mode 100644 index 0000000..ff74546 --- /dev/null +++ b/backend/src/main/java/com/example/BriefMe/properties/VertexAIProperties.java @@ -0,0 +1,34 @@ +package com.example.BriefMe.properties; + +import lombok.Getter; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +@Component +@Getter +public class VertexAIProperties { + + @Value("${google.vertex-ai.project}") + String project; + + @Value("${google.vertex-ai.location}") + String location; + + @Value("${google.vertex-ai.publisher}") + String publisher; + + @Value("${google.vertex-ai.model}") + String model; + + @Value("${google.vertex-ai.model.temperature}") + Double temperature; + + @Value("${google.vertex-ai.model.max-output-token}") + Integer maxOutputToken; + + @Value("${google.vertex-ai.model.top-p}") + Double topP; + + @Value("${google.vertex-ai.model.top-k}") + Integer topK; +} diff --git a/backend/src/main/java/com/example/BriefMe/request/VertexAIPrompt.java b/backend/src/main/java/com/example/BriefMe/request/VertexAIData.java similarity index 82% rename from backend/src/main/java/com/example/BriefMe/request/VertexAIPrompt.java rename to backend/src/main/java/com/example/BriefMe/request/VertexAIData.java index 6005616..4b17051 100644 --- a/backend/src/main/java/com/example/BriefMe/request/VertexAIPrompt.java +++ b/backend/src/main/java/com/example/BriefMe/request/VertexAIData.java @@ -5,6 +5,6 @@ @Data @AllArgsConstructor -public class VertexAIPrompt { +public class VertexAIData { String prompt; } diff --git a/backend/src/main/java/com/example/BriefMe/service/impl/PredictTextSummarizationSample.java b/backend/src/main/java/com/example/BriefMe/service/impl/PredictTextSummarizationSample.java index 316ec76..5fde3c2 100644 --- a/backend/src/main/java/com/example/BriefMe/service/impl/PredictTextSummarizationSample.java +++ b/backend/src/main/java/com/example/BriefMe/service/impl/PredictTextSummarizationSample.java @@ -1,9 +1,15 @@ package com.example.BriefMe.service.impl; +import com.example.BriefMe.properties.VertexAIProperties; import com.example.BriefMe.request.VertexAIParameters; -import com.example.BriefMe.request.VertexAIPrompt; +import com.example.BriefMe.request.VertexAIData; +import com.example.BriefMe.service.client.TextSummarizer; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.IOException; +import javax.print.DocFlavor.STRING; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.stereotype.Service; import com.google.cloud.aiplatform.v1beta1.EndpointName; import com.google.cloud.aiplatform.v1beta1.PredictResponse; @@ -16,67 +22,83 @@ @Service -public class PredictTextSummarizationSample { - public void predict() throws IOException { - String text =""; - VertexAIPrompt vertexAIPrompt = new VertexAIPrompt("Provide a short summary in five numeric bullet points:" + text); - ObjectMapper objectMapper = new ObjectMapper(); - String vertexAIRequestString = objectMapper.writeValueAsString(vertexAIPrompt); +@Slf4j +public class PredictTextSummarizationSample implements TextSummarizer { - VertexAIParameters vertexAIParameters = new VertexAIParameters(0.2, 256, 0.95, 40); - String vertexAIParamsString = objectMapper.writeValueAsString(vertexAIParameters); + //TODO: Set vaules in application.properties + @Autowired + VertexAIProperties vertexAIProperties; + @Override + public String generateSummary(String text, int numberOfLines) { + try{ + String prompt = createPromptString(text, numberOfLines); + String parameters= createParametersString(); - - String project = "artful-lane-419217"; - String location = "us-central1"; - String publisher = "google"; - String model = "text-bison@001"; + String endpoint = String.format("%s-aiplatform.googleapis.com:443", vertexAIProperties.getLocation()); + PredictionServiceSettings predictionServiceSettings = + PredictionServiceSettings.newBuilder() + .setEndpoint(endpoint) + .build(); - predictTextSummarization(vertexAIRequestString, vertexAIParamsString, project, location, publisher, model); - } - - // Get summarization from a supported text model - public void predictTextSummarization( - String instance, - String parameters, - String project, - String location, - String publisher, - String model) - throws IOException { - String endpoint = String.format("%s-aiplatform.googleapis.com:443", location); + // Initialize client + try (PredictionServiceClient predictionServiceClient = + PredictionServiceClient.create(predictionServiceSettings)) { + final EndpointName endpointName = + EndpointName.ofProjectLocationPublisherModelName( + vertexAIProperties.getProject(), + vertexAIProperties.getLocation(), + vertexAIProperties.getPublisher(), + vertexAIProperties.getModel()); - PredictionServiceSettings predictionServiceSettings = - PredictionServiceSettings.newBuilder() - .setEndpoint(endpoint) - .build(); + // Use Value.Builder to convert prompt to a dynamically typed value + Value.Builder instanceValue = Value.newBuilder(); + JsonFormat.parser().merge(prompt, instanceValue); + List instances = new ArrayList<>(); + instances.add(instanceValue.build()); - // Initialize client that will be used to send requests. This client only needs to be created - // once, and can be reused for multiple requests. - try (PredictionServiceClient predictionServiceClient = - PredictionServiceClient.create(predictionServiceSettings)) { - final EndpointName endpointName = - EndpointName.ofProjectLocationPublisherModelName(project, location, publisher, model); + // Use Value.Builder to convert parameter to a dynamically typed value + Value.Builder parameterValueBuilder = Value.newBuilder(); + JsonFormat.parser().merge(parameters, parameterValueBuilder); + Value parameterValue = parameterValueBuilder.build(); - // Use Value.Builder to convert instance to a dynamically typed value that can be - // processed by the service. - Value.Builder instanceValue = Value.newBuilder(); - JsonFormat.parser().merge(instance, instanceValue); - List instances = new ArrayList<>(); - instances.add(instanceValue.build()); + PredictResponse predictResponse = + predictionServiceClient.predict(endpointName, instances, parameterValue); - // Use Value.Builder to convert parameter to a dynamically typed value that can be - // processed by the service. - Value.Builder parameterValueBuilder = Value.newBuilder(); - JsonFormat.parser().merge(parameters, parameterValueBuilder); - Value parameterValue = parameterValueBuilder.build(); + //TODO: Fetch and return data + System.out.println("Predict Response"); + System.out.println(predictResponse); + } - PredictResponse predictResponse = - predictionServiceClient.predict(endpointName, instances, parameterValue); + } catch (Exception e) { + throw new RuntimeException(e); + } + return "Nothing to summarize"; + } - System.out.println("Predict Response"); - System.out.println(predictResponse); + private String createPromptString(String text, int numberOfLines){ + try{ + String prompt = "Provide a short summary in "+ numberOfLines +" numeric bullet points:" + text; + VertexAIData vertexAIData = new VertexAIData(prompt); + ObjectMapper objectMapper = new ObjectMapper(); + return objectMapper.writeValueAsString(vertexAIData); + } catch (JsonProcessingException e) { + log.error("Exception occurred while creating prompt string {}", e.getMessage()); + throw new RuntimeException(e); } } + private String createParametersString(){ + try{ + ObjectMapper objectMapper = new ObjectMapper(); + VertexAIParameters vertexAIParameters = new VertexAIParameters( + vertexAIProperties.getTemperature(), + vertexAIProperties.getMaxOutputToken(), + vertexAIProperties.getTopP(), + vertexAIProperties.getTopK()); + return objectMapper.writeValueAsString(vertexAIParameters); + } catch (JsonProcessingException e) { + log.error("Exception occurred while creating parameters string {}", e.getMessage()); + throw new RuntimeException(e); + } + } }