diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py index c1a83c045c..70abfc8961 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_bedrock.py @@ -318,13 +318,22 @@ async def generate_content_async( if self.additional_model_request_fields: kwargs["additionalModelRequestFields"] = self.additional_model_request_fields - def _run_converse_stream(**kw): - resp = client.converse_stream(**kw) - return list(resp.get("stream", [])) - try: if stream: - stream_body = await asyncio.to_thread(_run_converse_stream, **kwargs) + q: asyncio.Queue = asyncio.Queue() + loop = asyncio.get_running_loop() + + def _produce(): + try: + resp = client.converse_stream(**kwargs) + for event in resp.get("stream", []): + loop.call_soon_threadsafe(q.put_nowait, event) + except Exception as exc: + loop.call_soon_threadsafe(q.put_nowait, exc) + finally: + loop.call_soon_threadsafe(q.put_nowait, None) + + loop.run_in_executor(None, _produce) aggregated_text = "" tool_uses: dict[str, dict] = {} # toolUseId -> {name, input_json} @@ -332,7 +341,9 @@ def _run_converse_stream(**kw): stop_reason = "end_turn" usage_metadata: Optional[types.GenerateContentResponseUsageMetadata] = None - for event in stream_body: + while (event := await q.get()) is not None: + if isinstance(event, Exception): + raise event if "contentBlockStart" in event: start = event["contentBlockStart"].get("start", {}) if "toolUse" in start: