diff --git a/lib/ruby_llm/chat.rb b/lib/ruby_llm/chat.rb
index afc572342..352dee9f6 100644
--- a/lib/ruby_llm/chat.rb
+++ b/lib/ruby_llm/chat.rb
@@ -32,20 +32,20 @@ def initialize(model: nil, provider: nil, assume_model_exists: false, context: n
       }
     end
 
-    def ask(message = nil, with: nil, &)
-      add_message role: :user, content: build_content(message, with)
+    def ask(message = nil, with: nil, cache_point: false, &)
+      add_message role: :user, content: build_content(message, with), cache_point: cache_point
       complete(&)
     end
 
     alias say ask
 
-    def with_instructions(instructions, append: false, replace: nil)
+    def with_instructions(instructions, append: false, replace: nil, cache_point: false)
       append ||= (replace == false) unless replace.nil?
 
       if append
-        append_system_instruction(instructions)
+        append_system_instruction(instructions, cache_point: cache_point)
       else
-        replace_system_instruction(instructions)
+        replace_system_instruction(instructions, cache_point: cache_point)
       end
 
       self
@@ -329,21 +329,16 @@ def content_like?(object)
       object.is_a?(Content) || object.is_a?(Content::Raw)
     end
 
-    def append_system_instruction(instructions)
+    def append_system_instruction(instructions, cache_point: false)
       system_messages, non_system_messages = @messages.partition { |msg| msg.role == :system }
-      system_messages << Message.new(role: :system, content: instructions)
+      system_messages << Message.new(role: :system, content: instructions, cache_point: cache_point)
       @messages = system_messages + non_system_messages
     end
 
-    def replace_system_instruction(instructions)
-      system_messages, non_system_messages = @messages.partition { |msg| msg.role == :system }
+    def replace_system_instruction(instructions, cache_point: false)
+      non_system_messages = @messages.reject { |msg| msg.role == :system }
 
-      if system_messages.empty?
-        system_messages = [Message.new(role: :system, content: instructions)]
-      else
-        system_messages.first.content = instructions
-        system_messages = [system_messages.first]
-      end
+      system_messages = [Message.new(role: :system, content: instructions, cache_point: cache_point)]
 
       @messages = system_messages + non_system_messages
     end
diff --git a/lib/ruby_llm/message.rb b/lib/ruby_llm/message.rb
index eefb93e55..dc88da741 100644
--- a/lib/ruby_llm/message.rb
+++ b/lib/ruby_llm/message.rb
@@ -5,7 +5,8 @@ module RubyLLM
   class Message
     ROLES = %i[system user assistant tool].freeze
 
-    attr_reader :role, :model_id, :tool_calls, :tool_call_id, :raw, :thinking, :tokens
+    attr_reader :role, :model_id, :tool_calls, :tool_call_id, :raw, :thinking, :tokens, :cache_point
+    alias cache_point? cache_point
     attr_writer :content
 
     def initialize(options = {})
@@ -24,6 +25,8 @@ def initialize(options = {})
       )
       @raw = options[:raw]
       @thinking = options[:thinking]
+      # Coerce to a strict boolean so #cache_point? and #to_h never leak nil or other truthy values.
+      @cache_point = options[:cache_point] ? true : false
       ensure_valid_role
     end
 
@@ -80,7 +83,8 @@ def to_h
         tool_calls: tool_calls,
         tool_call_id: tool_call_id,
         thinking: thinking&.text,
-        thinking_signature: thinking&.signature
+        thinking_signature: thinking&.signature,
+        cache_point: @cache_point || nil
       }.merge(tokens ? tokens.to_h : {}).compact
     end
 
diff --git a/lib/ruby_llm/provider.rb b/lib/ruby_llm/provider.rb
index a434ac114..db5e77447 100644
--- a/lib/ruby_llm/provider.rb
+++ b/lib/ruby_llm/provider.rb
@@ -257,7 +257,7 @@ def maybe_normalize_temperature(temperature, _model)
 
     def sync_response(connection, payload, additional_headers = {})
       response = connection.post completion_url, payload do |req|
-        req.headers = additional_headers.merge(req.headers) unless additional_headers.empty?
+        req.headers.merge!(additional_headers) unless additional_headers.empty?
       end
       parse_completion_response response
     end
