From c05044db87754b230a488e7e9a34cef97053b7c9 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 3 Jul 2026 16:21:23 +0100 Subject: [PATCH 1/2] Add a queue to shmrpc --- dimos/protocol/pubsub/impl/shmpubsub.py | 23 +- dimos/protocol/pubsub/shm/ipc_factory.py | 251 ++++++++++++++++-- dimos/protocol/pubsub/shm/test_ipc_factory.py | 164 ++++++++++++ dimos/protocol/rpc/pubsubrpc.py | 3 + dimos/protocol/rpc/test_spec.py | 10 +- 5 files changed, 424 insertions(+), 27 deletions(-) create mode 100644 dimos/protocol/pubsub/shm/test_ipc_factory.py diff --git a/dimos/protocol/pubsub/impl/shmpubsub.py b/dimos/protocol/pubsub/impl/shmpubsub.py index 883afcdcc0..ec66e226fe 100644 --- a/dimos/protocol/pubsub/impl/shmpubsub.py +++ b/dimos/protocol/pubsub/impl/shmpubsub.py @@ -31,7 +31,7 @@ from dimos.protocol.pubsub.encoders import LCMEncoderMixin, PickleEncoderMixin from dimos.protocol.pubsub.impl.lcmpubsub import Topic -from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel +from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel, FrameChannel from dimos.protocol.pubsub.spec import PubSub from dimos.utils.logging_config import setup_logger @@ -61,6 +61,11 @@ class SharedMemoryPubSubBase(PubSub[str, Any]): - drop initial empty frame; synchronous local delivery; echo suppression """ + # Frame-channel implementation backing each topic. Streaming keeps the + # default double-buffered CpuShmChannel (latest-wins); subclasses that need + # every message delivered (e.g. ShmRPC) override this with CpuShmQueue. + _channel_class: type[FrameChannel] = CpuShmChannel + # Per-topic state # TODO: implement "is_cuda" below capacity, above cp class _TopicState: @@ -236,13 +241,21 @@ def _ensure_topic(self, topic: str) -> _TopicState: cap = int(self.config.default_capacity) def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]: - # Python's SharedMemory requires names without a leading '/' - # Use shorter digest to avoid macOS shared memory name length limits - h = hashlib.blake2b(f"{topic}:{capacity}".encode(), digest_size=8).hexdigest() + # Python's SharedMemory requires names without a leading '/'. + # Fold the channel class name into the hash so two pubsubs using + # different layouts (streaming CpuShmChannel vs RPC CpuShmQueue) + # on the same topic get distinct segments rather than mmapping + # incompatible layouts over each other. + # Use a short digest to avoid macOS shared memory name length limits. + h = hashlib.blake2b( + f"{topic}:{capacity}:{self._channel_class.__name__}".encode(), digest_size=8 + ).hexdigest() return f"psm_{h}_data", f"psm_{h}_ctrl" data_name, ctrl_name = _names_for_topic(topic, cap) - ch = CpuShmChannel((cap + 20,), np.uint8, data_name=data_name, ctrl_name=ctrl_name) + ch = self._channel_class( + (cap + 20,), np.uint8, data_name=data_name, ctrl_name=ctrl_name + ) st = SharedMemoryPubSubBase._TopicState(ch, cap, None) self._topics[topic] = st return st diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py index 39e3275c37..4fe666edf9 100644 --- a/dimos/protocol/pubsub/shm/ipc_factory.py +++ b/dimos/protocol/pubsub/shm/ipc_factory.py @@ -18,12 +18,19 @@ from multiprocessing import resource_tracker from multiprocessing.shared_memory import SharedMemory import os +import threading import time +from typing import Any import numpy as np +from numpy.typing import DTypeLike, NDArray + +from dimos.utils.logging_config import setup_logger _UNLINK_ON_GC = os.getenv("DIMOS_IPC_UNLINK_ON_GC", "0").lower() not in ("0", "false", "no") +logger = setup_logger() + def _unregister(shm: SharedMemory) -> SharedMemory: """Remove a SharedMemory segment from the resource tracker. @@ -55,11 +62,26 @@ def _open_shm_with_retry(name: str) -> SharedMemory: class FrameChannel(ABC): - """Single-slot 'freshest frame' IPC channel with a tiny control block. - - Double-buffered to avoid torn reads. - - Descriptor is JSON-safe; attach() reconstructs in another process. + """Shared-memory IPC channel carrying frames behind a tiny control block. + + Implementations range from a single-slot double-buffered 'freshest frame' + channel (CpuShmChannel) to a multi-slot ring buffer for reliable delivery + (CpuShmQueue). Descriptor is JSON-safe; attach() reconstructs in another + process. """ + @abstractmethod + def __init__( + self, + shape: tuple[int, ...], + dtype: DTypeLike = np.uint8, + *, + data_name: str | None = None, + ctrl_name: str | None = None, + ) -> None: + """Create (or attach by name) the channel's shared-memory segments.""" + ... + @property @abstractmethod def device(self) -> str: # "cpu" or "cuda" @@ -120,6 +142,20 @@ def _safe_unlink(name: str) -> None: pass +def _create_or_open(name: str, size: int) -> tuple[SharedMemory, bool]: + """Create a named SHM segment (owner) or attach to an existing one (reader).""" + try: + # Owner: leave registered because unlink() will unregister, and + # the tracker serves as safety net if the process crashes. + shm = SharedMemory(create=True, size=size, name=name) + owner = True + except FileExistsError: + # Reader: unregister because we only close(), never unlink(). + shm = _unregister(SharedMemory(name=name)) + owner = False + return shm, owner + + class CpuShmChannel(FrameChannel): def __init__( # type: ignore[no-untyped-def] self, @@ -133,18 +169,6 @@ def __init__( # type: ignore[no-untyped-def] self._dtype = np.dtype(dtype) self._nbytes = int(self._dtype.itemsize * np.prod(self._shape)) - def _create_or_open(name: str, size: int): # type: ignore[no-untyped-def] - try: - # Owner: leave registered because unlink() will unregister, and - # the tracker serves as safety net if the process crashes. - shm = SharedMemory(create=True, size=size, name=name) - owner = True - except FileExistsError: - # Reader: unregister because we only close(), never unlink(). - shm = _unregister(SharedMemory(name=name)) - owner = False - return shm, owner - if data_name is None or ctrl_name is None: # Fallback: random names (old behavior) -> always owner self._shm_data = SharedMemory(create=True, size=2 * self._nbytes) @@ -290,6 +314,203 @@ def close(self) -> None: pass +class CpuShmQueue(FrameChannel): + """Multi-slot ring-buffer SHM channel for reliable delivery under load.""" + + _HEADER_FIELDS = 3 # (seq, ts, length) per slot, all int64 + _CTRL_SLOTS = 2 # (producer_seq, last_ts) + # Ring depth. Sized so a reader polling at ~1ms keeps up with realistic RPC + # bursts without overflow. + _DEFAULT_SLOTS = 256 + + def __init__( + self, + shape: tuple[int, ...], + dtype: DTypeLike = np.uint8, + *, + data_name: str | None = None, + ctrl_name: str | None = None, + slots: int = _DEFAULT_SLOTS, + ) -> None: + self._shape = tuple(shape) + self._dtype = np.dtype(dtype) + self._frame_nbytes = int(self._dtype.itemsize * np.prod(self._shape)) + self._slots = int(slots) + self._pub_lock = threading.Lock() + + data_size = self._header_bytes + self._slots * self._frame_nbytes + ctrl_size = self._CTRL_SLOTS * 8 + + if data_name is None or ctrl_name is None: + self._shm_data = SharedMemory(create=True, size=data_size) + self._shm_ctrl = SharedMemory(create=True, size=ctrl_size) + self._is_owner = True + else: + self._shm_data, own_d = _create_or_open(data_name, data_size) + self._shm_ctrl, own_c = _create_or_open(ctrl_name, ctrl_size) + self._is_owner = own_d and own_c + + self._ctrl: NDArray[np.int64] = np.ndarray( + (self._CTRL_SLOTS,), dtype=np.int64, buffer=self._shm_ctrl.buf + ) + if self._is_owner: + self._ctrl[:] = 0 + + self._finalizer_data = ( + weakref.finalize(self, _safe_unlink, self._shm_data.name) + if (_UNLINK_ON_GC and self._is_owner) + else None + ) + self._finalizer_ctrl = ( + weakref.finalize(self, _safe_unlink, self._shm_ctrl.name) + if (_UNLINK_ON_GC and self._is_owner) + else None + ) + + @property + def _header_bytes(self) -> int: + return self._slots * self._HEADER_FIELDS * 8 + + def _map(self) -> tuple[NDArray[np.int64], NDArray[np.uint8]]: + """Build fresh header/payload views over the data segment. + + Views are transient (garbage-collected when the caller returns) so + ``close()`` never trips over exported buffer pointers. + """ + buf = self._shm_data.buf + headers: NDArray[np.int64] = np.ndarray( + (self._slots, self._HEADER_FIELDS), dtype=np.int64, buffer=buf + ) + payloads: NDArray[np.uint8] = np.ndarray( + (self._slots, self._frame_nbytes), + dtype=np.uint8, + buffer=buf, + offset=self._header_bytes, + ) + return headers, payloads + + @property + def device(self) -> str: + return "cpu" + + @property + def shape(self) -> tuple[int, ...]: + return self._shape + + @property + def dtype(self) -> np.dtype[Any]: + return self._dtype + + def publish(self, frame: NDArray[Any], length: int | None = None) -> None: + assert isinstance(frame, np.ndarray) + assert frame.shape == self._shape and frame.dtype == self._dtype + src = np.frombuffer(np.ascontiguousarray(frame), dtype=np.uint8) + n = self._frame_nbytes if length is None else min(int(length), self._frame_nbytes) + headers, payloads = self._map() + with self._pub_lock: + seq = int(self._ctrl[0]) + 1 + ts = int(time.time_ns()) + slot = seq % self._slots + # Payload first, metadata next, seq last: a reader that observes the + # new seq is then guaranteed to see a fully-written slot. + payloads[slot, :n] = src[:n] + headers[slot, 1] = ts + headers[slot, 2] = n + headers[slot, 0] = seq + # Publish globally: ts before seq (readers key off _ctrl[0]). + self._ctrl[1] = ts + self._ctrl[0] = seq + + def _read_slot( + self, headers: NDArray[np.int64], payloads: NDArray[np.uint8], want: int + ) -> tuple[int, NDArray[np.uint8]] | None: + """Return (ts, payload_copy) if the slot still holds ``want``, else None.""" + slot = want % self._slots + for _ in range(3): + if int(headers[slot, 0]) != want: + return None # not yet written, or already overwritten + ts = int(headers[slot, 1]) + n = int(headers[slot, 2]) + if n < 0 or n > self._frame_nbytes: + continue # torn header; retry + payload = np.array(payloads[slot, :n], copy=True) + if int(headers[slot, 0]) == want: + return ts, payload # seq stable across the copy -> consistent + return None + + def read( + self, last_seq: int = -1, require_new: bool = True + ) -> tuple[int, int, NDArray[np.uint8] | None]: + current = int(self._ctrl[0]) + if current <= 0: + return last_seq, int(self._ctrl[1]), None + if require_new: + if current <= last_seq: + return last_seq, int(self._ctrl[1]), None + want = last_seq + 1 + oldest = max(1, current - self._slots + 1) + if want < oldest: + logger.warning( + f"CpuShmQueue reader outpaced: dropping {oldest - want} message(s) " + f"(seq {want} -> {oldest}); increase slots or poll faster" + ) + want = oldest + else: + want = current # newest available + headers, payloads = self._map() + # Skip any messages overwritten between the snapshot and our read. + while want <= current: + got = self._read_slot(headers, payloads, want) + if got is not None: + ts, payload = got + return want, ts, payload + want += 1 + return last_seq, int(self._ctrl[1]), None + + def descriptor(self) -> dict[str, Any]: + return { + "kind": "cpu_queue", + "shape": self._shape, + "dtype": self._dtype.str, + "frame_nbytes": self._frame_nbytes, + "slots": self._slots, + "data_name": self._shm_data.name, + "ctrl_name": self._shm_ctrl.name, + } + + @classmethod + def attach(cls, desc: dict[str, Any]) -> "CpuShmQueue": + obj = object.__new__(cls) + obj._shape = tuple(desc["shape"]) + obj._dtype = np.dtype(desc["dtype"]) + obj._frame_nbytes = int(desc["frame_nbytes"]) + obj._slots = int(desc["slots"]) + obj._pub_lock = threading.Lock() + data_name = desc["data_name"] + ctrl_name = desc["ctrl_name"] + try: + obj._shm_data = _open_shm_with_retry(data_name) + obj._shm_ctrl = _open_shm_with_retry(ctrl_name) + except FileNotFoundError as e: + raise FileNotFoundError( + f"CPU IPC queue attach failed: control/data SHM not found " + f"(ctrl='{ctrl_name}', data='{data_name}'). " + f"Ensure the writer is running on the same host and the channel is alive." + ) from e + obj._ctrl = np.ndarray((cls._CTRL_SLOTS,), dtype=np.int64, buffer=obj._shm_ctrl.buf) + obj._is_owner = False + obj._finalizer_data = obj._finalizer_ctrl = None + return obj + + def close(self) -> None: + self._shm_ctrl.close() + self._shm_data.close() + if self._is_owner: + # Owner unlinks the segments; readers just drop their handle. + _safe_unlink(self._shm_ctrl.name) + _safe_unlink(self._shm_data.name) + + class CPU_IPC_Factory: """Creates/attaches CPU shared-memory channels.""" diff --git a/dimos/protocol/pubsub/shm/test_ipc_factory.py b/dimos/protocol/pubsub/shm/test_ipc_factory.py new file mode 100644 index 0000000000..3c05804fda --- /dev/null +++ b/dimos/protocol/pubsub/shm/test_ipc_factory.py @@ -0,0 +1,164 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the ring-buffer SHM channel (``CpuShmQueue``).""" + +import threading +import uuid + +import numpy as np + +from dimos.protocol.pubsub.shm.ipc_factory import CpuShmQueue + +CAP = 64 + + +def _frame(payload: bytes) -> np.ndarray: + frame = np.zeros((CAP,), dtype=np.uint8) + frame[: len(payload)] = np.frombuffer(payload, dtype=np.uint8) + return frame + + +def _publish(ch: CpuShmQueue, payload: bytes) -> None: + ch.publish(_frame(payload), length=len(payload)) + + +def _drain(ch: CpuShmQueue, last_seq: int = 0) -> tuple[list[tuple[int, bytes]], int]: + """Read every available message, returning [(seq, payload_bytes), ...].""" + out: list[tuple[int, bytes]] = [] + while True: + seq, _ts, view = ch.read(last_seq=last_seq, require_new=True) + if view is None: + break + last_seq = seq + out.append((seq, view.tobytes())) + return out, last_seq + + +def test_single_message() -> None: + ch = CpuShmQueue((CAP,), np.uint8, slots=8) + try: + _publish(ch, b"hello") + got, _ = _drain(ch) + assert got == [(1, b"hello")] + finally: + ch.close() + + +def test_empty_channel_returns_none() -> None: + ch = CpuShmQueue((CAP,), np.uint8, slots=8) + try: + _seq, _ts, view = ch.read(last_seq=0, require_new=True) + assert view is None + finally: + ch.close() + + +def test_sequential_exact_once() -> None: + """Every message published within one ring is delivered exactly once, in order.""" + ch = CpuShmQueue((CAP,), np.uint8, slots=8) + try: + for i in range(8): + _publish(ch, f"m{i}".encode()) + got, _ = _drain(ch) + assert [seq for seq, _ in got] == list(range(1, 9)) + assert [payload for _, payload in got] == [f"m{i}".encode() for i in range(8)] + finally: + ch.close() + + +def test_wraparound_when_reader_keeps_up() -> None: + """A reader that drains after each publish loses nothing across many wraps.""" + ch = CpuShmQueue((CAP,), np.uint8, slots=4) + try: + received: list[bytes] = [] + last = 0 + for i in range(12): # 3x the ring size + _publish(ch, f"m{i}".encode()) + got, last = _drain(ch, last) + received += [payload for _, payload in got] + assert received == [f"m{i}".encode() for i in range(12)] + finally: + ch.close() + + +def test_reader_outpaced_drops_oldest() -> None: + """When the ring overflows before a read, only the newest `slots` survive.""" + slots = 4 + ch = CpuShmQueue((CAP,), np.uint8, slots=slots) + try: + for i in range(2 * slots): # publish 8 into 4 slots without reading + _publish(ch, f"m{i}".encode()) + got, _ = _drain(ch) + # The oldest `slots` messages were overwritten; loss is visible as a + # sequence gap: the first delivered seq is slots+1, not 1. + assert [seq for seq, _ in got] == list(range(slots + 1, 2 * slots + 1)) + assert [payload for _, payload in got] == [ + f"m{i}".encode() for i in range(slots, 2 * slots) + ] + finally: + ch.close() + + +def test_concurrent_publishers_no_loss() -> None: + """`slots` threads each publishing once fill the ring with no loss or dupes.""" + slots = 32 + ch = CpuShmQueue((CAP,), np.uint8, slots=slots) + try: + threads = [ + threading.Thread(target=_publish, args=(ch, f"m{i}".encode())) for i in range(slots) + ] + for t in threads: + t.start() + for t in threads: + t.join() + + got, _ = _drain(ch) + assert len(got) == slots # exact-once: no message lost + assert sorted(seq for seq, _ in got) == list(range(1, slots + 1)) # unique seqs + assert {payload for _, payload in got} == {f"m{i}".encode() for i in range(slots)} + finally: + ch.close() + + +def test_two_instances_share_named_segment() -> None: + """Writer and reader in separate instances (as ShmRPC uses them) agree.""" + tag = uuid.uuid4().hex[:12] + data_name, ctrl_name = f"tq_{tag}_data", f"tq_{tag}_ctrl" + writer = CpuShmQueue((CAP,), np.uint8, data_name=data_name, ctrl_name=ctrl_name, slots=8) + reader = CpuShmQueue((CAP,), np.uint8, data_name=data_name, ctrl_name=ctrl_name, slots=8) + try: + for i in range(5): + _publish(writer, f"m{i}".encode()) + got, _ = _drain(reader) + assert [payload for _, payload in got] == [f"m{i}".encode() for i in range(5)] + finally: + reader.close() + writer.close() + + +def test_attach_roundtrip() -> None: + """A channel attached from a descriptor reads what the owner published.""" + writer = CpuShmQueue((CAP,), np.uint8, slots=8) + try: + reader = CpuShmQueue.attach(writer.descriptor()) + try: + _publish(writer, b"abc") + _publish(writer, b"defg") + got, _ = _drain(reader) + assert [payload for _, payload in got] == [b"abc", b"defg"] + finally: + reader.close() + finally: + writer.close() diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index abfd521666..56ceade436 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -30,6 +30,7 @@ from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH from dimos.protocol.pubsub.impl.lcmpubsub import PickleLCM, Topic from dimos.protocol.pubsub.impl.shmpubsub import PickleSharedMemory +from dimos.protocol.pubsub.shm.ipc_factory import CpuShmQueue from dimos.protocol.pubsub.spec import PubSub from dimos.protocol.rpc.rpc_utils import deserialize_exception, serialize_exception from dimos.protocol.rpc.spec import DEFAULT_RPC_TIMEOUT, DEFAULT_RPC_TIMEOUTS, Args, RPCSpec @@ -322,6 +323,8 @@ def topicgen(self, name: str, req_or_res: bool) -> Topic: class ShmRPC(PubSubRPCMixin[str, Any], PickleSharedMemory): + _channel_class = CpuShmQueue + def __init__( self, rpc_timeouts: dict[str, float] | None = None, diff --git a/dimos/protocol/rpc/test_spec.py b/dimos/protocol/rpc/test_spec.py index f403ab1370..732cc4d2a7 100644 --- a/dimos/protocol/rpc/test_spec.py +++ b/dimos/protocol/rpc/test_spec.py @@ -147,9 +147,10 @@ def test_basic_sync_call(rpc_context, impl_name: str) -> None: unsub() -async def test_async_call() -> None: +@pytest.mark.parametrize("rpc_context, impl_name", testdata) +async def test_async_call(rpc_context, impl_name: str) -> None: """Test asynchronous RPC calls.""" - with lcm_rpc_context() as (server, client): + with rpc_context() as (server, client): # Serve the function unsub = server.serve_rpc(add_function, "add_async") @@ -354,11 +355,6 @@ def test_multiple_services(rpc_context, impl_name: str) -> None: @pytest.mark.skipif_macos_bug def test_concurrent_calls(rpc_context, impl_name: str) -> None: """Test making multiple concurrent RPC calls.""" - # Skip for SharedMemory - double-buffered architecture can't handle concurrent bursts - # The channel only holds 2 frames, so 1000 rapid concurrent responses overwrite each other - if impl_name == "shm": - pytest.skip("SharedMemory uses double-buffering; can't handle 1000 concurrent responses") - with rpc_context() as (server, client): # Serve a function that we'll call concurrently unsub = server.serve_rpc(add_function, "concurrent_add") From fec3887bc6e9165441a776b3aa08dc7e5a1a15e8 Mon Sep 17 00:00:00 2001 From: Sam Bull Date: Fri, 3 Jul 2026 18:27:21 +0100 Subject: [PATCH 2/2] Minor fixes --- dimos/protocol/pubsub/impl/shmpubsub.py | 23 ++++++--- dimos/protocol/pubsub/shm/ipc_factory.py | 51 ++++++++++++++----- dimos/protocol/pubsub/shm/test_ipc_factory.py | 7 ++- dimos/protocol/rpc/pubsubrpc.py | 16 +++++- 4 files changed, 75 insertions(+), 22 deletions(-) diff --git a/dimos/protocol/pubsub/impl/shmpubsub.py b/dimos/protocol/pubsub/impl/shmpubsub.py index ec66e226fe..e0863d5d4e 100644 --- a/dimos/protocol/pubsub/impl/shmpubsub.py +++ b/dimos/protocol/pubsub/impl/shmpubsub.py @@ -66,6 +66,11 @@ class SharedMemoryPubSubBase(PubSub[str, Any]): # every message delivered (e.g. ShmRPC) override this with CpuShmQueue. _channel_class: type[FrameChannel] = CpuShmChannel + # Extra keyword args passed to _channel_class(...) at construction (empty for + # the streaming channel; ShmRPC sets e.g. {"slots": N}). Folded into the + # segment name so distinct layouts never mmap over one another. + _channel_kwargs: dict[str, Any] = {} + # Per-topic state # TODO: implement "is_cuda" below capacity, above cp class _TopicState: @@ -242,19 +247,25 @@ def _ensure_topic(self, topic: str) -> _TopicState: def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]: # Python's SharedMemory requires names without a leading '/'. - # Fold the channel class name into the hash so two pubsubs using - # different layouts (streaming CpuShmChannel vs RPC CpuShmQueue) - # on the same topic get distinct segments rather than mmapping - # incompatible layouts over each other. + # Fold the channel class name AND its layout kwargs (e.g. slots) + # into the hash so two pubsubs using different layouts (streaming + # CpuShmChannel vs RPC CpuShmQueue, or two ring sizes) on the same + # topic get distinct segments rather than mmapping incompatible + # layouts over each other. # Use a short digest to avoid macOS shared memory name length limits. + layout = f"{self._channel_class.__name__}:{sorted(self._channel_kwargs.items())}" h = hashlib.blake2b( - f"{topic}:{capacity}:{self._channel_class.__name__}".encode(), digest_size=8 + f"{topic}:{capacity}:{layout}".encode(), digest_size=8 ).hexdigest() return f"psm_{h}_data", f"psm_{h}_ctrl" data_name, ctrl_name = _names_for_topic(topic, cap) ch = self._channel_class( - (cap + 20,), np.uint8, data_name=data_name, ctrl_name=ctrl_name + (cap + 20,), + np.uint8, + data_name=data_name, + ctrl_name=ctrl_name, + **self._channel_kwargs, ) st = SharedMemoryPubSubBase._TopicState(ch, cap, None) self._topics[topic] = st diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py index 4fe666edf9..9e8da675c8 100644 --- a/dimos/protocol/pubsub/shm/ipc_factory.py +++ b/dimos/protocol/pubsub/shm/ipc_factory.py @@ -315,7 +315,19 @@ def close(self) -> None: class CpuShmQueue(FrameChannel): - """Multi-slot ring-buffer SHM channel for reliable delivery under load.""" + """Multi-slot ring-buffer SHM channel for reliable delivery under load. + + Delivery is reliable for a single writer *instance* (publishes are serialised + by ``_pub_lock``) feeding readers that keep up: every message lands in its own + slot with a monotonic seq. Two limits remain by design: + + - ``_pub_lock`` is per-instance, so multiple writer instances/processes on the + *same* named segment are not mutually excluded and can race the seq counter. + The pubsub uses one writer instance per topic. + - A reader outpaced by more than ``slots`` messages loses the oldest (logged). + Size ``slots`` for the expected burst; a multi-reader ring cannot apply + backpressure. + """ _HEADER_FIELDS = 3 # (seq, ts, length) per slot, all int64 _CTRL_SLOTS = 2 # (producer_seq, last_ts) @@ -344,26 +356,28 @@ def __init__( if data_name is None or ctrl_name is None: self._shm_data = SharedMemory(create=True, size=data_size) self._shm_ctrl = SharedMemory(create=True, size=ctrl_size) - self._is_owner = True + self._own_data = self._own_ctrl = True else: - self._shm_data, own_d = _create_or_open(data_name, data_size) - self._shm_ctrl, own_c = _create_or_open(ctrl_name, ctrl_size) - self._is_owner = own_d and own_c + # Track ownership per segment: under a first-attach race two processes + # can split ownership (one creates data, the other ctrl), and each must + # unlink the segment it created or it leaks. + self._shm_data, self._own_data = _create_or_open(data_name, data_size) + self._shm_ctrl, self._own_ctrl = _create_or_open(ctrl_name, ctrl_size) self._ctrl: NDArray[np.int64] = np.ndarray( (self._CTRL_SLOTS,), dtype=np.int64, buffer=self._shm_ctrl.buf ) - if self._is_owner: + if self._own_ctrl: self._ctrl[:] = 0 self._finalizer_data = ( weakref.finalize(self, _safe_unlink, self._shm_data.name) - if (_UNLINK_ON_GC and self._is_owner) + if (_UNLINK_ON_GC and self._own_data) else None ) self._finalizer_ctrl = ( weakref.finalize(self, _safe_unlink, self._shm_ctrl.name) - if (_UNLINK_ON_GC and self._is_owner) + if (_UNLINK_ON_GC and self._own_ctrl) else None ) @@ -411,8 +425,12 @@ def publish(self, frame: NDArray[Any], length: int | None = None) -> None: seq = int(self._ctrl[0]) + 1 ts = int(time.time_ns()) slot = seq % self._slots - # Payload first, metadata next, seq last: a reader that observes the - # new seq is then guaranteed to see a fully-written slot. + # Invalidate the slot before overwriting it: a reader still holding + # the previous occupant's seq sees 0 on its post-copy re-check and + # discards the torn payload instead of returning a mix of two + # messages. Then payload, metadata, and the new seq last, so a reader + # that observes the seq is guaranteed a fully-written slot. + headers[slot, 0] = 0 payloads[slot, :n] = src[:n] headers[slot, 1] = ts headers[slot, 2] = n @@ -447,7 +465,10 @@ def read( if require_new: if current <= last_seq: return last_seq, int(self._ctrl[1]), None - want = last_seq + 1 + # Clamp to the first real seq: seqs start at 1, so the ABC's default + # last_seq=-1 must not make want=0 (a phantom seq that inflates the + # outpaced-drop count and matches the zero-initialised slot 0). + want = max(1, last_seq + 1) oldest = max(1, current - self._slots + 1) if want < oldest: logger.warning( @@ -498,16 +519,18 @@ def attach(cls, desc: dict[str, Any]) -> "CpuShmQueue": f"Ensure the writer is running on the same host and the channel is alive." ) from e obj._ctrl = np.ndarray((cls._CTRL_SLOTS,), dtype=np.int64, buffer=obj._shm_ctrl.buf) - obj._is_owner = False + obj._own_data = obj._own_ctrl = False obj._finalizer_data = obj._finalizer_ctrl = None return obj def close(self) -> None: self._shm_ctrl.close() self._shm_data.close() - if self._is_owner: - # Owner unlinks the segments; readers just drop their handle. + # Unlink each segment we created; a reader that created neither just drops + # its handles. + if self._own_ctrl: _safe_unlink(self._shm_ctrl.name) + if self._own_data: _safe_unlink(self._shm_data.name) diff --git a/dimos/protocol/pubsub/shm/test_ipc_factory.py b/dimos/protocol/pubsub/shm/test_ipc_factory.py index 3c05804fda..e90e661306 100644 --- a/dimos/protocol/pubsub/shm/test_ipc_factory.py +++ b/dimos/protocol/pubsub/shm/test_ipc_factory.py @@ -112,7 +112,12 @@ def test_reader_outpaced_drops_oldest() -> None: def test_concurrent_publishers_no_loss() -> None: - """`slots` threads each publishing once fill the ring with no loss or dupes.""" + """Threads sharing ONE instance publish concurrently with no loss or dupes. + + This exercises the per-instance ``_pub_lock`` (single-writer-instance thread + safety). It does not cover multiple writer *instances* over one segment, which + are not cross-process serialised -- see the ``CpuShmQueue`` class docstring. + """ slots = 32 ch = CpuShmQueue((CAP,), np.uint8, slots=slots) try: diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index 56ceade436..821ca98470 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -322,8 +322,17 @@ def topicgen(self, name: str, req_or_res: bool) -> Topic: return Topic(topic=topic) +# RPC messages are small control-plane payloads (pickled dicts), so the ring uses +# a modest per-slot capacity and a deep slot count to absorb concurrent bursts. +# Segment bytes = slots * (capacity + 20) ~= 16 MiB here; inheriting the 3.6 MB +# streaming frame would make each segment ~900 MiB. +_SHM_RPC_CAPACITY = 64 * 1024 +_SHM_RPC_SLOTS = 256 + + class ShmRPC(PubSubRPCMixin[str, Any], PickleSharedMemory): _channel_class = CpuShmQueue + _channel_kwargs = {"slots": _SHM_RPC_SLOTS} def __init__( self, @@ -334,7 +343,12 @@ def __init__( ) -> None: if rpc_timeouts is None: rpc_timeouts = dict(DEFAULT_RPC_TIMEOUTS) - PickleSharedMemory.__init__(self, prefer=prefer, **kwargs) + # Only the pubsub base consumes default_capacity; pop it so it doesn't + # also flow into PubSubRPCMixin.__init__ via **kwargs. + default_capacity = kwargs.pop("default_capacity", _SHM_RPC_CAPACITY) + PickleSharedMemory.__init__( + self, prefer=prefer, default_capacity=default_capacity, **kwargs + ) PubSubRPCMixin.__init__( self, rpc_timeouts=rpc_timeouts, default_rpc_timeout=default_rpc_timeout, **kwargs )