diff --git a/whisper/__init__.py b/whisper/__init__.py index f284ec045..decc67463 100644 --- a/whisper/__init__.py +++ b/whisper/__init__.py @@ -103,7 +103,7 @@ def available_models() -> List[str]: def load_model( name: str, device: Optional[Union[str, torch.device]] = None, - download_root: str = None, + download_root: Optional[str] = None, in_memory: bool = False, ) -> Whisper: """ diff --git a/whisper/timing.py b/whisper/timing.py index 2340000bc..6379f0cd2 100644 --- a/whisper/timing.py +++ b/whisper/timing.py @@ -27,9 +27,9 @@ def median_filter(x: torch.Tensor, filter_width: int): # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D x = x[None, None, :] - assert ( - filter_width > 0 and filter_width % 2 == 1 - ), "`filter_width` should be an odd number" + assert filter_width > 0 and filter_width % 2 == 1, ( + "`filter_width` should be an odd number" + ) result = None x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect") @@ -211,7 +211,7 @@ def find_alignment( weights = (weights - mean) / std weights = median_filter(weights, medfilt_width) - matrix = weights.mean(axis=0) + matrix = weights.mean(dim=0) matrix = matrix[len(tokenizer.sot_sequence) : -1] text_indices, time_indices = dtw(-matrix) diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 0a4cc3623..bdca217d4 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -561,7 +561,7 @@ def valid_model_name(name): parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line") parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment") parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment") - parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS") + parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS") parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file") parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected") # fmt: on