0.4.0 Anthropic prompt caching support
obie committed Oct 19, 2024
1 parent ad7e44a commit 66a20b6
Showing 10 changed files with 381 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .rubocop.yml
@@ -11,7 +11,7 @@ Style/StringLiteralsInInterpolation:
EnforcedStyle: double_quotes

Layout/LineLength:
Max: 120
Max: 180

Metrics/BlockLength:
Enabled: false
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -8,3 +8,10 @@
- adds `ChatCompletion` module
- adds `PromptDeclarations` module
- adds `FunctionDispatch` module

## [0.3.2] - 2024-06-29
- adds support for streaming

## [0.4.0] - 2024-10-18
- adds support for Anthropic-style prompt caching
- defaults to `max_completion_tokens` when using OpenAI directly
24 changes: 24 additions & 0 deletions README.md
@@ -42,6 +42,30 @@ transcript << { role: "user", content: "What is the meaning of life?" }

One of the advantages of OpenRouter, and the reason this library uses it by default, is that it handles mapping message formats from the OpenAI standard to whatever other model you want to use (Anthropic, Cohere, etc.)

### Prompt Caching

Raix supports [Anthropic-style prompt caching](https://openrouter.ai/docs/prompt-caching#anthropic-claude) when using Anthropic's Claude family of models. You can specify a `cache_at` parameter when doing a chat completion. If the character count of a particular message's content is longer than `cache_at`, that message is sent to Anthropic as a multipart message with a cache control "breakpoint" set to "ephemeral".

Note that Anthropic enforces a limit of four breakpoints, and the cache expires within five minutes. It is therefore recommended to reserve the breakpoints for large bodies of text, such as character cards, CSV data, RAG data, book chapters, etc. Raix does not enforce the breakpoint limit, which means you may get an API error if you try to cache too many messages.

```ruby
>> my_class.chat_completion(params: { cache_at: 1000 })
=> {
"messages": [
{
"role": "system",
"content": [
{
"type": "text",
"text": "HUGE TEXT BODY LONGER THAN 1000 CHARACTERS",
"cache_control": {
"type": "ephemeral"
}
}
]
},
```
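For reference, the size check and multipart wrapping can be sketched as a standalone method (the name `apply_cache_control` is illustrative, not part of the gem's API):

```ruby
# Illustrative sketch of the cache_at conversion, not Raix's exact internals:
# a message whose content exceeds the threshold is wrapped in Anthropic's
# multipart format with an ephemeral cache_control breakpoint.
def apply_cache_control(message, cache_at)
  return message unless cache_at && message[:content].length > cache_at.to_i

  message.merge(
    content: [{ type: "text", text: message[:content], cache_control: { type: "ephemeral" } }]
  )
end

apply_cache_control({ role: "user", content: "hi" }, 1000)
# => { role: "user", content: "hi" }  (short messages pass through unchanged)
```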

### Use of Tools/Functions

The second (optional) module that you can add to your Ruby classes after `ChatCompletion` is `FunctionDispatch`. It lets you declare and implement functions to be called at the AI's discretion as part of a chat completion "loop" in a declarative, Rails-like "DSL" fashion.
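The dispatch idea can be sketched in plain Ruby as a simplified stand-in (not the gem's actual implementation): a class-level DSL registers blocks by name, and `dispatch` invokes one with parsed arguments, as an AI tool-call loop would.

```ruby
# Simplified stand-in for the FunctionDispatch pattern (illustrative only).
module FunctionDispatchSketch
  def self.included(base)
    base.extend(ClassMethods)
  end

  module ClassMethods
    # registry of declared functions, keyed by name
    def functions
      @functions ||= {}
    end

    # class-level DSL: declare a function and its implementation block
    def function(name, &block)
      functions[name] = block
    end
  end

  # invoke a declared function by name with the parsed arguments
  def dispatch(name, arguments = {})
    instance_exec(arguments, &self.class.functions.fetch(name))
  end
end

class Assistant
  include FunctionDispatchSketch

  function :check_weather do |args|
    "The weather in #{args[:city]} is sunny"
  end
end

Assistant.new.dispatch(:check_weather, city: "Austin")
# => "The weather in Austin is sunny"
```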
5 changes: 5 additions & 0 deletions lib/raix.rb
@@ -16,6 +16,9 @@ class Configuration
# The max_tokens option determines the maximum number of tokens to generate.
attr_accessor :max_tokens

# The max_completion_tokens option determines the maximum number of completion tokens to generate when calling OpenAI directly (OpenAI's successor to max_tokens).
attr_accessor :max_completion_tokens

# The model option determines the model to use for text generation. This option
# is normally set in each class that includes the ChatCompletion module.
attr_accessor :model
@@ -27,12 +30,14 @@ class Configuration
attr_accessor :openai_client

DEFAULT_MAX_TOKENS = 1000
DEFAULT_MAX_COMPLETION_TOKENS = 16_384
DEFAULT_MODEL = "meta-llama/llama-3-8b-instruct:free"
DEFAULT_TEMPERATURE = 0.0

# Initializes a new instance of the Configuration class with default values.
def initialize
self.temperature = DEFAULT_TEMPERATURE
self.max_completion_tokens = DEFAULT_MAX_COMPLETION_TOKENS
self.max_tokens = DEFAULT_MAX_TOKENS
self.model = DEFAULT_MODEL
end
64 changes: 34 additions & 30 deletions lib/raix/chat_completion.rb
@@ -2,6 +2,7 @@

require "active_support/concern"
require "active_support/core_ext/object/blank"
require "raix/message_adapters/base"
require "open_router"
require "openai"

@@ -17,9 +18,9 @@ module Raix
module ChatCompletion
extend ActiveSupport::Concern

attr_accessor :frequency_penalty, :logit_bias, :logprobs, :loop, :min_p, :model, :presence_penalty,
:repetition_penalty, :response_format, :stream, :temperature, :max_tokens, :seed, :stop, :top_a,
:top_k, :top_logprobs, :top_p, :tools, :tool_choice, :provider
attr_accessor :cache_at, :frequency_penalty, :logit_bias, :logprobs, :loop, :min_p, :model, :presence_penalty,
:repetition_penalty, :response_format, :stream, :temperature, :max_completion_tokens,
:max_tokens, :seed, :stop, :top_a, :top_k, :top_logprobs, :top_p, :tools, :tool_choice, :provider

# This method performs chat completion based on the provided transcript and parameters.
#
@@ -30,16 +31,12 @@ module ChatCompletion
# @option params [Boolean] :raw (false) Whether to return the raw response or dig the text content.
# @return [String|Hash] The completed chat response.
def chat_completion(params: {}, loop: false, json: false, raw: false, openai: false)
messages = transcript.flatten.compact.map { |msg| transform_message_format(msg) }
raise "Can't complete an empty transcript" if messages.blank?

# used by FunctionDispatch
self.loop = loop

# set params to default values if not provided
params[:cache_at] ||= cache_at.presence
params[:frequency_penalty] ||= frequency_penalty.presence
params[:logit_bias] ||= logit_bias.presence
params[:logprobs] ||= logprobs.presence
params[:max_completion_tokens] ||= max_completion_tokens.presence || Raix.configuration.max_completion_tokens
params[:max_tokens] ||= max_tokens.presence || Raix.configuration.max_tokens
params[:min_p] ||= min_p.presence
params[:presence_penalty] ||= presence_penalty.presence
@@ -57,23 +54,29 @@ def chat_completion(params:, loop: false, json: false, raw: false, openai: fa
params[:top_p] ||= top_p.presence

if json
params[:provider] ||= {}
params[:provider][:require_parameters] = true
unless openai
params[:provider] ||= {}
params[:provider][:require_parameters] = true
end
params[:response_format] ||= {}
params[:response_format][:type] = "json_object"
end

# used by FunctionDispatch
self.loop = loop

# set the model to the default if not provided
self.model ||= Raix.configuration.model

adapter = MessageAdapters::Base.new(self)
messages = transcript.flatten.compact.map { |msg| adapter.transform(msg) }
raise "Can't complete an empty transcript" if messages.blank?

begin
response = if openai
openai_request(params:, model: openai,
messages:)
openai_request(params:, model: openai, messages:)
else
openrouter_request(
params:, model:, messages:
)
openrouter_request(params:, model:, messages:)
end
retry_count = 0
content = nil
@@ -115,8 +118,8 @@ def chat_completion(params: {}, loop: false, json: false, raw: false, openai: fa
raise e # just fail if we can't get content after 3 attempts
end

# attempt to fix the JSON
JsonFixer.new.call(content, e.message)
puts "Bad JSON received!!!!!!: #{content}"
raise e
rescue Faraday::BadRequestError => e
# make sure we see the actual error message on console or Honeybadger
puts "Chat completion failed!!!!!!!!!!!!!!!!: #{e.response[:body]}"
@@ -132,6 +135,9 @@ def chat_completion(params: {}, loop: false, json: false, raw: false, openai: fa
# { user: "Hey what time is it?" },
# { assistant: "Sorry, pumpkins do not wear watches" }
#
# to add a function call use the following format:
# { function: { name: 'fancy_pants_function', arguments: { param: 'value' } } }
#
# to add a function result use the following format:
# { function: result, name: 'fancy_pants_function' }
#
@@ -143,11 +149,21 @@ def transcript
private

def openai_request(params:, model:, messages:)
# deprecated in favor of max_completion_tokens
params.delete(:max_tokens)

params[:stream] ||= stream.presence
params[:stream_options] = { include_usage: true } if params[:stream]

params.delete(:temperature) if model == "o1-preview"

Raix.configuration.openai_client.chat(parameters: params.compact.merge(model:, messages:))
end

def openrouter_request(params:, model:, messages:)
# max_completion_tokens is not supported by OpenRouter
params.delete(:max_completion_tokens)

retry_count = 0

begin
@@ -163,17 +179,5 @@ def openrouter_request(params:, model:, messages:)
raise e
end
end

def transform_message_format(message)
return message if message[:role].present?

if message[:function].present?
{ role: "assistant", name: message.dig(:function, :name), content: message.dig(:function, :arguments).to_json }
elsif message[:result].present?
{ role: "function", name: message[:name], content: message[:result] }
else
{ role: message.first.first, content: message.first.last }
end
end
end
end
50 changes: 50 additions & 0 deletions lib/raix/message_adapters/base.rb
@@ -0,0 +1,50 @@
# frozen_string_literal: true

require "active_support/core_ext/module/delegation"

module Raix
module MessageAdapters
# Transforms messages into the format expected by the OpenAI API
class Base
attr_accessor :context

delegate :cache_at, :model, to: :context

def initialize(context)
@context = context
end

def transform(message)
return message if message[:role].present?

if message[:function].present?
{ role: "assistant", name: message.dig(:function, :name), content: message.dig(:function, :arguments).to_json }
elsif message[:result].present?
{ role: "function", name: message[:name], content: message[:result] }
else
content(message)
end
end

protected

def content(message)
case message
in { system: content }
{ role: "system", content: }
in { user: content }
{ role: "user", content: }
in { assistant: content }
{ role: "assistant", content: }
else
raise ArgumentError, "Invalid message format: #{message.inspect}"
end.tap do |msg|
# convert to anthropic multipart format if model is claude-3 and cache_at is set
if model["anthropic/claude-3"] && cache_at && msg[:content].length > cache_at.to_i
msg[:content] = [{ type: "text", text: msg[:content], cache_control: { type: "ephemeral" } }]
end
end
end
end
end
end
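The `content` pattern match above can be exercised in isolation; a minimal standalone version, with the `cache_at` wrapping omitted for brevity, behaves like this (requires Ruby 3.0+ for `case/in`):

```ruby
# Standalone version of the adapter's shorthand-to-OpenAI normalization
# (cache handling omitted; illustrative, not the gem's exact code).
def normalize(message)
  case message
  in { system: content } then { role: "system", content: content }
  in { user: content } then { role: "user", content: content }
  in { assistant: content } then { role: "assistant", content: content }
  else
    raise ArgumentError, "Invalid message format: #{message.inspect}"
  end
end

normalize({ user: "What time is it?" })
# => { role: "user", content: "What time is it?" }
```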
2 changes: 1 addition & 1 deletion lib/raix/version.rb
@@ -1,5 +1,5 @@
# frozen_string_literal: true

module Raix
VERSION = "0.3.2"
VERSION = "0.4.0"
end