Skip to content

Commit

Permalink
Fix router bugs on max_new_tokens and dataprep gaudi yaml file (#273)
Browse files Browse the repository at this point in the history
* fix router bugs on max_new_tokens and dataprep gaudi yaml file.
Signed-off-by: zhlsunshine <[email protected]>

* change based on comments.
Signed-off-by: zhlsunshine <[email protected]>

* change the yaml file for data-prep.
Signed-off-by: zhlsunshine <[email protected]>
  • Loading branch information
zhlsunshine authored Aug 7, 2024
1 parent 4319660 commit 5735dd3
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 80 deletions.
52 changes: 50 additions & 2 deletions microservices-connector/cmd/router/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const (
ServiceURL = "serviceUrl"
ServiceNode = "node"
DataPrep = "DataPrep"
Parameters = "parameters"
)

type EnsembleStepOutput struct {
Expand Down Expand Up @@ -198,6 +199,32 @@ func executeStep(
return callService(step, serviceURL, input, headers)
}

func mergeRequests(respReq []byte, initReqData map[string]interface{}) []byte {
var respReqData map[string]interface{}

if _, exists := initReqData[Parameters]; exists {
if err := json.Unmarshal(respReq, &respReqData); err != nil {
log.Error(err, "Error unmarshaling respReqData:")
return nil
}
// Merge init request into respReq
for key, value := range initReqData[Parameters].(map[string]interface{}) {
/*if _, exists := respReqData[key]; !exists {
respReqData[key] = value
}*/
// overwrite the respReq by initial request
respReqData[key] = value
}
mergedBytes, err := json.Marshal(respReqData)
if err != nil {
log.Error(err, "Error marshaling merged data:")
return nil
}
return mergedBytes
}
return respReq
}

func handleSwitchNode(
route *mcv1alpha3.Step,
graph mcv1alpha3.GMConnector,
Expand Down Expand Up @@ -239,6 +266,13 @@ func handleSwitchPipeline(nodeName string,
var statusCode int
var responseBytes []byte
var err error

initReqData := make(map[string]interface{})
if err = json.Unmarshal(initInput, &initReqData); err != nil {
log.Error(err, "Error unmarshaling initReqData:")
return nil, 500, err
}

for index, route := range currentNode.Steps {
if route.InternalService.IsDownstreamService {
log.Info(
Expand All @@ -252,9 +286,11 @@ func handleSwitchPipeline(nodeName string,
}
log.Info("Current Step Information", "Node Name", nodeName, "Step Index", index)
request := input
log.Info("Print Original Request Bytes", "Request Bytes", request)
if route.Data == "$response" && index > 0 {
request = responseBytes
request = mergeRequests(responseBytes, initReqData)
}
log.Info("Print New Request Bytes", "Request Bytes", request)
if route.Condition == "" {
responseBytes, statusCode, err = handleSwitchNode(&route, graph, initInput, request, headers)
if err != nil {
Expand Down Expand Up @@ -348,6 +384,12 @@ func handleSequencePipeline(nodeName string,
var statusCode int
var responseBytes []byte
var err error

initReqData := make(map[string]interface{})
if err = json.Unmarshal(initInput, &initReqData); err != nil {
log.Error(err, "Error unmarshaling initReqData:")
return nil, 500, err
}
for i := range currentNode.Steps {
step := &currentNode.Steps[i]
stepType := ServiceURL
Expand All @@ -366,9 +408,11 @@ func handleSequencePipeline(nodeName string,
}
log.Info("Starting execution of step", "type", stepType, "stepName", step.StepName)
request := input
log.Info("Print Original Request Bytes", "Request Bytes", request)
if step.Data == "$response" && i > 0 {
request = responseBytes
request = mergeRequests(responseBytes, initReqData)
}
log.Info("Print New Request Bytes", "Request Bytes", request)
if step.Condition != "" {
if !gjson.ValidBytes(responseBytes) {
return nil, 500, fmt.Errorf("invalid response")
Expand Down Expand Up @@ -467,6 +511,10 @@ func mcGraphHandler(w http.ResponseWriter, req *http.Request) {
log.Error(err, "failed to write mcGraphHandler response")
return
}

if err := writer.Flush(); err != nil {
log.Error(err, "error flushing writer when processing response")
}
}
}()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,5 @@ spec:
config:
endpoint: /v1/dataprep
REDIS_URL: redis-vector-db
INDEX_NAME: data-prep
TEI_ENDPOINT: tei-embedding-svc
TEI_ENDPOINT: tei-embedding-gaudi-svc
isDownstreamService: true
151 changes: 75 additions & 76 deletions microservices-connector/config/samples/chatQnA_dataprep_xeon.yaml
Original file line number Diff line number Diff line change
@@ -1,77 +1,76 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

apiVersion: gmc.opea.io/v1alpha3
kind: GMConnector
metadata:
labels:
app.kubernetes.io/name: gmconnector
app.kubernetes.io/managed-by: kustomize
gmc/platform: xeon
name: chatqa
namespace: chatqa
spec:
routerConfig:
name: router
serviceName: router-service
nodes:
root:
routerType: Sequence
steps:
- name: Embedding
internalService:
serviceName: embedding-svc
config:
endpoint: /v1/embeddings
TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
- name: TeiEmbedding
internalService:
serviceName: tei-embedding-svc
isDownstreamService: true
- name: Retriever
data: $response
internalService:
serviceName: retriever-svc
config:
endpoint: /v1/retrieval
REDIS_URL: redis-vector-db
TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
- name: VectorDB
internalService:
serviceName: redis-vector-db
isDownstreamService: true
- name: Reranking
data: $response
internalService:
serviceName: reranking-svc
config:
endpoint: /v1/reranking
TEI_RERANKING_ENDPOINT: tei-reranking-svc
- name: TeiReranking
internalService:
serviceName: tei-reranking-svc
config:
endpoint: /rerank
isDownstreamService: true
- name: Llm
data: $response
internalService:
serviceName: llm-svc
config:
endpoint: /v1/chat/completions
TGI_LLM_ENDPOINT: tgi-service-m
- name: Tgi
internalService:
serviceName: tgi-service-m
config:
endpoint: /generate
isDownstreamService: true
- name: DataPrep
internalService:
serviceName: data-prep-svc
config:
endpoint: /v1/dataprep
REDIS_URL: redis-vector-db
INDEX_NAME: data-prep
TEI_ENDPOINT: tei-embedding-svc
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

apiVersion: gmc.opea.io/v1alpha3
kind: GMConnector
metadata:
labels:
app.kubernetes.io/name: gmconnector
app.kubernetes.io/managed-by: kustomize
gmc/platform: xeon
name: chatqa
namespace: chatqa
spec:
routerConfig:
name: router
serviceName: router-service
nodes:
root:
routerType: Sequence
steps:
- name: Embedding
internalService:
serviceName: embedding-svc
config:
endpoint: /v1/embeddings
TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
- name: TeiEmbedding
internalService:
serviceName: tei-embedding-svc
isDownstreamService: true
- name: Retriever
data: $response
internalService:
serviceName: retriever-svc
config:
endpoint: /v1/retrieval
REDIS_URL: redis-vector-db
TEI_EMBEDDING_ENDPOINT: tei-embedding-svc
- name: VectorDB
internalService:
serviceName: redis-vector-db
isDownstreamService: true
- name: Reranking
data: $response
internalService:
serviceName: reranking-svc
config:
endpoint: /v1/reranking
TEI_RERANKING_ENDPOINT: tei-reranking-svc
- name: TeiReranking
internalService:
serviceName: tei-reranking-svc
config:
endpoint: /rerank
isDownstreamService: true
- name: Llm
data: $response
internalService:
serviceName: llm-svc
config:
endpoint: /v1/chat/completions
TGI_LLM_ENDPOINT: tgi-service-m
- name: Tgi
internalService:
serviceName: tgi-service-m
config:
endpoint: /generate
isDownstreamService: true
- name: DataPrep
internalService:
serviceName: data-prep-svc
config:
endpoint: /v1/dataprep
REDIS_URL: redis-vector-db
TEI_ENDPOINT: tei-embedding-svc
isDownstreamService: true

0 comments on commit 5735dd3

Please sign in to comment.