Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions dimos/protocol/pubsub/impl/shmpubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from dimos.protocol.pubsub.encoders import LCMEncoderMixin, PickleEncoderMixin
from dimos.protocol.pubsub.impl.lcmpubsub import Topic
from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel
from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel, FrameChannel
from dimos.protocol.pubsub.spec import PubSub
from dimos.utils.logging_config import setup_logger

Expand Down Expand Up @@ -61,6 +61,16 @@ class SharedMemoryPubSubBase(PubSub[str, Any]):
- drop initial empty frame; synchronous local delivery; echo suppression
"""

# Frame-channel implementation backing each topic. Streaming keeps the
# default double-buffered CpuShmChannel (latest-wins); subclasses that need
# every message delivered (e.g. ShmRPC) override this with CpuShmQueue.
_channel_class: type[FrameChannel] = CpuShmChannel

# Extra keyword args passed to _channel_class(...) at construction (empty for
# the streaming channel; ShmRPC sets e.g. {"slots": N}). Folded into the
# segment name so distinct layouts never mmap over one another.
_channel_kwargs: dict[str, Any] = {}

# Per-topic state
# TODO: implement "is_cuda" below capacity, above cp
class _TopicState:
Expand Down Expand Up @@ -236,13 +246,27 @@ def _ensure_topic(self, topic: str) -> _TopicState:
cap = int(self.config.default_capacity)

def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]:
    """Derive the deterministic (data, ctrl) segment names for a topic.

    The names depend on the topic, the capacity, the channel class name and
    the channel's layout kwargs, so pubsubs that use different layouts on
    the same topic (streaming CpuShmChannel vs RPC CpuShmQueue, or two ring
    sizes) never map incompatible layouts over one another.
    """
    # Python's SharedMemory requires names without a leading '/'.
    # Sorting the kwargs makes the key independent of dict insertion order.
    ordered_kwargs = sorted(self._channel_kwargs.items())
    layout = f"{self._channel_class.__name__}:{ordered_kwargs}"
    key = f"{topic}:{capacity}:{layout}".encode()
    # A short digest keeps the name within macOS shared-memory name limits.
    digest = hashlib.blake2b(key, digest_size=8).hexdigest()
    return f"psm_{digest}_data", f"psm_{digest}_ctrl"
Comment thread
Dreamsorcerer marked this conversation as resolved.

data_name, ctrl_name = _names_for_topic(topic, cap)
ch = CpuShmChannel((cap + 20,), np.uint8, data_name=data_name, ctrl_name=ctrl_name)
ch = self._channel_class(
(cap + 20,),
np.uint8,
data_name=data_name,
ctrl_name=ctrl_name,
**self._channel_kwargs,
)
st = SharedMemoryPubSubBase._TopicState(ch, cap, None)
self._topics[topic] = st
return st
Expand Down
274 changes: 259 additions & 15 deletions dimos/protocol/pubsub/shm/ipc_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@
from multiprocessing import resource_tracker
from multiprocessing.shared_memory import SharedMemory
import os
import threading
import time
from typing import Any

import numpy as np
from numpy.typing import DTypeLike, NDArray

from dimos.utils.logging_config import setup_logger

_UNLINK_ON_GC = os.getenv("DIMOS_IPC_UNLINK_ON_GC", "0").lower() not in ("0", "false", "no")

logger = setup_logger()


def _unregister(shm: SharedMemory) -> SharedMemory:
"""Remove a SharedMemory segment from the resource tracker.
Expand Down Expand Up @@ -55,11 +62,26 @@ def _open_shm_with_retry(name: str) -> SharedMemory:


class FrameChannel(ABC):
"""Single-slot 'freshest frame' IPC channel with a tiny control block.
- Double-buffered to avoid torn reads.
- Descriptor is JSON-safe; attach() reconstructs in another process.
"""Shared-memory IPC channel carrying frames behind a tiny control block.

Implementations range from a single-slot double-buffered 'freshest frame'
channel (CpuShmChannel) to a multi-slot ring buffer for reliable delivery
(CpuShmQueue). Descriptor is JSON-safe; attach() reconstructs in another
process.
"""

@abstractmethod
def __init__(
self,
shape: tuple[int, ...],
dtype: DTypeLike = np.uint8,
*,
data_name: str | None = None,
ctrl_name: str | None = None,
) -> None:
"""Create (or attach by name) the channel's shared-memory segments."""
...

@property
@abstractmethod
def device(self) -> str: # "cpu" or "cuda"
Expand Down Expand Up @@ -120,6 +142,20 @@ def _safe_unlink(name: str) -> None:
pass


