Skip to content
42 changes: 41 additions & 1 deletion system/athena/athenad.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
create_connection)

import cereal.messaging as messaging
from cereal import log
from cereal import car, log
from cereal.services import SERVICE_LIST
from openpilot.common.api import Api, get_key_pair
from openpilot.common.utils import CallbackReader, get_upload_stream
Expand All @@ -44,6 +44,7 @@
ATHENA_HOST = os.getenv('ATHENA_HOST', 'wss://athena.comma.ai')
HANDLER_THREADS = int(os.getenv('HANDLER_THREADS', "4"))
LOCAL_PORT_WHITELIST = {22, } # SSH
WEBRTCD_PORT = 5001

LOG_ATTR_NAME = 'user.upload'
LOG_ATTR_VALUE_MAX_UNIX_TIME = int.to_bytes(2147483647, 4, sys.byteorder)
Expand Down Expand Up @@ -536,6 +537,16 @@ def getSshAuthorizedKeys() -> str:
def getGithubUsername() -> str:
return cast(str, Params().get("GithubUsername") or "")


@dispatcher.add_method
def getNotCar() -> bool:
Comment thread
stefpi marked this conversation as resolved.
Comment on lines +541 to +542

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

side note: we should probably come up with a better name than notCar... at the very least the negative is confusing

cp_bytes = Params().get("CarParamsPersistent")
if cp_bytes is not None:
with car.CarParams.from_bytes(cp_bytes) as CP:
return CP.notCar
return False


@dispatcher.add_method
def getSimInfo():
return HARDWARE.get_sim_info()
Expand All @@ -557,6 +568,35 @@ def getNetworks():
return HARDWARE.get_networks()


@dispatcher.add_method
def startStream(sdp: str) -> dict:
from openpilot.system.webrtc.webrtcd import StreamRequestBody
bridge_services_in = []

# get live car params to avoid stale notCar edge case
cp_bytes = Params().get("CarParams")
if cp_bytes is not None:
with car.CarParams.from_bytes(cp_bytes) as CP:
if CP.notCar:
bridge_services_in.append("testJoystick")

body = StreamRequestBody(sdp, ["driver"], bridge_services_in, ["carState"])
try:
resp = requests.post(f"http://localhost:{WEBRTCD_PORT}/stream",
json=asdict(body), timeout=10)
if not resp.ok:
try:
error_body = resp.json()
raise Exception(error_body.get("message", f"webrtcd returned {resp.status_code}"))
except ValueError:
resp.raise_for_status()
return resp.json()
except requests.ConnectTimeout as e:
raise Exception("webrtc took too long to respond. is it on?") from e
except requests.ConnectionError as e:
raise Exception("webrtc is not running. turn on comma body ignition.") from e


@dispatcher.add_method
def takeSnapshot() -> str | dict[str, str] | None:
from openpilot.system.camerad.snapshot import jpeg_write, snapshot
Expand Down
81 changes: 61 additions & 20 deletions system/webrtc/webrtcd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import argparse
import asyncio
import contextlib
import json
import uuid
import logging
Expand Down Expand Up @@ -83,11 +84,16 @@ def start(self):
assert self.task is None
self.task = asyncio.create_task(self.run())

def stop(self):
if self.task is None or self.task.done():
async def stop(self):
if self.task is None:
return
self.task.cancel()
task = self.task
self.task = None
if task.done():
return
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task

async def run(self):
from aiortc.exceptions import InvalidStateError
Expand Down Expand Up @@ -145,24 +151,29 @@ def __init__(self, sdp: str, cameras: list[str], incoming_services: list[str], o
self.outgoing_bridge_runner = CerealProxyRunner(self.outgoing_bridge)

self.run_task: asyncio.Task | None = None
self._cleanup_lock = asyncio.Lock()
self._cleanup_done = False
self.logger = logging.getLogger("webrtcd")
self.logger.info("New stream session (%s), cameras %s, incoming services %s, outgoing services %s",
self.identifier, cameras, incoming_services, outgoing_services)
self.logger.info(
"New stream session (%s), cameras %s, incoming services %s, outgoing services %s",
self.identifier, cameras, incoming_services, outgoing_services,
)

def start(self):
self.run_task = asyncio.create_task(self.run())

def stop(self):
if self.run_task.done():
return
self.run_task.cancel()
async def stop(self):
if self.run_task is not None and not self.run_task.done() and self.run_task is not asyncio.current_task():
self.run_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.run_task
self.run_task = None
asyncio.run(self.post_run_cleanup())
await self.post_run_cleanup()

async def get_answer(self):
return await self.stream.start()

async def message_handler(self, message: bytes):
def message_handler(self, message: bytes):
assert self.incoming_bridge is not None
try:
self.incoming_bridge.send(message)
Expand All @@ -183,16 +194,21 @@ async def run(self):
self.logger.info("Stream session (%s) connected", self.identifier)

await self.stream.wait_for_disconnection()
await self.post_run_cleanup()

self.logger.info("Stream session (%s) ended", self.identifier)
except Exception:
self.logger.exception("Stream session failure")
finally:
await self.post_run_cleanup()

async def post_run_cleanup(self):
await self.stream.stop()
if self.outgoing_bridge is not None:
self.outgoing_bridge_runner.stop()
async with self._cleanup_lock:
if self._cleanup_done:
return
self._cleanup_done = True
if self.outgoing_bridge_runner is not None:
await self.outgoing_bridge_runner.stop()
await self.stream.stop()


@dataclass
Expand All @@ -208,11 +224,33 @@ async def get_stream(request: 'web.Request'):
raw_body = await request.json()
body = StreamRequestBody(**raw_body)

session = StreamSession(body.sdp, body.cameras, body.bridge_services_in, body.bridge_services_out, debug_mode)
answer = await session.get_answer()
session.start()
async with request.app['stream_lock']:
# Fully disconnect any other active stream before starting the replacement.
for sid, s in list(stream_dict.items()):
if s.run_task and not s.run_task.done():
try:
ch = s.stream.get_messaging_channel()
ch.send(json.dumps({"type": "connectionReplaced", "data": "Another device has connected, closing this session."}))
except Exception:
pass
await s.stop()
del stream_dict[sid]

session = StreamSession(body.sdp, body.cameras, body.bridge_services_in, body.bridge_services_out, debug_mode)
try:
answer = await session.get_answer()
except ValueError as e:
await session.stop()
raise web.HTTPBadRequest(
text=json.dumps({"error": "invalid_sdp", "message": str(e)}),
content_type="application/json",
) from e
except Exception:
await session.stop()
raise
session.start()

stream_dict[session.identifier] = session
stream_dict[session.identifier] = session

return web.json_response({"sdp": answer.sdp, "type": answer.type})

Expand All @@ -224,6 +262,7 @@ async def get_schema(request: 'web.Request'):
schema_dict = {s: generate_field(log.Event.schema.fields[s]) for s in services}
return web.json_response(schema_dict)


async def post_notify(request: 'web.Request'):
try:
payload = await request.json()
Expand All @@ -239,9 +278,10 @@ async def post_notify(request: 'web.Request'):

return web.Response(status=200, text="OK")


async def on_shutdown(app: 'web.Application'):
for session in app['streams'].values():
session.stop()
await session.stop()
del app['streams']


Expand All @@ -254,6 +294,7 @@ def webrtcd_thread(host: str, port: int, debug: bool):
app = web.Application()

app['streams'] = dict()
app['stream_lock'] = asyncio.Lock()
app['debug'] = debug
app.on_shutdown.append(on_shutdown)
app.router.add_post("/stream", get_stream)
Expand Down
Loading