Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 10 additions & 15 deletions lib/ruby_llm/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,20 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
}
end

def ask(message = nil, with: nil, &)
add_message role: :user, content: build_content(message, with)
def ask(message = nil, with: nil, cache_point: false, &)
add_message role: :user, content: build_content(message, with), cache_point: cache_point
complete(&)
end

alias say ask

def with_instructions(instructions, append: false, replace: nil)
def with_instructions(instructions, append: false, replace: nil, cache_point: false)
append ||= (replace == false) unless replace.nil?

if append
append_system_instruction(instructions)
append_system_instruction(instructions, cache_point: cache_point)
else
replace_system_instruction(instructions)
replace_system_instruction(instructions, cache_point: cache_point)
end

self
Expand Down Expand Up @@ -329,21 +329,16 @@ def content_like?(object)
object.is_a?(Content) || object.is_a?(Content::Raw)
end

def append_system_instruction(instructions)
def append_system_instruction(instructions, cache_point: false)
system_messages, non_system_messages = @messages.partition { |msg| msg.role == :system }
system_messages << Message.new(role: :system, content: instructions)
system_messages << Message.new(role: :system, content: instructions, cache_point: cache_point)
@messages = system_messages + non_system_messages
end

def replace_system_instruction(instructions)
system_messages, non_system_messages = @messages.partition { |msg| msg.role == :system }
def replace_system_instruction(instructions, cache_point: false)
_, non_system_messages = @messages.partition { |msg| msg.role == :system }

if system_messages.empty?
system_messages = [Message.new(role: :system, content: instructions)]
else
system_messages.first.content = instructions
system_messages = [system_messages.first]
end
system_messages = [Message.new(role: :system, content: instructions, cache_point: cache_point)]

@messages = system_messages + non_system_messages
end
Expand Down
7 changes: 5 additions & 2 deletions lib/ruby_llm/message.rb
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ module RubyLLM
class Message
ROLES = %i[system user assistant tool].freeze

attr_reader :role, :model_id, :tool_calls, :tool_call_id, :raw, :thinking, :tokens
attr_reader :role, :model_id, :tool_calls, :tool_call_id, :raw, :thinking, :tokens, :cache_point
alias cache_point? cache_point
attr_writer :content

def initialize(options = {})
Expand All @@ -24,6 +25,7 @@ def initialize(options = {})
)
@raw = options[:raw]
@thinking = options[:thinking]
@cache_point = options.fetch(:cache_point, false)

ensure_valid_role
end
Expand Down Expand Up @@ -80,7 +82,8 @@ def to_h
tool_calls: tool_calls,
tool_call_id: tool_call_id,
thinking: thinking&.text,
thinking_signature: thinking&.signature
thinking_signature: thinking&.signature,
cache_point: @cache_point || nil
}.merge(tokens ? tokens.to_h : {}).compact
end

Expand Down
2 changes: 1 addition & 1 deletion lib/ruby_llm/provider.rb
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def maybe_normalize_temperature(temperature, _model)

def sync_response(connection, payload, additional_headers = {})
response = connection.post completion_url, payload do |req|
req.headers = additional_headers.merge(req.headers) unless additional_headers.empty?
req.headers.merge!(additional_headers) unless additional_headers.empty?
end
parse_completion_response response
end
Expand Down
7 changes: 7 additions & 0 deletions lib/ruby_llm/providers/anthropic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ def headers
}
end

# Opts the request into Anthropic's prompt-caching beta whenever any message
# in the conversation is flagged as a cache point, then delegates to the
# shared provider implementation with the (possibly) augmented headers.
#
# @param messages [Array] conversation messages (each responds to #cache_point?)
# @param headers [Hash] extra HTTP headers to send with the request
# @return the parent implementation's parsed completion response
def complete(messages, headers: {}, **kwargs, &block)
headers = headers.merge('anthropic-beta' => 'prompt-caching-2024-07-31') if messages.any?(&:cache_point?)

super(messages, headers: headers, **kwargs, &block) # rubocop:disable Style/SuperArguments
# Explicit super arguments are required here: a bare `super` would forward the
# ORIGINAL headers keyword, discarding the beta header merged in above.
end

class << self
def capabilities
Anthropic::Capabilities
Expand Down
24 changes: 19 additions & 5 deletions lib/ruby_llm/providers/anthropic/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@ def build_system_content(system_messages)
system_messages.flat_map do |msg|
content = msg.content

if content.is_a?(RubyLLM::Content::Raw)
content.value
else
Media.format_content(content)
end
blocks = if content.is_a?(RubyLLM::Content::Raw)
Array(content.value)
else
Array(Media.format_content(content))
end

msg.cache_point? ? inject_cache_control(blocks) : blocks
end
end

