Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,13 +524,23 @@ def valid_model_name(name):
f"model should be one of {available_models()} or path to a model checkpoint"
)

def valid_output_dir(path):
    """Create *path* if needed and check it can be used as an output directory.

    Intended as an argparse ``type=`` converter, so a bad output location is
    reported while arguments are parsed.

    Args:
        path: Directory path given on the command line.

    Returns:
        ``path`` unchanged, once the directory exists and is writable.

    Raises:
        argparse.ArgumentTypeError: If the directory cannot be created, or if
            the current process lacks write/execute permission on it.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except (OSError, ValueError) as e:
        # OSError covers filesystem failures (permissions, a file in the way,
        # read-only media); ValueError covers malformed paths, e.g. one that
        # contains an embedded NUL byte. Anything else is a programming error
        # and should not be disguised as a bad argument.
        raise argparse.ArgumentTypeError(
            f"Cannot create output directory {path}: {e}"
        ) from e

    # Execute (search) permission is needed in addition to write permission
    # to create entries inside a directory.
    if not os.access(path, os.W_OK | os.X_OK):
        raise argparse.ArgumentTypeError(
            f"Lack write/execute permission for output directory: {path}"
        )
    return path

# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="turbo", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_dir", "-o", type=valid_output_dir, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

Expand Down Expand Up @@ -572,7 +582,6 @@ def valid_model_name(name):
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)

if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
Expand All @@ -592,9 +601,10 @@ def valid_model_name(name):

from . import load_model

writer = get_writer(output_format, output_dir)

model = load_model(model_name, device=device, download_root=model_dir)

writer = get_writer(output_format, output_dir)
word_options = [
"highlight_words",
"max_line_count",
Expand Down
6 changes: 6 additions & 0 deletions whisper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ class ResultWriter:
extension: str

def __init__(self, output_dir: str):
if not os.path.exists(output_dir):
raise FileNotFoundError(f"Output directory does not exist: {output_dir}")
if not os.path.isdir(output_dir):
raise NotADirectoryError(f"Output path is not a directory: {output_dir}")
if not os.access(output_dir, os.W_OK | os.X_OK):
raise PermissionError(f"Lack write/execute permission for output directory: {output_dir}")
self.output_dir = output_dir

def __call__(
Expand Down