Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 49 additions & 11 deletions src/google/adk/models/interactions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,41 @@
_NEW_LINE = '\n'


def _extract_event_id_from_interaction_event(
event: 'InteractionSSEEvent',
) -> Optional[str]:
"""Extract the SDK event identifier from an interactions SSE event."""
event_id = getattr(event, 'event_id', None)
if isinstance(event_id, str):
return event_id

legacy_event_id = getattr(event, 'id', None)
if isinstance(legacy_event_id, str):
return legacy_event_id

return None


def _extract_interaction_id_from_event(
event: 'InteractionSSEEvent',
) -> Optional[str]:
"""Extract the interaction chain identifier from an SSE event."""
interaction = getattr(event, 'interaction', None)
interaction_id = getattr(interaction, 'id', None)
if isinstance(interaction_id, str):
return interaction_id

event_interaction_id = getattr(event, 'interaction_id', None)
if isinstance(event_interaction_id, str):
return event_interaction_id

legacy_interaction_id = getattr(event, 'id', None)
if isinstance(legacy_interaction_id, str):
return legacy_interaction_id

return None


def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]:
"""Convert a types.Part to an interaction content dict.

Expand Down Expand Up @@ -154,12 +189,12 @@ def convert_part_to_interaction_content(part: types.Part) -> Optional[dict]:
elif part.thought:
# part.thought is a boolean indicating this is a thought part
# ThoughtContentParam expects 'signature' (base64 encoded bytes)
result: dict[str, Any] = {'type': 'thought'}
thought_content: dict[str, Any] = {'type': 'thought'}
if part.thought_signature is not None:
result['signature'] = base64.b64encode(part.thought_signature).decode(
'utf-8'
)
return result
thought_content['signature'] = base64.b64encode(
part.thought_signature
).decode('utf-8')
return thought_content
elif part.code_execution_result is not None:
is_error = part.code_execution_result.outcome in (
types.Outcome.OUTCOME_FAILED,
Expand Down Expand Up @@ -487,6 +522,7 @@ def convert_interaction_event_to_llm_response(
from .llm_response import LlmResponse

event_type = getattr(event, 'event_type', None)
interaction_id = interaction_id or _extract_interaction_id_from_event(event)

if event_type == 'content.delta':
delta = event.delta
Expand Down Expand Up @@ -565,9 +601,10 @@ def convert_interaction_event_to_llm_response(
interaction_id=interaction_id,
)

elif event_type == 'interaction':
# Final interaction event with complete data
return convert_interaction_to_llm_response(event)
elif event_type in ('interaction.complete', 'interaction'):
# Final interaction event with complete data.
interaction = getattr(event, 'interaction', event)
return convert_interaction_to_llm_response(interaction)

elif event_type == 'interaction.status_update':
status = getattr(event, 'status', None)
Expand Down Expand Up @@ -834,7 +871,7 @@ def build_interactions_event_log(event: InteractionSSEEvent) -> str:
A formatted log string describing the event.
"""
event_type = getattr(event, 'event_type', 'unknown')
event_id = getattr(event, 'id', None)
event_id = _extract_event_id_from_interaction_event(event)

details = []

Expand Down Expand Up @@ -1014,8 +1051,9 @@ async def generate_content_via_interactions(
logger.debug(build_interactions_event_log(event))

# Extract interaction ID from event if available
if hasattr(event, 'id') and event.id:
current_interaction_id = event.id
current_interaction_id = (
_extract_interaction_id_from_event(event) or current_interaction_id
)
llm_response = convert_interaction_event_to_llm_response(
event, aggregated_parts, current_interaction_id
)
Expand Down
122 changes: 122 additions & 0 deletions tests/unittests/models/test_interactions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,36 @@

import base64
import json
from unittest.mock import AsyncMock
from unittest.mock import MagicMock

from google.adk.models import interactions_utils
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from google.genai._interactions.types import ContentDelta
from google.genai._interactions.types import ContentStop
from google.genai._interactions.types import Interaction
from google.genai._interactions.types import InteractionCompleteEvent
from google.genai._interactions.types import InteractionStartEvent
from google.genai._interactions.types import InteractionStatusUpdate
from google.genai._interactions.types.content_delta import DeltaFunctionCall
import pytest


class _MockAsyncIterator:
"""Simple async iterator for streaming test events."""

def __init__(self, values):
self._iterator = iter(values)

def __aiter__(self):
return self

async def __anext__(self):
try:
return next(self._iterator)
except StopIteration as exc:
raise StopAsyncIteration from exc


class TestConvertPartToInteractionContent:
Expand Down Expand Up @@ -955,3 +980,100 @@ def test_unknown_event_type_returns_none(self):

assert result is None
assert not aggregated_parts

def test_interaction_complete_event(self):
"""Test converting an interaction.complete event."""
interaction = Interaction(
id='int_complete',
created='2026-04-07T00:00:00Z',
updated='2026-04-07T00:00:01Z',
status='completed',
outputs=[{'type': 'text', 'text': 'Done'}],
)
event = InteractionCompleteEvent(
event_type='interaction.complete',
interaction=interaction,
)

result = interactions_utils.convert_interaction_event_to_llm_response(
event, aggregated_parts=[]
)

assert result is not None
assert result.interaction_id == 'int_complete'
assert result.content.parts[0].text == 'Done'
assert result.turn_complete is True


class TestGenerateContentViaInteractions:
"""Tests for generate_content_via_interactions."""

@pytest.mark.asyncio
async def test_stream_uses_interaction_start_id_for_function_calls(self):
"""Test that streaming function calls retain the interaction chain ID."""
interaction = Interaction(
id='int_stream_123',
created='2026-04-07T00:00:00Z',
updated='2026-04-07T00:00:01Z',
status='requires_action',
)
stream_events = [
InteractionStartEvent(
event_type='interaction.start',
interaction=interaction,
),
ContentDelta(
event_type='content.delta',
index=0,
delta=DeltaFunctionCall(
type='function_call',
id='fc_123',
name='get_weather',
arguments={'city': 'Tokyo'},
),
),
ContentStop(event_type='content.stop', index=0),
InteractionStatusUpdate(
event_type='interaction.status_update',
interaction_id='int_stream_123',
status='requires_action',
),
]
api_client = MagicMock()
api_client.aio.interactions.create = AsyncMock(
return_value=_MockAsyncIterator(stream_events)
)
llm_request = LlmRequest(
model='gemini-2.5-flash',
contents=[
types.Content(
role='user',
parts=[types.Part.from_text(text='Weather in Tokyo?')],
)
],
config=types.GenerateContentConfig(),
)

responses = [
response
async for response in (
interactions_utils.generate_content_via_interactions(
api_client=api_client,
llm_request=llm_request,
stream=True,
)
)
]

function_call_response = next(
response
for response in responses
if response.content
and response.content.parts
and response.content.parts[0].function_call
)

assert function_call_response.interaction_id == 'int_stream_123'
assert function_call_response.content.parts[0].function_call.name == (
'get_weather'
)