Skip to content

Commit

Permalink
one package
Browse files Browse the repository at this point in the history
  • Loading branch information
swuecho committed May 21, 2024
1 parent 6f786c8 commit c31b706
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 47 deletions.
32 changes: 3 additions & 29 deletions api/chat_main_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"github.com/rotisserie/eris"
"github.com/samber/lo"
openai "github.com/sashabaranov/go-openai"
gemini "github.com/swuecho/chat_backend/llm/gemini"
"github.com/swuecho/chat_backend/sqlc_queries"

"github.com/gorilla/mux"
Expand Down Expand Up @@ -1217,35 +1216,10 @@ func constructChatCompletionStreamReponse(answer_id string, answer string) opena
// "parts":[{
// "text": "Write a story about a magic backpack."}]}]}' 2> /dev/null

func genGemminPayload(chat_compeletion_messages []Message) ([]byte, error) {
payload := gemini.Payload{
Contents: make([]gemini.GeminiMessage, len(chat_compeletion_messages)),
}
for i, message := range chat_compeletion_messages {
geminiMessage := gemini.GeminiMessage{
Role: message.Role,
Parts: []gemini.Part{
{Text: message.Content},
},
}
if message.Role == "assistant" {
geminiMessage.Role = "model"
} else if message.Role == "system" {
geminiMessage.Role = "user"
}
payload.Contents[i] = geminiMessage
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
fmt.Println("Error marshalling payload:", err)
// handle err
return nil, err
}
return payloadBytes, nil
}


func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_queries.ChatSession, chat_compeletion_messages []Message, chatUuid string, regenerate bool) (string, string, bool) {
payloadBytes, err := genGemminPayload(chat_compeletion_messages)
payloadBytes, err := GenGemminPayload(chat_compeletion_messages)
if err != nil {
RespondWithError(w, http.StatusInternalServerError, eris.Wrap(err, "Error generating gemmi payload").Error(), err)
return "", "", true
Expand Down Expand Up @@ -1314,7 +1288,7 @@ func (h *ChatHandler) chatStreamGemini(w http.ResponseWriter, chatSession sqlc_q
}
line = bytes.TrimPrefix(line, headerData)
if len(line) > 0 {
answer = gemini.ParseRespLine(line, answer)
answer = ParseRespLine(line, answer)
data, _ := json.Marshal(constructChatCompletionStreamReponse(answer_id, answer))
fmt.Fprintf(w, "data: %v\n\n", string(data))
flusher.Flush()
Expand Down
11 changes: 3 additions & 8 deletions api/llm/gemini/gemini.go → api/llm_gemini.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package gemmi
package main

import (
"encoding/json"
Expand All @@ -15,7 +15,7 @@ type GeminiMessage struct {
Parts []Part `json:"parts"`
}

type Payload struct {
type GeminPayload struct {
Contents []GeminiMessage `json:"contents"`
}

Expand Down Expand Up @@ -47,11 +47,6 @@ type ResponseBody struct {
PromptFeedback PromptFeedback `json:"promptFeedback"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content"`
tokenCount int32
}
func ParseRespLine(line []byte, answer string) string {
var resp ResponseBody
if err := json.Unmarshal(line, &resp); err != nil {
Expand All @@ -68,7 +63,7 @@ func ParseRespLine(line []byte, answer string) string {
}

func GenGemminPayload(chat_compeletion_messages []Message) ([]byte, error) {
payload := Payload{
payload := GeminPayload{
Contents: make([]GeminiMessage, len(chat_compeletion_messages)),
}
for i, message := range chat_compeletion_messages {
Expand Down
10 changes: 0 additions & 10 deletions api/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,6 @@ type TokenResult struct {
ExpiresIn int `json:"expiresIn"`
}

type MultiMessage struct {
Role string `json:"role"`
Parts []Part `json:"parts"`
}

type Part struct {
Content string `json:"content"`
Type string `json:"type"`
}

type ConversationRequest struct {
UUID string `json:"uuid,omitempty"`
ConversationID string `json:"conversationId,omitempty"`
Expand Down

0 comments on commit c31b706

Please sign in to comment.