@@ -855,3 +855,112 @@ async def test_activity_tool_supports_complex_inputs_via_adk(client: Client):
855855 ),
856856 "annotate_trip" : "SFO->LAX:3" ,
857857 }
858+
859+
def litellm_agent(model_name: str) -> Agent:
    """Build a minimal test agent whose model is a TemporalModel for *model_name*."""
    temporal_model = TemporalModel(model_name)
    return Agent(name="litellm_test_agent", model=temporal_model)
865+
866+
@workflow.defn
class LiteLlmWorkflow:
    """Temporal workflow that runs one prompt through a litellm-backed agent.

    Returns the last ADK event produced by the runner, or None if the run
    yielded no events.
    """

    @workflow.run
    async def run(self, prompt: str, model_name: str) -> Event | None:
        """Drive the agent for *model_name* with *prompt* and return the final event."""
        runner = InMemoryRunner(
            agent=litellm_agent(model_name),
            app_name="litellm_test_app",
        )
        session = await runner.session_service.create_session(
            app_name="litellm_test_app", user_id="test"
        )

        message = types.Content(role="user", parts=[types.Part(text=prompt)])
        final_event: Event | None = None
        # Aclosing guarantees the async generator is closed even if iteration
        # is interrupted.
        async with Aclosing(
            runner.run_async(
                user_id="test",
                session_id=session.id,
                new_message=message,
            )
        ) as event_stream:
            async for evt in event_stream:
                final_event = evt

        return final_event
894+
895+
@pytest.mark.asyncio
async def test_litellm_model(client: Client):
    """Test that a litellm-backed model works with TemporalModel through a full Temporal workflow.

    A fake litellm provider is registered for ``fake/*`` model names so the
    test is fully hermetic (no network calls). The process-global
    ``litellm.custom_provider_map`` is saved and restored so this test does
    not leak the fake provider into other tests.
    """
    import litellm as litellm_module
    from google.adk.models.lite_llm import LiteLlm
    from litellm import ModelResponse
    from litellm.llms.custom_llm import CustomLLM

    class FakeLiteLlmProvider(CustomLLM):
        """A fake litellm provider that returns canned responses locally."""

        def _make_response(self, model: str) -> ModelResponse:
            # One canned assistant turn with a stop finish reason.
            return ModelResponse(
                choices=[
                    {
                        "message": {
                            "content": "hello from litellm",
                            "role": "assistant",
                        },
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
                model=model,
            )

        @staticmethod
        def _model_from_call(args: tuple, kwargs: dict) -> str:
            # litellm may pass the model name positionally or by keyword.
            return args[0] if args else kwargs.get("model", "unknown")

        def completion(self, *args: Any, **kwargs: Any) -> ModelResponse:
            return self._make_response(self._model_from_call(args, kwargs))

        async def acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
            return self._make_response(self._model_from_call(args, kwargs))

    class FakeLiteLlm(LiteLlm):
        """LiteLlm subclass that handles fake/.* model names for testing."""

        @classmethod
        def supported_models(cls) -> list[str]:
            return [r"fake/.*"]

    # Register our fake provider with litellm, remembering the previous
    # mapping: custom_provider_map is process-global state, and overwriting it
    # without restoring would leak into every later test in the session.
    previous_provider_map = litellm_module.custom_provider_map
    litellm_module.custom_provider_map = [
        {"provider": "fake", "custom_handler": FakeLiteLlmProvider()}
    ]
    # NOTE(review): no unregister API is visible here, so this registration
    # intentionally persists for the process lifetime.
    LLMRegistry.register(FakeLiteLlm)

    try:
        new_config = client.config()
        new_config["plugins"] = [GoogleAdkPlugin()]
        client = Client(**new_config)

        async with Worker(
            client,
            task_queue="adk-task-queue-litellm",
            workflows=[LiteLlmWorkflow],
            max_cached_workflows=0,
        ):
            handle = await client.start_workflow(
                LiteLlmWorkflow.run,
                args=["Say hello", "fake/test-model"],
                id=f"litellm-agent-workflow-{uuid.uuid4()}",
                task_queue="adk-task-queue-litellm",
                execution_timeout=timedelta(seconds=60),
            )
            result = await handle.result()

        assert result is not None
        assert result.content is not None
        assert result.content.parts is not None
        assert result.content.parts[0].text == "hello from litellm"
    finally:
        # Restore the global provider map so other tests see the original state.
        litellm_module.custom_provider_map = previous_provider_map
0 commit comments