diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0c687af12e..fd276b671f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1025,7 +1025,7 @@ def _get_per_token_logps_and_entropies( token_type_ids=None, mm_token_type_ids=None, image_position_ids=None, - ) -> dict[str, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak all_logps = [] diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 19e9fc8dea..48ba37628f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -681,7 +681,7 @@ def _get_per_token_logps_and_entropies( token_type_ids=None, mm_token_type_ids=None, image_position_ids=None, - ) -> dict[str, torch.Tensor | None]: + ) -> tuple[torch.Tensor, torch.Tensor | None]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak all_logps = []