Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

General: Track token usage of LLM service requests #9455

Merged
merged 34 commits into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
a5dfdbb
Define LLM token usage model
alexjoham Sep 28, 2024
d0bdae3
Update table, save data recieved from Pyris Exercise chat pipeline
alexjoham Oct 11, 2024
2a08cb2
Implement competency generation tracking, update enum
alexjoham Oct 11, 2024
f85cf46
Add comments to LLMTokenUsageService
alexjoham Oct 11, 2024
65fb259
Fix server test failures by checking if tokens received
alexjoham Oct 12, 2024
188ff22
Update database for cost tracking and trace_id functionality
alexjoham Oct 12, 2024
be85a3b
Update database, add information to competency gen, change traceId calc
alexjoham Oct 12, 2024
e974d59
Implement server Integration tests for token tracking and saving
alexjoham Oct 13, 2024
6337162
Update code based on code-rabbit feedback, fix tests
alexjoham Oct 14, 2024
84a60dc
minor comment changes, remove tokens from frontend
alexjoham Oct 14, 2024
5b0ab48
Merge branch 'develop' into feature/track-usage-of-iris-requests
alexjoham Oct 14, 2024
62dad8b
Fix github test fails
alexjoham Oct 14, 2024
897d643
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
1d10860
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
8b27861
Merge remote-tracking branch 'origin/feature/track-usage-of-iris-requ…
alexjoham Oct 14, 2024
86294c1
Fix test failure by removing @SpyBean
alexjoham Oct 15, 2024
56b20e7
Update database to safe only IDs, fix competency Integration Test user
alexjoham Oct 15, 2024
8a29c82
Implement builder pattern based on feedback
alexjoham Oct 16, 2024
abbd28f
Update database migration with foreign keys and on delete null
alexjoham Oct 16, 2024
8d34428
Rework database, update saveLLMTokens method
alexjoham Oct 18, 2024
52bf023
Implement new service in all Pipelines, update database, update test
alexjoham Oct 19, 2024
b8f5cca
fix server tests
krusche Oct 20, 2024
82fb76d
fix function naming
FelixTJDietrich Oct 21, 2024
6d3037a
replace ArraySet
FelixTJDietrich Oct 21, 2024
c785417
Merge branch 'develop' into feature/track-usage-of-iris-requests
FelixTJDietrich Oct 21, 2024
9f4cccd
Refactored token usage tracking and improved session-based job handling
bassner Oct 21, 2024
e437c71
Update tests to work with new changes
alexjoham Oct 21, 2024
cc127af
Athena: Add LLM token usage tracking (#9554)
FelixTJDietrich Oct 22, 2024
aded0ee
Implement feedback and fix server tests
alexjoham Oct 22, 2024
8a7bc2e
add foreign keys with onDelete=SET NULL to all ids in LLMTokenUsageTrace
alexjoham Oct 22, 2024
c7e7db6
Correct wrong Long type, update comment
alexjoham Oct 22, 2024
6a90887
Merge branch 'develop' into feature/track-usage-of-iris-requests
FelixTJDietrich Oct 23, 2024
79b1b88
Make LLM token tracking of chat suggestions multi-node compatible
bassner Oct 23, 2024
5a2e92a
Added return statements to handleStatusUpdate methods
bassner Oct 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package de.tum.cit.aet.artemis.core.domain;

public record LLMRequest(String model, int numInputTokens, float costPerMillionInputToken, int numOutputTokens, float costPerMillionOutputToken, String pipelineId) {
alexjoham marked this conversation as resolved.
Show resolved Hide resolved
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package de.tum.cit.aet.artemis.core.domain;

/**
* Enum representing different types of LLM (Large Language Model) services used in the system.
*/
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
public enum LLMServiceType {
IRIS, ATHENA
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package de.tum.cit.aet.artemis.core.domain;

import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonInclude;

@Entity
@Table(name = "llm_token_usage_request")
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsageRequest extends DomainObject {
alexjoham marked this conversation as resolved.
Show resolved Hide resolved

@Column(name = "model")
private String model;

@Column(name = "service_pipeline_id")
private String servicePipelineId;

@Column(name = "num_input_tokens")
private int numInputTokens;

@Column(name = "cost_per_million_input_tokens")
private float costPerMillionInputTokens;

@Column(name = "num_output_tokens")
private int numOutputTokens;

@Column(name = "cost_per_million_output_tokens")
private float costPerMillionOutputTokens;

@ManyToOne
private LLMTokenUsageTrace trace;
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public String getServicePipelineId() {
return servicePipelineId;
}

public void setServicePipelineId(String servicePipelineId) {
this.servicePipelineId = servicePipelineId;
}

public float getCostPerMillionInputTokens() {
return costPerMillionInputTokens;
}

public void setCostPerMillionInputTokens(float costPerMillionInputToken) {
this.costPerMillionInputTokens = costPerMillionInputToken;
}

public float getCostPerMillionOutputTokens() {
return costPerMillionOutputTokens;
}

public void setCostPerMillionOutputTokens(float costPerMillionOutputToken) {
this.costPerMillionOutputTokens = costPerMillionOutputToken;
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved

public int getNumInputTokens() {
return numInputTokens;
}

public void setNumInputTokens(int numInputTokens) {
this.numInputTokens = numInputTokens;
}

public int getNumOutputTokens() {
return numOutputTokens;
}

public void setNumOutputTokens(int numOutputTokens) {
this.numOutputTokens = numOutputTokens;
}

public LLMTokenUsageTrace getTrace() {
return trace;
}

public void setTrace(LLMTokenUsageTrace trace) {
this.trace = trace;
}
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package de.tum.cit.aet.artemis.core.domain;

import java.time.ZonedDateTime;
import java.util.HashSet;
import java.util.Set;

import jakarta.annotation.Nullable;
import jakarta.persistence.CascadeType;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.FetchType;
import jakarta.persistence.OneToMany;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonInclude;

@Entity
alexjoham marked this conversation as resolved.
Show resolved Hide resolved
@Table(name = "llm_token_usage_trace")
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsageTrace extends DomainObject {

@Column(name = "service")
@Enumerated(EnumType.STRING)
private LLMServiceType serviceType;

@Nullable
@Column(name = "course_id")
private Long courseId;

@Nullable
@Column(name = "exercise_id")
private Long exerciseId;

@Column(name = "user_id")
private Long userId;
krusche marked this conversation as resolved.
Show resolved Hide resolved

@Column(name = "time")
private ZonedDateTime time = ZonedDateTime.now();

@Nullable
@Column(name = "iris_message_id")
private Long irisMessageId;
alexjoham marked this conversation as resolved.
Show resolved Hide resolved

@OneToMany(mappedBy = "trace", fetch = FetchType.LAZY, cascade = CascadeType.ALL, orphanRemoval = true)
private Set<LLMTokenUsageRequest> llmRequests = new HashSet<>();

public LLMServiceType getServiceType() {
return serviceType;
}

public void setServiceType(LLMServiceType serviceType) {
this.serviceType = serviceType;
}

public Long getCourseId() {
return courseId;
}

public void setCourseId(Long courseId) {
this.courseId = courseId;
}

public Long getExerciseId() {
return exerciseId;
}

public void setExerciseId(Long exerciseId) {
this.exerciseId = exerciseId;
}

public long getUserId() {
return userId;
}

public void setUserId(long userId) {
this.userId = userId;
}

public ZonedDateTime getTime() {
return time;
}

public void setTime(ZonedDateTime time) {
this.time = time;
}

public Set<LLMTokenUsageRequest> getLLMRequests() {
return llmRequests;
}

public void setLlmRequests(Set<LLMTokenUsageRequest> llmRequests) {
this.llmRequests = llmRequests;
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved

public Long getIrisMessageId() {
return irisMessageId;
}

public void setIrisMessageId(Long messageId) {
this.irisMessageId = messageId;
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.repository;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest;
import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository;

@Profile(PROFILE_CORE)
@Repository
public interface LLMTokenUsageRequestRepository extends ArtemisJpaRepository<LLMTokenUsageRequest, Long> {
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.repository;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace;
import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository;

@Profile(PROFILE_CORE)
@Repository
public interface LLMTokenUsageTraceRepository extends ArtemisJpaRepository<LLMTokenUsageTrace, Long> {
}
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package de.tum.cit.aet.artemis.core.service;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_CORE;

import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;

import de.tum.cit.aet.artemis.core.domain.LLMRequest;
import de.tum.cit.aet.artemis.core.domain.LLMServiceType;
import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageRequest;
import de.tum.cit.aet.artemis.core.domain.LLMTokenUsageTrace;
import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRequestRepository;
import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageTraceRepository;

/**
* Service for managing the LLMTokenUsage by all LLMs in Artemis
*/
@Profile(PROFILE_CORE)
@Service
public class LLMTokenUsageService {

private final LLMTokenUsageTraceRepository llmTokenUsageTraceRepository;

private final LLMTokenUsageRequestRepository llmTokenUsageRequestRepository;

public LLMTokenUsageService(LLMTokenUsageTraceRepository llmTokenUsageTraceRepository, LLMTokenUsageRequestRepository llmTokenUsageRequestRepository) {
this.llmTokenUsageTraceRepository = llmTokenUsageTraceRepository;
this.llmTokenUsageRequestRepository = llmTokenUsageRequestRepository;
}

/**
* Saves the token usage to the database.
* This method records the usage of tokens by various LLM services in the system.
*
* @param llmRequests List of LLM requests containing details about the token usage.
* @param serviceType Type of the LLM service (e.g., IRIS, GPT-3).
* @param builderFunction A function that takes an LLMTokenUsageBuilder and returns a modified LLMTokenUsageBuilder.
* This function is used to set additional properties on the LLMTokenUsageTrace object, such as
* the course ID, user ID, exercise ID, and Iris message ID.
* Example usage:
* builder -> builder.withCourse(courseId).withUser(userId)
* @return The saved LLMTokenUsageTrace object, which includes the details of the token usage.
*/
// TODO: this should ideally be done Async
public LLMTokenUsageTrace saveLLMTokenUsage(List<LLMRequest> llmRequests, LLMServiceType serviceType, Function<LLMTokenUsageBuilder, LLMTokenUsageBuilder> builderFunction) {
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
LLMTokenUsageTrace llmTokenUsageTrace = new LLMTokenUsageTrace();
llmTokenUsageTrace.setServiceType(serviceType);

LLMTokenUsageBuilder builder = builderFunction.apply(new LLMTokenUsageBuilder());
builder.getIrisMessageID().ifPresent(llmTokenUsageTrace::setIrisMessageId);
builder.getCourseID().ifPresent(llmTokenUsageTrace::setCourseId);
builder.getExerciseID().ifPresent(llmTokenUsageTrace::setExerciseId);
builder.getUserID().ifPresent(llmTokenUsageTrace::setUserId);

llmTokenUsageTrace.setLlmRequests(llmRequests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest)
.peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(llmTokenUsageTrace)).collect(Collectors.toSet()));

return llmTokenUsageTraceRepository.save(llmTokenUsageTrace);
}
bassner marked this conversation as resolved.
Show resolved Hide resolved

private static LLMTokenUsageRequest convertLLMRequestToLLMTokenUsageRequest(LLMRequest llmRequest) {
LLMTokenUsageRequest llmTokenUsageRequest = new LLMTokenUsageRequest();
llmTokenUsageRequest.setModel(llmRequest.model());
llmTokenUsageRequest.setNumInputTokens(llmRequest.numInputTokens());
llmTokenUsageRequest.setNumOutputTokens(llmRequest.numOutputTokens());
llmTokenUsageRequest.setCostPerMillionInputTokens(llmRequest.costPerMillionInputToken());
llmTokenUsageRequest.setCostPerMillionOutputTokens(llmRequest.costPerMillionOutputToken());
llmTokenUsageRequest.setServicePipelineId(llmRequest.pipelineId());
return llmTokenUsageRequest;
}

// TODO: this should ideally be done Async
public void appendRequestsToTrace(List<LLMRequest> requests, LLMTokenUsageTrace trace) {
var requestSet = requests.stream().map(LLMTokenUsageService::convertLLMRequestToLLMTokenUsageRequest).peek(llmTokenUsageRequest -> llmTokenUsageRequest.setTrace(trace))
.collect(Collectors.toSet());
FelixTJDietrich marked this conversation as resolved.
Show resolved Hide resolved
llmTokenUsageRequestRepository.saveAll(requestSet);
}
bassner marked this conversation as resolved.
Show resolved Hide resolved

/**
* Class LLMTokenUsageBuilder to be used for saveLLMTokenUsage()
*/
public static class LLMTokenUsageBuilder {

private Optional<Long> courseID = Optional.empty();

private Optional<Long> irisMessageID = Optional.empty();
bassner marked this conversation as resolved.
Show resolved Hide resolved

private Optional<Long> exerciseID = Optional.empty();

private Optional<Long> userID = Optional.empty();

public LLMTokenUsageBuilder withCourse(Long courseID) {
this.courseID = Optional.ofNullable(courseID);
return this;
}

public LLMTokenUsageBuilder withIrisMessageID(Long irisMessageID) {
this.irisMessageID = Optional.ofNullable(irisMessageID);
return this;
}

public LLMTokenUsageBuilder withExercise(Long exerciseID) {
this.exerciseID = Optional.ofNullable(exerciseID);
return this;
}

public LLMTokenUsageBuilder withUser(Long userID) {
this.userID = Optional.ofNullable(userID);
return this;
}

public Optional<Long> getCourseID() {
return courseID;
}

public Optional<Long> getIrisMessageID() {
return irisMessageID;
}

public Optional<Long> getExerciseID() {
return exerciseID;
}

public Optional<Long> getUserID() {
return userID;
}
}
bassner marked this conversation as resolved.
Show resolved Hide resolved
}
Loading
Loading