Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion scripts/generate_tiny_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
Gemma2Config,
Gemma2ForCausalLM,
Gemma3ForConditionalGeneration,
Gemma4ForConditionalGeneration,
GemmaConfig,
GemmaForCausalLM,
GenerationConfig,
Expand Down Expand Up @@ -316,6 +317,7 @@ def init_weights_tiny_model(model):
# Vision Language Models
for model_id, model_class, dtype in [
("google/gemma-3-4b-it", Gemma3ForConditionalGeneration, torch.bfloat16),
("google/gemma-4-E2B-it", Gemma4ForConditionalGeneration, torch.bfloat16),
("google/paligemma-3b-pt-224", PaliGemmaForConditionalGeneration, torch.float32),
("HuggingFaceM4/idefics2-8b", Idefics2ForConditionalGeneration, torch.float32),
("HuggingFaceM4/Idefics3-8B-Llama3", Idefics3ForConditionalGeneration, torch.bfloat16),
Expand Down Expand Up @@ -394,7 +396,20 @@ def init_weights_tiny_model(model):
# See https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf/discussions/46
text_config["dtype"] = None

config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
if model_class is Gemma4ForConditionalGeneration:
# Gemma4 rope validation fails when passing text_config as a dict, so we mutate the config directly.
config = AutoConfig.from_pretrained(model_id)
for k, v in text_config.items():
setattr(config.text_config, k, v)
for k, v in vision_config.items():
setattr(config.vision_config, k, v)
config.text_config.layer_types = ["sliding_attention", "full_attention"]
config.text_config.num_kv_shared_layers = 0
config.text_config.global_head_dim = 8
config.text_config.hidden_size_per_layer_input = 16
config.audio_config = None
else:
config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
model = model_class(config).to(dtype=dtype)

if issubclass(model_class.config_class, Qwen3_5Config):
Expand Down
7 changes: 7 additions & 0 deletions tests/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,13 @@ def test_tag_added_peft(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
# "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now
# "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
Expand Down
14 changes: 14 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,6 +1923,13 @@ def test_prepare_input_called_with_correct_data(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
Expand Down Expand Up @@ -2215,6 +2222,13 @@ def reward_func(completions, **kwargs):
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
],
)
@require_vision
Expand Down
14 changes: 14 additions & 0 deletions tests/test_rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1314,6 +1314,13 @@ def test_prepare_input_called_with_correct_data(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
"trl-internal-testing/tiny-LlavaNextForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
Expand Down Expand Up @@ -1506,6 +1513,13 @@ def reward_func(completions, **kwargs):
[
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
],
)
@require_vision
Expand Down
14 changes: 14 additions & 0 deletions tests/test_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,6 +1629,13 @@ def test_tag_added_peft(self):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
# "trl-internal-testing/tiny-Idefics2ForConditionalGeneration", high memory peak, skipped for now
# "trl-internal-testing/tiny-Idefics3ForConditionalGeneration", high memory peak, skipped for now
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
Expand Down Expand Up @@ -1745,6 +1752,13 @@ def test_train_vlm_multi_image(self, model_id):
"trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
# Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly in the collator:
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
],
)
@require_vision
Expand Down
7 changes: 7 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,13 @@ def test_llm(self, model_id):
"model_id",
[
"trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
pytest.param(
"trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
marks=pytest.mark.skipif(
Version(transformers.__version__) < Version("5.5.0"),
reason="Gemma4 models were introduced in transformers-5.5.0",
),
),
"trl-internal-testing/tiny-Idefics2ForConditionalGeneration",
"trl-internal-testing/tiny-Idefics3ForConditionalGeneration",
"trl-internal-testing/tiny-LlavaForConditionalGeneration",
Expand Down
6 changes: 3 additions & 3 deletions trl/trainer/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ class DataCollatorForVisionPreference(DataCollatorMixin):
- `"completion_mask"`: Tensor indicating which tokens correspond to completions.
- `"pixel_values"`: Tensor representing image pixel values.
Additional keys may be present depending on the processor, such as `"image_grid_thw"` or `"pixel_position_ids"`.
Additional keys may be present depending on the processor, such as `"image_grid_thw"` or `"image_position_ids"`.
Args:
processor ([`~transformers.ProcessorMixin`]):
Expand Down Expand Up @@ -1035,7 +1035,7 @@ def compute_ref_log_probs(self, inputs):
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
"pixel_position_ids",
"image_position_ids",
):
if key in inputs:
model_kwargs[key] = inputs[key]
Expand Down Expand Up @@ -1161,7 +1161,7 @@ def _compute_loss(self, model, inputs, return_outputs):
"pixel_attention_mask",
"image_grid_thw",
"image_sizes",
"pixel_position_ids",
"image_position_ids",
):
if key in inputs:
model_kwargs[key] = inputs[key]
Expand Down
40 changes: 10 additions & 30 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,6 @@ def _get_last_hidden_state(
image_grid_thw=None,
pixel_attention_mask=None,
image_sizes=None,
pixel_position_ids=None,
image_position_ids=None,
):
if is_peft_model(unwrapped_model):
Expand All @@ -956,8 +955,6 @@ def _get_last_hidden_state(
# For LLaVa-Next
if image_sizes is not None:
model_inputs["image_sizes"] = image_sizes
if pixel_position_ids is not None:
model_inputs["pixel_position_ids"] = pixel_position_ids
if image_position_ids is not None:
model_inputs["image_position_ids"] = image_position_ids

Expand Down Expand Up @@ -1029,7 +1026,6 @@ def _get_per_token_logps_and_entropies(
image_sizes=None,
token_type_ids=None,
mm_token_type_ids=None,
pixel_position_ids=None,
image_position_ids=None,
) -> dict[str, torch.Tensor | None]:
"""Compute log-probs and (optionally) entropies for each token."""
Expand All @@ -1040,7 +1036,7 @@ def _get_per_token_logps_and_entropies(
input_ids_batch = input_ids[start : start + batch_size]
attention_mask_batch = attention_mask[start : start + batch_size]

# Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't)
# Build model inputs
model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch}
if image_grid_thw is not None and pixel_values is not None:
rows_per_image = image_grid_thw.prod(dim=-1)
Expand All @@ -1052,17 +1048,13 @@ def _get_per_token_logps_and_entropies(
cum_imgs = torch.tensor([0] + num_images).cumsum(0)
img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end]
elif image_position_ids is not None and pixel_values is not None:
cum_imgs = torch.tensor([0] + num_images).cumsum(0)
img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
model_inputs["pixel_values"] = pixel_values[img_start:img_end]
model_inputs["image_position_ids"] = image_position_ids[img_start:img_end]
elif pixel_values is not None:
if num_images is not None:
# pixel_values is indexed by image, not by sample — slice using cumulative image counts
num_images_list = num_images.tolist() if isinstance(num_images, torch.Tensor) else list(num_images)
cum_imgs = [0]
for n in num_images_list:
cum_imgs.append(cum_imgs[-1] + int(n))
img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size]
model_inputs["pixel_values"] = pixel_values[img_start:img_end]
else:
model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
model_inputs["pixel_values"] = pixel_values[start : start + batch_size]
if pixel_attention_mask is not None:
model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size]
if image_sizes is not None:
Expand All @@ -1071,14 +1063,6 @@ def _get_per_token_logps_and_entropies(
model_inputs["token_type_ids"] = token_type_ids[start : start + batch_size]
if mm_token_type_ids is not None:
model_inputs["mm_token_type_ids"] = mm_token_type_ids[start : start + batch_size]
if pixel_position_ids is not None:
model_inputs["pixel_position_ids"] = pixel_position_ids[start : start + batch_size]
if image_position_ids is not None:
# image_position_ids is indexed by image (like pixel_values), not by sample
if num_images is not None and pixel_values is not None:
model_inputs["image_position_ids"] = image_position_ids[img_start:img_end]
else:
model_inputs["image_position_ids"] = image_position_ids[start : start + batch_size]

# Only add logits_to_keep if the model supports it
if "logits_to_keep" in self.model_kwarg_keys:
Expand Down Expand Up @@ -2203,7 +2187,7 @@ def _generate_and_score_completions(
logits_to_keep,
batch_size,
num_images=num_images,
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids, image_position_ids
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids
)
else:
old_per_token_logps = None
Expand Down Expand Up @@ -2249,7 +2233,7 @@ def _generate_and_score_completions(
logits_to_keep,
batch_size=batch_size,
num_images=num_images,
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids, image_position_ids
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids
)
else:
# When training a PEFT adapter, how we obtain the reference depends on the setup:
Expand All @@ -2264,7 +2248,7 @@ def _generate_and_score_completions(
logits_to_keep,
batch_size=batch_size,
num_images=num_images,
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, pixel_position_ids, image_position_ids
**forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask, image_sizes, image_position_ids
)
else:
ref_per_token_logps = None
Expand Down Expand Up @@ -2443,8 +2427,6 @@ def _generate_and_score_completions(
output["token_type_ids"] = forward_kwargs["token_type_ids"]
if "mm_token_type_ids" in forward_kwargs:
output["mm_token_type_ids"] = forward_kwargs["mm_token_type_ids"]
if "pixel_position_ids" in forward_kwargs:
output["pixel_position_ids"] = forward_kwargs["pixel_position_ids"]
if "image_position_ids" in forward_kwargs:
output["image_position_ids"] = forward_kwargs["image_position_ids"]
if images is not None:
Expand All @@ -2471,7 +2453,6 @@ def compute_liger_loss(self, unwrapped_model, inputs):
inputs.get("image_grid_thw"),
inputs.get("pixel_attention_mask"),
inputs.get("image_sizes"),
inputs.get("pixel_position_ids"),
inputs.get("image_position_ids"),
)

Expand Down Expand Up @@ -2608,7 +2589,6 @@ def _compute_loss(self, model, inputs):
image_sizes=inputs.get("image_sizes"),
token_type_ids=inputs.get("token_type_ids"),
mm_token_type_ids=inputs.get("mm_token_type_ids"),
pixel_position_ids=inputs.get("pixel_position_ids"),
image_position_ids=inputs.get("image_position_ids"),
)

Expand Down
Loading
Loading