diff --git a/pyod/models/gaal_base.py b/pyod/models/gaal_base.py index a0610be3..5af2e70e 100644 --- a/pyod/models/gaal_base.py +++ b/pyod/models/gaal_base.py @@ -8,12 +8,10 @@ try: import torch -except ImportError: - print('please install torch first') - -import torch -import torch.nn as nn -import torch.nn.functional as F + import torch.nn as nn + import torch.nn.functional as F +except ImportError as e: + raise ImportError('PyTorch is required for GAAL models. Please install it with `pip install torch`.') from e def create_discriminator(latent_size, data_size):