Skip to content

Commit

Permalink
Use ContentRetriever
Browse files Browse the repository at this point in the history
  • Loading branch information
sberyozkin committed May 20, 2024
1 parent cf7bfeb commit dc36dbf
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,4 @@ public class Customer {
public String name;
public String email;
public int transactionLimit;
public int distanceLimit;
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,4 @@ public int getTransactionLimit(String customerName, String customerEmail) {
customerName,
customerEmail).firstResult().transactionLimit;
}

@Tool("Get the distance limit for a given customer")
public int getDistanceLimit(String customerName, String customerEmail) {
return find("name = ?1 and email = ?2",
customerName,
customerEmail).firstResult().distanceLimit;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,19 @@
import dev.langchain4j.service.UserMessage;
import io.quarkiverse.langchain4j.RegisterAiService;

@RegisterAiService(tools = { TransactionRepository.class, CustomerRepository.class })
@RegisterAiService(retrievalAugmentor = FraudDetectionRetrievalAugmentor.class)
public interface FraudDetectionAi {

@SystemMessage("""
You are a bank account fraud detection AI. You have to detect frauds in transactions.
""")
@UserMessage("""
Your task is to detect whether a fraud was committed for the customer {{customerName}}.
To detect a fraud, perform the following actions:
1 - Retrieve the transaction limit for the customer {{customerName}} with the {{customerEmail}} email address.
2 - Retrieve the transactions for the customer {{customerName}} with the {{customerEmail}} email address for the last 15 minutes.
3 - Sum the amount of all of these transactions. Make sure the sum is correct.
4 - If the amount is greater than the transaction limit for this customer, a fraud is detected.
Your task is to detect whether a fraud was committed for the customer {{customerName}} with the {{customerEmail}} email.
Answer with a **single** JSON document containing:
- the customer name in the 'customer-name' key
- the 'returning-customer' key set to a boolean value indicating if the same query was already issued before
- the transaction limit in the 'transaction-limit' key
- the computed sum in the 'total' key
- the computed sum of all transactions in the 'total' key
- the 'fraud' key set to a boolean value indicating if a fraud was detected
- the 'transactions' key containing the list of transaction amounts
- the 'explanation' key containing a explanation of your answer, including how the sum is computed.
Expand All @@ -37,37 +30,4 @@ public interface FraudDetectionAi {
""")
@Timeout(value = 2, unit = ChronoUnit.MINUTES)
String detectAmountFraudForCustomer(String customerName, String customerEmail);

@SystemMessage("""
You are a bank account fraud detection AI. You have to detect frauds in transactions.
""")
@UserMessage("""
Detect frauds based on the distance between two transactions for the customer: {{customerName}}.
To detect a fraud, perform the following actions:
1 - Retrieve the distance limit in kilometers for the customer {{customerName}} with the {{customerEmail}} email address.
2 - Retrieve the transactions for the customer {{customerName}} with the {{customerEmail}} email address for the last 15 minutes.
3 - Retrieve the city for each transaction.
4 - Check if the distance between 2 cities is greater than the distance limit, if so, a fraud is detected.
5 - If a fraud is detected, find the two transactions associated with these cities.
Answer with a **single** JSON document containing:
- the customer name in the 'customer-name' key
- the distance limit in the 'distance-limit' key
- the amount of the first transaction in the 'first-amount' key
- the amount of the second transaction in the 'second-amount' key
- the city of the first transaction in the 'first-city' key
- the city of the second transaction in the 'second-city' key
- the 'fraud' key set to a boolean value indicating if a fraud was detected (so the distance is greater than the distance limit)
- the 'distance' key set to the distance between the two cities
- the 'explanation' key containing a explanation of your answer.
- the 'cities' key containing all the cities for the transactions for the customer {{customerName}} in the last 15 minutes.
- if there is a fraud, the 'email' key containing an email to the customer {{customerName}} to warn about the fraud.
Your response must be just the raw JSON document, without ```json, ``` or anything else. Do not use null JSON properties.
""")
@Timeout(value = 2, unit = ChronoUnit.MINUTES)
String detectDistanceFraudForCustomer(String customerName, String customerEmail);

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package io.quarkiverse.langchain4j.sample;

import java.util.List;

import org.eclipse.microprofile.jwt.JsonWebToken;
import org.jboss.logging.Logger;

import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

@ApplicationScoped
public class FraudDetectionContentRetriever implements ContentRetriever {
private static final Logger log = Logger.getLogger(FraudDetectionContentRetriever.class);

@Inject
TransactionRepository transactionRepository;

@Inject
CustomerRepository customerRepository;

@Override
public List<Content> retrieve(Query query) {
JsonWebToken idToken = (JsonWebToken)query.metadata().chatMemoryId();
// User the customer name and email to retrive the content and return it as JSON
return List.of();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,4 @@ public FraudDetectionResource(FraudDetectionAi service) {
public String detectBaseOnAmount() {
return service.detectAmountFraudForCustomer(idToken.getName(), idToken.getClaim(Claims.email));
}

@GET
@Path("/distance")
public String detectBasedOnDistance() {
return service.detectDistanceFraudForCustomer(idToken.getName(), idToken.getClaim(Claims.email));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package io.quarkiverse.langchain4j.sample;

import java.util.function.Supplier;

import dev.langchain4j.rag.DefaultRetrievalAugmentor;
import dev.langchain4j.rag.RetrievalAugmentor;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.inject.Inject;

@ApplicationScoped
public class FraudDetectionRetrievalAugmentor implements Supplier<RetrievalAugmentor> {

@Inject
FraudDetectionContentRetriever contentRetriever;

@Override
public RetrievalAugmentor get() {
return DefaultRetrievalAugmentor.builder()
.contentRetriever(contentRetriever)
.build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import java.util.List;
import java.util.Random;

import io.quarkus.runtime.StartupEvent;
import jakarta.enterprise.context.ApplicationScoped;
import jakarta.enterprise.event.Observes;
import jakarta.transaction.Transactional;

import io.quarkus.runtime.StartupEvent;

@ApplicationScoped
public class Setup {

Expand All @@ -29,21 +28,18 @@ public void init(@Observes StartupEvent ev, CustomerRepository customers, Transa
customer1.name = "Clement Escofier";
customer1.email = "[email protected]";
customer1.transactionLimit = 10000;
customer1.distanceLimit = 500;
customers.persist(customer1);

var customer2 = new Customer();
customer2.name = "Georgios Andrianakis";
customer2.email = "[email protected]";
customer2.transactionLimit = 1000;
customer1.distanceLimit = 300;
customers.persist(customer2);

var customer3 = new Customer();
customer3.name = "Sergey Beryozkin";
customer3.email = "[email protected]";
customer1.transactionLimit = 500;
customer1.distanceLimit = 100;
customers.persist(customer3);
}

Expand Down

0 comments on commit dc36dbf

Please sign in to comment.