diff --git a/langextract/providers/gemini.py b/langextract/providers/gemini.py index a82afe1e..08b804e4 100644 --- a/langextract/providers/gemini.py +++ b/langextract/providers/gemini.py @@ -45,6 +45,7 @@ 'tools', 'stop_sequences', 'candidate_count', + 'thinking_config', } @@ -125,8 +126,9 @@ def __init__( Gemini handles this based on schema). **kwargs: Additional Gemini API parameters. Only allowlisted keys are forwarded to the API (response_schema, response_mime_type, tools, - safety_settings, stop_sequences, candidate_count, system_instruction). - See https://ai.google.dev/api/generate-content for details. + safety_settings, stop_sequences, candidate_count, + system_instruction, thinking_config). See + https://ai.google.dev/api/generate-content for details. """ try: # pylint: disable=import-outside-toplevel diff --git a/tests/inference_test.py b/tests/inference_test.py index 92433b04..045b6f7d 100644 --- a/tests/inference_test.py +++ b/tests/inference_test.py @@ -363,6 +363,7 @@ def test_gemini_allowlist_filtering(self, mock_client_class): tools=["tool1", "tool2"], stop_sequences=["\n\n"], system_instruction="Be helpful", + thinking_config={"thinking_level": "minimal"}, # Unknown parameters to test filtering unknown_param="should_be_ignored", another_unknown="also_ignored", @@ -372,6 +373,7 @@ def test_gemini_allowlist_filtering(self, mock_client_class): "tools": ["tool1", "tool2"], "stop_sequences": ["\n\n"], "system_instruction": "Be helpful", + "thinking_config": {"thinking_level": "minimal"}, } self.assertEqual( expected_extra_kwargs, @@ -386,7 +388,12 @@ def test_gemini_allowlist_filtering(self, mock_client_class): call_args = mock_client.models.generate_content.call_args config = call_args.kwargs["config"] - for key in ["tools", "stop_sequences", "system_instruction"]: + for key in [ + "tools", + "stop_sequences", + "system_instruction", + "thinking_config", + ]: self.assertIn(key, config, f"Expected {key} to be in API config") self.assertEqual( expected_extra_kwargs[key], @@ -415,6 +422,7 @@ def test_gemini_runtime_kwargs_filtered(self, mock_client_class): prompts, candidate_count=2, safety_settings={"HARM_CATEGORY_DANGEROUS": "BLOCK_NONE"}, + thinking_config={"thinking_level": "minimal"}, unknown_runtime_param="ignored", ) ) @@ -432,6 +440,11 @@ def test_gemini_runtime_kwargs_filtered(self, mock_client_class): config.get("safety_settings"), "safety_settings should be passed through to API", ) + self.assertEqual( + {"thinking_level": "minimal"}, + config.get("thinking_config"), + "thinking_config should be passed through to API", + ) self.assertNotIn( "unknown_runtime_param", config, "Unknown kwargs should be filtered out" )