Expand Down Expand Up @@ -159,6 +161,7 @@ def format_basic_message_with_thinking(msg, thinking_enabled)
end

append_formatted_content(content_blocks, msg.content)
inject_cache_control(content_blocks) if msg.cache_point?

{
role: convert_role(msg.role),
Expand Down Expand Up @@ -228,6 +231,17 @@ def append_formatted_content(content_blocks, content)
end
end

# Marks the final content block of a message as an Anthropic prompt-cache
# boundary by attaching `cache_control: { type: 'ephemeral' }` to it.
#
# @param blocks [Array] formatted content blocks (normally Hashes)
# @return [Array] the same array, with its last element replaced by a
#   cache-annotated copy when injection applies (mutates +blocks+ in place)
def inject_cache_control(blocks)
  return blocks if blocks.empty?

  last = blocks.last
  # Non-Hash trailing blocks (e.g. a bare String coming from Content::Raw)
  # cannot carry cache_control; leave them untouched instead of raising
  # NoMethodError on #merge.
  return blocks unless last.is_a?(Hash)
  # Don't duplicate if already present (e.g. Content::Raw with cache_control)
  return blocks if last[:cache_control]

  blocks[-1] = last.merge(cache_control: { type: 'ephemeral' })
  blocks
end

def convert_role(role)
case role
when :tool, :user then 'user'
Expand Down
32 changes: 20 additions & 12 deletions lib/ruby_llm/providers/bedrock/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -114,21 +114,26 @@ def render_message_content(msg)
text_and_media_blocks = Media.render_content(msg.content, used_document_names: @used_document_names)
blocks.concat(text_and_media_blocks) if text_and_media_blocks

if msg.tool_call?
msg.tool_calls.each_value do |tool_call|
blocks << {
toolUse: {
toolUseId: tool_call.id,
name: tool_call.name,
input: tool_call.arguments
}
}
end
end
append_tool_use_blocks(blocks, msg)
blocks << { cachePoint: { type: 'default' } } if msg.cache_point?

blocks
end

# Appends a Bedrock `toolUse` content block for every tool call carried by
# the message. Mutates +blocks+ in place; no-op when the message has no
# tool calls.
def append_tool_use_blocks(blocks, msg)
  return unless msg.tool_call?

  msg.tool_calls.each_value do |call|
    tool_use = {
      toolUseId: call.id,
      name: call.name,
      input: call.arguments
    }
    blocks.push(toolUse: tool_use)
  end
end

def render_raw_content(content)
value = content.value
value.is_a?(Array) ? value : [value]
Expand Down Expand Up @@ -200,7 +205,10 @@ def render_role(role)
end

def render_system(messages)
messages.flat_map { |msg| Media.render_content(msg.content, used_document_names: @used_document_names) }
messages.flat_map do |msg|
blocks = Media.render_content(msg.content, used_document_names: @used_document_names)
msg.cache_point? ? blocks + [{ cachePoint: { type: 'default' } }] : blocks
end
end

def render_inference_config(_model, temperature)
Expand Down
2 changes: 1 addition & 1 deletion lib/ruby_llm/streaming.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def stream_response(connection, payload, additional_headers = {}, &block)
accumulator = StreamAccumulator.new

response = connection.post stream_url, payload do |req|
req.headers = additional_headers.merge(req.headers) unless additional_headers.empty?
req.headers.merge!(additional_headers) unless additional_headers.empty?
if faraday_1?
req.options[:on_data] = handle_stream do |chunk|
accumulator.add chunk
Expand Down
58 changes: 58 additions & 0 deletions spec/ruby_llm/chat_cache_point_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# frozen_string_literal: true

require 'spec_helper'

# Verifies that the cache_point: keyword is forwarded from the public Chat
# API (#ask, #with_instructions) onto the Message objects it creates.
RSpec.describe RubyLLM::Chat do
  include_context 'with configured RubyLLM'

  describe 'cache_point forwarding' do
    let(:chat) { RubyLLM.chat }

    # message_finder locates the message the action under test is expected
    # to have created on the chat.
    shared_examples 'a method that supports cache_point' do |message_finder|
      it 'sets cache_point? true when cache_point: true' do
        action.call(cache_point: true)
        message = message_finder.call(chat)
        expect(message).not_to be_nil
        expect(message.cache_point?).to be true
      end

      it 'sets cache_point? false when cache_point is omitted' do
        action.call
        message = message_finder.call(chat)
        expect(message).not_to be_nil
        expect(message.cache_point?).to be false
      end
    end

    describe '#with_instructions' do
      let(:action) { ->(opts = {}) { chat.with_instructions('Be helpful', **opts) } }

      it_behaves_like 'a method that supports cache_point', ->(c) { c.messages.find { |m| m.role == :system } }

      it 'sets cache_point? true on appended message only' do
        chat.with_instructions('First instruction')
        chat.with_instructions('Second instruction', append: true, cache_point: true)
        system_msgs = chat.messages.select { |m| m.role == :system }
        expect(system_msgs.last.cache_point?).to be true
        expect(system_msgs.first.cache_point?).to be false
      end

      it 'preserves cache_point: true when replacing' do
        chat.with_instructions('Old instruction', cache_point: false)
        chat.with_instructions('New instruction', replace: true, cache_point: true)
        system_msgs = chat.messages.select { |m| m.role == :system }
        expect(system_msgs.size).to eq(1)
        expect(system_msgs.first.cache_point?).to be true
      end
    end

    describe '#ask' do
      # Stub out the network round-trip; we only inspect the queued message.
      before { allow(chat).to receive(:complete) }

      let(:action) { ->(opts = {}) { chat.ask('Hello', **opts) } }

      it_behaves_like 'a method that supports cache_point', ->(c) { c.messages.find { |m| m.role == :user } }
    end
  end
end
24 changes: 24 additions & 0 deletions spec/ruby_llm/message_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,30 @@
require 'spec_helper'

RSpec.describe RubyLLM::Message do
# cache_point? defaults to false and mirrors the cache_point: option
# passed to Message.new.
describe '#cache_point?' do
it 'returns false by default' do
message = described_class.new(role: :user, content: 'hello')
expect(message.cache_point?).to be false
end

it 'returns true when constructed with cache_point: true' do
message = described_class.new(role: :user, content: 'hello', cache_point: true)
expect(message.cache_point?).to be true
end
end

# to_h serializes cache_point only when truthy (falsey values are compacted
# away so serialized hashes stay minimal).
describe '#to_h' do
it 'omits cache_point key when false' do
message = described_class.new(role: :user, content: 'hello')
expect(message.to_h).not_to have_key(:cache_point)
end

it 'includes cache_point: true when set' do
message = described_class.new(role: :user, content: 'hello', cache_point: true)
expect(message.to_h[:cache_point]).to be true
end
end

describe '#content' do
it 'normalizes nil content to empty string for assistant tool-call messages' do
tool_call = RubyLLM::ToolCall.new(id: 'call_1', name: 'weather', arguments: {})
Expand Down
70 changes: 70 additions & 0 deletions spec/ruby_llm/providers/anthropic/chat_cache_control_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# frozen_string_literal: true

require 'spec_helper'

# Exercises cache_control injection in the Anthropic payload renderer:
# messages flagged with cache_point: true get `cache_control` attached to
# their final content block, without duplicating an existing marker.
RSpec.describe RubyLLM::Providers::Anthropic::Chat do
  let(:model) { instance_double(RubyLLM::Model::Info, id: 'claude-sonnet-4-5', max_tokens: nil) }

  # Renders a minimal chat payload for the given messages.
  def render(messages)
    described_class.render_payload(
      messages,
      tools: {},
      temperature: nil,
      model: model,
      stream: false,
      schema: nil
    )
  end

  describe 'cache_control injection' do
    context 'with a system message where cache_point is true' do
      it 'adds cache_control to the last system block' do
        msg = RubyLLM::Message.new(role: :system, content: 'You are helpful.', cache_point: true)
        payload = render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])

        last_block = payload[:system].last
        expect(last_block[:cache_control]).to eq(type: 'ephemeral')
      end

      it 'does not add cache_control when cache_point is false' do
        msg = RubyLLM::Message.new(role: :system, content: 'You are helpful.')
        payload = render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])

        payload[:system].each do |block|
          expect(block).not_to have_key(:cache_control)
        end
      end
    end

    context 'with a user message where cache_point is true' do
      it 'adds cache_control to the last content block' do
        msg = RubyLLM::Message.new(role: :user, content: 'Tell me a story.', cache_point: true)
        payload = render([msg])

        last_block = payload[:messages].first[:content].last
        expect(last_block[:cache_control]).to eq(type: 'ephemeral')
      end

      it 'does not add cache_control when cache_point is false' do
        msg = RubyLLM::Message.new(role: :user, content: 'Tell me a story.')
        payload = render([msg])

        payload[:messages].first[:content].each do |block|
          expect(block).not_to have_key(:cache_control)
        end
      end
    end

    context 'when a Content::Raw block already contains cache_control' do
      it 'does not duplicate cache_control when it is already present' do
        raw = RubyLLM::Providers::Anthropic::Content.new('Cached system', cache: true)
        msg = RubyLLM::Message.new(role: :system, content: raw, cache_point: true)
        payload = render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])

        blocks_with_cache = payload[:system].select { |b| b[:cache_control] }
        expect(blocks_with_cache.length).to eq(1)
        expect(blocks_with_cache.first[:cache_control]).to eq(type: 'ephemeral')
      end
    end
  end
end
Loading