Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion scripts/generate_tiny_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Gemma2Config,
Gemma2ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma4ForConditionalGeneration,
GemmaConfig,
GemmaForCausalLM,
GenerationConfig,
Expand Down Expand Up @@ -316,6 +317,7 @@ def init_weights_tiny_model(model):
# Vision Language Models
for model_id, model_class, dtype in [
("google/gemma-3-4b-it", Gemma3ForConditionalGeneration, torch.bfloat16),
("google/gemma-4-E2B-it", Gemma4ForConditionalGeneration, torch.bfloat16),
("google/paligemma-3b-pt-224", PaliGemmaForConditionalGeneration, torch.float32),
("HuggingFaceM4/idefics2-8b", Idefics2ForConditionalGeneration, torch.float32),
("HuggingFaceM4/Idefics3-8B-Llama3", Idefics3ForConditionalGeneration, torch.bfloat16),
Expand Down Expand Up @@ -394,7 +396,20 @@ def init_weights_tiny_model(model):
# See https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/discussions/46
text_config["dtype"] = None

config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
if model_class is Gemma4ForConditionalGeneration:
# Gemma4 rope validation fails when passing text_config as a dict, so we mutate the config directly.
config = AutoConfig.from_pretrained(model_id)
for k, v in text_config.items():
setattr(config.text_config, k, v)
for k, v in vision_config.items():
setattr(config.vision_config, k, v)
config.text_config.layer_types = ["sliding_attention", "full_attention"]
config.text_config.num_kv_shared_layers = 0
config.text_config.global_head_dim = 8
config.text_config.hidden_size_per_layer_input = 16
config.audio_config = None
else:
config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
model = model_class(config).to(dtype=dtype)

if issubclass(model_class.config_class, Qwen3_5Config):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,13 @@ def test_tag_added_peft(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
# "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now
# "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
Expand Down
14 changes: 14 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,6 +1923,13 @@ def test_prepare_input_called_with_correct_data(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
Expand Down Expand Up @@ -2215,6 +2222,13 @@ def reward_func(completions, **kwargs):
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
],
)
@require_vision
Expand Down
14 changes: 14 additions & 0 deletions tests/test_rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,13 @@ def test_prepare_input_called_with_correct_data(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
Expand Down Expand Up @@ -1506,6 +1513,13 @@ def reward_func(completions, **kwargs):
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
],
)
@require_vision
Expand Down
14 changes: 14 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,13 @@ def test_tag_added_peft(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
# "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now
# "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
Expand Down Expand Up @@ -1746,6 +1753,13 @@ def test_train_vlm_multi_image(self, model_id):
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
# Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly in the collator:
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
],
)
@require_vision
Expand Down
7 changes: 7 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,13 @@ def test_llm(self, model_id):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
"trl-internal-testing/tiny-Idefics2ForConditionalGeneration",
"trl-internal-testing/tiny-Idefics3ForConditionalGeneration",
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
Expand Down
Loading