Add question rephrasing to Claude pipeline
This adds a question rephrasing step to the Claude pipeline which
mirrors the existing OpenAI rephrasing.

Rather than just copying the existing implementation into the
`Claude` namespace and updating the API calls, we keep the existing
`Pipeline::QuestionRephraser` class and change it to call
provider-specific implementations.

The reason for this is that some logic in this step is applied
before we actually make the API calls, namely:

* returning early if this is the first question in the conversation
* plucking out the last 5 relevant questions from the conversation to
  include in the prompt
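The second of these steps can be sketched in plain Ruby. This is an illustration only: a Struct stands in for the ActiveRecord `Question` model, and `use_in_rephrasing?` mirrors the predicate the real class uses.

```ruby
# Illustrative sketch: a Struct stands in for an ActiveRecord Question record.
Question = Struct.new(:message, :relevant) do
  def use_in_rephrasing? = relevant
end

# Take the last 5 questions in the conversation, keeping only those
# suitable for inclusion in the rephrasing prompt.
def relevant_records(questions)
  questions.last(5).select(&:use_in_rephrasing?)
end
```

In the real class the equivalent query also eager-loads and joins answers before applying `last(5)` and the predicate.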

So it makes sense to perform this logic in the
`Pipeline::QuestionRephraser` class before calling the provider APIs. To
that end, we now do the early return logic and find the relevant
`Question` records, then pass those records into provider-specific
classes.

Aside from that, the implementation is relatively straightforward: we
use the same prompt and interpolate the same values into it.

In terms of the tests, I've added a new method: `bedrock_claude_text_response`. This
simply stubs a call to Bedrock with a specified text response. We didn't
have this previously because the existing structured answer stub covers
the tool call and tool call response, rather than a non-tool-call text
response.
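As a rough illustration, the stubbed response needs the shape the pipeline digs into (`output.message.content[0].text` plus `usage` token counts). The helper name and literal token values below are assumptions for the sketch, not the real test support code.

```ruby
# Hypothetical sketch of the payload a stubbed Bedrock `converse` call
# returns for a plain text (non-tool-call) response. The keys mirror the
# fields the rephraser digs out: the response text and the usage metrics.
def bedrock_text_response_payload(text)
  {
    output: { message: { role: "assistant", content: [{ text: text }] } },
    usage: { input_tokens: 10, output_tokens: 20 },
  }
end
```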

Additionally, it optionally asserts on the "user" messages within the
`messages` array we pass to the Bedrock `converse` method. If the
`user_messages` argument is present (either a String or a Regexp), then an
element in the `messages` array must have a role of `"user"` and must
match the value of the argument; otherwise an error is raised.
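A minimal sketch of that assertion follows; the helper name and the exact matching rules (substring for a String, `match?` for a Regexp) are assumptions here, not the real implementation.

```ruby
# Raises unless some user-role message in the Bedrock `converse` messages
# array matches the expected value (a String or a Regexp).
def assert_user_message!(messages, expected)
  matched = messages.any? do |message|
    next false unless message[:role] == "user"

    text = message[:content].map { |part| part[:text] }.join("\n")
    expected.is_a?(Regexp) ? text.match?(expected) : text.include?(expected)
  end

  raise "no user message matching #{expected.inspect}" unless matched
end
```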

This is similar to how the current OpenAI stubbing works, in that we
make assertions on the request body as well. It's trickier to make these
same assertions with the Bedrock library because we don't have the same
control over the request body.

With question rephrasing, it's quite important that we make some
assertions on the request body (i.e. the messages), as some specs test
that the rephrased question is used in the prompt.
jackbot committed Feb 11, 2025
1 parent 71a0acd commit da8964c
Showing 11 changed files with 534 additions and 237 deletions.
3 changes: 2 additions & 1 deletion lib/answer_composition/composer.rb
@@ -45,7 +45,7 @@ def compose_answer
when "openai_structured_answer"
PipelineRunner.call(question:, pipeline: [
Pipeline::JailbreakGuardrails,
Pipeline::QuestionRephraser,
Pipeline::QuestionRephraser.new(llm_provider: :openai),
Pipeline::QuestionRouter,
Pipeline::QuestionRoutingGuardrails,
Pipeline::SearchResultFetcher,
@@ -54,6 +54,7 @@ def compose_answer
])
when "claude_structured_answer"
PipelineRunner.call(question:, pipeline: [
Pipeline::QuestionRephraser.new(llm_provider: :claude),
Pipeline::Claude::StructuredAnswerComposer,
])
else
5 changes: 5 additions & 0 deletions lib/answer_composition/pipeline/claude.rb
@@ -0,0 +1,5 @@
module AnswerComposition::Pipeline::Claude
def self.prompt_config
Rails.configuration.govuk_chat_private.llm_prompts.claude
end
end
83 changes: 83 additions & 0 deletions lib/answer_composition/pipeline/claude/question_rephraser.rb
@@ -0,0 +1,83 @@
module AnswerComposition::Pipeline
module Claude
class QuestionRephraser
# TODO: change this to a more basic model
BEDROCK_MODEL = "eu.anthropic.claude-3-5-sonnet-20240620-v1:0".freeze

def self.call(...) = new(...).call

def initialize(question_message, message_records)
@question_message = question_message
@message_records = message_records
end

def call
response = bedrock_client.converse(
system: [{ text: config[:system_prompt] }],
model_id: BEDROCK_MODEL,
messages:,
inference_config:,
)

AnswerComposition::Pipeline::QuestionRephraser::Result.new(
llm_response: response.to_h,
rephrased_question: response.dig("output", "message", "content", 0, "text"),
metrics: build_metrics(response),
)
end

private

attr_reader :question_message, :message_records

def bedrock_client
@bedrock_client ||= Aws::BedrockRuntime::Client.new
end

def build_metrics(response)
{
llm_prompt_tokens: response.dig("usage", "input_tokens"),
llm_completion_tokens: response.dig("usage", "output_tokens"),
}
end

def config
Claude.prompt_config[:question_rephraser]
end

def inference_config
{
max_tokens: 4096,
temperature: 0.0,
}
end

def user_prompt
config[:user_prompt]
.sub("{question}", question_message)
.sub("{message_history}", message_history)
end

def messages
[{ role: "user", content: [{ text: user_prompt }] }]
end

def message_history
message_records.flat_map(&method(:map_question)).join("\n")
end

def map_question(question)
question_message = question.answer.rephrased_question || question.message

[
format_messsage("user", question_message),
format_messsage("assistant", question.answer.message),
]
end

def format_messsage(actor, message)
["#{actor}:", '"""', message, '"""'].join("\n")
end
end
end
end
84 changes: 84 additions & 0 deletions lib/answer_composition/pipeline/openai/question_rephraser.rb
@@ -0,0 +1,84 @@
module AnswerComposition::Pipeline::OpenAI
class QuestionRephraser
OPENAI_MODEL = "gpt-4o-mini".freeze

def self.call(...) = new(...).call

def initialize(question_message, question_records)
@question_message = question_message
@question_records = question_records
end

def call
AnswerComposition::Pipeline::QuestionRephraser::Result.new(
llm_response: openai_response_choice,
rephrased_question: openai_response_choice.dig("message", "content"),
metrics:,
)
end

private

attr_reader :question_message, :question_records

def openai_response
@openai_response ||= openai_client.chat(
parameters: {
model: OPENAI_MODEL,
messages:,
temperature: 0.0,
},
)
end

def openai_response_choice
@openai_response_choice ||= openai_response.dig("choices", 0)
end

def messages
[
{ role: "system", content: config[:system_prompt] },
{ role: "user", content: user_prompt },
]
end

def user_prompt
config[:user_prompt]
.sub("{question}", question_message)
.sub("{message_history}", message_history)
end

def openai_client
@openai_client ||= OpenAIClient.build
end

def message_history
question_records.flat_map(&method(:map_question)).join("\n")
end

def map_question(question)
question_message = question.answer.rephrased_question || question.message

[
format_messsage("user", question_message),
format_messsage("assistant", question.answer.message),
]
end

def config
Rails.configuration.govuk_chat_private.llm_prompts.openai.question_rephraser
end

def format_messsage(actor, message)
["#{actor}:", '"""', message, '"""'].join("\n")
end

def metrics
{
llm_prompt_tokens: openai_response.dig("usage", "prompt_tokens"),
llm_completion_tokens: openai_response.dig("usage", "completion_tokens"),
llm_cached_tokens: openai_response.dig("usage", "prompt_tokens_details", "cached_tokens"),
}
end
end
end
106 changes: 27 additions & 79 deletions lib/answer_composition/pipeline/question_rephraser.rb
@@ -1,100 +1,48 @@
module AnswerComposition
module Pipeline
class QuestionRephraser
OPENAI_MODEL = "gpt-4o-mini".freeze
Result = Data.define(:llm_response, :rephrased_question, :metrics)

delegate :question, to: :context

def self.call(...) = new(...).call

def initialize(context)
@context = context
def initialize(llm_provider:)
@llm_provider = llm_provider
end

def call
return if first_question?

start_time = Clock.monotonic_time

context.question_message = openai_response_choice.dig("message", "content")

context.answer.assign_llm_response("question_rephrasing", openai_response_choice)
def call(context)
records = message_records(context.question.conversation)

context.answer.assign_metrics("question_rephrasing", {
duration: Clock.monotonic_time - start_time,
llm_prompt_tokens: openai_response.dig("usage", "prompt_tokens"),
llm_completion_tokens: openai_response.dig("usage", "completion_tokens"),
llm_cached_tokens: openai_response.dig("usage", "prompt_tokens_details", "cached_tokens"),
})
end
return if records.blank? # First question in a conversation

private

attr_reader :context

def openai_response
@openai_response ||= openai_client.chat(
parameters: {
model: OPENAI_MODEL,
messages:,
temperature: 0.0,
},
start_time = Clock.monotonic_time
klass = case llm_provider
when :openai
Pipeline::OpenAI::QuestionRephraser
when :claude
Pipeline::Claude::QuestionRephraser
else
raise "Unknown llm provider: #{llm_provider}"
end

result = klass.call(context.question.message, records)

context.answer.assign_llm_response("question_rephrasing", result.llm_response)
context.question_message = result.rephrased_question
context.answer.assign_metrics(
"question_rephrasing",
{ duration: Clock.monotonic_time - start_time }.merge(result.metrics),
)
end

def openai_response_choice
@openai_response_choice ||= openai_response.dig("choices", 0)
end
private

def messages
[
{ role: "system", content: config[:system_prompt] },
{ role: "user", content: user_prompt },
]
end
attr_reader :llm_provider

def message_records
@message_records ||= Question.where(conversation: question.conversation)
def message_records(conversation)
@message_records ||= Question.where(conversation:)
.includes(:answer)
.joins(:answer)
.last(5)
.select(&:use_in_rephrasing?)
end

def message_history
message_records.flat_map(&method(:map_question)).join("\n")
end

def map_question(question)
question_message = question.answer.rephrased_question || question.message

[
format_messsage("user", question_message),
format_messsage("assistant", question.answer.message),
]
end

def first_question?
message_records.blank?
end

def config
Rails.configuration.govuk_chat_private.llm_prompts.openai.question_rephraser
end

def user_prompt
config[:user_prompt]
.sub("{question}", question.message)
.sub("{message_history}", message_history)
end

def openai_client
@openai_client ||= OpenAIClient.build
end

def format_messsage(actor, message)
["#{actor}:", '"""', message, '"""'].join("\n")
end
end
end
end
