diff --git a/whisper/decoding.py b/whisper/decoding.py index 49485d009..0a5143cc8 100644 --- a/whisper/decoding.py +++ b/whisper/decoding.py @@ -655,9 +655,9 @@ def _get_audio_features(self, mel: Tensor): audio_features = self.model.encoder(mel) if audio_features.dtype != ( - torch.float16 if self.options.fp16 else torch.float32 + torch.float16 if self.options.fp16 else torch.float32 ): - return TypeError( + raise TypeError( f"audio_features has an incorrect dtype: {audio_features.dtype}" )