diff --git a/mmengine/optim/optimizer/amp_optimizer_wrapper.py b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
index 4f3323f2cc..b60f3d09b8 100644
--- a/mmengine/optim/optimizer/amp_optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/amp_optimizer_wrapper.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from contextlib import contextmanager
+from functools import partial
 from typing import Union
 
 import torch
@@ -17,7 +18,8 @@
 elif is_mlu_available():
     from torch.mlu.amp import GradScaler
 else:
-    from torch.cuda.amp import GradScaler
+    from torch.amp import GradScaler
+    GradScaler = partial(GradScaler, device='cuda')
 
 
 @OPTIM_WRAPPERS.register_module()
diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
index ef1db241dd..ba4ca77d11 100644
--- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
+++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -8,7 +8,7 @@
 import torch.distributed as torch_dist
 import torch.nn as nn
 from parameterized import parameterized
-from torch.cuda.amp import GradScaler
+from torch.amp import GradScaler
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.optim import SGD, Adam, Optimizer
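
Note (illustrative sketch, not part of the patch): the reason the partial
shim keeps existing call sites working is that in recent PyTorch versions
the device-generic torch.amp.GradScaler takes the target device as its
first constructor argument, so binding device='cuda' reproduces the
behavior of the deprecated torch.cuda.amp.GradScaler:

    from functools import partial

    from torch.amp import GradScaler

    # Bind the device up front; at construction time the result behaves
    # like the old CUDA-specific class.
    GradScaler = partial(GradScaler, device='cuda')

    # Existing call sites keep working unchanged, e.g. constructing a
    # scaler with a custom initial loss scale:
    scaler = GradScaler(init_scale=512.0)

One caveat: after this change the name GradScaler is bound to a
functools.partial object rather than a class, so isinstance checks against
it would raise TypeError. The optimizer wrapper appears to only ever
instantiate it, so that should not matter here.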