diff --git a/lib/ruby_llm/providers/anthropic.rb b/lib/ruby_llm/providers/anthropic.rb
index a0686036f..f7e85ff32 100644
--- a/lib/ruby_llm/providers/anthropic.rb
+++ b/lib/ruby_llm/providers/anthropic.rb
@@ -22,6 +22,18 @@ def headers
         }
       end
 
+      # Opts in to the prompt-caching beta when any message carries a cache point,
+      # appending to (not replacing) beta flags the caller already supplied.
+      def complete(messages, headers: {}, **kwargs, &block)
+        if messages.any?(&:cache_point?)
+          betas = headers['anthropic-beta'].to_s.split(',').map(&:strip)
+          headers = headers.merge('anthropic-beta' => (betas | ['prompt-caching-2024-07-31']).join(','))
+        end
+
+        # Explicit arguments are required: `headers` may have been reassigned above.
+        super(messages, headers: headers, **kwargs, &block) # rubocop:disable Style/SuperArguments
+      end
+
       class << self
         def capabilities
           Anthropic::Capabilities
diff --git a/lib/ruby_llm/providers/anthropic/chat.rb b/lib/ruby_llm/providers/anthropic/chat.rb
index 9926fe98b..24f009a48 100644
--- a/lib/ruby_llm/providers/anthropic/chat.rb
+++ b/lib/ruby_llm/providers/anthropic/chat.rb
@@ -41,11 +41,15 @@ def build_system_content(system_messages)
           system_messages.flat_map do |msg|
             content = msg.content
 
-            if content.is_a?(RubyLLM::Content::Raw)
-              content.value
-            else
-              Media.format_content(content)
-            end
+            # [x].flatten(1) keeps a lone Hash block whole (Array(hash) would split it into
+            # key/value pairs) and copies arrays, so tagging never mutates the caller's payload.
+            blocks = if content.is_a?(RubyLLM::Content::Raw)
+                       [content.value].flatten(1)
+                     else
+                       [Media.format_content(content)].flatten(1)
+                     end
+
+            msg.cache_point? ? inject_cache_control(blocks) : blocks
           end
         end
 
@@ -159,6 +163,7 @@ def format_basic_message_with_thinking(msg, thinking_enabled)
           end
 
           append_formatted_content(content_blocks, msg.content)
