diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 57c529b299..7a50deeda6 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -319,6 +319,9 @@ class GRPOConfig(_BaseConfig): Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or `trackio`. + log_multimodal (`bool`, *optional*, defaults to `True`): + Whether to log multimodal content (images, videos, etc.) together with completions. Disable this to reduce + log size when using high-resolution multimodal data. num_completions_to_print (`int`, *optional*): Number of completions to print with `rich`. If `None`, all completions are logged. log_unique_prompts (`bool`, *optional*, defaults to `False`): @@ -844,6 +847,13 @@ class GRPOConfig(_BaseConfig): "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." }, ) + log_multimodal: bool = field( + default=True, + metadata={ + "help": "Whether to log multimodal content (images, videos, etc.) together with completions. Disable this " + "to reduce log size when using high-resolution multimodal data." + }, + ) num_completions_to_print: int | None = field( default=None, metadata={"help": "Number of completions to print with `rich`. 
If `None`, all completions are logged."}, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c5eed09419..b7272e3592 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -699,6 +699,7 @@ def cast_outputs_to_original_dtype(module, args, output): self._total_train_tokens = 0 self._current_train_step_time = 0.0 self.log_completions = args.log_completions + self.log_multimodal = args.log_multimodal self.log_unique_prompts = args.log_unique_prompts self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. @@ -2049,7 +2050,7 @@ def _generate_and_score_completions( self._metrics[mode][name].append(global_mean) self._pending_metrics.clear() - if images is not None: + if images is not None and self.log_multimodal: self._logs["images"].extend(gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: diff --git a/trl/trainer/rloo_config.py b/trl/trainer/rloo_config.py index 8cdb0335a5..271dc4df15 100644 --- a/trl/trainer/rloo_config.py +++ b/trl/trainer/rloo_config.py @@ -200,6 +200,9 @@ class RLOOConfig(_BaseConfig): Whether to log a sample of (prompt, completion) pairs every `logging_steps` steps. If `rich` is installed, it prints the sample. If `wandb` and/or `trackio` logging is enabled, it logs it to `wandb` and/or `trackio`. + log_multimodal (`bool`, *optional*, defaults to `True`): + Whether to log multimodal content (images, videos, etc.) together with completions. Disable this to reduce + log size when using high-resolution multimodal data. num_completions_to_print (`int`, *optional*): Number of completions to print with `rich`. If `None`, all completions are logged. log_unique_prompts (`bool`, *optional*, defaults to `False`): @@ -532,6 +535,13 @@ class RLOOConfig(_BaseConfig): "installed, it prints the sample. If `wandb` logging is enabled, it logs it to `wandb`." 
}, ) + log_multimodal: bool = field( + default=True, + metadata={ + "help": "Whether to log multimodal content (images, videos, etc.) together with completions. Disable this " + "to reduce log size when using high-resolution multimodal data." + }, + ) num_completions_to_print: int | None = field( default=None, metadata={"help": "Number of completions to print with `rich`. If `None`, all completions are logged."}, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 4e62310e7c..25d4d83535 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -474,6 +474,7 @@ def __init__( self._total_train_tokens = 0 self._current_train_step_time = 0.0 self.log_completions = args.log_completions + self.log_multimodal = args.log_multimodal self.log_unique_prompts = args.log_unique_prompts self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. @@ -1338,7 +1339,7 @@ def _generate_and_score_completions( self._metrics[mode][name].append(global_mean) self._pending_metrics.clear() - if images is not None: + if images is not None and self.log_multimodal: self._logs["images"].extend(gather_object(images)) output = {