Skip to content

Commit

Permalink
fix(openrons-ai-server, opentrons-ai-client): predict method (#16967)
Browse files Browse the repository at this point in the history
<!--
Thanks for taking the time to open a Pull Request (PR)! Please make sure
you've read the "Opening Pull Requests" section of our Contributing
Guide:


https://github.com/Opentrons/opentrons/blob/edge/CONTRIBUTING.md#opening-pull-requests

GitHub provides robust markdown to format your PR. Links, diagrams,
pictures, and videos along with text formatting make it possible to
create a rich and informative PR. For more information on GitHub
markdown, see:


https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax

To ensure your code is reviewed quickly and thoroughly, please fill out
the sections below to the best of your ability!
-->

# Overview
Two changed
- **Backend**: Previously, two endpoints end up using the same model,
therefore confused during message processing. Present, two endpoints use
separate models.
- **Frontend**: Flex gripper is concerned with only Flex robot 
   - opentrons-ai-client/src/resources/utils/createProtocolUtils.tsx 

Closes AUTH-1076
<!--
Describe your PR at a high level. State acceptance criteria and how this
PR fits into other work. Link issues, PRs, and other relevant resources.
-->

## Test Plan and Hands on Testing
- Visit `opentrons.ai` 
- Click `Update an existing protocol` and follow instructions
- Click `Create a new protocol` and follow instructions. Once you
complete providing labware and other information. Click Submit then it
will take you chat window where you need to see the generated protocol.
It should not start like 'Simulation is successful'.


 
<!--
Describe your testing of the PR. Emphasize testing not reflected in the
code. Attach protocols, logs, screenshots and any other assets that
support your testing.
-->

## Changelog

<!--
List changes introduced by this PR considering future developers and the
end user. Give careful thought and clear documentation to breaking
changes.
-->

## Review requests
All tests are passing 
<!--
- What do you need from reviewers to feel confident this PR is ready to
merge?
- Ask questions.
-->

## Risk assessment
Low

<!--
- Indicate the level of attention this PR needs.
- Provide context to guide reviewers.
- Discuss trade-offs, coupling, and side effects.
- Look for the possibility, even if you think it's small, that your
change may affect some other part of the system.
- For instance, changing return tip behavior may also change the
behavior of labware calibration.
- How do your unit tests and on hands on testing mitigate this PR's
risks and the risk of future regressions?
- Especially in high risk PRs, explain how you know your testing is
enough.
-->
  • Loading branch information
Elyorcv authored Nov 25, 2024
1 parent b6e29e9 commit 378a1f2
Show file tree
Hide file tree
Showing 8 changed files with 701 additions and 237 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,8 @@ export function generateChatPrompt(
.join('\n')
: `- ${t(values.instruments.pipettes)}`
const flexGripper =
values.instruments.flexGripper === FLEX_GRIPPER
values.instruments.flexGripper === FLEX_GRIPPER &&
values.instruments.robot === OPENTRONS_FLEX
? `\n- ${t('with_flex_gripper')}`
: ''
const modules = values.modules
Expand Down
76 changes: 41 additions & 35 deletions opentrons-ai-server/api/domain/anthropic_predict.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import uuid
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Literal

import requests
import structlog
Expand All @@ -23,7 +23,7 @@ def __init__(self, settings: Settings) -> None:
self.model_name: str = settings.anthropic_model_name
self.system_prompt: str = SYSTEM_PROMPT
self.path_docs: Path = ROOT_PATH / "api" / "storage" / "docs"
self._messages: List[MessageParam] = [
self.cached_docs: List[MessageParam] = [
{
"role": "user",
"content": [
Expand Down Expand Up @@ -77,19 +77,26 @@ def get_docs(self) -> str:
return "\n".join(xml_output)

@tracer.wrap()
def generate_message(self, max_tokens: int = 4096) -> Message:
def _process_message(
self, user_id: str, messages: List[MessageParam], message_type: Literal["create", "update"], max_tokens: int = 4096
) -> Message:
"""
Internal method to handle message processing with different system prompts.
For now, system prompt is the same.
"""

response = self.client.messages.create(
response: Message = self.client.messages.create(
model=self.model_name,
system=self.system_prompt,
max_tokens=max_tokens,
messages=self._messages,
messages=messages,
tools=self.tools, # type: ignore
extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
metadata={"user_id": user_id},
)

logger.info(
"Token usage",
f"Token usage: {message_type.capitalize()}",
extra={
"input_tokens": response.usage.input_tokens,
"output_tokens": response.usage.output_tokens,
Expand All @@ -100,15 +107,23 @@ def generate_message(self, max_tokens: int = 4096) -> Message:
return response

@tracer.wrap()
def predict(self, prompt: str) -> str | None:
def process_message(
self, user_id: str, prompt: str, history: List[MessageParam] | None = None, message_type: Literal["create", "update"] = "create"
) -> str | None:
"""Unified method for creating and updating messages"""
try:
self._messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
response = self.generate_message()
messages: List[MessageParam] = self.cached_docs.copy()
if history:
messages += history

messages.append({"role": "user", "content": PROMPT.format(USER_PROMPT=prompt)})
response = self._process_message(user_id=user_id, messages=messages, message_type=message_type)

if response.content[-1].type == "tool_use":
tool_use = response.content[-1]
self._messages.append({"role": "assistant", "content": response.content})
messages.append({"role": "assistant", "content": response.content})
result = self.handle_tool_use(tool_use.name, tool_use.input) # type: ignore
self._messages.append(
messages.append(
{
"role": "user",
"content": [
Expand All @@ -120,25 +135,26 @@ def predict(self, prompt: str) -> str | None:
],
}
)
follow_up = self.generate_message()
response_text = follow_up.content[0].text # type: ignore
self._messages.append({"role": "assistant", "content": response_text})
return response_text
follow_up = self._process_message(user_id=user_id, messages=messages, message_type=message_type)
return follow_up.content[0].text # type: ignore

elif response.content[0].type == "text":
response_text = response.content[0].text
self._messages.append({"role": "assistant", "content": response_text})
return response_text
return response.content[0].text

logger.error("Unexpected response type")
return None
except IndexError as e:
logger.error("Invalid response format", extra={"error": str(e)})
return None
except Exception as e:
logger.error("Error in predict method", extra={"error": str(e)})
logger.error(f"Error in {message_type} method", extra={"error": str(e)})
return None

@tracer.wrap()
def create(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
return self.process_message(user_id, prompt, history, "create")

@tracer.wrap()
def update(self, user_id: str, prompt: str, history: List[MessageParam] | None = None) -> str | None:
return self.process_message(user_id, prompt, history, "update")

@tracer.wrap()
def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
if func_name == "simulate_protocol":
Expand All @@ -148,17 +164,6 @@ def handle_tool_use(self, func_name: str, func_params: Dict[str, Any]) -> str:
logger.error("Unknown tool", extra={"tool": func_name})
raise ValueError(f"Unknown tool: {func_name}")

@tracer.wrap()
def reset(self) -> None:
self._messages = [
{
"role": "user",
"content": [
{"type": "text", "text": DOCUMENTS.format(doc_content=self.get_docs()), "cache_control": {"type": "ephemeral"}} # type: ignore
],
}
]

@tracer.wrap()
def simulate_protocol(self, protocol: str) -> str:
url = "https://Opentrons-simulator.hf.space/protocol"
Expand Down Expand Up @@ -197,8 +202,9 @@ def main() -> None:

settings = Settings()
llm = AnthropicPredict(settings)
prompt = Prompt.ask("Type a prompt to send to the Anthropic API:")
completion = llm.predict(prompt)
Prompt.ask("Type a prompt to send to the Anthropic API:")

completion = llm.create(user_id="1", prompt="hi", history=None)
print(completion)


Expand Down
36 changes: 19 additions & 17 deletions opentrons-ai-server/api/domain/config_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,11 @@
Your key responsibilities:
1. Welcome scientists warmly and understand their protocol needs
2. Generate accurate Python protocols using standard Opentrons labware
2. Generate accurate Python protocols using standard Opentrons labware (see <source> standard-loadname-info.md </source> in <document>)
3. Provide clear explanations and documentation
4. Flag potential safety or compatibility issues
5. Suggest protocol optimizations when appropriate
Call protocol simulation tool to validate the code - only when it is called explicitly by the user.
For all other queries, provide direct responses.
Important guidelines:
- Always verify labware compatibility before generating protocols
- Include appropriate error handling in generated code
Expand All @@ -28,26 +25,25 @@
"""

PROMPT = """
Here are the inputs you will work with:
<user_prompt>
{USER_PROMPT}
</user_prompt>
Follow these instructions to handle the user's prompt:
1. Analyze the user's prompt to determine if it's:
1. <Analyze the user's prompt to determine if it's>:
a) A request to generate a protocol
b) A question about the Opentrons Python API v2
b) A question about the Opentrons Python API v2 or about details of protocol
c) A common task (e.g., value changes, OT-2 to Flex conversion, slot correction)
d) An unrelated or unclear request
e) A tool calling. If a user calls simulate protocol explicity, then call.
f) A greeting. Respond kindly.
2. If the prompt is unrelated or unclear, ask the user for clarification. For example:
I apologize, but your prompt seems unclear. Could you please provide more details?
Note: when you respond you dont need mention the category or the type.
2. If the prompt is unrelated or unclear, ask the user for clarification.
I'm sorry, but your prompt seems unclear. Could you please provide more details?
You dont need to mention
3. If the prompt is a question about the API, answer it using only the information
3. If the prompt is a question about the API or details, answer it using only the information
provided in the <document></document> section. Provide references and place them under the <References> tag.
Format your response like this:
API answer:
Expand Down Expand Up @@ -86,8 +82,8 @@
}}
requirements = {{
'robotType': '[Robot type based on user prompt, OT-2 or Flex, default is OT-2]',
'apiLevel': '[apiLevel, default is 2.19 ]'
'robotType': '[Robot type: OT-2(default) for Opentrons OT-2, Flex for Opentrons Flex]',
'apiLevel': '[apiLevel, default: 2.19]'
}}
def run(protocol: protocol_api.ProtocolContext):
Expand Down Expand Up @@ -214,4 +210,10 @@ def run(protocol: protocol_api.ProtocolContext):
as a reference to generate a basic protocol.
Remember to use only the information provided in the <document></document>. Do not introduce any external information or assumptions.
Here are the inputs you will work with:
<user_prompt>
{USER_PROMPT}
</user_prompt>
"""
61 changes: 35 additions & 26 deletions opentrons-ai-server/api/handler/fast.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,19 @@ async def create_chat_completion(
return ChatResponse(reply="Default fake response. ", fake=body.fake)

response: Optional[str] = None

if body.history and body.history[0].get("content") and "Write a protocol using" in body.history[0]["content"]: # type: ignore
protocol_option = "create"
else:
protocol_option = "update"

if "openai" in settings.model.lower():
response = openai.predict(prompt=body.message, chat_completion_message_params=body.history)
else:
response = claude.predict(prompt=body.message)
if protocol_option == "create":
response = claude.create(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore
else:
response = claude.update(user_id=str(user.sub), prompt=body.message, history=body.history) # type: ignore

if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))
Expand All @@ -218,88 +227,88 @@ async def create_chat_completion(

@tracer.wrap()
@app.post(
"/api/chat/updateProtocol",
"/api/chat/createProtocol",
response_model=Union[ChatResponse, ErrorResponse],
summary="Updates protocol",
description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
summary="Creates protocol",
description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
)
async def update_protocol(
body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
async def create_protocol(
body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
) -> Union[ChatResponse, ErrorResponse]: # noqa: B008
"""
Generate an updated protocol using LLM.
- **request**: The HTTP request containing the existing protocol and other relevant parameters.
- **request**: The HTTP request containing the chat message.
- **returns**: A chat response or an error message.
"""
logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
try:
if not body.protocol_text or body.protocol_text == "":

if not body.prompt or body.prompt == "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
)

if body.fake:
return ChatResponse(reply="Fake response", fake=bool(body.fake))
return ChatResponse(reply="Fake response", fake=body.fake)

response: Optional[str] = None
if "openai" in settings.model.lower():
response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
else:
response = claude.predict(prompt=body.prompt)
response = claude.create(user_id=str(user.sub), prompt=body.prompt, history=None)

if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))

return ChatResponse(reply=response, fake=bool(body.fake))

except Exception as e:
logger.exception("Error processing protocol update")
logger.exception("Error processing protocol creation")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
) from e


@tracer.wrap()
@app.post(
"/api/chat/createProtocol",
"/api/chat/updateProtocol",
response_model=Union[ChatResponse, ErrorResponse],
summary="Creates protocol",
description="Generate a chat response based on the provided prompt that will create a new protocol with the required changes.",
summary="Updates protocol",
description="Generate a chat response based on the provided prompt that will update an existing protocol with the required changes.",
)
async def create_protocol(
body: CreateProtocol, user: Annotated[User, Security(auth.verify)]
async def update_protocol(
body: UpdateProtocol, user: Annotated[User, Security(auth.verify)]
) -> Union[ChatResponse, ErrorResponse]: # noqa: B008
"""
Generate an updated protocol using LLM.
- **request**: The HTTP request containing the chat message.
- **request**: The HTTP request containing the existing protocol and other relevant parameters.
- **returns**: A chat response or an error message.
"""
logger.info("POST /api/chat/createProtocol", extra={"body": body.model_dump(), "user": user})
logger.info("POST /api/chat/updateProtocol", extra={"body": body.model_dump(), "user": user})
try:

if not body.prompt or body.prompt == "":
if not body.protocol_text or body.protocol_text == "":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=EmptyRequestError(message="Request body is empty").model_dump()
)

if body.fake:
return ChatResponse(reply="Fake response", fake=body.fake)
return ChatResponse(reply="Fake response", fake=bool(body.fake))

response: Optional[str] = None
if "openai" in settings.model.lower():
response = openai.predict(prompt=str(body.model_dump()), chat_completion_message_params=None)
response = openai.predict(prompt=body.prompt, chat_completion_message_params=None)
else:
response = claude.predict(prompt=str(body.model_dump()))
response = claude.update(user_id=str(user.sub), prompt=body.prompt, history=None)

if response is None or response == "":
return ChatResponse(reply="No response was generated", fake=bool(body.fake))

return ChatResponse(reply=response, fake=bool(body.fake))

except Exception as e:
logger.exception("Error processing protocol creation")
logger.exception("Error processing protocol update")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=InternalServerError(exception_object=e).model_dump()
) from e
Expand Down
4 changes: 4 additions & 0 deletions opentrons-ai-server/api/models/chat_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@ class Chat(BaseModel):
Field(None, description="Chat history in the form of a list of messages. Type is from OpenAI's ChatCompletionMessageParam"),
]

ChatOptions = Literal["update", "create"]
ChatOptionType = Annotated[Optional[ChatOptions], Field("create", description="which chat pathway did the user enter: create or update")]


class ChatRequest(BaseModel):
message: str = Field(..., description="The latest message to be processed.")
history: HistoryType
fake: bool = Field(True, description="When set to true, the response will be a fake. OpenAI API is not used.")
fake_key: FakeKeyType
chat_options: ChatOptionType
Loading

0 comments on commit 378a1f2

Please sign in to comment.