Merge branch 'feature/remove-flows-zeroshot' into staging
zMardone committed Dec 26, 2024
2 parents c0f3ee5 + 09bbe31 commit 9fd7eec
Showing 9 changed files with 136 additions and 8 deletions.
3 changes: 3 additions & 0 deletions nexus/settings.py
@@ -500,6 +500,9 @@
FUNCTION_CALLING_CHATGPT_MODEL = env.str("FUNCTION_CALLING_CHATGPT_MODEL", "gpt-4o-mini")
FUNCTION_CALLING_CHATGPT_PROMPT = env.str("FUNCTION_CALLING_CHATGPT_PROMPT", "")

# Classification data
DEFAULT_CLASSIFICATION_MODEL = env.str("DEFAULT_CLASSIFICATION_MODEL", ZEROSHOT_MODEL_BACKEND)

# Reflection data
GROUNDEDNESS_MODEL = env.str("GROUNDEDNESS_MODEL", "gpt-4o-mini")
GROUNDEDNESS_SYSTEM_PROMPT = env.str("GROUNDEDNESS_SYSTEM_PROMPT", "")
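
DEFAULT_CLASSIFICATION_MODEL falls back to ZEROSHOT_MODEL_BACKEND when unset, and the client change below switches to function calling whenever the resolved value is anything other than the literal string "zeroshot". A minimal sketch of that resolution, assuming the project's env helper is django-environ (the fallback default shown is illustrative):

    import environ

    env = environ.Env()

    # Illustrative only: the real ZEROSHOT_MODEL_BACKEND default is defined
    # elsewhere in settings.py and is not visible in this diff.
    ZEROSHOT_MODEL_BACKEND = env.str("ZEROSHOT_MODEL_BACKEND", "zeroshot")
    DEFAULT_CLASSIFICATION_MODEL = env.str("DEFAULT_CLASSIFICATION_MODEL", ZEROSHOT_MODEL_BACKEND)
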
6 changes: 5 additions & 1 deletion nexus/zeroshot/api/views.py
@@ -9,6 +9,8 @@
from nexus.zeroshot.api.permissions import ZeroshotTokenPermission
from nexus.zeroshot.models import ZeroshotLogs

from django.conf import settings


logger = logging.getLogger(__name__)

@@ -21,6 +23,7 @@ class ZeroShotFastPredictAPIView(APIView):
def post(self, request):
data = request.data
try:

invoke_model = InvokeModel(data)
response = invoke_model.invoke()

@@ -30,7 +33,8 @@ def post(self, request):
other=response["output"].get("other", False),
options=data.get("options"),
nlp_log=str(json.dumps(response)),
language=data.get("language")
language=data.get("language"),
model=settings.DEFAULT_CLASSIFICATION_MODEL
)

return Response(status=200, data=response if response.get("output") else {"error": response})
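
A hypothetical request body for this endpoint, using only the keys the view and InvokeModel read in this diff (values are made up); the new model field on the log simply records which classifier handled the call:

    payload = {
        "context": "You route customer messages to the correct flow.",
        "text": "I want to track my order",
        "language": "eng",
        "options": [
            {"class": "track_order", "context": "Customer wants to track an order"},
            {"class": "talk_to_human", "context": "Customer asks for a human agent"},
        ],
    }

    # Inside the view, the log row is now created with
    #   model=settings.DEFAULT_CLASSIFICATION_MODEL
    # so zeroshot and function-calling predictions can be told apart later.
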
32 changes: 32 additions & 0 deletions nexus/zeroshot/client.py
@@ -7,6 +7,9 @@
from nexus.zeroshot.format_classification import FormatClassification
from nexus.zeroshot.format_prompt import FormatPrompt

from router.classifiers.chatgpt_function import ChatGPTFunctionClassifier
from router.entities.flow import FlowDTO


class InvokeModel:
def __init__(
@@ -112,8 +115,37 @@ def _invoke_zeroshot(self, model_backend: str):
"bedrock": self._invoke_bedrock
}.get(model_backend)

def _invoke_function_calling(self):

classifier = ChatGPTFunctionClassifier(
agent_goal=self.zeroshot_data.get("context"),
)

flow_dto_list = []
options = self.zeroshot_data.get("options", [])
for option in options:
flow_dto_list.append(FlowDTO(name=option.get("class"), prompt=option.get("context")))

prediction: str = classifier.predict(
message=self.zeroshot_data.get("text"),
flows=flow_dto_list,
language=self.zeroshot_data.get("language")
)

formated_prediction = {
"output": prediction
}

classification_formater = FormatClassification(formated_prediction)
formatted_classification = classification_formater.get_classification(self.zeroshot_data)

response = {"output": formatted_classification}
return response

def invoke(self):
prompt = self._get_prompt(self.zeroshot_data)
if settings.DEFAULT_CLASSIFICATION_MODEL != "zeroshot":
return self._invoke_function_calling()
invoke_zeroshot = self._invoke_zeroshot(self.model_backend)
if invoke_zeroshot:
return invoke_zeroshot(prompt)
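
A short sketch of the new dispatch in invoke(), under the assumptions visible in this diff: any DEFAULT_CLASSIFICATION_MODEL other than "zeroshot" bypasses the runpod/bedrock zeroshot backends and goes through ChatGPTFunctionClassifier instead.

    from nexus.zeroshot.client import InvokeModel

    def classify(zeroshot_data: dict, model_backend: str = "bedrock"):
        # model_backend default here is illustrative; InvokeModel has its own
        # default, which is not shown in this diff.
        invoke_model = InvokeModel(zeroshot_data, model_backend=model_backend)
        # invoke() routes internally:
        #   settings.DEFAULT_CLASSIFICATION_MODEL != "zeroshot" -> _invoke_function_calling()
        #   otherwise -> _invoke_zeroshot(model_backend), i.e. runpod or bedrock
        return invoke_model.invoke()
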
16 changes: 16 additions & 0 deletions nexus/zeroshot/format_classification.py
@@ -11,6 +11,9 @@ def __init__(self, classification_data: dict):
self.classification_data = classification_data

def get_classification(self, zeroshot_data):
if settings.DEFAULT_CLASSIFICATION_MODEL != "zeroshot":
return self._get_function_calling_classification(zeroshot_data)

if self.model_backend == "runpod":
return self._get_runpod_classification(zeroshot_data)
elif self.model_backend == "bedrock":
@@ -55,3 +58,16 @@ def _get_runpod_classification(self, zeroshot_data):
def _get_bedrock_classification(self, zeroshot_data):
output_text = self.classification_data.get("outputs")[0].get("text").strip()
return self._get_formatted_output(output_text, zeroshot_data)

def _get_function_calling_classification(self, zeroshot_data):
output_text = self.classification_data.get("output")
classification = {"other": True, "classification": self._get_data_none_class()}

if output_text:
response_prepared = output_text.strip().strip(".").strip("\n").strip("'").lower()
all_classes = [option.get("class").lower() for option in zeroshot_data.get("options", [])]

if response_prepared in all_classes:
classification["other"] = False
classification["classification"] = response_prepared
return classification
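
A worked example of the normalisation above, re-implemented standalone for illustration (the real method falls back to _get_data_none_class(), whose body is not part of this diff, so None stands in for it here):

    def normalise(output_text, options):
        classification = {"other": True, "classification": None}
        if output_text:
            prepared = output_text.strip().strip(".").strip("\n").strip("'").lower()
            all_classes = [option.get("class").lower() for option in options]
            if prepared in all_classes:
                classification["other"] = False
                classification["classification"] = prepared
        return classification

    print(normalise("Class1.", [{"class": "Class1"}, {"class": "Class2"}]))
    # -> {'other': False, 'classification': 'class1'}
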
18 changes: 18 additions & 0 deletions nexus/zeroshot/migrations/0002_zeroshotlogs_model.py
@@ -0,0 +1,18 @@
# Generated by Django 4.2.10 on 2024-12-18 18:24

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
('zeroshot', '0001_initial'),
]

operations = [
migrations.AddField(
model_name='zeroshotlogs',
name='model',
field=models.CharField(default='zeroshot', max_length=64, verbose_name='Model'),
),
]
1 change: 1 addition & 0 deletions nexus/zeroshot/models.py
@@ -18,3 +18,4 @@ class Meta:
nlp_log = models.TextField(blank=True)
created_at = models.DateTimeField("created at", auto_now_add=True)
language = models.CharField(verbose_name="Language", max_length=64, null=True, blank=True)
model = models.CharField(verbose_name="Model", max_length=64, default="zeroshot")
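
Because the migration adds the column with default="zeroshot", existing rows are backfilled with that value, so per-classifier reporting becomes straightforward. A hypothetical query, not part of the commit:

    from django.db.models import Count

    from nexus.zeroshot.models import ZeroshotLogs

    logs_per_model = (
        ZeroshotLogs.objects.values("model")
        .annotate(total=Count("pk"))
        .order_by("-total")
    )
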
53 changes: 51 additions & 2 deletions nexus/zeroshot/tests.py
@@ -1,6 +1,8 @@
- from django.test import TestCase
+ from django.test import TestCase, override_settings
from nexus.zeroshot.client import InvokeModel
- from unittest.mock import patch
+ from unittest.mock import patch, ANY

from router.classifiers.chatgpt_function import ChatGPTFunctionClassifier


class TestClient(TestCase):
@@ -12,19 +14,66 @@ def setUp(self) -> None:
'options': []
}

@override_settings(DEFAULT_CLASSIFICATION_MODEL="zeroshot")
@patch("nexus.zeroshot.client.InvokeModel._invoke_bedrock")
def test_call_bedrock(self, mock):
invoke_model = InvokeModel(self.zeroshot_data, model_backend="bedrock")
invoke_model.invoke()
self.assertTrue(mock.called)

@override_settings(DEFAULT_CLASSIFICATION_MODEL="zeroshot")
@patch("nexus.zeroshot.client.InvokeModel._invoke_runpod")
def test_call_runpod(self, mock):
invoke_model = InvokeModel(self.zeroshot_data, model_backend="runpod")
invoke_model.invoke()
self.assertTrue(mock.called)

@override_settings(DEFAULT_CLASSIFICATION_MODEL="zeroshot")
def test_value_error(self):
invoke_model = InvokeModel(self.zeroshot_data, model_backend="err")
with self.assertRaises(ValueError):
invoke_model.invoke()

@override_settings(DEFAULT_CLASSIFICATION_MODEL="function_calling")
@patch("nexus.zeroshot.client.InvokeModel._invoke_function_calling")
def test_call_function_calling(self, mock):
invoke_model = InvokeModel(self.zeroshot_data, model_backend="zeroshot")
invoke_model.invoke()
self.assertTrue(mock.called)


class TestFunctionCalling(TestCase):

@override_settings(DEFAULT_CLASSIFICATION_MODEL='function_calling')
def test_invoke_function_calling_calls_correct_methods(self):
zeroshot_data = {
"context": "This is the agent goal.",
"language": "eng",
"text": "User message to classify.",
"options": [
{"class": "Class1", "context": "Context for class 1"},
{"class": "Class2", "context": "Context for class 2"},
]
}

with patch.object(ChatGPTFunctionClassifier, '__init__', return_value=None) as mock_init, \
patch.object(ChatGPTFunctionClassifier, 'predict', return_value='Class1') as mock_predict:

invoke_model = InvokeModel(zeroshot_data)
response = invoke_model.invoke()

mock_init.assert_called_with(agent_goal='This is the agent goal.')

mock_predict.assert_called_with(
message='User message to classify.',
flows=ANY,
language='eng',
)

expected_response = {
'output': {
'other': False,
'classification': 'class1'
}
}
self.assertEqual(response, expected_response)
7 changes: 6 additions & 1 deletion router/classifiers/chatgpt_function.py
@@ -88,11 +88,16 @@ def predict(
self,
message: str,
flows: List[FlowDTO],
custom_prompt: str = None,
language: str = "por"
) -> str:

print(f"[+ ChatGPT message function classification: {message} ({language}) +]")
- formated_prompt = self.get_prompt()

+ formated_prompt = custom_prompt
+ if not custom_prompt:
+     formated_prompt = self.get_prompt()

msg = [
{
"role": "system",
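
Assumed call shapes after this change (the constructor argument mirrors the client.py usage above; whether ChatGPTFunctionClassifier requires more is not visible in this diff): omitting custom_prompt keeps the previous get_prompt() behaviour, while passing it overrides the generated system prompt.

    from router.classifiers.chatgpt_function import ChatGPTFunctionClassifier
    from router.entities.flow import FlowDTO

    classifier = ChatGPTFunctionClassifier(agent_goal="Route support messages to the right flow.")
    flows = [FlowDTO(name="refund", prompt="Customer asks for a refund")]

    # Unchanged behaviour: the system prompt is built by get_prompt()
    label = classifier.predict(message="I want my money back", flows=flows, language="eng")

    # New: caller-supplied system prompt
    label = classifier.predict(
        message="I want my money back",
        flows=flows,
        custom_prompt="Pick exactly one flow name for the message.",
        language="eng",
    )
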
8 changes: 4 additions & 4 deletions router/entities/flow.py
@@ -3,10 +3,10 @@

@dataclass
class FlowDTO:
-    pk: str
-    uuid: str
    name: str
    prompt: str
-    fallback: str
-    content_base_uuid: str
+    pk: str = None
+    uuid: str = None
+    fallback: str = None
+    content_base_uuid: str = None
    send_to_llm: bool = False
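
With pk, uuid, fallback and content_base_uuid now optional, the zeroshot client can build DTOs from request options alone, while router callers that pass the full set keep working. For example (values illustrative):

    from router.entities.flow import FlowDTO

    # Lightweight DTO as built in nexus/zeroshot/client.py
    flow = FlowDTO(name="track_order", prompt="Customer wants to track an order")

    # Fully populated DTO as a router caller might still construct it
    full_flow = FlowDTO(
        name="track_order",
        prompt="Customer wants to track an order",
        pk="42",
        uuid="00000000-0000-0000-0000-000000000001",
        fallback="talk_to_human",
        content_base_uuid="00000000-0000-0000-0000-000000000002",
        send_to_llm=True,
    )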
