Skip to content

Commit

Permalink
OPIK-720 Manually set span cost
Browse files Browse the repository at this point in the history
  • Loading branch information
Borys Tkachenko authored and Borys Tkachenko committed Jan 9, 2025
1 parent 9ae31b9 commit 20a738c
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 49 deletions.
6 changes: 4 additions & 2 deletions apps/opik-backend/src/main/java/com/comet/opik/api/Span.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Pattern;
Expand Down Expand Up @@ -51,8 +52,9 @@ public record Span(
@JsonView({Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) String lastUpdatedBy,
@JsonView({
Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) List<FeedbackScore> feedbackScores,
@JsonView({
Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) BigDecimal totalEstimatedCost,
@JsonView({Span.View.Public.class,
Span.View.Write.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY) @DecimalMin("0.0") BigDecimal totalEstimatedCost,
String totalEstimatedCostVersion,
@JsonView({
Span.View.Public.class}) @Schema(accessMode = Schema.AccessMode.READ_ONLY, description = "Duration in milliseconds as a decimal number to support sub-millisecond precision") Double duration){

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import com.fasterxml.jackson.databind.PropertyNamingStrategies;
import com.fasterxml.jackson.databind.annotation.JsonNaming;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.DecimalMin;
import jakarta.validation.constraints.NotNull;
import jakarta.validation.constraints.Pattern;
import lombok.Builder;

import java.math.BigDecimal;
import java.time.Instant;
import java.util.Map;
import java.util.Set;
Expand All @@ -32,5 +34,6 @@ public record SpanUpdate(
String provider,
Set<String> tags,
Map<String, Integer> usage,
@DecimalMin("0.0") BigDecimal totalEstimatedCost,
ErrorInfo errorInfo) {
}
67 changes: 46 additions & 21 deletions apps/opik-backend/src/main/java/com/comet/opik/domain/SpanDAO.java
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,6 @@ private Publisher<? extends Result> insert(List<Span> spans, Connection connecti
int i = 0;
for (Span span : spans) {

BigDecimal estimatedCost = calculateCost(span);

statement.bind("id" + i, span.id())
.bind("project_id" + i, span.projectId())
.bind("trace_id" + i, span.traceId())
Expand All @@ -761,15 +759,23 @@ private Publisher<? extends Result> insert(List<Span> spans, Connection connecti
.bind("metadata" + i, span.metadata() != null ? span.metadata().toString() : "")
.bind("model" + i, span.model() != null ? span.model() : "")
.bind("provider" + i, span.provider() != null ? span.provider() : "")
.bind("total_estimated_cost" + i, estimatedCost.toString())
.bind("total_estimated_cost_version" + i,
estimatedCost.compareTo(BigDecimal.ZERO) > 0 ? ESTIMATED_COST_VERSION : "")
.bind("tags" + i, span.tags() != null ? span.tags().toArray(String[]::new) : new String[]{})
.bind("error_info" + i,
span.errorInfo() != null ? JsonUtils.readTree(span.errorInfo()).toString() : "")
.bind("created_by" + i, userName)
.bind("last_updated_by" + i, userName);

if (span.totalEstimatedCost() != null) {
// Cost is set manually by the user
statement.bind("total_estimated_cost" + i, span.totalEstimatedCost().toString());
statement.bind("total_estimated_cost_version" + i, "");
} else {
BigDecimal estimatedCost = calculateCost(span);
statement.bind("total_estimated_cost" + i, estimatedCost.toString());
statement.bind("total_estimated_cost_version" + i,
estimatedCost.compareTo(BigDecimal.ZERO) > 0 ? ESTIMATED_COST_VERSION : "");
}

if (span.endTime() != null) {
statement.bind("end_time" + i, span.endTime().toString());
} else {
Expand Down Expand Up @@ -849,12 +855,15 @@ private Publisher<? extends Result> insert(Span span, Connection connection) {
statement.bind("provider", "");
}

BigDecimal estimatedCost = calculateCost(span);
statement.bind("total_estimated_cost", estimatedCost.toString());
if (estimatedCost.compareTo(BigDecimal.ZERO) > 0) {
statement.bind("total_estimated_cost_version", ESTIMATED_COST_VERSION);
} else {
if (span.totalEstimatedCost() != null) {
// Cost is set manually by the user
statement.bind("total_estimated_cost", span.totalEstimatedCost().toString());
statement.bind("total_estimated_cost_version", "");
} else {
BigDecimal estimatedCost = calculateCost(span);
statement.bind("total_estimated_cost", estimatedCost.toString());
statement.bind("total_estimated_cost_version",
estimatedCost.compareTo(BigDecimal.ZERO) > 0 ? ESTIMATED_COST_VERSION : "");
}

if (span.tags() != null) {
Expand Down Expand Up @@ -912,7 +921,7 @@ public Mono<Long> update(@NonNull UUID id, @NonNull SpanUpdate spanUpdate, Span
public Mono<Long> partialInsert(@NonNull UUID id, @NonNull UUID projectId, @NonNull SpanUpdate spanUpdate) {
return Mono.from(connectionFactory.create())
.flatMapMany(connection -> {
ST template = newUpdateTemplate(spanUpdate, PARTIAL_INSERT);
ST template = newUpdateTemplate(spanUpdate, PARTIAL_INSERT, false);

var statement = connection.createStatement(template.render());

Expand All @@ -926,7 +935,7 @@ public Mono<Long> partialInsert(@NonNull UUID id, @NonNull UUID projectId, @NonN
statement.bind("parent_span_id", "");
}

bindUpdateParams(spanUpdate, statement);
bindUpdateParams(spanUpdate, statement, false);

Segment segment = startSegment("spans", "Clickhouse", "partial_insert");

Expand All @@ -946,19 +955,19 @@ private Publisher<? extends Result> update(UUID id, SpanUpdate spanUpdate, Conne
.build();
}

var template = newUpdateTemplate(spanUpdate, UPDATE);
var template = newUpdateTemplate(spanUpdate, UPDATE, isManualCost(existingSpan));
var statement = connection.createStatement(template.render());
statement.bind("id", id);

bindUpdateParams(spanUpdate, statement);
bindUpdateParams(spanUpdate, statement, isManualCost(existingSpan));

Segment segment = startSegment("spans", "Clickhouse", "update");

return makeFluxContextAware(bindUserNameAndWorkspaceContextToStream(statement))
.doFinally(signalType -> endSegment(segment));
}

private void bindUpdateParams(SpanUpdate spanUpdate, Statement statement) {
private void bindUpdateParams(SpanUpdate spanUpdate, Statement statement, boolean isManualCostExist) {
Optional.ofNullable(spanUpdate.input())
.ifPresent(input -> statement.bind("input", input.toString()));
Optional.ofNullable(spanUpdate.output())
Expand Down Expand Up @@ -988,14 +997,22 @@ private void bindUpdateParams(SpanUpdate spanUpdate, Statement statement) {
Optional.ofNullable(spanUpdate.errorInfo())
.ifPresent(errorInfo -> statement.bind("error_info", JsonUtils.readTree(errorInfo).toString()));

if (StringUtils.isNotBlank(spanUpdate.model()) && Objects.nonNull(spanUpdate.usage())) {
statement.bind("total_estimated_cost",
ModelPrice.fromString(spanUpdate.model()).calculateCost(spanUpdate.usage()).toString());
statement.bind("total_estimated_cost_version", ESTIMATED_COST_VERSION);
if (Objects.nonNull(spanUpdate.totalEstimatedCost())) {
// Update with new manually set cost
statement.bind("total_estimated_cost", spanUpdate.totalEstimatedCost().toString());
statement.bind("total_estimated_cost_version", "");
} else {
// Calculate estimated cost only in case Span doesn't have manually set cost
if (!isManualCostExist && StringUtils.isNotBlank(spanUpdate.model())
&& Objects.nonNull(spanUpdate.usage())) {
statement.bind("total_estimated_cost",
ModelPrice.fromString(spanUpdate.model()).calculateCost(spanUpdate.usage()).toString());
statement.bind("total_estimated_cost_version", ESTIMATED_COST_VERSION);
}
}
}

private ST newUpdateTemplate(SpanUpdate spanUpdate, String sql) {
private ST newUpdateTemplate(SpanUpdate spanUpdate, String sql, boolean isManualCostExist) {
var template = new ST(sql);
Optional.ofNullable(spanUpdate.input())
.ifPresent(input -> template.add("input", input.toString()));
Expand All @@ -1015,7 +1032,9 @@ private ST newUpdateTemplate(SpanUpdate spanUpdate, String sql) {
.ifPresent(usage -> template.add("usage", usage.toString()));
Optional.ofNullable(spanUpdate.errorInfo())
.ifPresent(errorInfo -> template.add("error_info", JsonUtils.readTree(errorInfo).toString()));
if (StringUtils.isNotBlank(spanUpdate.model()) && Objects.nonNull(spanUpdate.usage())) {
// If we have manual cost in update OR if we can calculate it and user didn't set manual cost before
if ((!isManualCostExist && StringUtils.isNotBlank(spanUpdate.model()) && Objects.nonNull(spanUpdate.usage()))
|| Objects.nonNull(spanUpdate.totalEstimatedCost())) {
template.add("total_estimated_cost", "total_estimated_cost");
template.add("total_estimated_cost_version", "total_estimated_cost_version");
}
Expand Down Expand Up @@ -1097,6 +1116,9 @@ private Publisher<Span> mapToDto(Result result) {
row.get("total_estimated_cost", BigDecimal.class).compareTo(BigDecimal.ZERO) == 0
? null
: row.get("total_estimated_cost", BigDecimal.class))
.totalEstimatedCostVersion(row.getMetadata().contains("total_estimated_cost_version")
? row.get("total_estimated_cost_version", String.class)
: null)
.tags(Optional.of(Arrays.stream(row.get("tags", String[].class)).collect(Collectors.toSet()))
.filter(set -> !set.isEmpty())
.orElse(null))
Expand Down Expand Up @@ -1281,4 +1303,7 @@ public Mono<ProjectStats> getStats(@NonNull SpanSearchCriteria searchCriteria) {
.singleOrEmpty();
}

private boolean isManualCost(Span span) {
return span.totalEstimatedCost() != null && StringUtils.isBlank(span.totalEstimatedCostVersion());
}
}
Loading

0 comments on commit 20a738c

Please sign in to comment.