diff --git a/system/athena/athenad.py b/system/athena/athenad.py index de09c2da7637ef..36f1d848c2062f 100755 --- a/system/athena/athenad.py +++ b/system/athena/athenad.py @@ -28,7 +28,7 @@ create_connection) import cereal.messaging as messaging -from cereal import log +from cereal import car, log from cereal.services import SERVICE_LIST from openpilot.common.api import Api, get_key_pair from openpilot.common.utils import CallbackReader, get_upload_stream @@ -44,6 +44,7 @@ ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai') HANDLER_THREADS = int(os.getenv('HANDLER_THREADS', "4")) LOCAL_PORT_WHITELIST = {22, } # SSH +WEBRTCD_PORT = 5001 LOG_ATTR_NAME = 'user.upload' LOG_ATTR_VALUE_MAX_UNIX_TIME = int.to_bytes(2147483647, 4, sys.byteorder) @@ -536,6 +537,16 @@ def getSshAuthorizedKeys() -> str: def getGithubUsername() -> str: return cast(str, Params().get("GithubUsername") or "") + +@dispatcher.add_method +def getNotCar() -> bool: + cp_bytes = Params().get("CarParamsPersistent") + if cp_bytes is not None: + with car.CarParams.from_bytes(cp_bytes) as CP: + return CP.notCar + return False + + @dispatcher.add_method def getSimInfo(): return HARDWARE.get_sim_info() @@ -557,6 +568,35 @@ def getNetworks(): return HARDWARE.get_networks() +@dispatcher.add_method +def startStream(sdp: str) -> dict: + from openpilot.system.webrtc.webrtcd import StreamRequestBody + bridge_services_in = [] + + # get live car params to avoid stale notCar edge case + cp_bytes = Params().get("CarParams") + if cp_bytes is not None: + with car.CarParams.from_bytes(cp_bytes) as CP: + if CP.notCar: + bridge_services_in.append("testJoystick") + + body = StreamRequestBody(sdp, ["driver"], bridge_services_in, ["carState"]) + try: + resp = requests.post(f"http://localhost:{WEBRTCD_PORT}/stream", + json=asdict(body), timeout=10) + if not resp.ok: + try: + error_body = resp.json() + raise Exception(error_body.get("message", f"webrtcd returned {resp.status_code}")) + except ValueError: + resp.raise_for_status() + return resp.json() + except requests.ConnectTimeout as e: + raise Exception("webrtc took too long to respond. is it on?") from e + except requests.ConnectionError as e: + raise Exception("webrtc is not running. turn on comma body ignition.") from e + + @dispatcher.add_method def takeSnapshot() -> str | dict[str, str] | None: from openpilot.system.camerad.snapshot import jpeg_write, snapshot diff --git a/system/webrtc/webrtcd.py b/system/webrtc/webrtcd.py index d2c90cafb5b2e6..7a51d45034e42b 100755 --- a/system/webrtc/webrtcd.py +++ b/system/webrtc/webrtcd.py @@ -2,6 +2,7 @@ import argparse import asyncio +import contextlib import json import uuid import logging @@ -83,11 +84,16 @@ def start(self): assert self.task is None self.task = asyncio.create_task(self.run()) - def stop(self): - if self.task is None or self.task.done(): + async def stop(self): + if self.task is None: return - self.task.cancel() + task = self.task self.task = None + if task.done(): + return + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task async def run(self): from aiortc.exceptions import InvalidStateError @@ -145,24 +151,29 @@ def __init__(self, sdp: str, cameras: list[str], incoming_services: list[str], o self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge) self.run_task: asyncio.Task | None = None + self._cleanup_lock = asyncio.Lock() + self._cleanup_done = False self.logger = logging.getLogger("webrtcd") - self.logger.info("New stream session (%s), cameras %s, incoming services %s, outgoing services %s", - self.identifier, cameras, incoming_services, outgoing_services) + self.logger.info( + "New stream session (%s), cameras %s, incoming services %s, outgoing services %s", + self.identifier, cameras, incoming_services, outgoing_services, + ) def start(self): self.run_task = asyncio.create_task(self.run()) - def stop(self): - if self.run_task.done(): - return - self.run_task.cancel() + async def stop(self): + if self.run_task is not None and not self.run_task.done() and self.run_task is not asyncio.current_task(): + self.run_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self.run_task self.run_task = None - asyncio.run(self.post_run_cleanup()) + await self.post_run_cleanup() async def get_answer(self): return await self.stream.start() - async def message_handler(self, message: bytes): + def message_handler(self, message: bytes): assert self.incoming_bridge is not None try: self.incoming_bridge.send(message) @@ -183,16 +194,21 @@ async def run(self): self.logger.info("Stream session (%s) connected", self.identifier) await self.stream.wait_for_disconnection() - await self.post_run_cleanup() self.logger.info("Stream session (%s) ended", self.identifier) except Exception: self.logger.exception("Stream session failure") + finally: + await self.post_run_cleanup() async def post_run_cleanup(self): - await self.stream.stop() - if self.outgoing_bridge is not None: - self.outgoing_bridge_runner.stop() + async with self._cleanup_lock: + if self._cleanup_done: + return + self._cleanup_done = True + if self.outgoing_bridge_runner is not None: + await self.outgoing_bridge_runner.stop() + await self.stream.stop() @dataclass @@ -208,11 +224,33 @@ async def get_stream(request: 'web.Request'): raw_body = await request.json() body = StreamRequestBody(**raw_body) - session = StreamSession(body.sdp, body.cameras, body.bridge_services_in, body.bridge_services_out, debug_mode) - answer = await session.get_answer() - session.start() + async with request.app['stream_lock']: + # Fully disconnect any other active stream before starting the replacement. + for sid, s in list(stream_dict.items()): + if s.run_task and not s.run_task.done(): + try: + ch = s.stream.get_messaging_channel() + ch.send(json.dumps({"type": "connectionReplaced", "data": "Another device has connected, closing this session."})) + except Exception: + pass + await s.stop() + del stream_dict[sid] + + session = StreamSession(body.sdp, body.cameras, body.bridge_services_in, body.bridge_services_out, debug_mode) + try: + answer = await session.get_answer() + except ValueError as e: + await session.stop() + raise web.HTTPBadRequest( + text=json.dumps({"error": "invalid_sdp", "message": str(e)}), + content_type="application/json", + ) from e + except Exception: + await session.stop() + raise + session.start() - stream_dict[session.identifier] = session + stream_dict[session.identifier] = session return web.json_response({"sdp": answer.sdp, "type": answer.type}) @@ -224,6 +262,7 @@ async def get_schema(request: 'web.Request'): schema_dict = {s: generate_field(log.Event.schema.fields[s]) for s in services} return web.json_response(schema_dict) + async def post_notify(request: 'web.Request'): try: payload = await request.json() @@ -239,9 +278,10 @@ async def post_notify(request: 'web.Request'): return web.Response(status=200, text="OK") + async def on_shutdown(app: 'web.Application'): for session in app['streams'].values(): - session.stop() + await session.stop() del app['streams'] @@ -254,6 +294,7 @@ def webrtcd_thread(host: str, port: int, debug: bool): app = web.Application() app['streams'] = dict() + app['stream_lock'] = asyncio.Lock() app['debug'] = debug app.on_shutdown.append(on_shutdown) app.router.add_post("/stream", get_stream)