Skip to content
8 changes: 7 additions & 1 deletion livekit-agents/livekit/agents/vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,13 @@ def push_frame(self, frame: rtc.AudioFrame) -> None:
self._input_ch.send_nowait(frame)

def flush(self) -> None:
"""Mark the end of the current segment"""
"""Mark the end of the current segment.

Implementations MUST treat this as a hard segment boundary: drop any accumulated
speech/silence state so the next pushed frame starts a fresh segment. Used by the
pipeline to recover from out-of-band end-of-turn signals (e.g. STT EOS) without
tearing down and recreating the stream.
"""
self._check_input_not_ended()
self._check_not_closed()
self._input_ch.send_nowait(self._FlushSentinel())
Expand Down
22 changes: 19 additions & 3 deletions livekit-agents/livekit/agents/voice/audio_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from ..telemetry import trace_types, tracer
from ..types import NOT_GIVEN, NotGivenOr
from ..utils import aio, is_given
from ..vad import VADStream
from . import io
from ._utils import _set_participant_attributes
from .endpointing import BaseEndpointing
Expand Down Expand Up @@ -181,6 +182,7 @@ def __init__(

self._stt_pipeline: _STTPipeline | None = None
self._vad_ch: aio.Chan[rtc.AudioFrame] | None = None
self._vad_stream: VADStream | None = None

self._tasks: set[asyncio.Task[Any]] = set()

Expand Down Expand Up @@ -624,6 +626,7 @@ def update_stt(self, stt: io.STTNode | None, *, pipeline: _STTPipeline | None =
def update_vad(self, vad: vad.VAD | None) -> None:
self._vad = vad
if vad:
self._vad_stream = None
self._vad_ch = aio.Chan[rtc.AudioFrame]()
self._vad_atask = asyncio.create_task(
self._vad_task(vad, self._vad_ch, self._vad_atask)
Expand All @@ -634,6 +637,7 @@ def update_vad(self, vad: vad.VAD | None) -> None:
self._tasks.add(task)
self._vad_atask = None
self._vad_ch = None
self._vad_stream = None

self._interruption_enabled = (
self._interruption_detection is not None and self._vad is not None
Expand Down Expand Up @@ -984,11 +988,20 @@ async def _on_stt_event(self, ev: stt.SpeechEvent) -> None:
# reset VAD so that incorrect end of turn from STT can be corrected by VAD interruption
# if user is still speaking (an immediate VAD SOS will interrupt the agent)
if self._vad:
if self._speaking:
if self._vad_speech_started:
if self._vad_stream is not None:
self._vad_stream.flush()
else:
self.update_vad(self._vad)

logger.warning(
"stt end of speech received while user is speaking, resetting vad"
"stt end of speech received while vad is still in a speech segment, "
"flushing vad",
extra={
"vad_speech_start_time": self._speech_start_time,
"flushed": self._vad_stream is not None,
},
)
self.update_vad(self._vad)

self._speaking = False
self._user_turn_committed = True
Expand Down Expand Up @@ -1290,6 +1303,7 @@ async def _vad_task(
await aio.cancel_and_wait(task)

stream = vad.stream()
self._vad_stream = stream

@utils.log_exceptions(logger=logger)
async def _forward() -> None:
Expand All @@ -1304,6 +1318,8 @@ async def _forward() -> None:
finally:
await aio.cancel_and_wait(forward_task)
await stream.aclose()
if self._vad_stream is stream:
self._vad_stream = None

# reset the speaking state to prevent stuck user speaking state during handoff
if self._speaking:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def window_size_samples(self) -> int:
def context_size(self) -> int:
return self._context_size

def reset(self) -> None:
self._context.fill(0)
self._rnn_state.fill(0)
self._input_buffer.fill(0)

def __call__(self, x: np.ndarray) -> float:
self._input_buffer[:, : self._context_size] = self._context
self._input_buffer[:, self._context_size :] = x
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,9 +311,51 @@ async def _main_task(self) -> None:

extra_inference_time = 0.0

def _reset_state() -> None:
nonlocal speech_buffer_index
nonlocal pub_speaking, pub_speech_duration, pub_silence_duration
nonlocal pub_current_sample, pub_timestamp
nonlocal speech_threshold_duration, silence_threshold_duration
nonlocal input_frames, inference_frames, resampler
nonlocal input_copy_remaining_fract, extra_inference_time

self._model.reset()
self._exp_filter = utils.ExpFilter(alpha=0.35)

speech_buffer_index = 0
self._speech_buffer_max_reached = False
if self._speech_buffer is not None:
self._speech_buffer.fill(0)

pub_speaking = False
pub_speech_duration = 0.0
pub_silence_duration = 0.0
pub_current_sample = 0
pub_timestamp = 0.0
speech_threshold_duration = 0.0
silence_threshold_duration = 0.0

input_frames = []
inference_frames = []
input_copy_remaining_fract = 0.0
extra_inference_time = 0.0

if self._input_sample_rate and self._input_sample_rate != self._opts.sample_rate:
resampler = rtc.AudioResampler(
input_rate=self._input_sample_rate,
output_rate=self._opts.sample_rate,
quality=rtc.AudioResamplerQuality.QUICK,
)
else:
resampler = None

async for input_frame in self._input_ch:
if isinstance(input_frame, self._FlushSentinel):
_reset_state()
continue

if not isinstance(input_frame, rtc.AudioFrame):
continue # ignore flush sentinel for now
continue

if not self._input_sample_rate:
self._input_sample_rate = input_frame.sample_rate
Expand Down
8 changes: 7 additions & 1 deletion tests/fake_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import asyncio
import time

from livekit import rtc
from livekit.agents.vad import VAD, VADCapabilities, VADEvent, VADEventType, VADStream

from .fake_stt import FakeUserSpeech
Expand Down Expand Up @@ -41,7 +42,12 @@ async def _main_task(self) -> None:
if not self._vad._fake_user_speeches:
return

await self._input_ch.recv()
async for input_frame in self._input_ch:
if isinstance(input_frame, rtc.AudioFrame):
break
else:
return

start_time = time.perf_counter()

def current_time() -> float:
Expand Down
53 changes: 52 additions & 1 deletion tests/test_agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import logging
import time
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, Mock, patch

import pytest

Expand Down Expand Up @@ -756,6 +756,57 @@ async def test_backchannel_boundary_suppresses_start_boundary_backchannel() -> N
await _close_test_session(session)


async def _make_stt_eos_recognition() -> AudioRecognition:
return AudioRecognition(
create_session(FakeActions()),
hooks=_TestRecognitionHooks(),
endpointing=BaseEndpointing(min_delay=0.0, max_delay=0.0),
stt=None,
vad=None,
interruption_detection=None,
turn_detection="stt",
)


async def test_stt_eos_resets_active_vad_stream_without_restarting_vad() -> None:
recognition = await _make_stt_eos_recognition()
recognition._speaking = True
recognition._vad_speech_started = True
recognition._vad = MagicMock()
resettable_stream = MagicMock()
recognition._vad_stream = resettable_stream

try:
with patch.object(recognition, "update_vad") as update_vad:
await recognition._on_stt_event(SpeechEvent(type=SpeechEventType.END_OF_SPEECH))

resettable_stream.flush.assert_called_once_with()
update_vad.assert_not_called()
assert recognition._vad_stream is resettable_stream
finally:
if recognition._end_of_turn_task is not None:
await aio.cancel_and_wait(recognition._end_of_turn_task)
await _close_test_session(recognition._session)


async def test_stt_eos_falls_back_to_update_vad_when_no_active_stream() -> None:
recognition = await _make_stt_eos_recognition()
recognition._speaking = True
recognition._vad_speech_started = True
recognition._vad = MagicMock()
recognition._vad_stream = None

try:
with patch.object(recognition, "update_vad") as update_vad:
await recognition._on_stt_event(SpeechEvent(type=SpeechEventType.END_OF_SPEECH))

update_vad.assert_called_once_with(recognition._vad)
finally:
if recognition._end_of_turn_task is not None:
await aio.cancel_and_wait(recognition._end_of_turn_task)
await _close_test_session(recognition._session)


async def test_backchannel_boundary_releases_end_boundary_transcript() -> None:
actions = FakeActions()
session = create_session(
Expand Down
59 changes: 59 additions & 0 deletions tests/test_vad.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

import pytest

from livekit.agents import vad
Expand Down Expand Up @@ -60,6 +62,63 @@ async def test_chunks_vad(sample_rate) -> None:
f.write(utils.make_wav_file(inference_frames))


async def _drain_speech_segment(
stream: vad.VADStream, frames: list, *, timeout: float = 30.0
) -> tuple[vad.VADEvent, vad.VADEvent]:
"""Push *frames* until both START_OF_SPEECH and END_OF_SPEECH have fired."""

done = asyncio.Event()

async def _pump() -> None:
for frame in frames:
if done.is_set():
return
stream.push_frame(frame)
await asyncio.sleep(0)

async def _consume() -> tuple[vad.VADEvent, vad.VADEvent]:
sos_event: vad.VADEvent | None = None
async for ev in stream:
if ev.type == vad.VADEventType.START_OF_SPEECH and sos_event is None:
sos_event = ev
elif ev.type == vad.VADEventType.END_OF_SPEECH and sos_event is not None:
return sos_event, ev

raise AssertionError("stream ended before END_OF_SPEECH")

pump_task = asyncio.create_task(_pump())
try:
return await asyncio.wait_for(_consume(), timeout=timeout)
finally:
done.set()
pump_task.cancel()
try:
await pump_task
except asyncio.CancelledError:
pass


async def test_reset_recovers_full_speech_segment() -> None:
"""Real speech audio should still produce a complete SOS + EOS cycle after reset."""

frames, *_ = await utils.make_test_speech(chunk_duration_ms=10, sample_rate=16000)
assert len(frames) > 1, "frames aren't chunked"

stream = VAD.stream()
try:
first_sos, first_eos = await _drain_speech_segment(stream, frames)
assert first_sos.type == vad.VADEventType.START_OF_SPEECH
assert first_eos.type == vad.VADEventType.END_OF_SPEECH

stream.flush()

second_sos, second_eos = await _drain_speech_segment(stream, frames)
assert second_sos.type == vad.VADEventType.START_OF_SPEECH
assert second_eos.type == vad.VADEventType.END_OF_SPEECH
finally:
await stream.aclose()


@pytest.mark.parametrize("sample_rate", SAMPLE_RATES)
async def test_file_vad(sample_rate):
frames, *_ = await utils.make_test_speech(sample_rate=sample_rate)
Expand Down
Loading