diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index 51e0444e6a..bf8c39e13a 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -36,6 +36,7 @@ Gemma2Config, Gemma2ForCausalLM, Gemma3ForConditionalGeneration, + Gemma4ForConditionalGeneration, GemmaConfig, GemmaForCausalLM, GenerationConfig, @@ -316,6 +317,7 @@ def init_weights_tiny_model(model): # Vision Language Models for model_id, model_class, dtype in [ ("google/gemma-3-4b-it", Gemma3ForConditionalGeneration, torch.bfloat16), + ("google/gemma-4-E2B-it", Gemma4ForConditionalGeneration, torch.bfloat16), ("google/paligemma-3b-pt-224", PaliGemmaForConditionalGeneration, torch.float32), ("HuggingFaceM4/idefics2-8b", Idefics2ForConditionalGeneration, torch.float32), ("HuggingFaceM4/Idefics3-8B-Llama3", Idefics3ForConditionalGeneration, torch.bfloat16), @@ -394,7 +396,20 @@ def init_weights_tiny_model(model): # See https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/discussions/46 text_config["dtype"] = None - config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs) + if model_class is Gemma4ForConditionalGeneration: + # Gemma4 rope validation fails when passing text_config as a dict, so we mutate the config directly.
+ config = AutoConfig.from_pretrained(model_id) + for k, v in text_config.items(): + setattr(config.text_config, k, v) + for k, v in vision_config.items(): + setattr(config.vision_config, k, v) + config.text_config.layer_types = ["sliding_attention", "full_attention"] + config.text_config.num_kv_shared_layers = 0 + config.text_config.global_head_dim = 8 + config.text_config.hidden_size_per_layer_input = 16 + config.audio_config = None + else: + config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs) model = model_class(config).to(dtype=dtype) if issubclass(model_class.config_class, Qwen3_5Config): diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 192afeccb2..f2451d9b0a 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1021,6 +1021,13 @@ def test_tag_added_peft(self): "model_id", [ "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), # "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now # "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now "trl-internal-testing/tiny-LlavaForConditionalGeneration", diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index e193e3bb58..199148a13c 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1923,6 +1923,13 @@ def test_prepare_input_called_with_correct_data(self): "model_id", [ "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), 
+ ), "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", @@ -2215,6 +2222,13 @@ def reward_func(completions, **kwargs): [ "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), ], ) @require_vision diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 391bf11c8a..39f766ed5a 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1314,6 +1314,13 @@ def test_prepare_input_called_with_correct_data(self): "model_id", [ "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), "trl-internal-testing/tiny-LlavaNextForConditionalGeneration", "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", "trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", @@ -1506,6 +1513,13 @@ def reward_func(completions, **kwargs): [ "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), ], ) @require_vision diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 5363262d4e..c67382bc21 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py 
@@ -1630,6 +1630,13 @@ def test_tag_added_peft(self): "model_id", [ "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), # "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now # "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now "trl-internal-testing/tiny-LlavaForConditionalGeneration", @@ -1746,6 +1753,13 @@ def test_train_vlm_multi_image(self, model_id): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly in the collator: "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), ], ) @require_vision diff --git a/tests/test_utils.py b/tests/test_utils.py index 297729bc9b..63540f33da 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1052,6 +1052,13 @@ def test_llm(self, model_id): "model_id", [ "trl-internal-testing/tiny-Gemma3ForConditionalGeneration", + pytest.param( + "trl-internal-testing/tiny-Gemma4ForConditionalGeneration", + marks=pytest.mark.skipif( + Version(transformers.__version__) < Version("5.5.0"), + reason="Gemma4 models were introduced in transformers-5.5.0", + ), + ), "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", "trl-internal-testing/tiny-LlavaForConditionalGeneration",