Skip to content

Commit

Permalink
Merge branch 'staging' of https://github.com/weni-ai/nexus-ai into st…
Browse files Browse the repository at this point in the history
…aging
  • Loading branch information
AlisoSouza committed Nov 21, 2024
2 parents 0bd9899 + 2c778cd commit e6a70a2
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
26 changes: 24 additions & 2 deletions router/classifiers/chatgpt_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from openai import OpenAI

from typing import List
from typing import List, Dict

from django.conf import settings

Expand Down Expand Up @@ -35,13 +35,34 @@ class ChatGPTFunctionClassifier(Classifier):

def __init__(
self,
agent_goal: str,
client: OpenAIClientInterface = OpenAIClient(settings.OPENAI_API_KEY),
chatgpt_model: str = settings.FUNCTION_CALLING_CHATGPT_MODEL,
):
self.chatgpt_model = chatgpt_model
self.client = client
self.prompt = settings.CHATGPT_CONTEXT_PROMPT
self.flow_name_mapping = {}
self.agent_goal = agent_goal

def replace_vars(self, prompt: str, replace_variables: Dict) -> str:
for key in replace_variables.keys():
replace_str = "{{" + key + "}}"
value = replace_variables.get(key)
if not isinstance(value, str):
value = str(value)
prompt = prompt.replace(replace_str, value)
return prompt

def get_prompt(self):
variable = {
"agent_goal": "".join(self.agent_goal),
}

return self.replace_vars(
prompt=self.prompt,
replace_variables=variable
)

def tools(
self,
Expand Down Expand Up @@ -71,10 +92,11 @@ def predict(
) -> str:

print(f"[+ ChatGPT message function classification: {message} ({language}) +]")
formated_prompt = self.get_prompt()
msg = [
{
"role": "system",
"content": self.prompt
"content": formated_prompt
},
{
"role": "user",
Expand Down
3 changes: 1 addition & 2 deletions router/classifiers/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,10 @@ def setUp(self) -> None:
choice = MockChoice(message)
self.mock_response = MockResponse([choice])
self.mock_client = MockOpenAIClient(response=self.mock_response)
self.chatgpt_model = "gpt-3.5-turbo"

self.classifier = ChatGPTFunctionClassifier(
client=self.mock_client,
chatgpt_model=self.chatgpt_model,
agent_goal="Answer user questions"
)

def test_predict(self):
Expand Down
2 changes: 1 addition & 1 deletion router/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def start_route(

# classifier = ZeroshotClassifier(chatbot_goal=agent.goal)

classifier = ChatGPTFunctionClassifier()
classifier = ChatGPTFunctionClassifier(agent_goal=agent.goal)
classification = classification_handler.custom_actions(
classifier=classifier,
language=llm_config.language
Expand Down

0 comments on commit e6a70a2

Please sign in to comment.