Commit

fix: ai service race (#768)
fanweixiao authored Mar 25, 2024
1 parent 6ed7c5d commit 39a969b
Showing 2 changed files with 17 additions and 30 deletions.
9 changes: 0 additions & 9 deletions pkg/bridge/ai/api_server.go
@@ -7,7 +7,6 @@ import (
     "fmt"
     "net"
     "net/http"
-    "sync"
     "time"

     gonanoid "github.com/matoous/go-nanoid/v2"
@@ -127,14 +126,6 @@ func HandleInvoke(w http.ResponseWriter, r *http.Request) {
         return
     }

-    ci := &CacheItem{
-        wg:             &sync.WaitGroup{},
-        ResponseWriter: w,
-    }
-    if _, ok := service.cache[reqID]; !ok {
-        service.cache[reqID] = ci
-    }
-
     var req ai.InvokeRequest
     req.ReqID = reqID

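The 9 deleted lines in api_server.go are the unsynchronized half of the race: every HandleInvoke request wrote a CacheItem into the shared service.cache map, and Go's net/http serves each request on its own goroutine, so two concurrent invocations could write the map at the same time (concurrent map writes are a data race and can panic the process). A minimal, self-contained sketch of that failure mode — the handler and cache names here are hypothetical, not the repository's code — which go run -race flags immediately:

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"sync"
)

// cache is written by every request with no lock, mirroring the pattern removed above.
var cache = map[string]string{}

func handleUnsafe(w http.ResponseWriter, r *http.Request) {
	cache[r.URL.Query().Get("id")] = "pending" // unsynchronized map write
}

func main() {
	srv := httptest.NewServer(http.HandlerFunc(handleUnsafe))
	defer srv.Close()

	var wg sync.WaitGroup
	for i := 0; i < 16; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			http.Get(fmt.Sprintf("%s/?id=req-%d", srv.URL, i)) // concurrent requests -> concurrent writes
		}(i)
	}
	wg.Wait()
}
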
38 changes: 17 additions & 21 deletions pkg/bridge/ai/service.go
@@ -2,7 +2,6 @@ package ai

 import (
     "fmt"
-    "net/http"
     "sync"
     "time"

@@ -24,27 +23,19 @@
     services *expirable.LRU[string, *Service]
 )

-// CacheItem cache the http.ResponseWriter, which is used for writing response from reducer.
-// TODO: http.ResponseWriter is from the SimpleRestfulServer interface, should be decoupled
-// from here.
-type CacheItem struct {
-    ResponseWriter http.ResponseWriter
-    wg             *sync.WaitGroup
-    mu             sync.Mutex
-}
-
 // Service is used to invoke LLM Provider to get the functions to be executed,
 // then, use source to send arguments which returned by llm provider to target
 // function. Finally, use reducer to aggregate all the results, and write the
 // result by the http.ResponseWriter.
 type Service struct {
-    credential   string
-    zipperAddr   string
-    md           metadata.M
-    source       yomo.Source
-    reducer      yomo.StreamFunction
-    cache        map[string]*CacheItem
+    credential string
+    zipperAddr string
+    md         metadata.M
+    source     yomo.Source
+    reducer    yomo.StreamFunction
+    // cache map[string]*CacheItem
     sfnCallCache map[string]*sfnAsyncCall
+    muCallCache  sync.Mutex
     LLMProvider
 }

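In the new struct the per-request call cache is paired with muCallCache, so the map is never touched without holding a lock. A small sketch of that guarded-map pattern under assumed, illustrative names (not the repository's API):

package main

import (
	"fmt"
	"sync"
)

// callCache keeps the mutex next to the map it protects; every access goes through a locked method.
type callCache struct {
	mu    sync.Mutex
	calls map[string]string
}

func (c *callCache) set(reqID, state string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.calls[reqID] = state
}

func (c *callCache) get(reqID string) (string, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	v, ok := c.calls[reqID]
	return v, ok
}

func main() {
	c := &callCache{calls: make(map[string]string)}

	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			c.set(fmt.Sprintf("req-%d", i), "running") // safe: writes serialize on mu
		}(i)
	}
	wg.Wait()

	fmt.Println(c.get("req-0"))
}
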
@@ -72,9 +63,9 @@ func DefaultExchangeMetadataFunc(credential string) (metadata.M, error) {

 func newService(credential string, zipperAddr string, aiProvider LLMProvider, exFn ExchangeMetadataFunc) (*Service, error) {
     s := &Service{
-        credential:   credential,
-        zipperAddr:   zipperAddr,
-        cache:        make(map[string]*CacheItem),
+        credential: credential,
+        zipperAddr: zipperAddr,
+        // cache: make(map[string]*CacheItem),
         LLMProvider:  aiProvider,
         sfnCallCache: make(map[string]*sfnAsyncCall),
     }
@@ -116,7 +107,6 @@ func (s *Service) Release() {
     if s.reducer != nil {
         s.reducer.Close()
     }
-    clear(s.cache)
 }

 func (s *Service) createSource() (yomo.Source, error) {
@@ -156,7 +146,9 @@ func (s *Service) createReducer() (yomo.StreamFunction, error) {
         reqID := invoke.ReqID

         // write parallel function calling results to cache, after all the results are written, the reducer will be done
+        s.muCallCache.Lock()
         c, ok := s.sfnCallCache[reqID]
+        s.muCallCache.Unlock()
         if !ok {
             ylog.Error("[sfn-reducer] req_id not found", "req_id", reqID)
             return
@@ -242,7 +234,9 @@ func (s *Service) runFunctionCalls(fns map[uint32][]*ai.ToolCall, reqID string)
         wg:  &sync.WaitGroup{},
         val: make(map[string]ai.ToolMessage),
     }
+    s.muCallCache.Lock()
     s.sfnCallCache[reqID] = asyncCall
+    s.muCallCache.Unlock()

     for tag, tcs := range fns {
         ylog.Debug("+++invoke toolCalls", "tag", tag, "len(toolCalls)", len(tcs), "reqID", reqID)
@@ -262,11 +256,13 @@

     arr := make([]ai.ToolMessage, 0)

+    asyncCall.mu.RLock()
     for _, call := range asyncCall.val {
         ylog.Debug("---invoke done", "id", call.ToolCallId, "content", call.Content)
         call.Role = "tool"
         arr = append(arr, call)
     }
+    asyncCall.mu.RUnlock()

     return arr, nil
 }
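The reducer goroutine and runFunctionCalls now coordinate through two locks: muCallCache guards registration and lookup of the per-request asyncCall, while asyncCall.mu guards the result map itself, and the WaitGroup marks when every parallel tool call has reported back. A runnable sketch of that fan-out/fan-in shape, with illustrative names standing in for the sfn plumbing:

package main

import (
	"fmt"
	"sync"
)

// asyncCall collects results written by concurrent workers; the
// RWMutex lets the final read-only pass take RLock (illustrative).
type asyncCall struct {
	wg  *sync.WaitGroup
	mu  sync.RWMutex
	val map[string]string
}

func main() {
	call := &asyncCall{
		wg:  &sync.WaitGroup{},
		val: make(map[string]string),
	}

	// Fan out: one goroutine per "tool call", each writes its result under the write lock.
	for i := 0; i < 4; i++ {
		call.wg.Add(1)
		go func(i int) {
			defer call.wg.Done()
			call.mu.Lock()
			call.val[fmt.Sprintf("tool-%d", i)] = "done"
			call.mu.Unlock()
		}(i)
	}

	// Fan in: wait for every writer, then read the aggregated results under the read lock.
	call.wg.Wait()
	call.mu.RLock()
	for id, result := range call.val {
		fmt.Println(id, result)
	}
	call.mu.RUnlock()
}
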
@@ -307,6 +303,6 @@ func init() {

 type sfnAsyncCall struct {
     wg  *sync.WaitGroup
-    mu  sync.Mutex
+    mu  sync.RWMutex
     val map[string]ai.ToolMessage
 }

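The last hunk switches sfnAsyncCall.mu from sync.Mutex to sync.RWMutex, so the writers filling val serialize on Lock while the final read-only pass over the results takes RLock. Races of this kind are normally confirmed, and guarded against regression, with Go's built-in race detector; assuming the package path shown in this diff, a check would look something like:

go test -race ./pkg/bridge/ai/...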