fix(async-grpo): honor model init dtype #5416
kdubovikov wants to merge 2 commits into huggingface:main from
Conversation
Cursor Bugbot has reviewed your changes and found 1 potential issue.
```python
raise ValueError(
    "Invalid `dtype` passed to `AsyncGRPOConfig`. Expected either 'auto' or a string "
    f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
) from exc
```
Missing isinstance check after getattr allows non-dtype values
Low Severity
After `getattr(torch, dtype)` resolves a string, the code only catches `AttributeError` but never verifies that the result is actually a `torch.dtype`. A string like `"tensor"` or `"cuda"` would resolve successfully via `getattr` (returning the `torch.tensor` function or the `torch.cuda` module) without raising `AttributeError`, and the non-dtype value would be silently forwarded to `from_pretrained`. The CPO and BCO trainers guard against this with an explicit `not isinstance(dtype, torch.dtype)` check after the `getattr` call.
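A minimal sketch of the guard Bugbot suggests (the helper name `normalize_dtype` is hypothetical, not the PR's actual function; it mirrors the `isinstance` pattern attributed to the CPO/BCO trainers):

```python
import torch

def normalize_dtype(dtype):
    # Hypothetical helper: accept "auto" or an existing torch.dtype unchanged.
    if dtype == "auto" or isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        resolved = getattr(torch, dtype, None)
        # getattr alone is not enough: "tensor" resolves to a function and
        # "cuda" to a module, so verify the attribute is really a dtype.
        if isinstance(resolved, torch.dtype):
            return resolved
    raise ValueError(
        "Invalid `dtype`. Expected either 'auto' or a string representing "
        f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
    )
```

With this check, `"bfloat16"` is converted to `torch.bfloat16`, while `"tensor"` or `"cuda"` raise `ValueError` instead of being silently forwarded to `from_pretrained`.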


What does this PR do?
This PR fixes `AsyncGRPOTrainer`'s string-model load path so it no longer hardcodes `dtype=torch.float32`.

Before this change, `AsyncGRPOTrainer` always instantiated models from a model name with fp32, even when the trainer config was using bf16/fp16 defaults. That made async GRPO behave inconsistently with other TRL trainers and could increase memory usage enough to trigger avoidable OOMs on larger runs.

This PR brings async GRPO in line with the existing TRL model-init pattern by:

- adding `model_init_kwargs` to `AsyncGRPOConfig`
- defaulting `AsyncGRPOTrainer` to `dtype="auto"`
- normalizing dtype strings such as `"bfloat16"` into `torch.dtype` values
- keeping the `device_map=None` default unless explicitly overridden

It also adds regression tests covering:

- the `dtype="auto"` path

Before submitting
AI writing disclosure
We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.
Note
Medium Risk
Changes
`AsyncGRPOTrainer` model loading defaults and dtype normalization, which can alter numerical precision and memory usage across runs. Risk is mitigated by added regression tests but could affect users relying on the previous fp32-forced behavior.

Overview
Fixes `AsyncGRPOTrainer` string-model loading to stop hardcoding fp32 and instead honor configurable model init settings.

Adds `model_init_kwargs` to `AsyncGRPOConfig` and updates trainer initialization to default `dtype` to `"auto"`, normalize string dtypes (e.g. `"bfloat16"` → `torch.bfloat16`), validate invalid values, and preserve `device_map=None` unless overridden.

Introduces new tests asserting the default `dtype="auto"` behavior, dtype string conversion, and invalid dtype rejection via `ValueError`.

Written by Cursor Bugbot for commit 4262968.
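As a hedged usage sketch (only `model_init_kwargs` is confirmed by this PR; the import path and any other constructor arguments are assumptions), a user who previously always got fp32 from a string model name could now request bf16 explicitly:

```python
from trl import AsyncGRPOConfig  # import path assumed

config = AsyncGRPOConfig(
    # With this PR, "bfloat16" is normalized to torch.bfloat16 before being
    # forwarded to from_pretrained; omitting the key falls back to the new
    # dtype="auto" default instead of the old hardcoded torch.float32.
    model_init_kwargs={"dtype": "bfloat16"},
)
```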