Skip to content

Commit

Permalink
refactor predict class
Browse files Browse the repository at this point in the history
  • Loading branch information
Maleehak committed Apr 15, 2024
1 parent fb9696f commit 9ecd302
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,5 @@ public class BriefMeApplication {
public static void main(String[] args) throws IOException {

SpringApplication.run(BriefMeApplication.class, args);
PredictTextSummarizationSample predictTextSummarization = new PredictTextSummarizationSample();
predictTextSummarization.predict();

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.example.BriefMe.properties;

import lombok.Getter;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

@Component
@Getter
public class VertexAIProperties {

@Value("${google.vertex-ai.project}")
String project;

@Value("${google.vertex-ai.location}")
String location;

@Value("${google.vertex-ai.publisher}")
String publisher;

@Value("${google.vertex-ai.model}")
String model;

@Value("${google.vertex-ai.model.temperature}")
Double temperature;

@Value("${google.vertex-ai.model.max-output-token}")
Integer maxOutputToken;

@Value("${google.vertex-ai.model.top-p}")
Double topP;

@Value("${google.vertex-ai.model.top-k}")
Integer topK;
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@

@Data
@AllArgsConstructor
public class VertexAIPrompt {
public class VertexAIData {
String prompt;
}
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package com.example.BriefMe.service.impl;

import com.example.BriefMe.properties.VertexAIProperties;
import com.example.BriefMe.request.VertexAIParameters;
import com.example.BriefMe.request.VertexAIPrompt;
import com.example.BriefMe.request.VertexAIData;
import com.example.BriefMe.service.client.TextSummarizer;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import javax.print.DocFlavor.STRING;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.google.cloud.aiplatform.v1beta1.EndpointName;
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
Expand All @@ -16,67 +22,83 @@


@Service
public class PredictTextSummarizationSample {
public void predict() throws IOException {
String text ="";
VertexAIPrompt vertexAIPrompt = new VertexAIPrompt("Provide a short summary in five numeric bullet points:" + text);
ObjectMapper objectMapper = new ObjectMapper();
String vertexAIRequestString = objectMapper.writeValueAsString(vertexAIPrompt);
@Slf4j
public class PredictTextSummarizationSample implements TextSummarizer {

VertexAIParameters vertexAIParameters = new VertexAIParameters(0.2, 256, 0.95, 40);
String vertexAIParamsString = objectMapper.writeValueAsString(vertexAIParameters);
//TODO: Set vaules in application.properties
@Autowired
VertexAIProperties vertexAIProperties;
@Override
public String generateSummary(String text, int numberOfLines) {
try{
String prompt = createPromptString(text, numberOfLines);
String parameters= createParametersString();


String project = "artful-lane-419217";
String location = "us-central1";
String publisher = "google";
String model = "text-bison@001";
String endpoint = String.format("%s-aiplatform.googleapis.com:443", vertexAIProperties.getLocation());
PredictionServiceSettings predictionServiceSettings =
PredictionServiceSettings.newBuilder()
.setEndpoint(endpoint)
.build();

predictTextSummarization(vertexAIRequestString, vertexAIParamsString, project, location, publisher, model);
}

// Get summarization from a supported text model
public void predictTextSummarization(
String instance,
String parameters,
String project,
String location,
String publisher,
String model)
throws IOException {
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);
// Initialize client
try (PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings)) {
final EndpointName endpointName =
EndpointName.ofProjectLocationPublisherModelName(
vertexAIProperties.getProject(),
vertexAIProperties.getLocation(),
vertexAIProperties.getPublisher(),
vertexAIProperties.getModel());

PredictionServiceSettings predictionServiceSettings =
PredictionServiceSettings.newBuilder()
.setEndpoint(endpoint)
.build();
// Use Value.Builder to convert prompt to a dynamically typed value
Value.Builder instanceValue = Value.newBuilder();
JsonFormat.parser().merge(prompt, instanceValue);
List<Value> instances = new ArrayList<>();
instances.add(instanceValue.build());

// Initialize client that will be used to send requests. This client only needs to be created
// once, and can be reused for multiple requests.
try (PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings)) {
final EndpointName endpointName =
EndpointName.ofProjectLocationPublisherModelName(project, location, publisher, model);
// Use Value.Builder to convert parameter to a dynamically typed value
Value.Builder parameterValueBuilder = Value.newBuilder();
JsonFormat.parser().merge(parameters, parameterValueBuilder);
Value parameterValue = parameterValueBuilder.build();

// Use Value.Builder to convert instance to a dynamically typed value that can be
// processed by the service.
Value.Builder instanceValue = Value.newBuilder();
JsonFormat.parser().merge(instance, instanceValue);
List<Value> instances = new ArrayList<>();
instances.add(instanceValue.build());
PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameterValue);

// Use Value.Builder to convert parameter to a dynamically typed value that can be
// processed by the service.
Value.Builder parameterValueBuilder = Value.newBuilder();
JsonFormat.parser().merge(parameters, parameterValueBuilder);
Value parameterValue = parameterValueBuilder.build();
//TODO: Fetch and return data
System.out.println("Predict Response");
System.out.println(predictResponse);
}

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameterValue);
} catch (Exception e) {
throw new RuntimeException(e);
}
return "Nothing to summarize";
}

System.out.println("Predict Response");
System.out.println(predictResponse);
private String createPromptString(String text, int numberOfLines){
try{
String prompt = "Provide a short summary in "+ numberOfLines +" numeric bullet points:" + text;
VertexAIData vertexAIData = new VertexAIData(prompt);
ObjectMapper objectMapper = new ObjectMapper();
return objectMapper.writeValueAsString(vertexAIData);
} catch (JsonProcessingException e) {
log.error("Exception occurred while creating prompt string {}", e.getMessage());
throw new RuntimeException(e);
}
}

private String createParametersString(){
try{
ObjectMapper objectMapper = new ObjectMapper();
VertexAIParameters vertexAIParameters = new VertexAIParameters(
vertexAIProperties.getTemperature(),
vertexAIProperties.getMaxOutputToken(),
vertexAIProperties.getTopP(),
vertexAIProperties.getTopK());
return objectMapper.writeValueAsString(vertexAIParameters);
} catch (JsonProcessingException e) {
log.error("Exception occurred while creating parameters string {}", e.getMessage());
throw new RuntimeException(e);
}
}
}

0 comments on commit 9ecd302

Please sign in to comment.