def _create_or_open(name: str, size: int) -> tuple[SharedMemory, bool]:
    """Create the named SHM segment, or attach to it if it already exists.

    Returns:
        ``(segment, owner)``. ``owner`` is True when this call created the
        segment and False when it attached to an existing one.
    """
    try:
        # Creator path: the segment stays registered with the resource
        # tracker, because unlink() unregisters it and the tracker acts as a
        # safety net if this process crashes first.
        created = SharedMemory(create=True, size=size, name=name)
    except FileExistsError:
        # Attach path: this handle is only ever close()d, never unlink()ed,
        # so it is taken out of the resource tracker.
        return _unregister(SharedMemory(name=name)), False
    return created, True


class CpuShmChannel(FrameChannel):
def __init__( # type: ignore[no-untyped-def]
self,
Expand All @@ -133,18 +169,6 @@ def __init__( # type: ignore[no-untyped-def]
self._dtype = np.dtype(dtype)
self._nbytes = int(self._dtype.itemsize * np.prod(self._shape))

def _create_or_open(name: str, size: int): # type: ignore[no-untyped-def]
try:
# Owner: leave registered because unlink() will unregister, and
# the tracker serves as safety net if the process crashes.
shm = SharedMemory(create=True, size=size, name=name)
owner = True
except FileExistsError:
# Reader: unregister because we only close(), never unlink().
shm = _unregister(SharedMemory(name=name))
owner = False
return shm, owner

if data_name is None or ctrl_name is None:
# Fallback: random names (old behavior) -> always owner
self._shm_data = SharedMemory(create=True, size=2 * self._nbytes)
Expand Down Expand Up @@ -290,6 +314,226 @@ def close(self) -> None:
pass


class CpuShmQueue(FrameChannel):
    """Multi-slot ring-buffer SHM channel for reliable delivery under load.

    Delivery is reliable for a single writer *instance* (publishes are serialised
    by ``_pub_lock``) feeding readers that keep up: every message lands in its own
    slot with a monotonic seq. Two limits remain by design:

    - ``_pub_lock`` is per-instance, so multiple writer instances/processes on the
      *same* named segment are not mutually excluded and can race the seq counter.
      The pubsub uses one writer instance per topic.
    - A reader outpaced by more than ``slots`` messages loses the oldest (logged).
      Size ``slots`` for the expected burst; a multi-reader ring cannot apply
      backpressure.

    Segment layout:

    - data: ``slots`` int64 header triples ``(seq, ts, length)`` followed by
      ``slots`` payload areas of ``frame_nbytes`` bytes each.
    - ctrl: two int64 values ``(producer_seq, last_ts)``.
    """

    _HEADER_FIELDS = 3  # (seq, ts, length) per slot, all int64
    _CTRL_SLOTS = 2  # (producer_seq, last_ts)
    # Ring depth. Sized so a reader polling at ~1ms keeps up with realistic RPC
    # bursts without overflow.
    _DEFAULT_SLOTS = 256

    def __init__(
        self,
        shape: tuple[int, ...],
        dtype: DTypeLike = np.uint8,
        *,
        data_name: str | None = None,
        ctrl_name: str | None = None,
        slots: int = _DEFAULT_SLOTS,
    ) -> None:
        """Create the ring's segments, or attach to them when they already exist.

        Args:
            shape: Shape of one frame.
            dtype: Element dtype of one frame.
            data_name: Name of the data segment. When either name is None,
                both segments are created anonymously and owned by this
                instance.
            ctrl_name: Name of the control segment (see ``data_name``).
            slots: Ring depth; must be at least 1.

        Raises:
            ValueError: If ``slots`` is smaller than 1.
        """
        slot_count = int(slots)
        if slot_count < 1:
            # Fail fast with a clear message: a non-positive ring depth would
            # otherwise surface as an unrelated "size" error from SharedMemory
            # (and ``seq % slots`` cannot work with zero slots).
            raise ValueError(f"CpuShmQueue requires slots >= 1, got {slots!r}")

        self._shape = tuple(shape)
        self._dtype = np.dtype(dtype)
        self._frame_nbytes = int(self._dtype.itemsize * np.prod(self._shape))
        self._slots = slot_count
        self._pub_lock = threading.Lock()

        data_size = self._header_bytes + self._slots * self._frame_nbytes
        ctrl_size = self._CTRL_SLOTS * 8

        if data_name is None or ctrl_name is None:
            self._shm_data = SharedMemory(create=True, size=data_size)
            self._shm_ctrl = SharedMemory(create=True, size=ctrl_size)
            self._own_data = self._own_ctrl = True
        else:
            # Track ownership per segment: under a first-attach race two processes
            # can split ownership (one creates data, the other ctrl), and each must
            # unlink the segment it created or it leaks.
            self._shm_data, self._own_data = _create_or_open(data_name, data_size)
            self._shm_ctrl, self._own_ctrl = _create_or_open(ctrl_name, ctrl_size)

        self._ctrl: NDArray[np.int64] = np.ndarray(
            (self._CTRL_SLOTS,), dtype=np.int64, buffer=self._shm_ctrl.buf
        )
        if self._own_ctrl:
            # Only the creator resets the counters; an attaching instance must
            # not clobber a producer_seq that is already in use.
            self._ctrl[:] = 0

        self._finalizer_data = (
            weakref.finalize(self, _safe_unlink, self._shm_data.name)
            if (_UNLINK_ON_GC and self._own_data)
            else None
        )
        self._finalizer_ctrl = (
            weakref.finalize(self, _safe_unlink, self._shm_ctrl.name)
            if (_UNLINK_ON_GC and self._own_ctrl)
            else None
        )

    @property
    def _header_bytes(self) -> int:
        """Size in bytes of the header table at the start of the data segment."""
        return self._slots * self._HEADER_FIELDS * 8

    def _map(self) -> tuple[NDArray[np.int64], NDArray[np.uint8]]:
        """Build fresh header/payload views over the data segment.

        Views are transient (garbage-collected when the caller returns) so
        ``close()`` never trips over exported buffer pointers.
        """
        buf = self._shm_data.buf
        headers: NDArray[np.int64] = np.ndarray(
            (self._slots, self._HEADER_FIELDS), dtype=np.int64, buffer=buf
        )
        payloads: NDArray[np.uint8] = np.ndarray(
            (self._slots, self._frame_nbytes),
            dtype=np.uint8,
            buffer=buf,
            offset=self._header_bytes,
        )
        return headers, payloads

    @property
    def device(self) -> str:
        return "cpu"

    @property
    def shape(self) -> tuple[int, ...]:
        return self._shape

    @property
    def dtype(self) -> np.dtype[Any]:
        return self._dtype

    def publish(self, frame: NDArray[Any], length: int | None = None) -> None:
        """Write ``frame`` into the next ring slot and advance the producer seq.

        Args:
            frame: Array whose shape and dtype match the channel's.
            length: Number of leading bytes of ``frame`` to publish. None means
                the whole frame; values above the frame size are clamped to it.

        Raises:
            ValueError: If ``length`` is negative.
        """
        # NOTE(review): assertions are stripped under ``python -O``; switch to
        # explicit raises if this contract must hold in optimised runs.
        assert isinstance(frame, np.ndarray)
        assert frame.shape == self._shape and frame.dtype == self._dtype
        if length is None:
            n = self._frame_nbytes
        else:
            n = int(length)
            if n < 0:
                # A negative length would be stored in the slot header, where
                # _read_slot rejects it as a torn header and the message is
                # silently lost; reject it before touching any slot.
                raise ValueError(f"length must be non-negative, got {length!r}")
            n = min(n, self._frame_nbytes)
        src = np.frombuffer(np.ascontiguousarray(frame), dtype=np.uint8)
        headers, payloads = self._map()
        with self._pub_lock:
            seq = int(self._ctrl[0]) + 1
            ts = int(time.time_ns())
            slot = seq % self._slots
            # Invalidate the slot before overwriting it: a reader still holding
            # the previous occupant's seq sees 0 on its post-copy re-check and
            # discards the torn payload instead of returning a mix of two
            # messages. Then payload, metadata, and the new seq last, so a reader
            # that observes the seq is guaranteed a fully-written slot.
            headers[slot, 0] = 0
            payloads[slot, :n] = src[:n]
            headers[slot, 1] = ts
            headers[slot, 2] = n
            headers[slot, 0] = seq
            # Publish globally: ts before seq (readers key off _ctrl[0]).
            self._ctrl[1] = ts
            self._ctrl[0] = seq

    def _read_slot(
        self, headers: NDArray[np.int64], payloads: NDArray[np.uint8], want: int
    ) -> tuple[int, NDArray[np.uint8]] | None:
        """Return (ts, payload_copy) if the slot still holds ``want``, else None."""
        slot = want % self._slots
        for _ in range(3):
            if int(headers[slot, 0]) != want:
                return None  # not yet written, or already overwritten
            ts = int(headers[slot, 1])
            n = int(headers[slot, 2])
            if n < 0 or n > self._frame_nbytes:
                continue  # torn header; retry
            payload = np.array(payloads[slot, :n], copy=True)
            if int(headers[slot, 0]) == want:
                return ts, payload  # seq stable across the copy -> consistent
        return None

    def read(
        self, last_seq: int = -1, require_new: bool = True
    ) -> tuple[int, int, NDArray[np.uint8] | None]:
        """Read one message from the ring.

        Args:
            last_seq: Seq of the last message the caller consumed.
            require_new: When True, return the oldest still-available message
                after ``last_seq``; when False, return the newest message.

        Returns:
            ``(seq, ts, payload)``. ``payload`` is None when nothing could be
            read, in which case ``seq`` is ``last_seq`` unchanged and ``ts`` is
            the control block's last timestamp.
        """
        current = int(self._ctrl[0])
        if current <= 0:
            return last_seq, int(self._ctrl[1]), None
        if require_new:
            if current <= last_seq:
                return last_seq, int(self._ctrl[1]), None
            # Clamp to the first real seq: seqs start at 1, so the ABC's default
            # last_seq=-1 must not make want=0 (a phantom seq that inflates the
            # outpaced-drop count and matches the zero-initialised slot 0).
            want = max(1, last_seq + 1)
            oldest = max(1, current - self._slots + 1)
            if want < oldest:
                logger.warning(
                    f"CpuShmQueue reader outpaced: dropping {oldest - want} message(s) "
                    f"(seq {want} -> {oldest}); increase slots or poll faster"
                )
                want = oldest
        else:
            want = current  # newest available
        headers, payloads = self._map()
        # Skip any messages overwritten between the snapshot and our read.
        while want <= current:
            got = self._read_slot(headers, payloads, want)
            if got is not None:
                ts, payload = got
                return want, ts, payload
            want += 1
        return last_seq, int(self._ctrl[1]), None

    def descriptor(self) -> dict[str, Any]:
        """Return the fields ``attach()`` reads to open this queue elsewhere."""
        return {
            "kind": "cpu_queue",
            "shape": self._shape,
            "dtype": self._dtype.str,
            "frame_nbytes": self._frame_nbytes,
            "slots": self._slots,
            "data_name": self._shm_data.name,
            "ctrl_name": self._shm_ctrl.name,
        }

    @classmethod
    def attach(cls, desc: dict[str, Any]) -> "CpuShmQueue":
        """Open an existing queue from a descriptor without taking ownership.

        Raises:
            FileNotFoundError: If either named segment cannot be opened.
        """
        # Bypass __init__: it would try to create the segments.
        obj = object.__new__(cls)
        obj._shape = tuple(desc["shape"])
        obj._dtype = np.dtype(desc["dtype"])
        obj._frame_nbytes = int(desc["frame_nbytes"])
        obj._slots = int(desc["slots"])
        obj._pub_lock = threading.Lock()
        data_name = desc["data_name"]
        ctrl_name = desc["ctrl_name"]
        try:
            obj._shm_data = _open_shm_with_retry(data_name)
            obj._shm_ctrl = _open_shm_with_retry(ctrl_name)
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"CPU IPC queue attach failed: control/data SHM not found "
                f"(ctrl='{ctrl_name}', data='{data_name}'). "
                f"Ensure the writer is running on the same host and the channel is alive."
            ) from e
        obj._ctrl = np.ndarray((cls._CTRL_SLOTS,), dtype=np.int64, buffer=obj._shm_ctrl.buf)
        obj._own_data = obj._own_ctrl = False
        obj._finalizer_data = obj._finalizer_ctrl = None
        return obj

    def close(self) -> None:
        """Close both segment handles and unlink the segments this instance created.

        Cleanup is nested in ``finally`` blocks so that an error while closing
        one handle still closes the other and still unlinks owned segments; the
        error then propagates to the caller.
        """
        try:
            self._shm_ctrl.close()
        finally:
            try:
                self._shm_data.close()
            finally:
                # Unlink each segment we created; a reader that created neither
                # just drops its handles.
                if self._own_ctrl:
                    _safe_unlink(self._shm_ctrl.name)
                if self._own_data:
                    _safe_unlink(self._shm_data.name)
                # Disarm the GC-time unlink callbacks: segment names can be
                # deterministic, so a callback firing after this point could
                # target a same-named segment created after this close.
                for finalizer in (self._finalizer_ctrl, self._finalizer_data):
                    if finalizer is not None:
                        finalizer.detach()
Comment thread
Dreamsorcerer marked this conversation as resolved.


class CPU_IPC_Factory:
"""Creates/attaches CPU shared-memory channels."""

Expand Down
Loading
Loading