Skip to content

Commit

Permalink
fix(openrons-ai-server, opentrons-ai-client): predict method is divid…
Browse files Browse the repository at this point in the history
…ed into create and update

Previously, users end up with the same chatcompletion endpoint no matter which method it is used
  • Loading branch information
Elyorcv committed Nov 25, 2024
1 parent 6e407e4 commit 81156e9
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ export function generateChatPrompt(
.join('\n')
: `- ${t(values.instruments.pipettes)}`
const flexGripper =
values.instruments.flexGripper === FLEX_GRIPPER
values.instruments.flexGripper === FLEX_GRIPPER &&
values.instruments.robot === OPENTRONS_FLEX
? `\n- ${t('with_flex_gripper')}`
: ''
const modules = values.modules
Expand Down
76 changes: 41 additions & 35 deletions opentrons-ai-server/api/domain/anthropic_predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Literal

import requests
import structlog
Expand All @@ -23,7 +23,7 @@ def __init__(self, settings: Settings) -> None:
self.model_name: str = settings.anthropic_model_name
self.system_prompt: str = SYSTEM_PROMPT
self.path_docs: Path = ROOT_PATH / "api" / "storage" / "docs"
self._messages: List[MessageParam] = [
self.cashed_docs: List[MessageParam] = [
{
"role": "user",
"content": [
Expand Down Expand Up @@ -77,19 +77,26 @@ def get_docs(self) -> str:
return "\n".join(xml_output)

@tracer.wrap()
def generate_message(self, max_tokens: int = 4096) -> Message:
def _process_message(
self, user_id: str, messages: List[MessageParam], message_type: Literal["create", "update"], max_tokens: int = 4096
) -> Message:
"""
Internal method to handle message processing with different system prompts.
For now, system prompt is the same.
"""

response = self.client.messages.create(
response: Message = self.client.messages.create(
model=self.model_name,
system=self.system_prompt,
max_tokens=max_tokens,
messages=self._messages,
messages=messages,
tools=self.tools, # type: ignore
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
metadata={"user_id": user_id},
)

logger.info(
"Token usage",
f"Token usage: {message_type.capitalize()}",
extra={
"input_tokens": response.usage.input_tokens,
"output_tokens": response.usage.output_tokens,
Expand All @@ -100,15 +107,23 @@ def generate_message(self, max_tokens: int = 4096) -> Message:
return response

@tracer.wrap()
def predict(self, prompt: str) -> str | None:
def process_message(
self, user_id: str, prompt: str, history: List[MessageParam] | None = None, message_type: Literal["create", "update"] = "create"
) -> str | None:
"""Unified method for creating and updating messages"""
try:
self._messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
response = self.generate_message()
messages: List[MessageParam] = self.cashed_docs.copy()
if history:
messages += history

messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
response = self._process_message(user_id=user_id, messages=messages, message_type=message_type)

if response.content[-1].type == "tool_use":
tool_use = response.content[-1]
self._messages.append({"role": "assistant", "content": response.content})
messages.append({"role": "assistant", "content": response.content})
result = self.handle_tool_use(tool_use.name, tool_use.input) # type: ignore
self._messages.append(
messages.append(
{
"role": "user",
"content": [
Expand All @@ -120,25 +135,26 @@ def predict(self, prompt: str) -> str | None:
],
}
)
follow_up = self.generate_message()
response_text = follow_up.content[0].text # type: ignore
self._messages.append({"role": "assistant", "content": response_text})
return response_text
follow_up = self._process_message(user_id=user_id, messages=messages, message_type=message_type)
return follow_up.content[0].text # type: ignore

elif response.content[0].type == "text":
response_text = response.content[0].text
self._messages.append({"role": "assistant", "content": response_text})
return response_text
return response.content[0].text

logger.error("Unexpected response type")
return None
except IndexError as e:
logger.error("Invalid response format", extra={"error": str(e)})
return None
except Exception as e:
logger.error("Error in predict method", extra={"error": str(e)})
logger.error(f"Error in {message_type} method", extra={"error": str(e)})
return None

@tracer.wrap()
def create(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
return self.process_message(user_id, prompt, history, "create")

@tracer.wrap()
def update(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
return self.process_message(user_id, prompt, history, "update")

@tracer.wrap()
def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
if func_name == "simulate_protocol":
Expand All @@ -148,17 +164,6 @@ def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
logger.error("Unknown tool", extra={"tool": func_name})
raise ValueError(f"Unknown tool: {func_name}")

@tracer.wrap()
def reset(self) -> None:
self._messages = [
{
"role": "user",
"content": [
{"type": "text", "text": DOCUMENTS.format(doc_content=self.get_docs()), "cache_control": {"type": "ephemeral"}} # type: ignore
],
}
]

@tracer.wrap()
def simulate_protocol(self, protocol: str) -> str:
url = "https://Opentrons-simulator.hf.space/protocol"
Expand Down Expand Up @@ -197,8 +202,9 @@ def main() -> None:

settings = Settings()
llm = AnthropicPredict(settings)
prompt = Prompt.ask("Type a prompt to send to the Anthropic API:")
completion = llm.predict(prompt)
Prompt.ask("Type a prompt to send to the Anthropic API:")

completion = llm.create(user_id="1", prompt="hi", history=None)
print(completion)


Expand Down
34 changes: 18 additions & 16 deletions opentrons-ai-server/api/domain/config_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@
4. Flag potential safety or compatibility issues
5. Suggest protocol optimizations when appropriate
Call protocol simulation tool to validate the code - only when it is called explicitly by the user.
For all other queries, provide direct responses.
Important guidelines:
- Always verify labware compatibility before generating protocols
- Include appropriate error handling in generated code
Expand All @@ -28,26 +25,25 @@
"""

PROMPT = """
Here are the inputs you will work with:
<user_prompt>
{USER_PROMPT}
</user_prompt>
Follow these instructions to handle the user's prompt:
1. Analyze the user's prompt to determine if it's:
1. <Analyze the user's prompt to determine if it's>:
a) A request to generate a protocol
b) A question about the Opentrons Python API v2
b) A question about the Opentrons Python API v2 or about details of protocol
c) A common task (e.g., value changes, OT-2 to Flex conversion, slot correction)
d) An unrelated or unclear request
e) A tool calling. If a user calls simulate protocol explicity, then call.
f) A greeting. Respond kindly.
2. If the prompt is unrelated or unclear, ask the user for clarification. For example:
I apologize, but your prompt seems unclear. Could you please provide more details?
Note: when you respond you dont need mention the category or the type.
2. If the prompt is unrelated or unclear, ask the user for clarification.
I'm sorry, but your prompt seems unclear. Could you please provide more details?
You dont need to mention
3. If the prompt is a question about the API, answer it using only the information
3. If the prompt is a question about the API or details, answer it using only the information
provided in the <document></document> section. Provide references and place them under the <References> tag.
Format your response like this:
API answer:
Expand Down Expand Up @@ -86,8 +82,8 @@
}}
requirements = {{
'robotType': '[Robot type based on user prompt, OT-2 or Flex, default is OT-2]',
'apiLevel': '[apiLevel, default is 2.19 ]'
'robotType': '[Robot type: OT-2(default) for Opentrons OT-2, Flex for Opentrons Flex]',
'apiLevel': '[apiLevel, default: 2.19]'
}}
def run(protocol: protocol_api.ProtocolContext):
Expand Down Expand Up @@ -214,4 +210,10 @@ def run(protocol: protocol_api.ProtocolContext):
as a reference to generate a basic protocol.
Remember to use only the information provided in the <document></document>. Do not introduce any external information or assumptions.
Here are the inputs you will work with:
<user_prompt>
{USER_PROMPT}
</user_prompt>
"""
61 changes: 35 additions & 26 deletions opentrons-ai-server/api/handler/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,19 @@ async def create_chat_completion(
return ChatResponse(reply="Default fake response. ", fake=body.fake)

response: Optional[str] = None

if "Write a protocol using" in body.history[0]["content"]: # type: ignore
protocol_option = "create"
else:
protocol_option = "update"

if "openai" in settings.model.lower():
response = openai.predict(prompt=body.message, chat_completion_message_params=body.history)
else:
response = claude.predict(prompt=body.message)
if protocol_option == "create":
response = claude.create(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore
else:
response = claude.update(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore

if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))
Expand All @@ -218,88 +227,88 @@ async def create_chat_completion(

@tracer.wrap()
@app.post(
"/api/chat/updateProtocol",
"/api/chat/createProtocol",
response_model=Union[ChatResponse, ErrorResponse],
summary="Updates protocol",
description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
summary="Creates protocol",
description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
)
async def update_protocol(
body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
async def create_protocol(
body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
) -> Union[ChatResponse, ErrorResponse]: # noqa: B008
"""
Generate an updated protocol using LLM.
- **request**: The HTTP request containing the existing protocol and other relevant parameters.
- **request**: The HTTP request containing the chat message.
- **returns**: A chat response or an error message.
"""
logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
try:
if not body.protocol_text or body.protocol_text == "":

if not body.prompt or body.prompt == "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
)

if body.fake:
return ChatResponse(reply="Fake response", fake=bool(body.fake))
return ChatResponse(reply="Fake response", fake=body.fake)

response: Optional[str] = None
if "openai" in settings.model.lower():
response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
else:
response = claude.predict(prompt=body.prompt)
response = claude.create(user_id=str(user.sub), prompt=body.prompt, history=None)

if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))

return ChatResponse(reply=response, fake=bool(body.fake))

except Exception as e:
logger.exception("Error processing protocol update")
logger.exception("Error processing protocol creation")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
) from e


@tracer.wrap()
@app.post(
"/api/chat/createProtocol",
"/api/chat/updateProtocol",
response_model=Union[ChatResponse, ErrorResponse],
summary="Creates protocol",
description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
summary="Updates protocol",
description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
)
async def create_protocol(
body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
async def update_protocol(
body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
) -> Union[ChatResponse, ErrorResponse]: # noqa: B008
"""
Generate an updated protocol using LLM.
- **request**: The HTTP request containing the chat message.
- **request**: The HTTP request containing the existing protocol and other relevant parameters.
- **returns**: A chat response or an error message.
"""
logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
try:

if not body.prompt or body.prompt == "":
if not body.protocol_text or body.protocol_text == "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
)

if body.fake:
return ChatResponse(reply="Fake response", fake=body.fake)
return ChatResponse(reply="Fake response", fake=bool(body.fake))

response: Optional[str] = None
if "openai" in settings.model.lower():
response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
else:
response = claude.predict(prompt=str(body.model_dump()))
response = claude.update(user_id=str(user.sub), prompt=body.prompt, history=None)

if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))

return ChatResponse(reply=response, fake=bool(body.fake))

except Exception as e:
logger.exception("Error processing protocol creation")
logger.exception("Error processing protocol update")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
) from e
Expand Down
4 changes: 4 additions & 0 deletions opentrons-ai-server/api/models/chat_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ class Chat(BaseModel):
Field(None, description="Chat history in the form of a list of messages. Type is from OpenAI's ChatCompletionMessageParam"),
]

ChatOptions = Literal["update", "create"]
ChatOptionsType = Annotated[Optional[ChatOptions], Field("create", description="which chat pathway did the user enter: create or update")]


class ChatRequest(BaseModel):
message: str = Field(..., description="The latest message to be processed.")
history: HistoryType
fake: bool = Field(True, description="When set to true, the response will be a fake. OpenAI API is not used.")
fake_key: FakeKeyType
chat_options: ChatOptionsType
2 changes: 1 addition & 1 deletion opentrons-ai-server/tests/helpers/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_health(self) -> Response:
@timeit
def get_chat_completion(self, message: str, fake: bool = True, fake_key: Optional[FakeKeys] = None, bad_auth: bool = False) -> Response:
"""Call the /chat/completion endpoint and return the response."""
request = ChatRequest(message=message, fake=fake, fake_key=fake_key, history=None)
request = ChatRequest(message=message, fake=fake, fake_key=fake_key, history=None, chat_options=None)
headers = self.standard_headers if not bad_auth else self.invalid_auth_headers
return self.httpx.post("/chat/completion", headers=headers, json=request.model_dump())

Expand Down

0 comments on commit 81156e9

Please sign in to comment.