diff --git a/pkg/providers/octoml/chat.go b/pkg/providers/octoml/chat.go index 5dedf596..f4978333 100644 --- a/pkg/providers/octoml/chat.go +++ b/pkg/providers/octoml/chat.go @@ -69,7 +69,7 @@ func (c *Client) Chat(ctx context.Context, request *schemas.UnifiedChatRequest) return nil, err } - if len(chatResponse.ProviderResponse.Message.Content) == 0 { + if len(chatResponse.ModelResponse.Message.Content) == 0 { return nil, ErrEmptyResponse } @@ -139,47 +139,37 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Parse the response JSON - var responseJSON map[string]interface{} + var openAICompletion schemas.OpenAIChatCompletion // Octo uses the same response schema as OpenAI - err = json.Unmarshal(bodyBytes, &responseJSON) + err = json.Unmarshal(bodyBytes, &openAICompletion) if err != nil { - c.telemetry.Logger.Error("failed to parse octoml chat response", zap.Error(err)) + c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err)) return nil, err } - // Parse response - var response schemas.UnifiedChatResponse - - var responsePayload schemas.ProviderResponse - - var tokenCount schemas.TokenCount - - message := responseJSON["choices"].([]interface{})[0].(map[string]interface{})["message"].(map[string]interface{}) - messageStruct := schemas.ChatMessage{ - Role: message["role"].(string), - Content: message["content"].(string), - } - - tokenCount = schemas.TokenCount{ - PromptTokens: responseJSON["usage"].(map[string]interface{})["prompt_tokens"].(float64), - ResponseTokens: responseJSON["usage"].(map[string]interface{})["completion_tokens"].(float64), - TotalTokens: responseJSON["usage"].(map[string]interface{})["total_tokens"].(float64), - } - - responsePayload = schemas.ProviderResponse{ - ResponseID: map[string]string{"system_fingerprint": "none"}, - Message: messageStruct, - TokenCount: tokenCount, - } - - response = schemas.UnifiedChatResponse{ - ID: responseJSON["id"].(string), - Created: responseJSON["created"].(float64), - Provider: "octoml", - Router: "chat", // TODO: Update this with actual router - Model: responseJSON["model"].(string), - Cached: false, - ProviderResponse: responsePayload, + // Map response to UnifiedChatResponse schema + response := schemas.UnifiedChatResponse{ + ID: openAICompletion.ID, + Created: openAICompletion.Created, + Provider: providerName, + Router: "chat", // TODO: this will be the router used + Model: openAICompletion.Model, + Cached: false, + ModelResponse: schemas.ProviderResponse{ + ResponseID: map[string]string{ + "system_fingerprint": openAICompletion.SystemFingerprint, + }, + Message: schemas.ChatMessage{ + Role: openAICompletion.Choices[0].Message.Role, + Content: openAICompletion.Choices[0].Message.Content, + Name: "", + }, + TokenCount: schemas.TokenCount{ + PromptTokens: openAICompletion.Usage.PromptTokens, + ResponseTokens: openAICompletion.Usage.CompletionTokens, + TotalTokens: openAICompletion.Usage.TotalTokens, + }, + }, } return &response, nil