Skip to content

Commit

Permalink
Added OpenAI Test Cases (#143)
Browse files Browse the repository at this point in the history
* Added OpenAI Test Cases

* Update BuildAndRun.yml
  • Loading branch information
EmadHanif01 authored Jul 14, 2023
1 parent acfa542 commit 975a938
Show file tree
Hide file tree
Showing 10 changed files with 140 additions and 116 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/BuildAndRun.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ jobs:
working-directory: ./FlySpring/edgechain-app
run: mvn clean package -DskipTests

# - name: Run edgechain testcases
# working-directory: ./FlySpring/edgechain-app
# run: mvn test
- name: Run edgechain testcases
working-directory: ./FlySpring/edgechain-app
run: mvn test

- name: Copy edgechain-app JAR to Examples folder
run: cp ./FlySpring/edgechain-app/target/edgechain-app-1.0.0.jar ./BuildOutput/
Expand Down
19 changes: 0 additions & 19 deletions FlySpring/edgechain-app/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -315,23 +315,4 @@
</plugins>
</build>

<profiles>

<profile>
<id>app</id>
<properties>
<spring.boot.mainclass>com.edgechain.EdgeChainAppRunner</spring.boot.mainclass>
</properties>
</profile>

<profile>
<id>service</id>
<properties>
<spring.boot.mainclass>com.edgechain.service.EdgeChainServiceRunner</spring.boot.mainclass>
</properties>
</profile>


</profiles>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,23 @@ public OpenAiEndpoint(
this.stream = stream;
}

public OpenAiEndpoint(
String url,
String apiKey,
String orgId,
String model,
String role,
Double temperature,
Boolean stream) {
super(url, apiKey, null);
this.orgId = orgId;
this.model = model;
this.role = role;
this.temperature = temperature;
this.stream = stream;
}


public OpenAiEndpoint(
String url,
String apiKey,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import com.edgechain.lib.openai.response.ChatCompletionResponse;
import com.edgechain.lib.openai.response.CompletionResponse;
import com.edgechain.lib.rxjava.transformer.observable.EdgeChain;
import com.edgechain.lib.utils.JsonUtils;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.reactivex.rxjava3.core.Observable;
import org.slf4j.Logger;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public ExponentialDelay() {}

public ExponentialDelay(long firstDelay, int maxRetries, int factor, TimeUnit unit) {
this.firstDelay = firstDelay;
this.maxRetries = maxRetries + 1;
this.maxRetries = maxRetries;
this.factor = factor;
this.unit = unit;
this.retryCount = 0;
Expand All @@ -44,11 +44,11 @@ public Observable<?> apply(Observable<? extends Throwable> observable) throws Th
long compute = compute(firstDelay, retryCount, factor, unit);

if (++retryCount < maxRetries) {
logger.info("Retrying it.... " + throwable.getMessage());
logger.info(String.format("Retrying: Attempt: %s, Max Retries: %s ~ %s", retryCount, maxRetries, throwable.getMessage()));
return Observable.timer(compute, TimeUnit.MILLISECONDS);
}

logger.error(throwable.getMessage());
logger.error(String.format("Error Occurred: Attempt: %s, Max Retries: %s ~ %s", retryCount, maxRetries, throwable.getMessage()));
return Observable.error(throwable);
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class FixedDelay extends RetryPolicy {
public FixedDelay() {}

public FixedDelay(int maxRetries, int retryDelay, TimeUnit unit) {
this.maxRetries = maxRetries + 1;
this.maxRetries = maxRetries;
this.retryDelay = retryDelay;
this.unit = unit;
this.retryCount = 0;
Expand All @@ -41,13 +41,12 @@ public Observable<?> apply(final Observable<? extends Throwable> attempts) {
return Observable.empty();

if (++retryCount < maxRetries) {
// Unsubscribe the original observable & resubscribed it.
logger.info("Retrying it.... " + throwable.getMessage());
logger.info(String.format("Retrying: Attempt: %s, Max Retries: %s ~ %s", retryCount, maxRetries, throwable.getMessage()));
return Observable.timer(unit.toMillis(retryDelay), TimeUnit.MILLISECONDS);
}

// Once, max-retries hit, emit an error.
logger.error(throwable.getMessage());
logger.error(String.format("Error Occurred: Attempt: %s, Max Retries: %s ~ %s", retryCount, maxRetries, throwable.getMessage()));
return Observable.error(throwable);
});
}
Expand All @@ -62,6 +61,10 @@ public String toString() {
return sb.toString();
}

public int getRetryCount() {
return retryCount;
}

public int getMaxRetries() {
return maxRetries;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,46 +1,100 @@
package com.edgechain.openai;

import com.edgechain.lib.embeddings.WordEmbeddings;
import com.edgechain.lib.endpoint.impl.OpenAiEndpoint;
import com.edgechain.lib.openai.request.ChatCompletionRequest;
import com.edgechain.lib.openai.request.ChatMessage;
import com.edgechain.lib.openai.response.ChatCompletionResponse;
import com.edgechain.lib.rxjava.retry.impl.ExponentialDelay;
;
import com.edgechain.lib.utils.JsonUtils;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.reactivex.rxjava3.observers.TestObserver;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.*;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static com.edgechain.lib.constants.EndpointConstants.*;
import static com.edgechain.lib.constants.EndpointConstants.OPENAI_CHAT_COMPLETION_API;
import static org.junit.jupiter.api.Assertions.*;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
public class OpenAiClientTest {

@LocalServerPort int randomServerPort;

private Logger logger = LoggerFactory.getLogger(this.getClass());
private final Logger logger = LoggerFactory.getLogger(this.getClass());

@BeforeEach
public void setup() {
System.setProperty("server.port", "" + randomServerPort);
}

@ParameterizedTest
@CsvSource({
"Write 10 unique sentences on Java Language",
"Can you explain Ant Bee Colony Optimization Algorithm?"
})
public void testOpenAiEndpoint_ChatCompletionShouldAssertNoErrors(String prompt)
@ValueSource(classes = {ChatCompletionRequest.class})
@DisplayName("Test ChatCompletionRequest Json Request")
@Order(1)
public void testOpenAiClient_ChatCompletionRequest_ShouldMatchRequestBody(Class<?> clazz)
throws IOException {

ObjectMapper mapper = new ObjectMapper();
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);

ChatCompletionRequest chatCompletionRequest =
ChatCompletionRequest.builder()
.model("gpt-3.5-turbo")
.temperature(0.7)
.messages(
List.of(
new ChatMessage(
"user", "Can you write two unique sentences on Java Language?")))
.stream(false)
.build();

byte[] bytes =
Files.readAllBytes(Paths.get("src/test/java/resources/" + clazz.getSimpleName() + ".json"));
String originalJson = new String(bytes);

assertEquals(
mapper.readTree(JsonUtils.convertToString(chatCompletionRequest)),
mapper.readTree(originalJson));
}

@ParameterizedTest
@ValueSource(classes = {ChatCompletionResponse.class})
@DisplayName("Test ChatCompletionResponse POJO")
@Order(2)
public void testOpenAiClient_ChatCompletionResponse_ShouldMappedToPOJO(Class<?> clazz) {
assertDoesNotThrow(
() -> {
byte[] bytes =
Files.readAllBytes(
Paths.get("src/test/java/resources/" + clazz.getSimpleName() + ".json"));
String json = new String(bytes);

ChatCompletionResponse chatCompletionResponse =
JsonUtils.convertToObject(json, ChatCompletionResponse.class);
logger.info("" + chatCompletionResponse); // Printing the object
});
}

@Test
@DisplayName("Test OpenAiEndpoint With Retry Mechanism")
@Order(3)
public void testOpenAiClient_WithRetryMechanism_ShouldThrowExceptionWithRetry(TestInfo testInfo)
throws InterruptedException {

// Step 1 : Create OpenAi Endpoint
System.out.println("======== " + testInfo.getDisplayName() +" ========");

OpenAiEndpoint endpoint =
new OpenAiEndpoint(
OPENAI_CHAT_COMPLETION_API,
Expand All @@ -52,26 +106,24 @@ public void testOpenAiEndpoint_ChatCompletionShouldAssertNoErrors(String prompt)
false,
new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS));

TestObserver<ChatCompletionResponse> test = endpoint.getChatCompletion(prompt).test();
TestObserver<ChatCompletionResponse> test =
endpoint.getChatCompletion("Can you write two unique sentences on Java Language?").test();

// Step 4: To act & assert
test.await();

logger.info(test.values().toString());

// Assert
test.assertNoErrors();
test.assertError(Exception.class);
}

@ParameterizedTest
@CsvSource({
"Write 10 unique sentences on Java Language",
"Can you explain Ant Bee Colony Optimization Algorithm?"
})
@DisplayName("Test OpenAI ChatCompletion Stream")
public void testOpenAiEndpoint_ChatCompletionStreamResponseShouldAssertNoErrors(String prompt)
@Test
@DisplayName("Test OpenAiEndpoint With No Retry Mechanism")
@Order(4)
public void testOpenAiClient_WithNoRetryMechanism_ShouldThrowExceptionWithNoRetry(TestInfo testInfo)
throws InterruptedException {

System.out.println("======== " + testInfo.getDisplayName() +" ========");

// Step 1 : Create OpenAi Endpoint
OpenAiEndpoint endpoint =
new OpenAiEndpoint(
Expand All @@ -81,46 +133,15 @@ public void testOpenAiEndpoint_ChatCompletionStreamResponseShouldAssertNoErrors(
"gpt-3.5-turbo",
"user",
0.7,
true,
new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS));

TestObserver<ChatCompletionResponse> test = endpoint.getChatCompletion(prompt).test();

// Step 4: To act & assert
test.await();

logger.info(test.values().toString());

// Assert
test.assertNoErrors();
}

@Test
@DisplayName("Test OpenAI Embeddings")
public void testOpenAiEndpoint_EmbeddingsShouldAssertNoErrors() throws InterruptedException {

String input = "Hey, we are building LLMs using Spring and Java";
false);

// Step 1 : Create OpenAi Endpoint
OpenAiEndpoint endpoint =
new OpenAiEndpoint(
OPENAI_EMBEDDINGS_API,
"", // apiKey
"", // orgId
"text-embedding-ada-002", // model
null,
null,
null,
new ExponentialDelay(3, 3, 2, TimeUnit.SECONDS));

TestObserver<WordEmbeddings> test = endpoint.getEmbeddings(input).test();
TestObserver<ChatCompletionResponse> test =
endpoint.getChatCompletion("Can you write two unique sentences on Java Language?").test();

// Step 4: To act & assert
test.await();

logger.info(test.values().toString());

// Assert
test.assertNoErrors();
test.assertError(Exception.class);
}
}
Original file line number Diff line number Diff line change
@@ -1,34 +1,8 @@
package com.edgechain.redis;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.web.server.LocalServerPort;

@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT)
@SpringBootTest
public class RedisClientTest {

@LocalServerPort int randomServerPort;

private Logger logger = LoggerFactory.getLogger(this.getClass());

@BeforeEach
public void setup() {
System.setProperty("server.port", "" + randomServerPort);
}

@Test
@DisplayName("Test Redis Endpoint Upsert With OpenAI")
public void testRedisEndpoint_UpsertWithOpenAiShouldAssertNoErrors() {}

@Test
@DisplayName("Test Redis Endpoint Similarity Search With OpenAI")
public void testRedisEndpoint_SimilaritySearchWithOpenAiShouldAssertNoErrors() {}

@Test
@DisplayName("Test Redis Endpoint Delete By Pattern")
public void testRedisEndpoint_DeleteByPatternShouldAssertNoErrors() {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"model" : "gpt-3.5-turbo",
"temperature" : 0.7,
"messages" : [ {
"role" : "user",
"content" : "Can you write two unique sentences on Java Language?"
} ],
"stream" : false
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"id" : "chatcmpl-7cCBttNGLwPkAMswN8074tMYdd3Zm",
"object" : "chat.completion",
"created" : 1689337681,
"model" : "gpt-3.5-turbo-0613",
"choices" : [ {
"index" : 0,
"message" : {
"role" : "assistant",
"content" : "1. Java language is a versatile, object-oriented programming language used for developing various applications such as mobile applications, desktop applications, and web services.\n2. Known for its write once, run anywhere feature, Java offers robust security and cross-platform compatibility making it a popular choice among developers."
},
"finish_reason" : "stop"
} ],
"usage" : {
"prompt_tokens" : 880,
"total_tokens" : 1396
}
}

0 comments on commit 975a938

Please sign in to comment.