Skip to content

Commit

Permalink
edit start_route task to work with preview
Browse files Browse the repository at this point in the history
  • Loading branch information
AlisoSouza committed Dec 11, 2024
1 parent fa5f482 commit 70bc0b4
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 163 deletions.
152 changes: 9 additions & 143 deletions nexus/actions/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,165 +253,31 @@ def destroy(self, request, *args, **kwargs):
except IntelligencePermissionDenied:
return Response(status=status.HTTP_401_UNAUTHORIZED)


from router.entities import Message as UserMessage
from router.tasks.tasks import start_route
class MessagePreviewView(APIView):

def post(self, request, *args, **kwargs):
try:
flows_user_email = os.environ.get("FLOW_USER_EMAIL")
flow_start = SimulateFlowStart(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_INTERNAL_TOKEN'
)
)
broadcast = SimulateBroadcast(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_INTERNAL_TOKEN'
),
get_file_info
)

content_base_repository = ContentBaseORMRepository()
message_logs_repository = MessageLogsRepository()

data = request.data

project_uuid = kwargs.get("project_uuid")
text = data.get("text")
contact_urn = data.get("contact_urn")
attachments = data.get("attachments", [])
metadata = data.get("metadata", {})

project = projects.get_project_by_uuid(project_uuid)
indexer = projects.ProjectsUseCase().get_indexer_database_by_project(
project
)

has_project_permission(
user=request.user,
project=project,
method="post"
)

log_usecase = CreateLogUsecase()

message = Message(
data = request.data
message = UserMessage(
project_uuid=project_uuid,
text=text,
contact_urn=contact_urn,
attachments=attachments,
metadata=metadata
)

print(
f"[+ Message: {message.text} - Contact: {message.contact_urn} - Project: {message.project_uuid} +]"
)

project_uuid: str = message.project_uuid

flows_repository = FlowsORMRepository(project_uuid=project_uuid)

content_base: ContentBaseDTO = content_base_repository.get_content_base_by_project(
message.project_uuid
)

agent: AgentDTO = content_base_repository.get_agent(
content_base.uuid
)
agent = agent.set_default_if_null()

llm_model = get_llm_by_project_uuid(project_uuid)

llm_config = LLMSetupDTO(
model=llm_model.model.lower(),
model_version=llm_model.setup.get("version"),
temperature=llm_model.setup.get("temperature"),
top_k=llm_model.setup.get("top_k"),
top_p=llm_model.setup.get("top_p"),
token=llm_model.setup.get("token"),
max_length=llm_model.setup.get("max_length"),
max_tokens=llm_model.setup.get("max_tokens"),
language=llm_model.setup.get(
"language", settings.WENIGPT_DEFAULT_LANGUAGE)
)

print(
f"[+ LLM model: {llm_config.model}:{llm_config.model_version} +]"
)

pre_classification = PreClassification(
flows_repository=flows_repository,
message=message,
msg_event={},
flow_start=flow_start,
user_email=flows_user_email
)

pre_classification_response = pre_classification.pre_classification_preview()
if pre_classification_response:
return Response(pre_classification_response)

classification_handler = Classification(
flows_repository=flows_repository,
message=message,
msg_event={},
flow_start=flow_start,
user_email=flows_user_email
)

started_flow = classification_handler.non_custom_actions_preview()
if started_flow:
return Response(started_flow)

message_log = log_usecase.create_message_log(
text=text,
contact_urn=contact_urn,
source="preview"
)

if project_uuid == os.environ.get("DEMO_FUNC_CALLING_PROJECT_UUID"):
classifier = ChatGPTFunctionClassifier(agent_goal=agent.goal)
else:
classifier = ZeroshotClassifier(chatbot_goal=agent.goal)

classification = classification_handler.custom_actions(
classifier=classifier,
language=llm_config.language
)

llm_client = LLMClient.get_by_type(llm_config.model)
llm_client: LLMClient = list(llm_client)[0](
model_version=llm_config.model_version
)

if llm_config.model.lower() != "wenigpt":
llm_client.api_key = llm_config.token

print(f"[+ Classfication: {classification} +]")

response: dict = route(
classification=classification,
message=message,
content_base_repository=content_base_repository,
flows_repository=flows_repository,
message_logs_repository=message_logs_repository,
indexer=indexer(),
llm_client=llm_client,
direct_message=broadcast,
flow_start=flow_start,
llm_config=llm_config,
flows_user_email=flows_user_email,
log_usecase=log_usecase,
message_log=message_log
text=data.get("text"),
contact_urn=data.get("contact_urn"),
attachments=data.get("attachments", []),
metadata=data.get("metadata", {})
)

log_usecase.update_status("S")
response = start_route(message, preview=True)

return Response(data=response)
except IntelligencePermissionDenied:
Expand Down
72 changes: 52 additions & 20 deletions router/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,53 @@
MessageLogsRepository
)
from nexus.usecases.projects.projects_use_case import ProjectsUseCase

