From 2458e2f1ec7f7e383429a54047814347e18c363d Mon Sep 17 00:00:00 2001
From: lvliang-intel
Date: Mon, 15 Jul 2024 17:09:48 +0800
Subject: [PATCH] Allow the Ollama microservice to be configurable with
 different models (#280)

Signed-off-by: lvliang-intel
Co-authored-by: Sihan Chen <39623753+Spycsh@users.noreply.github.com>
---
 comps/cores/proto/docarray.py               | 1 +
 comps/llms/text-generation/ollama/README.md | 2 +-
 comps/llms/text-generation/ollama/llm.py    | 2 +-
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py
index 040d4035a..819cb11c8 100644
--- a/comps/cores/proto/docarray.py
+++ b/comps/cores/proto/docarray.py
@@ -77,6 +77,7 @@ class RerankedDoc(BaseDoc):
 
 
 class LLMParamsDoc(BaseDoc):
+    model: Optional[str] = None  # for openai and ollama
     query: str
     max_new_tokens: int = 1024
     top_k: int = 10
diff --git a/comps/llms/text-generation/ollama/README.md b/comps/llms/text-generation/ollama/README.md
index a5bd486d6..1ad636098 100644
--- a/comps/llms/text-generation/ollama/README.md
+++ b/comps/llms/text-generation/ollama/README.md
@@ -62,5 +62,5 @@ docker run --network host opea/llm-ollama:latest
 # Consume the Ollama Microservice
 
 ```bash
-curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"query":"What is Deep Learning?","max_new_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
+curl http://127.0.0.1:9000/v1/chat/completions -X POST -d '{"model": "llama3", "query":"What is Deep Learning?","max_new_tokens":32,"top_k":10,"top_p":0.95,"typical_p":0.95,"temperature":0.01,"repetition_penalty":1.03,"streaming":true}' -H 'Content-Type: application/json'
 ```
diff --git a/comps/llms/text-generation/ollama/llm.py b/comps/llms/text-generation/ollama/llm.py
index 2cec1beac..5374cfa69 100644
--- a/comps/llms/text-generation/ollama/llm.py
+++ b/comps/llms/text-generation/ollama/llm.py
@@ -21,7 +21,7 @@
 def llm_generate(input: LLMParamsDoc):
     ollama = Ollama(
         base_url=ollama_endpoint,
-        model="llama3",
+        model=input.model,
         num_predict=input.max_new_tokens,
         top_k=input.top_k,
         top_p=input.top_p,
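
For reference, a minimal client-side sketch of how the new `model` field might be exercised once the patched microservice is running. The endpoint, port, and payload fields follow the README change and `LLMParamsDoc` above; the use of the `requests` library and the "llama3" model name are illustrative assumptions, not part of this change.

# illustrative_client.py -- not part of the patch
# Calls the Ollama microservice with an explicit "model", mirroring the curl
# example in the README. Assumes the service listens on 127.0.0.1:9000 and
# that the named model has already been pulled into the Ollama server.
import requests

payload = {
    "model": "llama3",  # any model available on the backing Ollama server
    "query": "What is Deep Learning?",
    "max_new_tokens": 32,
    "top_k": 10,
    "top_p": 0.95,
    "typical_p": 0.95,
    "temperature": 0.01,
    "repetition_penalty": 1.03,
    "streaming": True,
}

# streaming=True makes the service send the generation incrementally,
# so read the response line by line instead of as one JSON body.
with requests.post(
    "http://127.0.0.1:9000/v1/chat/completions", json=payload, stream=True
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)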