Fix the deprecated torch.cuda.amp module#21
Conversation
There was a problem hiding this comment.
Pull Request Overview
This PR modernizes the GradScaler import by migrating from the deprecated torch.cuda.amp.GradScaler to the newer torch.amp.GradScaler API, which is device-agnostic. The changes ensure backward compatibility by wrapping the new API with a partial function that defaults to CUDA device when neither NPU nor MLU is available.
- Updated GradScaler import from
torch.cuda.amptotorch.ampin production code and tests - Added a device-specific wrapper using
functools.partialto maintain CUDA as default device
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| mmengine/optim/optimizer/amp_optimizer_wrapper.py | Updated GradScaler import to use torch.amp and added partial wrapper for CUDA device specification |
| tests/test_optim/test_optimizer/test_optimizer_wrapper.py | Updated test imports to use the new torch.amp.GradScaler API |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Documentation build overview
Show files changed (2 files in total): 📝 2 modified | ➕ 0 added | ➖ 0 deleted
|
|
If you fix the linting error I will merge. |
|
@lauriebax Lint fixed. And, as described in open-mmlab#1676:
I suggest to update |
This sub-PR is related to open-mmlab#1665
Brief
According to PyTorch:
This includes two related replacement:
amp_optimizer_wrappertest_optimizer_wrapperPyTest Result After this PR
pytest tests/test_optim/test_optimizer/test_optimizer_wrapper.py