Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mmengine/optim/optimizer/amp_optimizer_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Union
from functools import partial

import torch
import torch.nn as nn
Expand All @@ -17,7 +18,8 @@
elif is_mlu_available():
from torch.mlu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
GradScaler = partial(GradScaler, device='cuda')


@OPTIM_WRAPPERS.register_module()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_optim/test_optimizer/test_optimizer_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch.distributed as torch_dist
import torch.nn as nn
from parameterized import parameterized
from torch.cuda.amp import GradScaler
from torch.amp import GradScaler
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD, Adam, Optimizer

Expand Down
Loading