from nexus.usecases.intelligences.retrieve import get_file_info
from router.clients.preview.simulator.broadcast import SimulateBroadcast
from router.clients.preview.simulator.flow_start import SimulateFlowStart

@celery_app.task(bind=True)
def start_route(self, message: Dict) -> bool: # pragma: no cover
def start_route(self, message: Dict, preview: bool = False) -> bool: # pragma: no cover
def get_action_clients(preview: bool = False):
if preview:
flow_start = SimulateFlowStart(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_INTERNAL_TOKEN'
)
)
broadcast = SimulateBroadcast(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_INTERNAL_TOKEN'
),
get_file_info
)
return broadcast, flow_start

broadcast = SendMessageHTTPClient(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_SEND_MESSAGE_INTERNAL_TOKEN'
)
)
flow_start = FlowStartHTTPClient(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_INTERNAL_TOKEN'
)
)
return broadcast, flow_start

source = "preview" if preview else "router"
print(f"[+ Message from: {source} +]")

# Initialize Redis client using the REDIS_URL from settings
redis_client = Redis.from_url(settings.REDIS_URL)
Expand All @@ -48,27 +91,14 @@ def start_route(self, message: Dict) -> bool: # pragma: no cover
mailroom_msg_event['metadata'] = mailroom_msg_event.get('metadata') or {}

log_usecase = CreateLogUsecase()

try:
project_uuid: str = message.project_uuid
indexer = ProjectsUseCase().get_indexer_database_by_uuid(project_uuid)
flows_repository = FlowsORMRepository(project_uuid=project_uuid)

broadcast = SendMessageHTTPClient(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_SEND_MESSAGE_INTERNAL_TOKEN'
)
)
flow_start = FlowStartHTTPClient(
os.environ.get(
'FLOWS_REST_ENDPOINT'
),
os.environ.get(
'FLOWS_INTERNAL_TOKEN'
)
)
broadcast, flow_start = get_action_clients(preview)

flows_user_email = os.environ.get("FLOW_USER_EMAIL")

content_base: ContentBaseDTO = content_base_repository.get_content_base_by_project(
Expand Down Expand Up @@ -102,7 +132,7 @@ def start_route(self, message: Dict) -> bool: # pragma: no cover
message_log = log_usecase.create_message_log(
text=message.text,
contact_urn=message.contact_urn,
source="router",
source=source,
)

llm_model = get_llm_by_project_uuid(project_uuid)
Expand Down Expand Up @@ -163,7 +193,7 @@ def start_route(self, message: Dict) -> bool: # pragma: no cover
redis_client.set(pending_task_key, self.request.id)

# Generate response for the concatenated message
route(
response: dict = route(
classification=classification,
message=message,
content_base_repository=content_base_repository,
Expand All @@ -184,6 +214,8 @@ def start_route(self, message: Dict) -> bool: # pragma: no cover
redis_client.delete(pending_task_key)

log_usecase.update_status("S")
return response

except Exception as e:
print(f"[- START ROUTE - Error: {e} -]")
if message.text:
Expand Down

0 comments on commit 70bc0b4

Please sign in to comment.