General: Track token usage of LLM service requests #9455

Merged
merged 34 commits into from
Oct 23, 2024
Changes from 6 commits
Commits
a5dfdbb
Define LLM token usage model
alexjoham Sep 28, 2024
d0bdae3
Update table, save data recieved from Pyris Exercise chat pipeline
alexjoham Oct 11, 2024
2a08cb2
Implement competency generation tracking, update enum
alexjoham Oct 11, 2024
f85cf46
Add comments to LLMTokenUsageService
alexjoham Oct 11, 2024
65fb259
Fix server test failures by checking if tokens received
alexjoham Oct 12, 2024
188ff22
Update database for cost tracking and trace_id functionality
alexjoham Oct 12, 2024
be85a3b
Update database, add information to competency gen, change traceId calc
alexjoham Oct 12, 2024
e974d59
Implement server Integration tests for token tracking and saving
alexjoham Oct 13, 2024
6337162
Update code based on code-rabbit feedback, fix tests
alexjoham Oct 14, 2024
84a60dc
minor comment changes, remove tokens from frontend
alexjoham Oct 14, 2024
5b0ab48
Merge branch 'develop' into feature/track-usage-of-iris-requests
alexjoham Oct 14, 2024
62dad8b
Fix github test fails
alexjoham Oct 14, 2024
897d643
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
1d10860
Change servicetype to type String to prevent failures
alexjoham Oct 14, 2024
8b27861
Merge remote-tracking branch 'origin/feature/track-usage-of-iris-requ…
alexjoham Oct 14, 2024
86294c1
Fix test failure by removing @SpyBean
alexjoham Oct 15, 2024
56b20e7
Update database to safe only IDs, fix competency Integration Test user
alexjoham Oct 15, 2024
8a29c82
Implement builder pattern based on feedback
alexjoham Oct 16, 2024
abbd28f
Update database migration with foreign keys and on delete null
alexjoham Oct 16, 2024
8d34428
Rework database, update saveLLMTokens method
alexjoham Oct 18, 2024
52bf023
Implement new service in all Pipelines, update database, update test
alexjoham Oct 19, 2024
b8f5cca
fix server tests
krusche Oct 20, 2024
82fb76d
fix function naming
FelixTJDietrich Oct 21, 2024
6d3037a
replace ArraySet
FelixTJDietrich Oct 21, 2024
c785417
Merge branch 'develop' into feature/track-usage-of-iris-requests
FelixTJDietrich Oct 21, 2024
9f4cccd
Refactored token usage tracking and improved session-based job handling
bassner Oct 21, 2024
e437c71
Update tests to work with new changes
alexjoham Oct 21, 2024
cc127af
Athena: Add LLM token usage tracking (#9554)
FelixTJDietrich Oct 22, 2024
aded0ee
Implement feedback and fix server tests
alexjoham Oct 22, 2024
8a7bc2e
add foreign keys with onDelete=SET NULL to all ids in LLMTokenUsageTrace
alexjoham Oct 22, 2024
c7e7db6
Correct wrong Long type, update comment
alexjoham Oct 22, 2024
6a90887
Merge branch 'develop' into feature/track-usage-of-iris-requests
FelixTJDietrich Oct 23, 2024
79b1b88
Make LLM token tracking of chat suggestions multi-node compatible
bassner Oct 23, 2024
5a2e92a
Added return statements to handleStatusUpdate methods
bassner Oct 23, 2024
@@ -0,0 +1,6 @@
package de.tum.cit.aet.artemis.core.domain;

public enum LLMServiceType {
ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION,
IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, NOT_SET
}
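The commit list mentions switching the service type to a String "to prevent failures", which suggests that unknown pipeline names from external services were a concern. As a rough sketch only (the `ServiceTypeParser` class and its `parseOrNotSet` method are hypothetical and not part of this PR), an unknown or missing name could be mapped to `NOT_SET` instead of throwing. The enum is repeated here only so the example compiles on its own.

```java
import java.util.Locale;

// Copy of the enum from this diff, repeated so the sketch is self-contained.
enum LLMServiceType {
    ATHENA_PRELIMINARY_FEEDBACK, ATHENA_FEEDBACK_SUGGESTION, IRIS_CODE_FEEDBACK, IRIS_CHAT_COURSE_MESSAGE, IRIS_CHAT_EXERCISE_MESSAGE, IRIS_INTERACTION_SUGGESTION,
    IRIS_CHAT_LECTURE_MESSAGE, IRIS_COMPETENCY_GENERATION, IRIS_CITATION_PIPELINE, NOT_SET
}

// Hypothetical helper, not part of the PR.
final class ServiceTypeParser {

    private ServiceTypeParser() {
    }

    // Maps a raw pipeline name to the enum; null, blank, or unknown names become NOT_SET.
    static LLMServiceType parseOrNotSet(String raw) {
        if (raw == null || raw.isBlank()) {
            return LLMServiceType.NOT_SET;
        }
        try {
            return LLMServiceType.valueOf(raw.trim().toUpperCase(Locale.ROOT));
        }
        catch (IllegalArgumentException e) {
            // valueOf throws for names that are not enum constants
            return LLMServiceType.NOT_SET;
        }
    }

    public static void main(String[] args) {
        System.out.println(parseOrNotSet("iris_code_feedback"));
        System.out.println(parseOrNotSet("SOME_FUTURE_PIPELINE"));
    }
}
```

Whether silently falling back to `NOT_SET` is preferable to rejecting the payload depends on how strictly usage should be attributed; the sketch only illustrates the fallback option.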
183 changes: 183 additions & 0 deletions src/main/java/de/tum/cit/aet/artemis/core/domain/LLMTokenUsage.java
@@ -0,0 +1,183 @@
package de.tum.cit.aet.artemis.core.domain;

import java.time.ZonedDateTime;

import jakarta.annotation.Nullable;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EnumType;
import jakarta.persistence.Enumerated;
import jakarta.persistence.Inheritance;
import jakarta.persistence.InheritanceType;
import jakarta.persistence.JoinColumn;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Table;

import org.hibernate.annotations.Cache;
import org.hibernate.annotations.CacheConcurrencyStrategy;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonTypeInfo;

import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;

@Entity
@Table(name = "llm_token_usage")
@Inheritance(strategy = InheritanceType.SINGLE_TABLE)
@Cache(usage = CacheConcurrencyStrategy.NONSTRICT_READ_WRITE)
@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type")
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public class LLMTokenUsage extends DomainObject {

@Column(name = "service")
@Enumerated(EnumType.STRING)
private LLMServiceType serviceType;

@Column(name = "model")
private String model;

@Column(name = "num_input_tokens")
private int num_input_tokens;

@Column(name = "cost_per_input_token")
private float cost_per_input_token;

@Column(name = "num_output_tokens")
private int num_output_tokens;

@Column(name = "cost_per_output_token")
private float cost_per_output_token;

@Nullable
@ManyToOne
@JsonIgnore
@JoinColumn(name = "course_id")
private Course course;

@Nullable
@ManyToOne
@JsonIgnore
@JoinColumn(name = "exercise_id")
private Exercise exercise;

@Column(name = "user_id")
private long userId;

@Nullable
@Column(name = "timestamp")
private ZonedDateTime timestamp = ZonedDateTime.now();

@Column(name = "trace_id")
private Long traceId;

@Nullable
@ManyToOne
@JsonIgnore
@JoinColumn(name = "iris_message_id")
IrisMessage irisMessage;

public LLMServiceType getServiceType() {
return serviceType;
}

public void setServiceType(LLMServiceType serviceType) {
this.serviceType = serviceType;
}

public String getModel() {
return model;
}

public void setModel(String model) {
this.model = model;
}

public float getCost_per_input_token() {
return cost_per_input_token;
}

public void setCost_per_input_token(float cost_per_input_token) {
this.cost_per_input_token = cost_per_input_token;
}

public float getCost_per_output_token() {
return cost_per_output_token;
}

public void setCost_per_output_token(float cost_per_output_token) {
this.cost_per_output_token = cost_per_output_token;
}

public int getNum_input_tokens() {
return num_input_tokens;
}

public void setNum_input_tokens(int num_input_tokens) {
this.num_input_tokens = num_input_tokens;
}

public int getNum_output_tokens() {
return num_output_tokens;
}

public void setNum_output_tokens(int num_output_tokens) {
this.num_output_tokens = num_output_tokens;
}

public Course getCourse() {
return course;
}

public void setCourse(Course course) {
this.course = course;
}

public Exercise getExercise() {
return exercise;
}

public void setExercise(Exercise exercise) {
this.exercise = exercise;
}

public long getUserId() {
return userId;
}

public void setUserId(long userId) {
this.userId = userId;
}

public ZonedDateTime getTimestamp() {
return timestamp;
}

public void setTimestamp(ZonedDateTime timestamp) {
this.timestamp = timestamp;
}

public Long getTraceId() {
return traceId;
}

public void setTraceId(Long traceId) {
this.traceId = traceId;
}

public IrisMessage getIrisMessage() {
return irisMessage;
}

public void setIrisMessage(IrisMessage message) {
this.irisMessage = message;
}

@Override
public String toString() {
return "LLMTokenUsage{" + "serviceType=" + serviceType + ", model=" + model + ", num_input_tokens=" + num_input_tokens + ", cost_per_input_token=" + cost_per_input_token
+ ", num_output_tokens=" + num_output_tokens + ", cost_per_output_token=" + cost_per_output_token + ", course=" + course + ", exercise=" + exercise + ", userId="
+ userId + ", timestamp=" + timestamp + ", trace_id=" + traceId + ", irisMessage=" + irisMessage + '}';
}
}
@@ -0,0 +1,14 @@
package de.tum.cit.aet.artemis.core.repository;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Repository;

import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage;
import de.tum.cit.aet.artemis.core.repository.base.ArtemisJpaRepository;

@Repository
@Profile(PROFILE_IRIS)
public interface LLMTokenUsageRepository extends ArtemisJpaRepository<LLMTokenUsage, Long> {
}
@@ -0,0 +1,76 @@
package de.tum.cit.aet.artemis.core.service;

import static de.tum.cit.aet.artemis.core.config.Constants.PROFILE_IRIS;

import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;

import de.tum.cit.aet.artemis.core.domain.Course;
import de.tum.cit.aet.artemis.core.domain.LLMTokenUsage;
import de.tum.cit.aet.artemis.core.domain.User;
import de.tum.cit.aet.artemis.core.repository.LLMTokenUsageRepository;
import de.tum.cit.aet.artemis.exercise.domain.Exercise;
import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;

/**
* Service for managing the LLMTokenUsage by all LLMs in Artemis
*/
@Service
@Profile(PROFILE_IRIS)
public class LLMTokenUsageService {

private final LLMTokenUsageRepository llmTokenUsageRepository;

public LLMTokenUsageService(LLMTokenUsageRepository llmTokenUsageRepository) {
this.llmTokenUsageRepository = llmTokenUsageRepository;
}

/**
* Saves the tokens used for a specific IrisMessage or Athena call.
* In case of an Athena call, the IrisMessage can be null and the
* LLMServiceType in tokens has to be an Athena service type.
*
* @param message IrisMessage related to the token usage
* @param exercise Exercise in which the request was made
* @param user User that made the request
* @param course Course in which the request was made
* @param tokens List of PyrisLLMCostDTO entries with the token counts and costs
* @return List of the created LLMTokenUsage entries
*/
public List<LLMTokenUsage> saveTokenUsage(IrisMessage message, Exercise exercise, User user, Course course, List<PyrisLLMCostDTO> tokens) {
List<LLMTokenUsage> tokenUsages = new ArrayList<>();

// Combine current time and UUID to create a unique traceId
long timestamp = System.currentTimeMillis();
long uuidComponent = UUID.randomUUID().getLeastSignificantBits() & Long.MAX_VALUE;
Long traceId = timestamp + uuidComponent;

for (PyrisLLMCostDTO cost : tokens) {
LLMTokenUsage llmTokenUsage = new LLMTokenUsage();
if (message != null) {
llmTokenUsage.setIrisMessage(message);
llmTokenUsage.setTimestamp(message.getSentAt());
}
llmTokenUsage.setServiceType(cost.pipeline());
llmTokenUsage.setExercise(exercise);
if (user != null) {
llmTokenUsage.setUserId(user.getId());
}
llmTokenUsage.setCourse(course);
llmTokenUsage.setNum_input_tokens(cost.num_input_tokens());
llmTokenUsage.setCost_per_input_token(cost.cost_per_input_token());
llmTokenUsage.setNum_output_tokens(cost.num_output_tokens());
llmTokenUsage.setCost_per_output_token(cost.cost_per_output_token());
llmTokenUsage.setModel(cost.model_info());
llmTokenUsage.setTraceId(traceId);
tokenUsages.add(llmTokenUsageRepository.save(llmTokenUsage));
}
return tokenUsages;
}
}
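Since every entry written by one `saveTokenUsage` call shares the same `traceId`, the stored rows can later be grouped per request. The following is a minimal sketch under stated assumptions: `UsageRow` is a hypothetical, trimmed-down stand-in for `LLMTokenUsage`, and `TraceSummary` is an illustrative helper that does not exist in this PR.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

// Hypothetical, trimmed-down view of one saved LLMTokenUsage row.
record UsageRow(long traceId, String model, int numInputTokens, int numOutputTokens) {
}

// Hypothetical helper, not part of the PR.
final class TraceSummary {

    private TraceSummary() {
    }

    // Sums input and output tokens over all rows that share a trace ID.
    static Map<Long, Integer> totalTokensPerTrace(List<UsageRow> rows) {
        return rows.stream()
                .collect(Collectors.groupingBy(UsageRow::traceId, TreeMap::new, Collectors.summingInt((UsageRow row) -> row.numInputTokens() + row.numOutputTokens())));
    }

    public static void main(String[] args) {
        List<UsageRow> rows = List.of(
                new UsageRow(1L, "model-a", 100, 20),
                new UsageRow(1L, "model-b", 50, 5),
                new UsageRow(2L, "model-a", 10, 1));
        // Trace 1 totals 175 tokens, trace 2 totals 11 tokens.
        System.out.println(totalTokensPerTrace(rows));
    }
}
```

In the real application such an aggregation would more likely be done in a repository query; the in-memory version just shows what the shared trace ID makes possible.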
@@ -9,6 +9,7 @@

import de.tum.cit.aet.artemis.iris.domain.message.IrisMessage;
import de.tum.cit.aet.artemis.iris.service.IrisRateLimitService;
+ import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO;

/**
@@ -21,7 +22,7 @@
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List<PyrisStageDTO> stages,
- List<String> suggestions) {
+ List<String> suggestions, List<PyrisLLMCostDTO> tokens) {

/**
* Creates a new IrisWebsocketDTO instance with the given parameters
@@ -31,8 +32,9 @@ public record IrisChatWebsocketDTO(IrisWebsocketMessageType type, IrisMessage me
* @param rateLimitInfo the rate limit information
* @param stages the stages of the Pyris pipeline
*/
- public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List<PyrisStageDTO> stages, List<String> suggestions) {
- this(determineType(message), message, rateLimitInfo, stages, suggestions);
+ public IrisChatWebsocketDTO(@Nullable IrisMessage message, IrisRateLimitService.IrisRateLimitInformation rateLimitInfo, List<PyrisStageDTO> stages, List<String> suggestions,
+ List<PyrisLLMCostDTO> tokens) {
+ this(determineType(message), message, rateLimitInfo, stages, suggestions, tokens);
}

/**
@@ -8,6 +8,7 @@
import de.tum.cit.aet.artemis.atlas.domain.competency.CompetencyTaxonomy;
import de.tum.cit.aet.artemis.core.domain.Course;
import de.tum.cit.aet.artemis.core.domain.User;
+ import de.tum.cit.aet.artemis.core.service.LLMTokenUsageService;
import de.tum.cit.aet.artemis.iris.service.pyris.PyrisJobService;
import de.tum.cit.aet.artemis.iris.service.pyris.PyrisPipelineService;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.competency.PyrisCompetencyExtractionPipelineExecutionDTO;
@@ -25,14 +26,18 @@ public class IrisCompetencyGenerationService {

private final PyrisPipelineService pyrisPipelineService;

+ private final LLMTokenUsageService llmTokenUsageService;

private final IrisWebsocketService websocketService;

private final PyrisJobService pyrisJobService;

- public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, IrisWebsocketService websocketService, PyrisJobService pyrisJobService) {
+ public IrisCompetencyGenerationService(PyrisPipelineService pyrisPipelineService, LLMTokenUsageService llmTokenUsageService, IrisWebsocketService websocketService,
+ PyrisJobService pyrisJobService) {
this.pyrisPipelineService = pyrisPipelineService;
this.websocketService = websocketService;
this.pyrisJobService = pyrisJobService;
+ this.llmTokenUsageService = llmTokenUsageService;
}

/**
@@ -50,7 +55,7 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String
"default",
pyrisJobService.createTokenForJob(token -> new CompetencyExtractionJob(token, course.getId(), user.getLogin())),
executionDto -> new PyrisCompetencyExtractionPipelineExecutionDTO(executionDto, courseDescription, currentCompetencies, CompetencyTaxonomy.values(), 5),
- stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null))
+ stages -> websocketService.send(user.getLogin(), websocketTopic(course.getId()), new PyrisCompetencyStatusUpdateDTO(stages, null, null))
);
// @formatter:on
}
@@ -63,6 +68,9 @@ public void executeCompetencyExtractionPipeline(User user, Course course, String
* @param statusUpdate the status update containing the new competency recommendations
*/
public void handleStatusUpdate(String userLogin, long courseId, PyrisCompetencyStatusUpdateDTO statusUpdate) {
+ if (statusUpdate.tokens() != null) {
+ var tokenUsages = llmTokenUsageService.saveTokenUsage(null, null, null, null, statusUpdate.tokens());
+ }
websocketService.send(userLogin, websocketTopic(courseId), statusUpdate);
}

@@ -4,8 +4,9 @@

import com.fasterxml.jackson.annotation.JsonInclude;

+ import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO;

@JsonInclude(JsonInclude.Include.NON_EMPTY)
- public record PyrisChatStatusUpdateDTO(String result, List<PyrisStageDTO> stages, List<String> suggestions) {
+ public record PyrisChatStatusUpdateDTO(String result, List<PyrisStageDTO> stages, List<String> suggestions, List<PyrisLLMCostDTO> tokens) {
}
@@ -4,6 +4,7 @@

import com.fasterxml.jackson.annotation.JsonInclude;

+ import de.tum.cit.aet.artemis.iris.service.pyris.dto.data.PyrisLLMCostDTO;
import de.tum.cit.aet.artemis.iris.service.pyris.dto.status.PyrisStageDTO;

/**
@@ -15,5 +16,5 @@
* @param result List of competencies recommendations that have been generated so far
*/
@JsonInclude(JsonInclude.Include.NON_EMPTY)
- public record PyrisCompetencyStatusUpdateDTO(List<PyrisStageDTO> stages, List<PyrisCompetencyRecommendationDTO> result) {
+ public record PyrisCompetencyStatusUpdateDTO(List<PyrisStageDTO> stages, List<PyrisCompetencyRecommendationDTO> result, List<PyrisLLMCostDTO> tokens) {
}
@@ -0,0 +1,6 @@
package de.tum.cit.aet.artemis.iris.service.pyris.dto.data;

import de.tum.cit.aet.artemis.core.domain.LLMServiceType;

public record PyrisLLMCostDTO(String model_info, int num_input_tokens, float cost_per_input_token, int num_output_tokens, float cost_per_output_token, LLMServiceType pipeline) {
}
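The DTO carries token counts together with per-token prices, so a total cost can be derived from it. The sketch below assumes that `cost_per_input_token` and `cost_per_output_token` are prices for a single token, which the diff does not state explicitly. `CostEntry` is a hypothetical stand-in for `PyrisLLMCostDTO` (the `pipeline` field is left out for brevity), and `TokenCostCalculator` is an illustrative helper, not code from this PR.

```java
import java.math.BigDecimal;
import java.util.List;

// Hypothetical stand-in for PyrisLLMCostDTO; field names follow the diff, pipeline omitted.
record CostEntry(String model_info, int num_input_tokens, float cost_per_input_token, int num_output_tokens, float cost_per_output_token) {
}

// Hypothetical helper, not part of the PR.
final class TokenCostCalculator {

    private TokenCostCalculator() {
    }

    // Converts a float price via its decimal string so no extra binary noise is introduced.
    private static BigDecimal price(float value) {
        return new BigDecimal(Float.toString(value));
    }

    // Assumption: cost = input tokens * input price + output tokens * output price, summed over all entries.
    static BigDecimal totalCost(List<CostEntry> entries) {
        BigDecimal total = BigDecimal.ZERO;
        for (CostEntry entry : entries) {
            BigDecimal inputCost = BigDecimal.valueOf(entry.num_input_tokens()).multiply(price(entry.cost_per_input_token()));
            BigDecimal outputCost = BigDecimal.valueOf(entry.num_output_tokens()).multiply(price(entry.cost_per_output_token()));
            total = total.add(inputCost).add(outputCost);
        }
        return total;
    }

    public static void main(String[] args) {
        List<CostEntry> entries = List.of(new CostEntry("example-model", 10, 0.5f, 4, 0.25f));
        System.out.println(totalCost(entries));
    }
}
```

Summing in `BigDecimal` rather than `float` is a design choice of the sketch: per-token prices are typically tiny, and accumulating many of them in single precision can lose digits.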