Skip to content

Commit 50f99e2

Browse files
xumapleclaude
andcommitted
Add LiteLlm + TemporalModel integration test and sandbox fix
Add litellm and httpx to the GoogleAdkPlugin sandbox passthrough modules. Without this, any LiteLlm-backed model crashes inside the workflow sandbox because litellm transitively imports httpx which fails sandbox restrictions. Add an integration test proving LiteLlm works with TemporalModel through the full Temporal workflow path, using a fake litellm custom provider that requires no API key. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 4f3e320 commit 50f99e2

2 files changed

Lines changed: 124 additions & 2 deletions

File tree

temporalio/contrib/google_adk_agents/_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class GoogleAdkPlugin(SimplePlugin):
6262
6363
This plugin configures:
6464
- Pydantic Payload Converter (required for ADK objects).
65-
- Sandbox Passthrough for google.adk and google.genai modules.
65+
- Sandbox Passthrough for ADK and its dependencies.
6666
"""
6767

6868
def __init__(
@@ -89,7 +89,7 @@ def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner:
8989
return dataclasses.replace(
9090
runner,
9191
restrictions=runner.restrictions.with_passthrough_modules(
92-
"google.adk", "google.genai", "mcp"
92+
"google.adk", "google.genai", "mcp", "litellm", "httpx"
9393
),
9494
)
9595
return runner

tests/contrib/google_adk_agents/test_google_adk_agents.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,15 @@ async def test_mcp_toolset_outside_workflow_no_not_in_workflow_toolset():
664664
await toolset.get_tools()
665665

666666

667+
try:
668+
import litellm as _litellm_check
669+
670+
del _litellm_check
671+
_has_litellm = True
672+
except ImportError:
673+
_has_litellm = False
674+
675+
667676
complex_activity_inputs_seen: dict[str, object] = {}
668677

669678

@@ -855,3 +864,116 @@ async def test_activity_tool_supports_complex_inputs_via_adk(client: Client):
855864
),
856865
"annotate_trip": "SFO->LAX:3",
857866
}
867+
868+
869+
def litellm_agent(model_name: str) -> Agent:
870+
return Agent(
871+
name="litellm_test_agent",
872+
model=TemporalModel(model_name),
873+
)
874+
875+
876+
@workflow.defn
877+
class LiteLlmAgent:
878+
@workflow.run
879+
async def run(self, prompt: str, model_name: str) -> Event | None:
880+
agent = litellm_agent(model_name)
881+
882+
runner = InMemoryRunner(
883+
agent=agent,
884+
app_name="litellm_test_app",
885+
)
886+
887+
session = await runner.session_service.create_session(
888+
app_name="litellm_test_app", user_id="test"
889+
)
890+
891+
last_event = None
892+
async with Aclosing(
893+
runner.run_async(
894+
user_id="test",
895+
session_id=session.id,
896+
new_message=types.Content(role="user", parts=[types.Part(text=prompt)]),
897+
)
898+
) as agen:
899+
async for event in agen:
900+
last_event = event
901+
902+
return last_event
903+
904+
905+
@pytest.mark.asyncio
906+
@pytest.mark.skipif(not _has_litellm, reason="litellm not installed")
907+
async def test_litellm_model(client: Client):
908+
"""Test that ADK's LiteLlm class works with TemporalModel through a full Temporal workflow."""
909+
# Import inside the test so the module loads cleanly when litellm is not installed.
910+
import litellm as litellm_module
911+
from google.adk.models.lite_llm import LiteLlm
912+
from google.adk.models.registry import _llm_registry_dict
913+
from litellm import ModelResponse
914+
from litellm.llms.custom_llm import CustomLLM
915+
916+
class FakeLiteLlmProvider(CustomLLM):
917+
"""A fake litellm provider that returns canned responses locally."""
918+
919+
def _make_response(self, model: str) -> ModelResponse:
920+
return ModelResponse(
921+
choices=[
922+
{
923+
"message": {
924+
"content": "hello from litellm",
925+
"role": "assistant",
926+
},
927+
"index": 0,
928+
"finish_reason": "stop",
929+
}
930+
],
931+
model=model,
932+
)
933+
934+
def completion(self, *args: Any, **kwargs: Any) -> ModelResponse:
935+
model = args[0] if args else kwargs.get("model", "unknown")
936+
return self._make_response(model)
937+
938+
async def acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
939+
model = args[0] if args else kwargs.get("model", "unknown")
940+
return self._make_response(model)
941+
942+
# Register our fake provider with litellm
943+
litellm_module.custom_provider_map = [
944+
{"provider": "fake", "custom_handler": FakeLiteLlmProvider()}
945+
]
946+
947+
try:
948+
# Register LiteLlm to handle "fake/.*" model names via ADK's LLMRegistry
949+
_llm_registry_dict[r"fake/.*"] = LiteLlm
950+
# Clear the resolve LRU cache so the new pattern is picked up
951+
LLMRegistry.resolve.cache_clear()
952+
new_config = client.config()
953+
new_config["plugins"] = [GoogleAdkPlugin()]
954+
client = Client(**new_config)
955+
956+
async with Worker(
957+
client,
958+
task_queue="adk-task-queue-litellm",
959+
workflows=[LiteLlmAgent],
960+
max_cached_workflows=0,
961+
):
962+
handle = await client.start_workflow(
963+
LiteLlmAgent.run,
964+
args=["Say hello", "fake/test-model"],
965+
id=f"litellm-agent-workflow-{uuid.uuid4()}",
966+
task_queue="adk-task-queue-litellm",
967+
execution_timeout=timedelta(seconds=60),
968+
)
969+
result = await handle.result()
970+
971+
assert result is not None
972+
assert result.content is not None
973+
assert result.content.parts is not None
974+
assert result.content.parts[0].text == "hello from litellm"
975+
finally:
976+
# Clean up registry state
977+
_llm_registry_dict.pop(r"fake/.*", None)
978+
LLMRegistry.resolve.cache_clear()
979+
litellm_module.custom_provider_map = []

0 commit comments

Comments
 (0)