Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion whisper/mlx_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def transcribe(
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
carry_initial_prompt: bool = False,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
Expand Down Expand Up @@ -126,6 +127,11 @@ def transcribe(
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those words correctly.

carry_initial_prompt: bool
If True, the `initial_prompt` is forcefully carried forward into the context window for all
subsequent text segments. This ensures the model does not forget prompt-engineered context
over long audio files.

decode_options: dict
Keyword arguments to construct `DecodingOptions` instances

Expand Down Expand Up @@ -293,7 +299,27 @@ def new_segment(
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)

decode_options["prompt"] = all_tokens[prompt_reset_since:]
prompt_tokens = all_tokens[prompt_reset_since:]
if carry_initial_prompt and initial_prompt_tokens:
# Extract previous text tokens, removing the initial prompt if it's currently at the beginning
prev_tokens = (
prompt_tokens[len(initial_prompt_tokens) :]
if prompt_reset_since == 0
else prompt_tokens
)
# Calculate available space across context window
max_prompt_length = model.dims.n_text_ctx // 2 - 1
max_prev_length = max(
0, max_prompt_length - len(initial_prompt_tokens)
)
# Retain initial prompt and latest previous tokens
prompt_tokens = (
initial_prompt_tokens + prev_tokens[-max_prev_length:]
if max_prev_length > 0
else initial_prompt_tokens
)

decode_options["prompt"] = prompt_tokens
result: DecodingResult = decode_with_fallback(mel_segment)

tokens = np.array(result.tokens)
Expand Down
11 changes: 11 additions & 0 deletions whisper/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,17 @@ def test_transcribe(self):
),
)

def test_carry_initial_prompt(self):
    """Smoke-test `carry_initial_prompt=True` on the fp32 model.

    Runs a transcription with the initial prompt forcefully carried into
    every context window and checks that a non-empty "text" field comes
    back in the result dictionary.
    """
    options = dict(
        path_or_hf_repo=MLX_FP32_MODEL_PATH,
        fp16=False,
        initial_prompt="A test prompt.",
        carry_initial_prompt=True,
    )
    transcription = mlx_whisper.transcribe(TEST_AUDIO, **options)
    self.assertIn("text", transcription)
    self.assertGreater(len(transcription["text"]), 0)

def test_transcribe_alice(self):
audio_file = os.path.join(
os.path.expanduser("~"),
Expand Down