fix(async-grpo): honor model init dtype #5416

Open
kdubovikov wants to merge 2 commits into huggingface:main from kdubovikov:async-dtype-fix

Conversation

Contributor

@kdubovikov kdubovikov commented Mar 31, 2026

What does this PR do?

This PR fixes AsyncGRPOTrainer's string-model load path so it no longer hardcodes dtype=torch.float32.

Before this change, AsyncGRPOTrainer always instantiated models from a model name with fp32, even when the trainer config was using bf16/fp16 defaults. That made async GRPO behave inconsistently with other TRL trainers and could increase memory usage enough to trigger avoidable OOMs on larger runs.

This PR brings async GRPO in line with the existing TRL model-init pattern by:

  • adding model_init_kwargs to AsyncGRPOConfig
  • removing the hardcoded fp32 override from AsyncGRPOTrainer
  • defaulting string-model loading to dtype="auto"
  • validating and normalizing string dtypes like "bfloat16" into torch.dtype
  • preserving the existing device_map=None default unless explicitly overridden
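
The normalization described above can be sketched roughly as follows (the function name and exact shape are illustrative, not the trainer's actual code; the actual PR wires this into `AsyncGRPOTrainer`'s string-model load path):

```python
import torch

def normalize_model_init_kwargs(model_init_kwargs):
    """Sketch of the dtype handling this PR describes.

    - "auto" (the new default) is passed through unchanged.
    - A string such as "bfloat16" is resolved to the matching `torch.dtype`.
    - Anything that does not resolve to a real `torch.dtype` raises ValueError.
    """
    kwargs = dict(model_init_kwargs)
    kwargs.setdefault("dtype", "auto")      # new default instead of hardcoded torch.float32
    kwargs.setdefault("device_map", None)   # preserve the existing default
    dtype = kwargs["dtype"]
    if dtype == "auto" or isinstance(dtype, torch.dtype):
        return kwargs
    resolved = getattr(torch, dtype, None) if isinstance(dtype, str) else None
    if not isinstance(resolved, torch.dtype):
        raise ValueError(
            "Invalid `dtype` passed to `AsyncGRPOConfig`. Expected either 'auto' or a "
            f"string representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
        )
    kwargs["dtype"] = resolved
    return kwargs
```

The normalized kwargs would then be forwarded to `from_pretrained` unchanged, so `"auto"` is resolved by transformers itself against the checkpoint's stored dtype.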

It also adds regression tests covering:

  • the default dtype="auto" path
  • explicit dtype conversion from config
  • invalid dtype rejection

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

AI writing disclosure

We welcome the use of AI tools to help with contributions. For transparency and to help us improve our review process, please indicate the level of AI involvement in this PR.

  • No AI usage: the PR was written entirely by a human.
  • AI-assisted: some parts were suggested or improved by AI, but the PR was written and reviewed by a human.
  • AI-generated: the PR was mostly or fully generated by an AI tool.

Note

Medium Risk
Changes AsyncGRPOTrainer model loading defaults and dtype normalization, which can alter numerical precision and memory usage across runs. Risk is mitigated by added regression tests but could affect users relying on the previous fp32-forced behavior.

Overview
Fixes AsyncGRPOTrainer string-model loading to stop hardcoding fp32 and instead honor configurable model init settings.

Adds model_init_kwargs to AsyncGRPOConfig and updates trainer initialization to default dtype to "auto", normalize string dtypes (e.g. "bfloat16" → torch.bfloat16), validate invalid values, and preserve device_map=None unless overridden.

Introduces new tests asserting the default dtype="auto" behavior, dtype string conversion, and invalid dtype rejection via ValueError.

Written by Cursor Bugbot for commit 4262968. This will update automatically on new commits. Configure here.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.


raise ValueError(
"Invalid `dtype` passed to `AsyncGRPOConfig`. Expected either 'auto' or a string "
f"representing a `torch.dtype` (e.g., 'float32'), but got {dtype}."
) from exc

Missing isinstance check after getattr allows non-dtype values

Low Severity

After getattr(torch, dtype) resolves a string, the code only catches AttributeError but never verifies the result is actually a torch.dtype. A string like "tensor" or "cuda" would successfully resolve via getattr (returning torch.tensor function or torch.cuda module) without raising AttributeError, and the non-dtype value would be silently forwarded to from_pretrained. The CPO and BCO trainers guard against this with an explicit not isinstance(dtype, torch.dtype) check after the getattr call.
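
The pitfall Bugbot describes can be demonstrated in a few lines (a sketch; `resolve_dtype` is a hypothetical helper showing the guarded pattern, not the actual TRL code):

```python
import torch

# getattr(torch, "tensor") succeeds without raising AttributeError, but it
# returns the torch.tensor function rather than a torch.dtype -- so catching
# only AttributeError lets non-dtype values slip through to from_pretrained.
print(isinstance(getattr(torch, "tensor"), torch.dtype))  # False

def resolve_dtype(name):
    """Resolve a dtype string with the explicit isinstance guard
    that the CPO/BCO trainers use after the getattr call."""
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"Invalid dtype string: {name!r}")
    return dtype
```

With the guard, strings like "tensor" or "cuda" are rejected the same way as strings that do not resolve at all.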

