@@ -664,6 +664,15 @@ async def test_mcp_toolset_outside_workflow_no_not_in_workflow_toolset():
664664 await toolset .get_tools ()
665665
666666
667+ try :
668+ import litellm as _litellm_check
669+
670+ del _litellm_check
671+ _has_litellm = True
672+ except ImportError :
673+ _has_litellm = False
674+
675+
667676complex_activity_inputs_seen : dict [str , object ] = {}
668677
669678
@@ -855,3 +864,116 @@ async def test_activity_tool_supports_complex_inputs_via_adk(client: Client):
855864 ),
856865 "annotate_trip" : "SFO->LAX:3" ,
857866 }
867+
868+
869+ def litellm_agent (model_name : str ) -> Agent :
870+ return Agent (
871+ name = "litellm_test_agent" ,
872+ model = TemporalModel (model_name ),
873+ )
874+
875+
876+ @workflow .defn
877+ class LiteLlmAgent :
878+ @workflow .run
879+ async def run (self , prompt : str , model_name : str ) -> Event | None :
880+ agent = litellm_agent (model_name )
881+
882+ runner = InMemoryRunner (
883+ agent = agent ,
884+ app_name = "litellm_test_app" ,
885+ )
886+
887+ session = await runner .session_service .create_session (
888+ app_name = "litellm_test_app" , user_id = "test"
889+ )
890+
891+ last_event = None
892+ async with Aclosing (
893+ runner .run_async (
894+ user_id = "test" ,
895+ session_id = session .id ,
896+ new_message = types .Content (role = "user" , parts = [types .Part (text = prompt )]),
897+ )
898+ ) as agen :
899+ async for event in agen :
900+ last_event = event
901+
902+ return last_event
903+
904+
905+ @pytest .mark .asyncio
906+ @pytest .mark .skipif (not _has_litellm , reason = "litellm not installed" )
907+ async def test_litellm_model (client : Client ):
908+ """Test that ADK's LiteLlm class works with TemporalModel through a full Temporal workflow."""
909+ # Import inside the test so the module loads cleanly when litellm is not installed.
910+ import litellm as litellm_module
911+ from google .adk .models .lite_llm import LiteLlm
912+ from google .adk .models .registry import _llm_registry_dict
913+ from litellm import ModelResponse
914+ from litellm .llms .custom_llm import CustomLLM
915+
916+ class FakeLiteLlmProvider (CustomLLM ):
917+ """A fake litellm provider that returns canned responses locally."""
918+
919+ def _make_response (self , model : str ) -> ModelResponse :
920+ return ModelResponse (
921+ choices = [
922+ {
923+ "message" : {
924+ "content" : "hello from litellm" ,
925+ "role" : "assistant" ,
926+ },
927+ "index" : 0 ,
928+ "finish_reason" : "stop" ,
929+ }
930+ ],
931+ model = model ,
932+ )
933+
934+ def completion (self , * args : Any , ** kwargs : Any ) -> ModelResponse :
935+ model = args [0 ] if args else kwargs .get ("model" , "unknown" )
936+ return self ._make_response (model )
937+
938+ async def acompletion (self , * args : Any , ** kwargs : Any ) -> ModelResponse :
939+ model = args [0 ] if args else kwargs .get ("model" , "unknown" )
940+ return self ._make_response (model )
941+
942+ # Register our fake provider with litellm
943+ litellm_module .custom_provider_map = [
944+ {"provider" : "fake" , "custom_handler" : FakeLiteLlmProvider ()}
945+ ]
946+
947+ try :
948+ # Register LiteLlm to handle "fake/.*" model names via ADK's LLMRegistry
949+ _llm_registry_dict [r"fake/.*" ] = LiteLlm
950+ # Clear the resolve LRU cache so the new pattern is picked up
951+ LLMRegistry .resolve .cache_clear ()
952+ new_config = client .config ()
953+ new_config ["plugins" ] = [GoogleAdkPlugin ()]
954+ client = Client (** new_config )
955+
956+ async with Worker (
957+ client ,
958+ task_queue = "adk-task-queue-litellm" ,
959+ workflows = [LiteLlmAgent ],
960+ max_cached_workflows = 0 ,
961+ ):
962+ handle = await client .start_workflow (
963+ LiteLlmAgent .run ,
964+ args = ["Say hello" , "fake/test-model" ],
965+ id = f"litellm-agent-workflow-{ uuid .uuid4 ()} " ,
966+ task_queue = "adk-task-queue-litellm" ,
967+ execution_timeout = timedelta (seconds = 60 ),
968+ )
969+ result = await handle .result ()
970+
971+ assert result is not None
972+ assert result .content is not None
973+ assert result .content .parts is not None
974+ assert result .content .parts [0 ].text == "hello from litellm"
975+ finally :
976+ # Clean up registry state
977+ _llm_registry_dict .pop (r"fake/.*" , None )
978+ LLMRegistry .resolve .cache_clear ()
979+ litellm_module .custom_provider_map = []
0 commit comments