Skip to content

Commit

Permalink
add google vertex ai
Browse files Browse the repository at this point in the history
  • Loading branch information
Maleehak committed Apr 15, 2024
1 parent 8349295 commit fb9696f
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 1 deletion.
5 changes: 5 additions & 0 deletions backend/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-speech</artifactId>
</dependency>
<dependency>
<groupId>com.google.cloud</groupId>
<artifactId>google-cloud-aiplatform</artifactId>
<version>3.42.0</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package com.example.BriefMe;

import com.example.BriefMe.service.impl.PredictTextSummarizationSample;
import java.io.IOException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;


@SpringBootApplication
@Slf4j
public class BriefMeApplication {
public static void main(String[] args) {
public static void main(String[] args) throws IOException {

SpringApplication.run(BriefMeApplication.class, args);
PredictTextSummarizationSample predictTextSummarization = new PredictTextSummarizationSample();
predictTextSummarization.predict();

}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package com.example.BriefMe.request;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class VertexAIParameters {
Double temperature;
Integer maxOutputTokens;
Double topP;
Integer topK;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.example.BriefMe.request;

import lombok.AllArgsConstructor;
import lombok.Data;

@Data
@AllArgsConstructor
public class VertexAIPrompt {
String prompt;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.example.BriefMe.service.impl;

import com.example.BriefMe.request.VertexAIParameters;
import com.example.BriefMe.request.VertexAIPrompt;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.IOException;
import org.springframework.stereotype.Service;
import com.google.cloud.aiplatform.v1beta1.EndpointName;
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
import com.google.protobuf.Value;
import com.google.protobuf.util.JsonFormat;
import java.util.ArrayList;
import java.util.List;


@Service
public class PredictTextSummarizationSample {
public void predict() throws IOException {
String text ="";
VertexAIPrompt vertexAIPrompt = new VertexAIPrompt("Provide a short summary in five numeric bullet points:" + text);
ObjectMapper objectMapper = new ObjectMapper();
String vertexAIRequestString = objectMapper.writeValueAsString(vertexAIPrompt);

VertexAIParameters vertexAIParameters = new VertexAIParameters(0.2, 256, 0.95, 40);
String vertexAIParamsString = objectMapper.writeValueAsString(vertexAIParameters);


String project = "artful-lane-419217";
String location = "us-central1";
String publisher = "google";
String model = "text-bison@001";

predictTextSummarization(vertexAIRequestString, vertexAIParamsString, project, location, publisher, model);
}

// Get summarization from a supported text model
public void predictTextSummarization(
String instance,
String parameters,
String project,
String location,
String publisher,
String model)
throws IOException {
String endpoint = String.format("%s-aiplatform.googleapis.com:443", location);

PredictionServiceSettings predictionServiceSettings =
PredictionServiceSettings.newBuilder()
.setEndpoint(endpoint)
.build();

// Initialize client that will be used to send requests. This client only needs to be created
// once, and can be reused for multiple requests.
try (PredictionServiceClient predictionServiceClient =
PredictionServiceClient.create(predictionServiceSettings)) {
final EndpointName endpointName =
EndpointName.ofProjectLocationPublisherModelName(project, location, publisher, model);

// Use Value.Builder to convert instance to a dynamically typed value that can be
// processed by the service.
Value.Builder instanceValue = Value.newBuilder();
JsonFormat.parser().merge(instance, instanceValue);
List<Value> instances = new ArrayList<>();
instances.add(instanceValue.build());

// Use Value.Builder to convert parameter to a dynamically typed value that can be
// processed by the service.
Value.Builder parameterValueBuilder = Value.newBuilder();
JsonFormat.parser().merge(parameters, parameterValueBuilder);
Value parameterValue = parameterValueBuilder.build();

PredictResponse predictResponse =
predictionServiceClient.predict(endpointName, instances, parameterValue);

System.out.println("Predict Response");
System.out.println(predictResponse);
}
}

}

0 comments on commit fb9696f

Please sign in to comment.