+          inject_cache_control(content_blocks) if msg.cache_point?
 
           {
             role: convert_role(msg.role),
@@ -228,6 +233,19 @@ def append_formatted_content(content_blocks, content)
           end
         end
 
+        # Marks the final content block as a cache breakpoint. Mutates and returns
+        # +blocks+; callers that must not alter shared data pass a copy.
+        def inject_cache_control(blocks)
+          last = blocks.last
+          # Nothing to tag: empty content, or a non-Hash block that cannot carry the key.
+          return blocks unless last.is_a?(Hash)
+          # Don't duplicate if already present (e.g. Content::Raw with cache_control)
+          return blocks if last[:cache_control] || last['cache_control']
+
+          blocks[-1] = last.merge(cache_control: { type: 'ephemeral' })
+          blocks
+        end
+
         def convert_role(role)
           case role
           when :tool, :user then 'user'
diff --git a/lib/ruby_llm/providers/bedrock/chat.rb b/lib/ruby_llm/providers/bedrock/chat.rb
index c39fa2942..a2dbfa415 100644
--- a/lib/ruby_llm/providers/bedrock/chat.rb
+++ b/lib/ruby_llm/providers/bedrock/chat.rb
@@ -114,21 +114,26 @@ def render_message_content(msg)
           text_and_media_blocks = Media.render_content(msg.content, used_document_names: @used_document_names)
           blocks.concat(text_and_media_blocks) if text_and_media_blocks
 
-          if msg.tool_call?
-            msg.tool_calls.each_value do |tool_call|
-              blocks << {
-                toolUse: {
-                  toolUseId: tool_call.id,
-                  name: tool_call.name,
-                  input: tool_call.arguments
-                }
-              }
-            end
-          end
+          append_tool_use_blocks(blocks, msg)
+          blocks << { cachePoint: { type: 'default' } } if msg.cache_point?
 
           blocks
         end
 
+        def append_tool_use_blocks(blocks, msg)
+          return unless msg.tool_call?
+
+          msg.tool_calls.each_value do |tool_call|
+            blocks << {
+              toolUse: {
+                toolUseId: tool_call.id,
+                name: tool_call.name,
+                input: tool_call.arguments
+              }
+            }
+          end
+        end
+
         def render_raw_content(content)
           value = content.value
           value.is_a?(Array) ? value : [value]
@@ -200,7 +205,13 @@ def render_role(role)
         end
 
         def render_system(messages)
-          messages.flat_map { |msg| Media.render_content(msg.content, used_document_names: @used_document_names) }
+          messages.flat_map do |msg|
+            blocks = Media.render_content(msg.content, used_document_names: @used_document_names)
+            # render_content may return nil (render_message_content guards for it as well).
+            next blocks if blocks.nil? || !msg.cache_point?
+
+            blocks + [{ cachePoint: { type: 'default' } }]
+          end
         end
 
         def render_inference_config(_model, temperature)
diff --git a/lib/ruby_llm/streaming.rb b/lib/ruby_llm/streaming.rb
index a671f9cca..df7e82566 100644
--- a/lib/ruby_llm/streaming.rb
+++ b/lib/ruby_llm/streaming.rb
@@ -9,7 +9,7 @@ def stream_response(connection, payload, additional_headers = {}, &block)
       accumulator = StreamAccumulator.new
 
       response = connection.post stream_url, payload do |req|
-        req.headers = additional_headers.merge(req.headers) unless additional_headers.empty?
+        req.headers.merge!(additional_headers) unless additional_headers.empty?
         if faraday_1?
           req.options[:on_data] = handle_stream do |chunk|
             accumulator.add chunk
diff --git a/spec/ruby_llm/chat_cache_point_spec.rb b/spec/ruby_llm/chat_cache_point_spec.rb
new file mode 100644
index 000000000..2f730c1a4
--- /dev/null
+++ b/spec/ruby_llm/chat_cache_point_spec.rb
@@ -0,0 +1,57 @@
+# frozen_string_literal: true
+
+require 'spec_helper'
+
+RSpec.describe RubyLLM::Chat do
+  include_context 'with configured RubyLLM'
+
+  describe 'cache_point forwarding' do
+    let(:chat) { RubyLLM.chat }
+
+    shared_examples 'a method that supports cache_point' do |message_finder|
+      it 'sets cache_point? true when cache_point: true' do
+        action.call(cache_point: true)
+        message = message_finder.call(chat)
+        expect(message).not_to be_nil
+        expect(message.cache_point?).to be true
+      end
+
+      it 'sets cache_point? false when cache_point is omitted' do
+        action.call
+        message = message_finder.call(chat)
+        expect(message).not_to be_nil
+        expect(message.cache_point?).to be false
+      end
+    end
+
+    describe '#with_instructions' do
+      let(:action) { ->(opts = {}) { chat.with_instructions('Be helpful', **opts) } }
+
+      it_behaves_like 'a method that supports cache_point', ->(c) { c.messages.find { |m| m.role == :system } }
+
+      it 'sets cache_point? true on appended message only' do
+        chat.with_instructions('First instruction')
+        chat.with_instructions('Second instruction', append: true, cache_point: true)
+        system_msgs = chat.messages.select { |m| m.role == :system }
+        expect(system_msgs.last.cache_point?).to be true
+        expect(system_msgs.first.cache_point?).to be false
+      end
+
+      it 'preserves cache_point: true when replacing' do
+        chat.with_instructions('Old instruction', cache_point: false)
+        chat.with_instructions('New instruction', replace: true, cache_point: true)
+        system_msgs = chat.messages.select { |m| m.role == :system }
+        expect(system_msgs.size).to eq(1)
+        expect(system_msgs.first.cache_point?).to be true
+      end
+    end
+
+    describe '#ask' do
+      before { allow(chat).to receive(:complete) }
+
+      let(:action) { ->(opts = {}) { chat.ask('Hello', **opts) } }
+
+      it_behaves_like 'a method that supports cache_point', ->(c) { c.messages.find { |m| m.role == :user } }
+    end
+  end
+end
diff --git a/spec/ruby_llm/message_spec.rb b/spec/ruby_llm/message_spec.rb
index a787fd0e7..2672f911f 100644
--- a/spec/ruby_llm/message_spec.rb
+++ b/spec/ruby_llm/message_spec.rb
@@ -3,6 +3,35 @@
 require 'spec_helper'
 
 RSpec.describe RubyLLM::Message do
+  describe '#cache_point?' do
+    it 'returns false by default' do
+      message = described_class.new(role: :user, content: 'hello')
+      expect(message.cache_point?).to be false
+    end
+
+    it 'returns true when constructed with cache_point: true' do
+      message = described_class.new(role: :user, content: 'hello', cache_point: true)
+      expect(message.cache_point?).to be true
+    end
+
+    it 'coerces an explicit nil to false' do
+      message = described_class.new(role: :user, content: 'hello', cache_point: nil)
+      expect(message.cache_point?).to be false
+    end
+  end
+
+  describe '#to_h' do
+    it 'omits cache_point key when false' do
+      message = described_class.new(role: :user, content: 'hello')
+      expect(message.to_h).not_to have_key(:cache_point)
+    end
+
+    it 'includes cache_point: true when set' do
+      message = described_class.new(role: :user, content: 'hello', cache_point: true)
+      expect(message.to_h[:cache_point]).to be true
+    end
+  end
+
   describe '#content' do
     it 'normalizes nil content to empty string for assistant tool-call messages' do
       tool_call = RubyLLM::ToolCall.new(id: 'call_1', name: 'weather', arguments: {})
diff --git a/spec/ruby_llm/providers/anthropic/chat_cache_control_spec.rb b/spec/ruby_llm/providers/anthropic/chat_cache_control_spec.rb
new file mode 100644
index 000000000..1fa15c8f5
--- /dev/null
+++ b/spec/ruby_llm/providers/anthropic/chat_cache_control_spec.rb
@@ -0,0 +1,88 @@
+# frozen_string_literal: true
+
+require 'spec_helper'
+
+RSpec.describe RubyLLM::Providers::Anthropic::Chat do
+  let(:model) { instance_double(RubyLLM::Model::Info, id: 'claude-sonnet-4-5', max_tokens: nil) }
+
+  def render(messages)
+    described_class.render_payload(
+      messages,
+      tools: {},
+      temperature: nil,
+      model: model,
+      stream: false,
+      schema: nil
+    )
+  end
+
+  describe 'cache_control injection' do
+    context 'with a system message where cache_point is true' do
+      it 'adds cache_control to the last system block' do
+        msg = RubyLLM::Message.new(role: :system, content: 'You are helpful.', cache_point: true)
+        payload = render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])
+
+        last_block = payload[:system].last
+        expect(last_block[:cache_control]).to eq(type: 'ephemeral')
+      end
+
+      it 'does not add cache_control when cache_point is false' do
+        msg = RubyLLM::Message.new(role: :system, content: 'You are helpful.')
+        payload = render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])
+
+        payload[:system].each do |block|
+          expect(block).not_to have_key(:cache_control)
+        end
+      end
+    end
+
+    context 'with a user message where cache_point is true' do
+      it 'adds cache_control to the last content block' do
+        msg = RubyLLM::Message.new(role: :user, content: 'Tell me a story.', cache_point: true)
+        payload = render([msg])
+
+        last_block = payload[:messages].first[:content].last
+        expect(last_block[:cache_control]).to eq(type: 'ephemeral')
+      end
+
+      it 'does not add cache_control when cache_point is false' do
+        msg = RubyLLM::Message.new(role: :user, content: 'Tell me a story.')
+        payload = render([msg])
+
+        payload[:messages].first[:content].each do |block|
+          expect(block).not_to have_key(:cache_control)
+        end
+      end
+    end
+
+    context 'when a Content::Raw block already contains cache_control' do
+      it 'does not duplicate cache_control' do
+        raw = RubyLLM::Providers::Anthropic::Content.new('Cached system', cache: true)
+        msg = RubyLLM::Message.new(role: :system, content: raw, cache_point: true)
+        payload = render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])
+
+        blocks_with_cache = payload[:system].select { |b| b[:cache_control] }
+        expect(blocks_with_cache.length).to eq(1)
+        expect(blocks_with_cache.first[:cache_control]).to eq(type: 'ephemeral')
+      end
+    end
+
+    context 'when Content::Raw supplies the blocks' do
+      it 'keeps a single Hash block intact and tags it' do
+        raw = RubyLLM::Content::Raw.new({ type: 'text', text: 'Solo block' })
+        msg = RubyLLM::Message.new(role: :system, content: raw, cache_point: true)
+        payload = render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])
+
+        expect(payload[:system].last).to eq(type: 'text', text: 'Solo block', cache_control: { type: 'ephemeral' })
+      end
+
+      it 'does not mutate the caller-supplied payload' do
+        raw = RubyLLM::Content::Raw.new([{ type: 'text', text: 'Shared' }])
+        msg = RubyLLM::Message.new(role: :system, content: raw, cache_point: true)
+        render([msg, RubyLLM::Message.new(role: :user, content: 'Hi')])
+
+        expect(raw.value).to eq([{ type: 'text', text: 'Shared' }])
+      end
+    end
+  end
+end
diff --git a/spec/ruby_llm/providers/bedrock/chat_cache_point_spec.rb b/spec/ruby_llm/providers/bedrock/chat_cache_point_spec.rb
new file mode 100644
index 000000000..7f562657f
--- /dev/null
+++ b/spec/ruby_llm/providers/bedrock/chat_cache_point_spec.rb
@@ -0,0 +1,77 @@
+# frozen_string_literal: true
+
+require 'spec_helper'
+
+RSpec.describe RubyLLM::Providers::Bedrock::Chat do
+  let(:model) do
+    instance_double(RubyLLM::Model::Info,
+                    id: 'anthropic.claude-haiku-4-5-20251001-v1:0',
+                    max_tokens: nil,
+                    metadata: {})
+  end
+
+  let(:base_args) do
+    { tools: {}, temperature: nil, model: model, stream: false }
+  end
+
+  def render(messages)
+    described_class.render_payload(messages, **base_args)
+  end
+
+  def msg(role, content, cache_point: false)
+    RubyLLM::Message.new(role: role, content: content, cache_point: cache_point)
+  end
+
+  describe 'cache_point injection' do
+    context 'with a system message where cache_point is true' do
+      it 'appends a cachePoint block to the system content' do
+        payload = render([msg(:system, 'You are helpful.', cache_point: true),
+                          msg(:user, 'Hi')])
+
+        last_block = payload[:system].last
+        expect(last_block).to eq(cachePoint: { type: 'default' })
+      end
+
+      it 'does not append cachePoint when cache_point is false' do
+        payload = render([msg(:system, 'You are helpful.'), msg(:user, 'Hi')])
+
+        expect(payload[:system]).not_to include(cachePoint: { type: 'default' })
+      end
+    end
+
+    context 'with a user message where cache_point is true' do
+      it 'appends a cachePoint block to the message content' do
+        payload = render([msg(:user, 'Tell me a story.', cache_point: true)])
+
+        last_block = payload[:messages].first[:content].last
+        expect(last_block).to eq(cachePoint: { type: 'default' })
+      end
+
+      it 'does not append cachePoint when cache_point is false' do
+        payload = render([msg(:user, 'Tell me a story.')])
+
+        content = payload[:messages].first[:content]
+        expect(content).not_to include(cachePoint: { type: 'default' })
+      end
+    end
+
+    context 'when multiple messages have cache_point: true' do
+      it 'appends cachePoint to each cache-pointed message' do
+        payload = render([
+                           msg(:system, 'System prompt', cache_point: true),
+                           msg(:user, 'User context', cache_point: true),
+                           msg(:user, 'Dynamic question')
+                         ])
+
+        system_has_cache = payload[:system].last == { cachePoint: { type: 'default' } }
+        user_messages = payload[:messages]
+        first_user_has_cache = user_messages.first[:content].last == { cachePoint: { type: 'default' } }
+        last_user_no_cache = user_messages.last[:content].last != { cachePoint: { type: 'default' } }
+
+        expect(system_has_cache).to be true
+        expect(first_user_has_cache).to be true
+        expect(last_user_no_cache).to be true
+      end
+    end
+  end
+end