Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,9 @@ class GRPOConfig(_BaseConfig):
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or
`trackio`.
log_multimodal (`bool`, *optional*, defaults to `True`):
Wether to log multimodal content (images, videos, etc.) together with completions. Disable this to reduce
log size when using high-resolution multimodal data.
num_completions_to_print (`int`, *optional*):
Number of completions to print with `rich`. If `None`, all completions are logged.
log_unique_prompts (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -844,6 +847,13 @@ class GRPOConfig(_BaseConfig):
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
},
)
log_multimodal: bool = field(
default=True,
metadata={
"help": "Whether to log multimodal content (images, videos, etc.) together with completions. Disable this "
"to reduce log size when using high-resolution multimodal data."
},
)
num_completions_to_print: int | None = field(
default=None,
metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."},
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,7 @@ def cast_outputs_to_original_dtype(module, args, output):
self._total_train_tokens = 0
self._current_train_step_time = 0.0
self.log_completions = args.log_completions
self.log_multimodal = args.log_multimodal
self.log_unique_prompts = args.log_unique_prompts
self.num_completions_to_print = args.num_completions_to_print
# Keep logs sized to the generation batch to record only outputs from the latest model update.
Expand Down Expand Up @@ -2049,7 +2050,7 @@ def _generate_and_score_completions(
self._metrics[mode][name].append(global_mean)
self._pending_metrics.clear()

if images is not None:
if images is not None and self.log_multimodal:
self._logs["images"].extend(gather_object(images))

if self.use_vllm and self.vllm_importance_sampling_correction:
Expand Down
10 changes: 10 additions & 0 deletions trl/trainer/rloo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,9 @@ class RLOOConfig(_BaseConfig):
Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed,
it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or
`trackio`.
log_multimodal (`bool`, *optional*, defaults to `True`):
Wether to log multimodal content (images, videos, etc.) together with completions. Disable this to reduce
log size when using high-resolution multimodal data.
num_completions_to_print (`int`, *optional*):
Number of completions to print with `rich`. If `None`, all completions are logged.
log_unique_prompts (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -532,6 +535,13 @@ class RLOOConfig(_BaseConfig):
"installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`."
},
)
log_multimodal: bool = field(
default=True,
metadata={
"help": "Whether to log multimodal content (images, videos, etc.) together with completions. Disable this "
"to reduce log size when using high-resolution multimodal data."
},
)
num_completions_to_print: int | None = field(
default=None,
metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."},
Expand Down
3 changes: 2 additions & 1 deletion trl/trainer/rloo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ def __init__(
self._total_train_tokens = 0
self._current_train_step_time = 0.0
self.log_completions = args.log_completions
self.log_multimodal = args.log_multimodal
self.log_unique_prompts = args.log_unique_prompts
self.num_completions_to_print = args.num_completions_to_print
# Keep logs sized to the generation batch to record only outputs from the latest model update.
Expand Down Expand Up @@ -1338,7 +1339,7 @@ def _generate_and_score_completions(
self._metrics[mode][name].append(global_mean)
self._pending_metrics.clear()

if images is not None:
if images is not None and self.log_multimodal:
self._logs["images"].extend(gather_object(images))

output = {
Expand Down