Add question rephrasing to Claude pipeline
This adds a question rephrasing step to the Claude pipeline, mirroring the existing OpenAI rephrasing.

Rather than copying the existing implementation into the `Claude` namespace and updating the API calls, we keep the existing `Pipeline::QuestionRephraser` class and change it to call provider-specific implementations. The reason is that some logic in this step applies before we make any API calls, namely:

* returning early if this is the first question in the conversation
* plucking out the last 5 relevant questions from the conversation to include in the prompt

So it makes sense to perform this logic in `Pipeline::QuestionRephraser` before calling the provider APIs. To that end, we now do the early-return logic and find the relevant `Question` records there, then pass those records into provider-specific classes. Aside from that, the implementation is relatively straightforward: we use the same prompt and interpolate the same values into it.

In terms of the tests, I've added a new method: `bedrock_claude_text_response`. This simply stubs a call to Bedrock with a specified text response. We didn't have this previously because the structured answer call stubs the tool call and tool call response, rather than a non-tool-call text response.

Additionally, the helper optionally asserts on the "user" messages within the `messages` array we pass to the Bedrock `converse` method. If the `user_messages` argument is present (be it a String or a Regexp), then an element in the `messages` array must have a role of `"user"` and must match the value of the argument, otherwise an error is raised. This is similar to how the current OpenAI stubbing works, in that we make assertions on the request body as well. Making the same assertions with the Bedrock library is trickier because we don't have the same control over the request body. With question rephrasing it's quite important that we assert on the request body (i.e. the messages), as some specs test that the rephrased question is used in the prompt.
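For illustration, here's a minimal sketch of what a helper like `bedrock_claude_text_response` could look like, using the AWS SDK's built-in response stubbing. The helper name and behaviour come from the commit message above, but the implementation details (stubbing via `Aws.config`, the exact error message) are assumptions — the real helper in this repo may be built differently:

```ruby
# A sketch only — the stubbing mechanism is assumed, not taken from this commit.
def bedrock_claude_text_response(text, user_messages: nil)
  Aws.config[:bedrockruntime] = {
    stub_responses: {
      converse: lambda do |context|
        if user_messages
          # `user_messages` may be a String or a Regexp; at least one "user"
          # message in the request must match it, otherwise we raise.
          matched = context.params[:messages].any? do |message|
            message[:role] == "user" &&
              message[:content].any? { |block| block[:text].to_s.match?(user_messages) }
          end
          raise "No user message matching #{user_messages.inspect}" unless matched
        end

        # Response shape mirrors what the rephraser digs out of `converse`.
        {
          output: { message: { role: "assistant", content: [{ text: }] } },
          usage: { input_tokens: 10, output_tokens: 20, total_tokens: 30 },
          stop_reason: "end_turn",
        }
      end,
    },
  }
end
```

A spec could then stub the response while asserting that the rephrased question reaches the prompt, e.g. `bedrock_claude_text_response("Some answer", user_messages: /rephrased question/)`.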
Showing 11 changed files with 534 additions and 237 deletions.
5 changes: 5 additions & 0 deletions
lib/answer_composition/pipeline/claude.rb
```ruby
module AnswerComposition::Pipeline::Claude
  def self.prompt_config
    Rails.configuration.govuk_chat_private.llm_prompts.claude
  end
end
```
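The rephraser below reads `config[:system_prompt]` and `config[:user_prompt]` out of `prompt_config[:question_rephraser]`, so the structure under `llm_prompts.claude` is presumably something like the following. The key names come from the code, but the prompt text is invented — the real prompts live in govuk_chat_private:

```ruby
# Assumed shape only; the real prompt text is not part of this diff.
{
  question_rephraser: {
    system_prompt: "Rephrase the user's latest question so it stands alone.",
    user_prompt: "Message history:\n\n{message_history}\n\nQuestion:\n\n{question}",
  }
}
```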
83 changes: 83 additions & 0 deletions
lib/answer_composition/pipeline/claude/question_rephraser.rb
```ruby
module AnswerComposition::Pipeline
  module Claude
    class QuestionRephraser
      # TODO: change this to a more basic model
      BEDROCK_MODEL = "eu.anthropic.claude-3-5-sonnet-20240620-v1:0".freeze

      def self.call(...) = new(...).call

      def initialize(question_message, message_records)
        @question_message = question_message
        @message_records = message_records
      end

      def call
        response = bedrock_client.converse(
          system: [{ text: config[:system_prompt] }],
          model_id: BEDROCK_MODEL,
          messages:,
          inference_config:,
        )

        AnswerComposition::Pipeline::QuestionRephraser::Result.new(
          llm_response: response.to_h,
          rephrased_question: response.dig("output", "message", "content", 0, "text"),
          metrics: build_metrics(response),
        )
      end

      private

      attr_reader :question_message, :message_records

      def bedrock_client
        @bedrock_client ||= Aws::BedrockRuntime::Client.new
      end

      def build_metrics(response)
        {
          llm_prompt_tokens: response.dig("usage", "input_tokens"),
          llm_completion_tokens: response.dig("usage", "output_tokens"),
        }
      end

      def config
        Claude.prompt_config[:question_rephraser]
      end

      def inference_config
        {
          max_tokens: 4096,
          temperature: 0.0,
        }
      end

      def user_prompt
        config[:user_prompt]
          .sub("{question}", question_message)
          .sub("{message_history}", message_history)
      end

      def messages
        [{ role: "user", content: [{ text: user_prompt }] }]
      end

      def message_history
        message_records.flat_map(&method(:map_question)).join("\n")
      end

      def map_question(question)
        question_message = question.answer.rephrased_question || question.message

        [
          format_message("user", question_message),
          format_message("assistant", question.answer.message),
        ]
      end

      def format_message(actor, message)
        ["#{actor}:", '"""', message, '"""'].join("\n")
      end
    end
  end
end
```
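For a sense of the interface, here is a hypothetical direct invocation — the record query is lifted from `Pipeline::QuestionRephraser` further down, while the question text is invented:

```ruby
# Hypothetical usage; in practice Pipeline::QuestionRephraser supplies the records.
records = Question.where(conversation:).includes(:answer).joins(:answer).last(5)

result = AnswerComposition::Pipeline::Claude::QuestionRephraser.call(
  "What about the deadline?", records
)

result.rephrased_question # => standalone question text returned by Claude
result.metrics            # => { llm_prompt_tokens: ..., llm_completion_tokens: ... }
```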
84 changes: 84 additions & 0 deletions
lib/answer_composition/pipeline/openai/question_rephraser.rb
```ruby
module AnswerComposition::Pipeline::OpenAI
  class QuestionRephraser
    OPENAI_MODEL = "gpt-4o-mini".freeze

    def self.call(...) = new(...).call

    def initialize(question_message, question_records)
      @question_message = question_message
      @question_records = question_records
    end

    def call
      AnswerComposition::Pipeline::QuestionRephraser::Result.new(
        llm_response: openai_response_choice,
        rephrased_question: openai_response_choice.dig("message", "content"),
        metrics:,
      )
    end

    private

    attr_reader :question_message, :question_records

    def openai_response
      @openai_response ||= openai_client.chat(
        parameters: {
          model: OPENAI_MODEL,
          messages:,
          temperature: 0.0,
        },
      )
    end

    def openai_response_choice
      @openai_response_choice ||= openai_response.dig("choices", 0)
    end

    def messages
      [
        { role: "system", content: config[:system_prompt] },
        { role: "user", content: user_prompt },
      ]
    end

    def user_prompt
      config[:user_prompt]
        .sub("{question}", question_message)
        .sub("{message_history}", message_history)
    end

    def openai_client
      @openai_client ||= OpenAIClient.build
    end

    def message_history
      question_records.flat_map(&method(:map_question)).join("\n")
    end

    def map_question(question)
      question_message = question.answer.rephrased_question || question.message

      [
        format_message("user", question_message),
        format_message("assistant", question.answer.message),
      ]
    end

    def config
      Rails.configuration.govuk_chat_private.llm_prompts.openai.question_rephraser
    end

    def format_message(actor, message)
      ["#{actor}:", '"""', message, '"""'].join("\n")
    end

    def metrics
      {
        llm_prompt_tokens: openai_response.dig("usage", "prompt_tokens"),
        llm_completion_tokens: openai_response.dig("usage", "completion_tokens"),
        llm_cached_tokens: openai_response.dig("usage", "prompt_tokens_details", "cached_tokens"),
      }
    end
  end
end
```
lib/answer_composition/pipeline/question_rephraser.rb
```diff
@@ -1,100 +1,48 @@
 module AnswerComposition
   module Pipeline
     class QuestionRephraser
-      OPENAI_MODEL = "gpt-4o-mini".freeze
-
-      delegate :question, to: :context
+      Result = Data.define(:llm_response, :rephrased_question, :metrics)
 
       def self.call(...) = new(...).call
 
-      def initialize(context)
-        @context = context
+      def initialize(llm_provider:)
+        @llm_provider = llm_provider
       end
 
-      def call
-        return if first_question?
-
-        start_time = Clock.monotonic_time
-
-        context.question_message = openai_response_choice.dig("message", "content")
-
-        context.answer.assign_llm_response("question_rephrasing", openai_response_choice)
+      def call(context)
+        records = message_records(context.question.conversation)
 
-        context.answer.assign_metrics("question_rephrasing", {
-          duration: Clock.monotonic_time - start_time,
-          llm_prompt_tokens: openai_response.dig("usage", "prompt_tokens"),
-          llm_completion_tokens: openai_response.dig("usage", "completion_tokens"),
-          llm_cached_tokens: openai_response.dig("usage", "prompt_tokens_details", "cached_tokens"),
-        })
-      end
+        return if records.blank? # First question in a conversation
 
-      private
-
-      attr_reader :context
-
-      def openai_response
-        @openai_response ||= openai_client.chat(
-          parameters: {
-            model: OPENAI_MODEL,
-            messages:,
-            temperature: 0.0,
-          },
+        start_time = Clock.monotonic_time
+        klass = case llm_provider
+                when :openai
+                  Pipeline::OpenAI::QuestionRephraser
+                when :claude
+                  Pipeline::Claude::QuestionRephraser
+                else
+                  raise "Unknown llm provider: #{llm_provider}"
+                end
+
+        result = klass.call(context.question.message, records)
+
+        context.answer.assign_llm_response("question_rephrasing", result.llm_response)
+        context.question_message = result.rephrased_question
+        context.answer.assign_metrics(
+          "question_rephrasing",
+          { duration: Clock.monotonic_time - start_time }.merge(result.metrics),
         )
       end
 
-      def openai_response_choice
-        @openai_response_choice ||= openai_response.dig("choices", 0)
-      end
+      private
 
-      def messages
-        [
-          { role: "system", content: config[:system_prompt] },
-          { role: "user", content: user_prompt },
-        ]
-      end
+      attr_reader :llm_provider
 
-      def message_records
-        @message_records ||= Question.where(conversation: question.conversation)
+      def message_records(conversation)
+        @message_records ||= Question.where(conversation:)
           .includes(:answer)
           .joins(:answer)
           .last(5)
           .select(&:use_in_rephrasing?)
       end
-
-      def message_history
-        message_records.flat_map(&method(:map_question)).join("\n")
-      end
-
-      def map_question(question)
-        question_message = question.answer.rephrased_question || question.message
-
-        [
-          format_messsage("user", question_message),
-          format_messsage("assistant", question.answer.message),
-        ]
-      end
-
-      def first_question?
-        message_records.blank?
-      end
-
-      def config
-        Rails.configuration.govuk_chat_private.llm_prompts.openai.question_rephraser
-      end
-
-      def user_prompt
-        config[:user_prompt]
-          .sub("{question}", question.message)
-          .sub("{message_history}", message_history)
-      end
-
-      def openai_client
-        @openai_client ||= OpenAIClient.build
-      end
-
-      def format_messsage(actor, message)
-        ["#{actor}:", '"""', message, '"""'].join("\n")
-      end
     end
   end
 end
```
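The new signatures imply the step is constructed once with its provider and then invoked per request with the pipeline context — a sketch of the assumed wiring, which isn't itself shown in this diff:

```ruby
# Assumed wiring; the pipeline composition is not part of this commit.
step = AnswerComposition::Pipeline::QuestionRephraser.new(llm_provider: :claude)
step.call(context) # no-op for the first question in a conversation; otherwise
                   # rewrites context.question_message and records LLM metrics
```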