From a1ab537cfb01c378f95f5726aee3aba6f674305f Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 00:32:11 +0300 Subject: [PATCH 01/10] *lio publishes pointclouds in IMU frame, no mention of body --- dimos/control/blueprints/mobile.py | 2 ++ .../sensors/lidar/fastlio2/cpp/flake.lock | 8 ++++---- .../sensors/lidar/fastlio2/cpp/flake.nix | 5 +++-- .../sensors/lidar/fastlio2/cpp/main.cpp | 4 +--- .../hardware/sensors/lidar/fastlio2/module.py | 9 +++------ .../sensors/lidar/pointlio/cpp/flake.lock | 20 +++++++++---------- .../sensors/lidar/pointlio/cpp/flake.nix | 4 +++- .../sensors/lidar/pointlio/cpp/main.cpp | 4 +--- .../hardware/sensors/lidar/pointlio/module.py | 9 +++------ .../robot/diy/alfred/blueprints/alfred_nav.py | 1 + .../navigation/unitree_go2_nav_3d.py | 10 +++++----- 11 files changed, 36 insertions(+), 40 deletions(-) diff --git a/dimos/control/blueprints/mobile.py b/dimos/control/blueprints/mobile.py index 710526b926..703fe5d6a8 100644 --- a/dimos/control/blueprints/mobile.py +++ b/dimos/control/blueprints/mobile.py @@ -140,6 +140,8 @@ def _flowbase_twist_base( "publish_free_paths": False, }, simple_planner={ + # FastLio2 publishes odom -> mid360_link (no separate body frame). + "body_frame": "mid360_link", "cell_size": 0.2, "obstacle_height_threshold": 0.15, "inflation_radius": 0.3, # FlowBase footprint smaller than G1's 0.5 diff --git a/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.lock b/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.lock index 34bf2f67bf..9435c81456 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.lock +++ b/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.lock @@ -37,16 +37,16 @@ "fast-lio": { "flake": false, "locked": { - "lastModified": 1781776629, - "narHash": "sha256-Ik3OwjSUZza/C545iPC4G/fzfJfFsdIo2GvplgN45hA=", + "lastModified": 1783023362, + "narHash": "sha256-h+cK9PIA8dQKEaWendfxHz1kqfEc5sqaaPRG0ma1SRc=", "owner": "dimensionalOS", "repo": "dimos-module-fastlio2", - "rev": "a32c9f599940a94595aa72868e2e4ab436a44b75", + "rev": "7841812385d4ee79ed981b90e9a411933b293754", "type": "github" }, "original": { "owner": "dimensionalOS", - "ref": "jeff/feat/fastlio-body-cloud", + "ref": "ivan/fix/body-cloud-imu-extrinsic", "repo": "dimos-module-fastlio2", "type": "github" } diff --git a/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.nix b/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.nix index c7c4319440..15312f3f9f 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.nix +++ b/dimos/hardware/sensors/lidar/fastlio2/cpp/flake.nix @@ -13,8 +13,9 @@ flake = false; }; fast-lio = { - # v0.3.0-quiet-logs + get_body_cloud() (sensor-frame cloud). - url = "github:dimensionalOS/dimos-module-fastlio2?ref=jeff/feat/fastlio-body-cloud"; + # get_body_cloud()/get_body_cloud_down() with the IMU<-lidar extrinsic + # applied (PR #1); retarget to jeff/feat/fastlio-body-cloud once merged. + url = "github:dimensionalOS/dimos-module-fastlio2?ref=ivan/fix/body-cloud-imu-extrinsic"; flake = false; }; lcm-extended = { diff --git a/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp b/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp index 94294ec312..0c2207dbc1 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp +++ b/dimos/hardware/sensors/lidar/fastlio2/cpp/main.cpp @@ -59,7 +59,6 @@ static FastLio* g_fastlio = nullptr; static std::string g_lidar_topic; static std::string g_odometry_topic; static std::string g_frame_id; // required via --frame_id -static std::string g_child_frame_id; // required via --child_frame_id static std::string g_sensor_frame_id; // required via --sensor_frame_id static float g_frequency = 10.0f; @@ -154,7 +153,7 @@ static void publish_odometry(const custom_messages::Odometry& odom, double times nav_msgs::Odometry msg; msg.header = make_header(g_frame_id, timestamp); - msg.child_frame_id = g_child_frame_id; + msg.child_frame_id = g_sensor_frame_id; msg.pose.pose.position.x = odom.pose.pose.position.x; msg.pose.pose.position.y = odom.pose.pose.position.y; @@ -399,7 +398,6 @@ int main(int argc, char** argv) { std::string lidar_ip = mod.arg("lidar_ip", "192.168.1.155"); g_frequency = mod.arg_float("frequency", 10.0f); g_frame_id = mod.arg_required("frame_id"); - g_child_frame_id = mod.arg_required("child_frame_id"); g_sensor_frame_id = mod.arg_required("sensor_frame_id"); float pointcloud_freq = mod.arg_float("pointcloud_freq", 5.0f); float odom_freq = mod.arg_float("odom_freq", 50.0f); diff --git a/dimos/hardware/sensors/lidar/fastlio2/module.py b/dimos/hardware/sensors/lidar/fastlio2/module.py index 6f0694ad4a..21c8363023 100644 --- a/dimos/hardware/sensors/lidar/fastlio2/module.py +++ b/dimos/hardware/sensors/lidar/fastlio2/module.py @@ -52,7 +52,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.navigation.cmu_nav.frames import FRAME_BODY, FRAME_ODOM +from dimos.navigation.cmu_nav.frames import FRAME_ODOM from dimos.spec import perception # Human-readable enums; the C++ binary maps these strings to FAST-LIO's int codes. @@ -71,12 +71,9 @@ class FastLio2Config(NativeModuleConfig): lidar_ip: str | None = Field(default_factory=lambda: os.environ.get("DIMOS_FASTLIO_LIDAR_IP")) frequency: float = 10.0 - # Odometry is published as frame_id (fixed) -> child_frame_id (moving body), + # Odometry is published as frame_id (fixed) -> sensor_frame_id (moving sensor), # and also broadcast on TF. The point cloud is stamped with sensor_frame_id - # (the lidar's own frame — get_body_cloud is the undistorted scan, not yet - # transformed into the body frame). frame_id: str = FRAME_ODOM - child_frame_id: str = FRAME_BODY sensor_frame_id: str = "mid360_link" # FAST-LIO internal processing rates @@ -150,7 +147,7 @@ def _on_odom_for_tf(self, msg: Odometry) -> None: self.tf.publish( Transform( frame_id=self.frame_id, - child_frame_id=self.config.child_frame_id, + child_frame_id=self.config.sensor_frame_id, translation=Vector3( msg.pose.position.x, msg.pose.position.y, diff --git a/dimos/hardware/sensors/lidar/pointlio/cpp/flake.lock b/dimos/hardware/sensors/lidar/pointlio/cpp/flake.lock index 3cb06c2284..6588a595e7 100644 --- a/dimos/hardware/sensors/lidar/pointlio/cpp/flake.lock +++ b/dimos/hardware/sensors/lidar/pointlio/cpp/flake.lock @@ -37,18 +37,18 @@ "fast-lio": { "flake": false, "locked": { - "lastModified": 1781782101, - "narHash": "sha256-2phOAdagFal8BTBEKxEbl3LDSx/7SNGVTFu0zYEXB1g=", - "owner": "dimensionalOS", - "repo": "dimos-module-fastlio2", - "rev": "288e357e5457723c1cce4d4060f76ed7f85b10d4", - "type": "github" + "lastModified": 1783023293, + "narHash": "sha256-isV6Jn3ACmpRFzUvm65cknT4B8bvswox93BBzEUHgUw=", + "ref": "main", + "rev": "82ef3a327347e2866e981bd95c8bece8b72903cf", + "revCount": 76, + "type": "git", + "url": "ssh://git@github.com/dimensionalOS/dimos-module-pointlio" }, "original": { - "owner": "dimensionalOS", - "ref": "pointlio", - "repo": "dimos-module-fastlio2", - "type": "github" + "ref": "main", + "type": "git", + "url": "ssh://git@github.com/dimensionalOS/dimos-module-pointlio" } }, "flake-utils": { diff --git a/dimos/hardware/sensors/lidar/pointlio/cpp/flake.nix b/dimos/hardware/sensors/lidar/pointlio/cpp/flake.nix index 0ef30ba768..04b26b88fb 100644 --- a/dimos/hardware/sensors/lidar/pointlio/cpp/flake.nix +++ b/dimos/hardware/sensors/lidar/pointlio/cpp/flake.nix @@ -13,7 +13,9 @@ flake = false; }; fast-lio = { - url = "github:dimensionalOS/dimos-module-fastlio2/pointlio"; + # Point-LIO fork (split out of dimos-module-fastlio2's pointlio branch). + # Repo is org-internal for now, hence git+ssh instead of github:. + url = "git+ssh://git@github.com/dimensionalOS/dimos-module-pointlio?ref=main"; flake = false; }; lcm-extended = { diff --git a/dimos/hardware/sensors/lidar/pointlio/cpp/main.cpp b/dimos/hardware/sensors/lidar/pointlio/cpp/main.cpp index 9c7bcb3714..91c5c53e0c 100644 --- a/dimos/hardware/sensors/lidar/pointlio/cpp/main.cpp +++ b/dimos/hardware/sensors/lidar/pointlio/cpp/main.cpp @@ -77,7 +77,6 @@ static std::vector parse_doubles(const std::string& csv) { static std::string g_lidar_topic; static std::string g_odometry_topic; static std::string g_frame_id; // required via --frame_id -static std::string g_child_frame_id; // required via --child_frame_id static std::string g_sensor_frame_id; // required via --sensor_frame_id static float g_frequency = 10.0f; @@ -158,7 +157,7 @@ static void publish_odometry(const custom_messages::Odometry& odom, double times nav_msgs::Odometry msg; msg.header = make_header(g_frame_id, timestamp); - msg.child_frame_id = g_child_frame_id; + msg.child_frame_id = g_sensor_frame_id; // Pose in the SLAM/sensor frame. msg.pose.pose.position.x = odom.pose.pose.position.x; @@ -382,7 +381,6 @@ int main(int argc, char** argv) { std::string lidar_ip = mod.arg("lidar_ip", "192.168.1.155"); g_frequency = mod.arg_float("frequency", 10.0f); g_frame_id = mod.arg_required("frame_id"); - g_child_frame_id = mod.arg_required("child_frame_id"); g_sensor_frame_id = mod.arg_required("sensor_frame_id"); float pointcloud_freq = mod.arg_float("pointcloud_freq", 5.0f); float odom_freq = mod.arg_float("odom_freq", 50.0f); diff --git a/dimos/hardware/sensors/lidar/pointlio/module.py b/dimos/hardware/sensors/lidar/pointlio/module.py index 9db92e9f6d..352768254b 100644 --- a/dimos/hardware/sensors/lidar/pointlio/module.py +++ b/dimos/hardware/sensors/lidar/pointlio/module.py @@ -58,7 +58,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 -from dimos.navigation.cmu_nav.frames import FRAME_BODY, FRAME_ODOM +from dimos.navigation.cmu_nav.frames import FRAME_ODOM from dimos.spec import perception # Human-readable enums; the C++ binary (main.cpp) maps these strings to @@ -81,12 +81,9 @@ class PointLioConfig(NativeModuleConfig): lidar_ip: str | None = Field(default_factory=lambda: os.environ.get("DIMOS_POINTLIO_LIDAR_IP")) frequency: float = 10.0 - # Odometry is published as frame_id (fixed) -> child_frame_id (moving body), + # Odometry is published as frame_id (fixed) -> sensor_frame_id (moving sensor), # and also broadcast on TF. The point cloud is stamped with sensor_frame_id - # (the lidar's own frame — get_body_cloud is the undistorted scan, not yet - # transformed into the body frame). frame_id: str = FRAME_ODOM - child_frame_id: str = FRAME_BODY sensor_frame_id: str = "mid360_link" # Point-LIO internal processing rates (Hz) @@ -186,7 +183,7 @@ def _on_odom_for_tf(self, msg: Odometry) -> None: self.tf.publish( Transform( frame_id=self.frame_id, - child_frame_id=self.config.child_frame_id, + child_frame_id=self.config.sensor_frame_id, translation=Vector3( msg.pose.position.x, msg.pose.position.y, diff --git a/dimos/robot/diy/alfred/blueprints/alfred_nav.py b/dimos/robot/diy/alfred/blueprints/alfred_nav.py index 16530ffadf..f7ec7f5485 100644 --- a/dimos/robot/diy/alfred/blueprints/alfred_nav.py +++ b/dimos/robot/diy/alfred/blueprints/alfred_nav.py @@ -40,6 +40,7 @@ "publish_free_paths": False, }, simple_planner={ + "body_frame": "mid360_link", "cell_size": 0.2, "obstacle_height_threshold": 0.15, "inflation_radius": 0.3, diff --git a/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py b/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py index b18782043c..ef71c385ba 100644 --- a/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py +++ b/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py @@ -47,11 +47,11 @@ def _render_path(msg: Any) -> Any: def _static_robot_body(rr: Any) -> list[Any]: - """Go2-shaped box on pointlio's body frame, counter-rotated for the lidar pitch.""" + """Go2-shaped box on pointlio's sensor frame, counter-rotated for the lidar pitch.""" return [ rr.Boxes3D(half_sizes=[0.35, 0.155, 0.2], colors=[(0, 255, 127)]), rr.Transform3D( - parent_frame="tf#/body", + parent_frame="tf#/mid360_link", rotation=rr.RotationAxisAngle(axis=(0, 1, 0), degrees=-45.0), ), ] @@ -66,9 +66,9 @@ def _static_robot_body(rr: Any) -> list[Any]: }, "memory_limit": "256MB", # base_link tf comes from the go2 internal odometry, which is not the map - # frame. Anchor the robot box to pointlio's body frame instead and hide the + # frame. Anchor the robot box to pointlio's sensor frame instead and hide the # camera frustum that rides base_link. - "static": {"world/tf/body": _static_robot_body}, + "static": {"world/tf/mid360_link": _static_robot_body}, "visual_override": { **rerun_config["visual_override"], "world/global_map": _render_global_map, @@ -91,7 +91,7 @@ def _static_robot_body(rr: Any) -> list[Any]: (GO2Connection, "odom", "odom_go2"), ] ), - PointLio.blueprint(child_frame_id="body"), + PointLio.blueprint(), RayTracingVoxelMap.blueprint( voxel_size=voxel_size, emit_every=1, From 73f9e678475521cc4c66ed55a0ad9a7b37fc0592 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 00:50:33 +0300 Subject: [PATCH 02/10] go2 3d nav odom fix --- .../unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py | 4 +++- dimos/robot/unitree/go2/connection.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py b/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py index ef71c385ba..b072c4d725 100644 --- a/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py +++ b/dimos/robot/unitree/go2/blueprints/navigation/unitree_go2_nav_3d.py @@ -85,7 +85,9 @@ def _static_robot_body(rr: Any) -> list[Any]: unitree_go2_nav_3d = autoconnect( vis_module(viewer_backend=global_config.viewer, rerun_config=_nav_rerun_config), # "mcf" for stair traversal - GO2Connection.blueprint(lidar=False, camera=False, motion_mode="mcf").remappings( + GO2Connection.blueprint( + lidar=False, camera=False, motion_mode="mcf", odom_frame_id="go2_odom" + ).remappings( [ (GO2Connection, "lidar", "lidar_l1"), (GO2Connection, "odom", "odom_go2"), diff --git a/dimos/robot/unitree/go2/connection.py b/dimos/robot/unitree/go2/connection.py index 777fe5605a..b2f74bf327 100644 --- a/dimos/robot/unitree/go2/connection.py +++ b/dimos/robot/unitree/go2/connection.py @@ -70,6 +70,9 @@ class ConnectionConfig(ModuleConfig): motion_mode: str | None = None # Per-device AES-128 key (Go2 fw >=1.1.15); defaults from GlobalConfig. aes_128_key: str | None = Field(default_factory=lambda m: m["g"].unitree_aes_128_key) + # TF parent frame of the internal odometry (odom_frame_id -> base_link). + # Rename (e.g. "go2_odom") when another odom source owns the tree root + odom_frame_id: str = "world" class Go2ConnectionProtocol(Protocol): @@ -323,6 +326,7 @@ def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: ] def _publish_tf(self, msg: PoseStamped) -> None: + msg.frame_id = self.config.odom_frame_id transforms = self._odom_to_tf(msg) self.tf.publish(*transforms) if self.odom.transport: From e57ce74ff3f54c242e7c7a58288dd9decff9db0c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 03:25:22 +0300 Subject: [PATCH 03/10] StreamTF: tf service over a recorded memory2 stream Unify replay tf with the live service: StreamTF(MultiTBuffer, TFSpec) mirrors PubSubTF, pulling windows from a recorded tf stream on demand instead of receiving pushed messages. Lookups span buffer_size backward (or an explicit time_tolerance) plus forward_tolerance ahead; a cache miss prefetches cache_span past the query window and evicts everything first, so chronological replay costs one db query per cache_span. - TFLookup protocol (read side) + mypy conformance checks - get_pose hoisted from PubSubTF to TFSpec - MultiTBuffer: None tolerance resolves to buffer_size explicitly - map global: registration via tf stream (never Observation poses), --frame auto-detects world/map/odom via probe lookups, fail-fast when the cloud frame can't be resolved - tf lookup tests consolidated into a grid: live MultiTBuffer vs StreamTF over memory/sqlite stores --- dimos/mapping/utils/cli/map.py | 180 +++++++++++----- dimos/memory2/tf.py | 137 ++++++++++++ dimos/protocol/tf/test_tf.py | 380 +++++++++++++++++---------------- dimos/protocol/tf/tf.py | 71 ++++-- docs/usage/transforms.md | 2 +- 5 files changed, 511 insertions(+), 259 deletions(-) create mode 100644 dimos/memory2/tf.py diff --git a/dimos/mapping/utils/cli/map.py b/dimos/mapping/utils/cli/map.py index cae3e04114..91540afaf4 100644 --- a/dimos/mapping/utils/cli/map.py +++ b/dimos/mapping/utils/cli/map.py @@ -18,7 +18,7 @@ import math from pathlib import Path import subprocess -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import rerun as rr import rerun.blueprint as rrb @@ -31,6 +31,7 @@ from dimos.mapping.loop_closure.pgo import PoseGraph from dimos.memory2.stream import Stream from dimos.memory2.type.observation import Observation + from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 @@ -40,6 +41,24 @@ # labels never overlap the boxes. MARKER_STEM = 1.0 +# Sentinel a register resolver returns when a frame has no usable registration +# source and must be dropped (None means "already in the world frame"). +_DROP_FRAME = cast("Transform", object()) + +# Conventional world frames tried in order when --frame isn't given. +_WORLD_FRAMES = ("world", "map", "odom") + + +def _detect_world(tf_buf: Any, cloud_frame: str, ts: float) -> str | None: + """Pick the first conventional world frame that resolves the cloud frame via tf.""" + if cloud_frame in _WORLD_FRAMES: + return cloud_frame + if tf_buf is not None: + for cand in _WORLD_FRAMES: + if tf_buf.get(cand, cloud_frame, time_point=ts) is not None: + return cand + return None + def _log_markers( prefix: str, @@ -97,34 +116,21 @@ def _accumulate( block_count: int, device: str, graph: PoseGraph | None = None, - world_frame: bool = True, + register: Callable[[Observation[Any]], Transform | None] | None = None, carve_columns: bool = False, progress_cb: Callable[[Observation[Any]], None] | None = None, ) -> PointCloud2 | None: """Accumulate a voxel map from `obs_iter`, optionally PGO-correcting each frame. - By default the clouds are assumed already world-registered (the go2/fastlio - path) — only the PGO correction is applied, if any. Set ``world_frame=False`` - (the ``--use-tf`` path) when each frame's cloud is in the sensor/body frame - and must be registered into the world via its per-frame pose. + ``register`` maps each observation to the transform lifting its cloud into + the world frame: ``None`` when the cloud is already world-registered, or + ``_DROP_FRAME`` when it has no registration source (the frame is skipped). + With ``register=None`` all clouds are assumed world-registered. Returns the final ``PointCloud2`` (or ``None`` if the input was empty). Disposal of the underlying ``VoxelGrid`` is handled by ``VoxelMapTransformer``. """ from dimos.mapping.voxels import VoxelMapTransformer - from dimos.msgs.geometry_msgs.Quaternion import Quaternion - from dimos.msgs.geometry_msgs.Transform import Transform - from dimos.msgs.geometry_msgs.Vector3 import Vector3 - - def _pose_tf(obs: Observation[Any]) -> Transform: - pose = obs.pose - assert pose is not None - return Transform( - translation=Vector3(pose.position.x, pose.position.y, pose.position.z), - rotation=Quaternion( - pose.orientation.x, pose.orientation.y, pose.orientation.z, pose.orientation.w - ), - ) def prepared() -> Iterable[Observation[PointCloud2]]: for obs in obs_iter: @@ -132,14 +138,14 @@ def prepared() -> Iterable[Observation[PointCloud2]]: progress_cb(obs) if len(obs.data) == 0: continue - # body->world via the per-frame pose, unless the clouds are already - # world-registered (go2 default). graph adds the PGO correction on top - # (correction ∘ pose), applied after the pose. + # sensor->world via `register`, unless the clouds are already + # world-registered. graph adds the PGO correction on top + # (correction ∘ tf), applied after the registration. tf: Transform | None = None - if not world_frame: - if obs.pose is None: + if register is not None: + tf = register(obs) + if tf is _DROP_FRAME: continue - tf = _pose_tf(obs) if graph is not None: if obs.pose_tuple is None: continue @@ -328,11 +334,19 @@ def main( None, "--out", help="Output .rrd path (default: ./.rrd)" ), no_gui: bool = typer.Option(False, "--no-gui", help="Write the .rrd but don't launch rerun"), - use_tf: bool = typer.Option( - False, - "--use-tf", - help="Clouds are in the sensor/body frame; register each by its per-frame pose. " - "By default clouds are assumed already world-registered (e.g. go2/fastlio).", + frame: str | None = typer.Option( + None, + "--frame", + help="World frame to register clouds into. Default: auto-detect — the " + "first of 'world', 'map', 'odom' that resolves the cloud frame via the " + "dataset's tf stream. Clouds whose frame_id differs from it are " + "registered via tf; clouds already in it pass through verbatim.", + ), + tf_tolerance: float | None = typer.Option( + None, + "--tf-tolerance", + help="Max |Δts| (s) for tf lookups; default unlimited (nearest message), " + "which also serves static/rarely-published transforms", ), carve: bool = typer.Option( False, @@ -412,44 +426,102 @@ def main( total = lidar.count() - # Spatial dedup: bucket frames by 3D cell using the raw pose, keep the - # latest per cell. Shared by raw and PGO rebuilds. Doesn't touch obs.data - # so it stays cheap (no pointcloud loading). With pgo_tol<=0 the bucketing - # is disabled and every posed frame is kept (keyed by index). - seen: dict[Any, Observation[Any]] = {} - for i, obs in enumerate(lidar): + # Register clouds into the world frame via the dataset's tf stream. Clouds + # already stamped with the world frame pass through verbatim; sensor-frame + # clouds with no tf lookup are dropped. Stored per-frame poses are never + # used for registration — only as trajectory metadata (dedup/path) when + # the tf stream can't provide a position. + from dimos.memory2.tf import StreamTF + + tf_buf = StreamTF.from_store(store) + # Streams are homogeneous: read the cloud frame from the first observation. + first_obs = next(iter(lidar), None) + cloud_frame: str | None = first_obs.data.frame_id if first_obs is not None else None + + world = frame + if world is None and first_obs is not None and cloud_frame is not None: + world = _detect_world(tf_buf, cloud_frame, first_obs.ts) + if world is None: + frames = tf_buf.get_frames() if tf_buf is not None else set() + known = ", ".join(sorted(frames)) or "dataset has no tf stream" + raise typer.BadParameter( + f"none of {', '.join(_WORLD_FRAMES)} resolves {cloud_frame!r} clouds; " + f"pass --frame (tf frames: {known})", + param_hint="--frame", + ) + if world is None: + world = "world" # empty lidar stream; the frame is moot + + if first_obs is not None and cloud_frame is not None and cloud_frame != world: + # Fail fast when registration is impossible: probe the first cloud's + # timestamp (unbounded tolerance — "possible at all", not "in range"). + probe = ( + tf_buf.get(world, cloud_frame, time_point=first_obs.ts) if tf_buf is not None else None + ) + if probe is None: + frames = tf_buf.get_frames() if tf_buf is not None else set() + known = ", ".join(sorted(frames)) or "dataset has no tf stream" + raise typer.BadParameter( + f"cannot register {cloud_frame!r} clouds into {world!r} (tf frames: {known})", + param_hint="--frame", + ) + print(f"registering clouds {world!r} ← {cloud_frame!r} via tf") + elif cloud_frame is not None: + print(f"clouds already in world frame {world!r}; accumulating verbatim") + + def _register(obs: Observation[Any]) -> Transform | None: + cf = obs.data.frame_id + if cf == world: + return None + if tf_buf is None: + return _DROP_FRAME + tf = tf_buf.get(world, cf, time_point=obs.ts, time_tolerance=tf_tolerance) + return tf if tf is not None else _DROP_FRAME + + def _position(obs: Observation[Any]) -> tuple[float, float, float] | None: + """Trajectory position for dedup/path: tf lookup, else the stored pose.""" + if tf_buf is not None and cloud_frame is not None and cloud_frame != world: + tf = tf_buf.get(world, cloud_frame, time_point=obs.ts, time_tolerance=tf_tolerance) + if tf is not None: + return (tf.translation.x, tf.translation.y, tf.translation.z) pose = obs.pose - if pose is None: - continue # Reject placeholder poses: zero translation OR uninitialized rotation. # Same condition as pgo_keyframes so dedup and PGO see the same frames. - if pose.position.is_zero() or pose.orientation.is_zero(): + if pose is not None and not (pose.position.is_zero() or pose.orientation.is_zero()): + return (pose.position.x, pose.position.y, pose.position.z) + return None + + # Spatial dedup: bucket frames by 3D cell using the trajectory position, + # keep the latest per cell. Shared by raw and PGO rebuilds. Doesn't touch + # obs.data so it stays cheap (no pointcloud loading). With pgo_tol<=0 the + # bucketing is disabled and every positioned frame is kept (keyed by index). + seen: dict[Any, tuple[Observation[Any], tuple[float, float, float]]] = {} + for i, obs in enumerate(lidar): + pos = _position(obs) + if pos is None: continue if pgo_tol > 0: - t = pose.position # math.floor so negative coords bucket consistently; int() truncates # toward zero and silently folds -0.5 and 0.5 into the same cell. key: Any = ( - math.floor(t.x / pgo_tol), - math.floor(t.y / pgo_tol), - math.floor(t.z / pgo_tol), + math.floor(pos[0] / pgo_tol), + math.floor(pos[1] / pgo_tol), + math.floor(pos[2] / pgo_tol), ) else: key = i - seen[key] = obs + seen[key] = (obs, pos) n_kept = len(seen) pct = 100 * n_kept / total if total else 0 if pgo_tol > 0: print(f"dedup: kept [{n_kept}/{total}] frames ({pct:.1f}%) at tol={pgo_tol}m") else: - print(f"dedup: disabled, kept all [{n_kept}/{total}] posed frames") + print(f"dedup: disabled, kept all [{n_kept}/{total}] positioned frames") # Dict insertion order = lidar iteration order = chronological. - # `seen` only contains entries with non-None poses (filtered above). - path: list[tuple[float, float, float]] = [ - (p[0], p[1], p[2]) for obs in seen.values() if (p := obs.pose_tuple) is not None - ] + kept = [obs for obs, _ in seen.values()] + path: list[tuple[float, float, float]] = [pos for _, pos in seen.values()] pgo_map = None pgo_path: list[tuple[float, float, float]] = [] @@ -465,12 +537,12 @@ def main( ] pgo_map = _accumulate( - seen.values(), + kept, voxel=voxel, block_count=block_count, device=device, graph=graph, - world_frame=not use_tf, + register=_register, carve_columns=carve, progress_cb=progress(n_kept, "pgo pass 2 (rebuilding)"), ) @@ -484,18 +556,18 @@ def main( block_count=block_count, device=device, graph=graph, - world_frame=not use_tf, + register=_register, carve_columns=carve, progress_cb=progress(total, "full pgo (rebuilding)"), ) # Raw map: same dedup'd frames, no PGO correction. global_map = _accumulate( - seen.values(), + kept, voxel=voxel, block_count=block_count, device=device, - world_frame=not use_tf, + register=_register, carve_columns=carve, progress_cb=progress(n_kept, "reconstructing global map"), ) diff --git a/dimos/memory2/tf.py b/dimos/memory2/tf.py new file mode 100644 index 0000000000..e5a1d59227 --- /dev/null +++ b/dimos/memory2/tf.py @@ -0,0 +1,137 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""TF service backed by a recorded ``tf`` stream.""" + +from __future__ import annotations + +import math +from typing import TYPE_CHECKING, Any, cast + +from dimos.memory2.stream import Stream +from dimos.msgs.tf2_msgs.TFMessage import TFMessage +from dimos.protocol.tf.tf import MultiTBuffer, TFConfig, TFSpec + +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs.Transform import Transform + from dimos.protocol.tf.tf import TFLookup + + +class StreamTFConfig(TFConfig): + stream: Stream[TFMessage] | None = ( + None # Required field but needs default for config inheritance + ) + # Prefetch span (s) cached past a missed query window, so chronological + # replay costs one db query per cache_span of progress. Also the cache + # size bound: a miss evicts everything before re-caching. + cache_span: float = 300.0 + + +class StreamTF(MultiTBuffer, TFSpec): + """A tf service whose backend is a recorded memory2 ``tf`` stream. + + The read-side mirror of :class:`~dimos.protocol.tf.tf.PubSubTF`: the same + :class:`MultiTBuffer` cache and lookup API, but ingestion pulls windows + from the stream on demand instead of receiving pushed messages. + + Lookups reach as far as they would against the live service: with no + explicit ``time_tolerance`` a query window spans ``buffer_size`` seconds + backward (what a live buffer would still hold) plus ``forward_tolerance`` + ahead — the recorded-time analog of the live wall-clock wait for future + transforms, which is why lookups here never block. A cache miss fetches + the window plus ``cache_span`` beyond it in one query. The cache pins the + underlying :class:`MultiTBuffer` to infinite retention: insert-time + pruning would silently delete data the cache still claims to hold, so + eviction is explicit (miss → full clear → re-cache). + """ + + config: StreamTFConfig + + def __init__(self, stream: Stream[TFMessage] | None = None, **kwargs: Any) -> None: + if stream is not None: + kwargs["stream"] = stream + TFSpec.__init__(self, **kwargs) + MultiTBuffer.__init__(self, buffer_size=math.inf) + + if self.config.stream is None: + raise ValueError("Stream configuration is missing") + self.stream = self.config.stream + + self._covered: tuple[float, float] | None = None + + @classmethod + def from_store(cls, store: Any, stream: str = "tf") -> StreamTF | None: + if stream not in store.list_streams(): + return None + return cls(store.stream(stream, TFMessage)) + + def publish(self, *args: Transform) -> None: + raise NotImplementedError("StreamTF is a read-only replay service.") + + def publish_static(self, *args: Transform) -> None: + raise NotImplementedError("StreamTF is a read-only replay service.") + + def _load(self, lo: float, hi: float) -> None: + # at() windows are boundary-inclusive; from/to_timestamp are strict and + # would skip messages stamped exactly at the stream's first timestamp. + for obs in self.stream.at((lo + hi) / 2, (hi - lo) / 2): + self.receive_transform(*obs.data.transforms) + + def _ensure(self, lo: float, hi: float) -> None: + """Serve ``[lo, hi]`` from the cache, else re-cache ``[lo, hi + cache_span]``. + + The prefetch past ``hi`` makes chronological replay cost one db query + per ``cache_span`` of progress. A miss evicts everything first — a full + clear (not partial pruning) keeps ``_covered`` truthful. + """ + if self._covered is not None: + clo, chi = self._covered + if clo <= lo and hi <= chi: + return + with self._cv: + self.buffers.clear() + self._load(lo, hi + self.config.cache_span) + self._covered = (lo, hi + self.config.cache_span) + + def get( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, + ) -> Transform | None: + tp = time_point + if tp is None: + last = next(iter(self.stream.order_by("ts", desc=True).limit(1)), None) + tp = last.ts if last is not None else None + if tp is not None: + back = time_tolerance if time_tolerance is not None else self.config.buffer_size + fwd = time_tolerance if time_tolerance is not None else forward_tolerance + self._ensure(tp - back, tp + fwd) + # The recorded-time lookahead above stands in for the live wall-clock + # wait, so the base lookup must never block. + return super().get( + parent_frame, + child_frame, + time_point, + time_tolerance, + forward_tolerance=0.0, + ) + + +if TYPE_CHECKING: + # mypy conformance check: StreamTF satisfies the read-side tf protocol. + _lookup_impl: TFLookup = cast("StreamTF", None) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 281e99c0d7..24951be375 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -20,11 +20,15 @@ import pytest +from dimos.memory2.store.memory import MemoryStore +from dimos.memory2.store.sqlite import SqliteStore +from dimos.memory2.tf import StreamTF from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.protocol.tf.tf import TF, MultiTBuffer, TBuffer +from dimos.msgs.tf2_msgs.TFMessage import TFMessage +from dimos.protocol.tf.tf import TF, MultiTBuffer, TBuffer, TFLookup # from https://foxglove.dev/blog/understanding-ros-transforms @@ -310,68 +314,6 @@ def test_graph(self) -> None: print(ttbuffer.graph()) - def test_get_latest_transform(self) -> None: - ttbuffer = MultiTBuffer() - - # Add multiple transforms - for i in range(3): - transform = Transform( - translation=Vector3(float(i), 0.0, 0.0), - frame_id="world", - child_frame_id="robot", - ts=time.time() + i * 0.1, - ) - ttbuffer.receive_transform(transform) - time.sleep(0.01) - - # Get latest transform - latest = ttbuffer.get("world", "robot") - assert latest is not None - assert latest.translation.x == 2.0 - - def test_get_transform_at_time(self) -> None: - ttbuffer = MultiTBuffer() - base_time = time.time() - - # Add transforms at known times - for i in range(5): - transform = Transform( - translation=Vector3(float(i), 0.0, 0.0), - frame_id="world", - child_frame_id="robot", - ts=base_time + i * 0.5, - ) - ttbuffer.receive_transform(transform) - - # Get transform closest to middle time - middle_time = base_time + 1.25 # Should be closest to i=2 (t=1.0) or i=3 (t=1.5) - result = ttbuffer.get("world", "robot", time_point=middle_time) - assert result is not None - # At t=1.25, it's equidistant from i=2 (t=1.0) and i=3 (t=1.5) - # The implementation picks the later one when equidistant - assert result.translation.x == 3.0 - - def test_time_tolerance(self) -> None: - ttbuffer = MultiTBuffer() - base_time = time.time() - - # Add single transform - transform = Transform( - translation=Vector3(1.0, 0.0, 0.0), - frame_id="world", - child_frame_id="robot", - ts=base_time, - ) - ttbuffer.receive_transform(transform) - - # Within tolerance - result = ttbuffer.get("world", "robot", time_point=base_time + 0.1, time_tolerance=0.2) - assert result is not None - - # Outside tolerance - result = ttbuffer.get("world", "robot", time_point=base_time + 0.5, time_tolerance=0.1) - assert result is None - def test_forward_tolerance_returns_when_buffer_fills(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -465,39 +407,6 @@ def publish_after_delay() -> None: assert result.translation.x == 1.0 assert result.translation.y == 2.0 - def test_same_frame_returns_identity(self) -> None: - ttbuffer = MultiTBuffer() - - # Empty buffer: same-frame lookup still returns identity - result = ttbuffer.get("base_link", "base_link") - assert result is not None - assert result.frame_id == "base_link" - assert result.child_frame_id == "base_link" - assert result.translation.x == 0.0 - assert result.translation.y == 0.0 - assert result.translation.z == 0.0 - assert result.rotation.x == 0.0 - assert result.rotation.y == 0.0 - assert result.rotation.z == 0.0 - assert result.rotation.w == 1.0 - - # Same behavior when the frame happens to exist in the buffer - ttbuffer.receive_transform( - Transform(frame_id="world", child_frame_id="base_link", ts=time.time()) - ) - result = ttbuffer.get("world", "world") - assert result is not None - assert result.frame_id == "world" - assert result.child_frame_id == "world" - assert result.rotation.w == 1.0 - - def test_nonexistent_frame_pair(self) -> None: - ttbuffer = MultiTBuffer() - - # Try to get transform for non-existent frame pair - result = ttbuffer.get("foo", "bar") - assert result is None - def test_get_transform_search_direct(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -705,103 +614,210 @@ def test_string_representations(self) -> None: assert "TBuffer(world -> robot2, 1 msgs" in ttbuffer_str assert "TBuffer(robot1 -> sensor, 1 msgs" in ttbuffer_str - def test_get_with_transform_chain_composition(self) -> None: - ttbuffer = MultiTBuffer() - base_time = time.time() - # Create transform chain: world -> robot -> sensor - # world -> robot: translate by (1, 0, 0) - transform1 = Transform( - translation=Vector3(1.0, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), # Identity - frame_id="world", - child_frame_id="robot", - ts=base_time, - ) - - # robot -> sensor: translate by (0, 2, 0) and rotate 90 degrees around Z - import math +# --- Grid tests: every get() scenario below runs against the live service and +# --- against StreamTF replaying a recorded tf stream (memory- and sqlite-backed). - # 90 degrees around Z: quaternion (0, 0, sin(45°), cos(45°)) - transform2 = Transform( - translation=Vector3(0.0, 2.0, 0.0), - rotation=Quaternion(0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)), - frame_id="robot", - child_frame_id="sensor", - ts=base_time, - ) +T0 = 1_700_000_000.0 - ttbuffer.receive_transform(transform1, transform2) - # Get composed transform from world to sensor - result = ttbuffer.get("world", "sensor") - assert result is not None - - # The composed transform should: - # 1. Apply world->robot translation: (1, 0, 0) - # 2. Apply robot->sensor translation in robot frame: (0, 2, 0) - # Total translation: (1, 2, 0) - assert abs(result.translation.x - 1.0) < 1e-6 - assert abs(result.translation.y - 2.0) < 1e-6 - assert abs(result.translation.z - 0.0) < 1e-6 - - # Rotation should be 90 degrees around Z (same as transform2) - assert abs(result.rotation.x - 0.0) < 1e-6 - assert abs(result.rotation.y - 0.0) < 1e-6 - assert abs(result.rotation.z - math.sin(math.pi / 4)) < 1e-6 - assert abs(result.rotation.w - math.cos(math.pi / 4)) < 1e-6 - - # Frame IDs should be correct - assert result.frame_id == "world" - assert result.child_frame_id == "sensor" +def _t(parent: str, child: str, x: float, ts: float) -> Transform: + return Transform( + frame_id=parent, + child_frame_id=child, + translation=Vector3(x, 0.0, 0.0), + ts=ts, + ) - def test_get_with_longer_transform_chain(self) -> None: - ttbuffer = MultiTBuffer() - base_time = time.time() - # Create longer chain: world -> base -> arm -> hand - # Each adds a translation along different axes - transforms = [ +@pytest.fixture(params=["live", "stream_memory", "stream_sqlite"]) +def make_tf(request, tmp_path): # type: ignore[no-untyped-def] + """Builder fixture: feed it transforms, get back a TFLookup over them.""" + stores = [] + + def build(*transforms: Transform) -> TFLookup: + if request.param == "live": + buf = MultiTBuffer() + buf.receive_transform(*transforms) + return buf + store = ( + MemoryStore() + if request.param == "stream_memory" + else SqliteStore(path=str(tmp_path / "tf.db")) + ) + stores.append(store) + stream = store.stream("tf", TFMessage) + for t in transforms: + stream.append(TFMessage(t), ts=t.ts, pose=None) + return StreamTF(store.stream("tf", TFMessage)) + + yield build + for store in stores: + store.stop() + + +class TestLookupGrid: + """get() scenarios that must answer identically live and over a recording.""" + + def test_latest(self, make_tf) -> None: # type: ignore[no-untyped-def] + tf = make_tf(*(_t("world", "robot", float(i), T0 + i * 0.1) for i in range(3))) + got = tf.get("world", "robot") + assert got is not None + assert got.translation.x == 2.0 + + def test_nearest_in_time(self, make_tf) -> None: # type: ignore[no-untyped-def] + tf = make_tf(*(_t("world", "robot", float(i), T0 + i * 0.5) for i in range(5))) + got = tf.get("world", "robot", time_point=T0 + 1.25) + assert got is not None + # Equidistant between i=2 (t=1.0) and i=3 (t=1.5) — the later one wins. + assert got.translation.x == 3.0 + + def test_inverse(self, make_tf) -> None: # type: ignore[no-untyped-def] + tf = make_tf(_t("world", "robot", 5.0, T0)) + got = tf.get("robot", "world", time_point=T0) + assert got is not None + assert got.translation.x == pytest.approx(-5.0) + + def test_time_tolerance(self, make_tf) -> None: # type: ignore[no-untyped-def] + tf = make_tf(_t("world", "robot", 1.0, T0)) + assert tf.get("world", "robot", time_point=T0 + 0.1, time_tolerance=0.2) is not None + assert tf.get("world", "robot", time_point=T0 + 0.5, time_tolerance=0.1) is None + + def test_same_frame_identity(self, make_tf) -> None: # type: ignore[no-untyped-def] + tf = make_tf(_t("world", "robot", 1.0, T0)) + got = tf.get("world", "world", time_point=T0) + assert got is not None + assert got.frame_id == "world" + assert got.child_frame_id == "world" + assert got.translation.x == 0.0 + assert got.rotation.w == 1.0 + + def test_unknown_frame(self, make_tf) -> None: # type: ignore[no-untyped-def] + tf = make_tf(_t("world", "robot", 1.0, T0)) + assert tf.get("foo", "bar", time_point=T0) is None + + def test_chain_composition(self, make_tf) -> None: # type: ignore[no-untyped-def] + # world -> robot: translate (1, 0, 0); robot -> sensor: translate + # (0, 2, 0) and rotate 90° around Z. + tf = make_tf( Transform( - translation=Vector3(1.0, 0.0, 0.0), # Move 1 along X + translation=Vector3(1.0, 0.0, 0.0), rotation=Quaternion(0.0, 0.0, 0.0, 1.0), frame_id="world", - child_frame_id="base", - ts=base_time, - ), - Transform( - translation=Vector3(0.0, 2.0, 0.0), # Move 2 along Y - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base", - child_frame_id="arm", - ts=base_time, + child_frame_id="robot", + ts=T0, ), Transform( - translation=Vector3(0.0, 0.0, 3.0), # Move 3 along Z - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="arm", - child_frame_id="hand", - ts=base_time, + translation=Vector3(0.0, 2.0, 0.0), + rotation=Quaternion(0.0, 0.0, math.sin(math.pi / 4), math.cos(math.pi / 4)), + frame_id="robot", + child_frame_id="sensor", + ts=T0, ), - ] - - for t in transforms: - ttbuffer.receive_transform(t) - - # Get composed transform from world to hand - result = ttbuffer.get("world", "hand") + ) + result = tf.get("world", "sensor", time_point=T0) assert result is not None - - # Total translation should be sum of all: (1, 2, 3) - assert abs(result.translation.x - 1.0) < 1e-6 - assert abs(result.translation.y - 2.0) < 1e-6 - assert abs(result.translation.z - 3.0) < 1e-6 - - # Rotation should still be identity (all rotations were identity) - assert abs(result.rotation.x - 0.0) < 1e-6 - assert abs(result.rotation.y - 0.0) < 1e-6 - assert abs(result.rotation.z - 0.0) < 1e-6 - assert abs(result.rotation.w - 1.0) < 1e-6 - + assert result.translation.x == pytest.approx(1.0) + assert result.translation.y == pytest.approx(2.0) + assert result.translation.z == pytest.approx(0.0) + assert result.rotation.z == pytest.approx(math.sin(math.pi / 4)) + assert result.rotation.w == pytest.approx(math.cos(math.pi / 4)) assert result.frame_id == "world" - assert result.child_frame_id == "hand" + assert result.child_frame_id == "sensor" + + def test_chain_with_sparse_static_edge(self, make_tf) -> None: # type: ignore[no-untyped-def] + # A static edge published once at startup composes with dynamic data + # arriving ten seconds later (default lookup reach = buffer_size). + tf = make_tf( + _t("world", "map", 100.0, T0), + *(_t("map", "base", i / 10, T0 + i / 10) for i in range(100)), + ) + got = tf.get("world", "base", time_point=T0 + 9.0) + assert got is not None + assert got.translation.x == pytest.approx(109.0) + + def test_conforms_to_tf_lookup(self, make_tf) -> None: # type: ignore[no-untyped-def] + assert isinstance(make_tf(_t("world", "robot", 1.0, T0)), TFLookup) + + +class TestStreamTF: + """Replay-specific surface: construction, bounded caching, and lookahead.""" + + @pytest.fixture + def store(self): # type: ignore[no-untyped-def] + store = MemoryStore() + stream = store.stream("tf", TFMessage) + # Startup burst: a static stamped exactly at the stream start (regression: + # strict `ts >` range queries used to drop it), then dynamic map→base. + stream.append(TFMessage(_t("world", "map", 100.0, T0)), ts=T0, pose=None) + for i in range(100): + ts = T0 + i / 10 + stream.append(TFMessage(_t("map", "base", i / 10, ts)), ts=ts, pose=None) + yield store + store.stop() + + def test_from_store(self, store) -> None: # type: ignore[no-untyped-def] + assert StreamTF.from_store(store) is not None + assert StreamTF.from_store(store, "nope") is None + + def test_missing_stream(self) -> None: + with pytest.raises(ValueError, match="Stream configuration"): + StreamTF() + + def test_empty_stream(self) -> None: + store = MemoryStore() + store.stream("tf", TFMessage) + try: + tf = StreamTF(store.stream("tf", TFMessage)) + assert tf.get("world", "base", time_point=T0) is None + finally: + store.stop() + + def test_get_frames(self, store) -> None: # type: ignore[no-untyped-def] + tf = StreamTF(store.stream("tf", TFMessage)) + # Nothing is cached until a lookup pulls a window in. + assert tf.get_frames() == set() + tf.get("map", "base", time_point=T0 + 5.0) + assert tf.get_frames() == {"world", "map", "base"} + + def test_read_only(self, store) -> None: # type: ignore[no-untyped-def] + tf = StreamTF(store.stream("tf", TFMessage)) + with pytest.raises(NotImplementedError): + tf.publish(_t("map", "base", 0.0, T0)) + + def test_get_pose(self, store) -> None: # type: ignore[no-untyped-def] + tf = StreamTF(store.stream("tf", TFMessage)) + pose = tf.get_pose("map", "base", time_point=T0 + 5.0) + assert pose is not None + assert pose.position.x == pytest.approx(5.0) + + def test_cache_prefetch_and_eviction(self) -> None: + # 40 s of data with a small cache_span: a miss caches the query window + # plus cache_span ahead, follow-ups inside it are pure cache hits, and + # the first query past it evicts everything and re-caches. + store = MemoryStore() + stream = store.stream("tf", TFMessage) + for i in range(400): + ts = T0 + i / 10 + stream.append(TFMessage(_t("map", "base", i / 10, ts)), ts=ts, pose=None) + try: + tf = StreamTF(store.stream("tf", TFMessage), cache_span=2.0) + got = tf.get("map", "base", time_point=T0 + 35.0, time_tolerance=0.5) + assert got is not None + assert got.translation.x == pytest.approx(35.0) + covered = tf._covered + assert covered == pytest.approx((T0 + 34.5, T0 + 37.5)) + # Inside the prefetched span: served from cache, no re-query. + got = tf.get("map", "base", time_point=T0 + 37.0, time_tolerance=0.5) + assert got is not None + assert got.translation.x == pytest.approx(37.0) + assert tf._covered == covered + # Past the span: evict and re-cache around the new query. + got = tf.get("map", "base", time_point=T0 + 39.0, time_tolerance=0.5) + assert got is not None + assert got.translation.x == pytest.approx(39.0) + assert tf._covered == pytest.approx((T0 + 38.5, T0 + 41.5)) + # Bounded: the buffer holds the re-cached span, not the stream. + assert len(tf.buffers[("map", "base")]) < 100 + finally: + store.stop() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index eb3d72b470..f976ead6c6 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -20,6 +20,7 @@ from functools import reduce import threading import time +from typing import TYPE_CHECKING, Protocol, runtime_checkable from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Transform import Transform @@ -34,6 +35,24 @@ logger = setup_logger() +@runtime_checkable +class TFLookup(Protocol): + """Read side of a tf service: resolve ``parent ← child`` at a time point. + + Satisfied by the live services (:class:`MultiTBuffer`, :class:`PubSubTF`) + and by replay backends like ``dimos.memory2.tf.StreamTF``. Code that only + queries transforms should accept this instead of a concrete service. + """ + + def get( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + ) -> Transform | None: ... + + # generic configuration for transform service class TFConfig(BaseConfig): buffer_size: float = 10.0 # seconds @@ -64,6 +83,26 @@ def get( forward_tolerance: float = 0.0, ) -> Transform | None: ... + def get_pose( + self, + parent_frame: str, + child_frame: str, + time_point: float | None = None, + time_tolerance: float | None = None, + *, + forward_tolerance: float = 0.0, + ) -> PoseStamped | None: + tf = self.get( + parent_frame, + child_frame, + time_point, + time_tolerance, + forward_tolerance=forward_tolerance, + ) + if not tf: + return None + return tf.to_pose() + def receive_transform(self, *args: Transform) -> None: ... def receive_tfmessage(self, msg: TFMessage) -> None: @@ -159,16 +198,20 @@ def get_transform( ts=time_point if time_point is not None else time.time(), ) + # No explicit tolerance means "anything still buffered" — the buffer + # holds at most buffer_size seconds, so that is the effective reach. + tolerance = time_tolerance if time_tolerance is not None else self.buffer_size + with self._cv: # Check forward direction key = (parent_frame, child_frame) if key in self.buffers: - return self.buffers[key].get(time_point, time_tolerance) # type: ignore[arg-type] + return self.buffers[key].get(time_point, tolerance) # Check reverse direction and return inverse reverse_key = (child_frame, parent_frame) if reverse_key in self.buffers: - transform = self.buffers[reverse_key].get(time_point, time_tolerance) # type: ignore[arg-type] + transform = self.buffers[reverse_key].get(time_point, tolerance) return transform.inverse() if transform else None return None @@ -388,26 +431,6 @@ def get( forward_tolerance=forward_tolerance, ) - def get_pose( - self, - parent_frame: str, - child_frame: str, - time_point: float | None = None, - time_tolerance: float | None = None, - *, - forward_tolerance: float = 0.0, - ) -> PoseStamped | None: - tf = self.get( - parent_frame, - child_frame, - time_point, - time_tolerance, - forward_tolerance=forward_tolerance, - ) - if not tf: - return None - return tf.to_pose() - def receive_msg(self, msg: TFMessage, topic: Topic) -> None: self.receive_tfmessage(msg) @@ -423,3 +446,7 @@ class LCMTF(PubSubTF): TF = LCMTF + +if TYPE_CHECKING: + # mypy conformance checks: the live services satisfy the read-side protocol. + _lookup_impls: tuple[type[TFLookup], ...] = (MultiTBuffer, PubSubTF) diff --git a/docs/usage/transforms.md b/docs/usage/transforms.md index df7f70274f..596a6b5b4b 100644 --- a/docs/usage/transforms.md +++ b/docs/usage/transforms.md @@ -200,7 +200,7 @@ Every module has access to `self.tf`, a transform service that: - **Looks up** transforms between any two frames - **Buffers** historical transforms for temporal queries -The TF service is implemented in [`tf.py`](/dimos/protocol/tf/tf.py) and is lazily initialized on first access. +The TF service is implemented in [`protocol/tf/tf.py`](/dimos/protocol/tf/tf.py) and is lazily initialized on first access. ### Multi-Module Transform Example From 90c1f2d32d59bc627f0a51c9d164df5aa274a6fd Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 03:34:27 +0300 Subject: [PATCH 04/10] comments cleanup --- dimos/memory2/tf.py | 42 +++++++----------------------------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/dimos/memory2/tf.py b/dimos/memory2/tf.py index e5a1d59227..7c8dae0d5d 100644 --- a/dimos/memory2/tf.py +++ b/dimos/memory2/tf.py @@ -32,30 +32,10 @@ class StreamTFConfig(TFConfig): stream: Stream[TFMessage] | None = ( None # Required field but needs default for config inheritance ) - # Prefetch span (s) cached past a missed query window, so chronological - # replay costs one db query per cache_span of progress. Also the cache - # size bound: a miss evicts everything before re-caching. cache_span: float = 300.0 class StreamTF(MultiTBuffer, TFSpec): - """A tf service whose backend is a recorded memory2 ``tf`` stream. - - The read-side mirror of :class:`~dimos.protocol.tf.tf.PubSubTF`: the same - :class:`MultiTBuffer` cache and lookup API, but ingestion pulls windows - from the stream on demand instead of receiving pushed messages. - - Lookups reach as far as they would against the live service: with no - explicit ``time_tolerance`` a query window spans ``buffer_size`` seconds - backward (what a live buffer would still hold) plus ``forward_tolerance`` - ahead — the recorded-time analog of the live wall-clock wait for future - transforms, which is why lookups here never block. A cache miss fetches - the window plus ``cache_span`` beyond it in one query. The cache pins the - underlying :class:`MultiTBuffer` to infinite retention: insert-time - pruning would silently delete data the cache still claims to hold, so - eviction is explicit (miss → full clear → re-cache). - """ - config: StreamTFConfig def __init__(self, stream: Stream[TFMessage] | None = None, **kwargs: Any) -> None: @@ -83,18 +63,12 @@ def publish_static(self, *args: Transform) -> None: raise NotImplementedError("StreamTF is a read-only replay service.") def _load(self, lo: float, hi: float) -> None: - # at() windows are boundary-inclusive; from/to_timestamp are strict and - # would skip messages stamped exactly at the stream's first timestamp. for obs in self.stream.at((lo + hi) / 2, (hi - lo) / 2): self.receive_transform(*obs.data.transforms) + self._covered = (lo, hi) def _ensure(self, lo: float, hi: float) -> None: - """Serve ``[lo, hi]`` from the cache, else re-cache ``[lo, hi + cache_span]``. - - The prefetch past ``hi`` makes chronological replay cost one db query - per ``cache_span`` of progress. A miss evicts everything first — a full - clear (not partial pruning) keeps ``_covered`` truthful. - """ + """Serve ``[lo, hi]`` from the cache, else re-cache ``[lo, hi + cache_span]``.""" if self._covered is not None: clo, chi = self._covered if clo <= lo and hi <= chi: @@ -102,7 +76,6 @@ def _ensure(self, lo: float, hi: float) -> None: with self._cv: self.buffers.clear() self._load(lo, hi + self.config.cache_span) - self._covered = (lo, hi + self.config.cache_span) def get( self, @@ -117,12 +90,11 @@ def get( if tp is None: last = next(iter(self.stream.order_by("ts", desc=True).limit(1)), None) tp = last.ts if last is not None else None - if tp is not None: - back = time_tolerance if time_tolerance is not None else self.config.buffer_size - fwd = time_tolerance if time_tolerance is not None else forward_tolerance - self._ensure(tp - back, tp + fwd) - # The recorded-time lookahead above stands in for the live wall-clock - # wait, so the base lookup must never block. + + back = time_tolerance if time_tolerance is not None else self.config.buffer_size + fwd = time_tolerance if time_tolerance is not None else forward_tolerance + self._ensure(tp - back, tp + fwd) + return super().get( parent_frame, child_frame, From 42034cbf75a72dce85cb7010d8c6c897df656b92 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 03:51:54 +0300 Subject: [PATCH 05/10] type fix --- dimos/memory2/tf.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dimos/memory2/tf.py b/dimos/memory2/tf.py index 7c8dae0d5d..5db3ef6c8c 100644 --- a/dimos/memory2/tf.py +++ b/dimos/memory2/tf.py @@ -91,9 +91,10 @@ def get( last = next(iter(self.stream.order_by("ts", desc=True).limit(1)), None) tp = last.ts if last is not None else None - back = time_tolerance if time_tolerance is not None else self.config.buffer_size - fwd = time_tolerance if time_tolerance is not None else forward_tolerance - self._ensure(tp - back, tp + fwd) + if tp is not None: + back = time_tolerance if time_tolerance is not None else self.config.buffer_size + fwd = time_tolerance if time_tolerance is not None else forward_tolerance + self._ensure(tp - back, tp + fwd) return super().get( parent_frame, From 1ce8306df9b9eb7f85b06c00f2b577ebf9aefbc7 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 04:01:53 +0300 Subject: [PATCH 06/10] small cleanup --- dimos/mapping/utils/cli/map.py | 40 +++++++++++++++------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/dimos/mapping/utils/cli/map.py b/dimos/mapping/utils/cli/map.py index 91540afaf4..4310a3e6eb 100644 --- a/dimos/mapping/utils/cli/map.py +++ b/dimos/mapping/utils/cli/map.py @@ -18,7 +18,7 @@ import math from pathlib import Path import subprocess -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any import rerun as rr import rerun.blueprint as rrb @@ -41,10 +41,6 @@ # labels never overlap the boxes. MARKER_STEM = 1.0 -# Sentinel a register resolver returns when a frame has no usable registration -# source and must be dropped (None means "already in the world frame"). -_DROP_FRAME = cast("Transform", object()) - # Conventional world frames tried in order when --frame isn't given. _WORLD_FRAMES = ("world", "map", "odom") @@ -123,9 +119,8 @@ def _accumulate( """Accumulate a voxel map from `obs_iter`, optionally PGO-correcting each frame. ``register`` maps each observation to the transform lifting its cloud into - the world frame: ``None`` when the cloud is already world-registered, or - ``_DROP_FRAME`` when it has no registration source (the frame is skipped). - With ``register=None`` all clouds are assumed world-registered. + the world frame; ``None`` means no transform is available and the frame is + skipped. With ``register=None`` all clouds are assumed world-registered. Returns the final ``PointCloud2`` (or ``None`` if the input was empty). Disposal of the underlying ``VoxelGrid`` is handled by ``VoxelMapTransformer``. @@ -144,7 +139,7 @@ def prepared() -> Iterable[Observation[PointCloud2]]: tf: Transform | None = None if register is not None: tf = register(obs) - if tf is _DROP_FRAME: + if tf is None: continue if graph is not None: if obs.pose_tuple is None: @@ -452,13 +447,17 @@ def main( if world is None: world = "world" # empty lidar stream; the frame is moot + # Registration: sensor-frame clouds get a per-frame tf lookup lifting them + # into the world frame (frames with no tf answer are dropped); clouds + # already stamped with the world frame accumulate verbatim (register=None). + register: Callable[[Observation[Any]], Transform | None] | None = None if first_obs is not None and cloud_frame is not None and cloud_frame != world: # Fail fast when registration is impossible: probe the first cloud's # timestamp (unbounded tolerance — "possible at all", not "in range"). probe = ( tf_buf.get(world, cloud_frame, time_point=first_obs.ts) if tf_buf is not None else None ) - if probe is None: + if tf_buf is None or probe is None: frames = tf_buf.get_frames() if tf_buf is not None else set() known = ", ".join(sorted(frames)) or "dataset has no tf stream" raise typer.BadParameter( @@ -466,18 +465,15 @@ def main( param_hint="--frame", ) print(f"registering clouds {world!r} ← {cloud_frame!r} via tf") + buf = tf_buf + + def _register(obs: Observation[Any]) -> Transform | None: + return buf.get(world, obs.data.frame_id, time_point=obs.ts, time_tolerance=tf_tolerance) + + register = _register elif cloud_frame is not None: print(f"clouds already in world frame {world!r}; accumulating verbatim") - def _register(obs: Observation[Any]) -> Transform | None: - cf = obs.data.frame_id - if cf == world: - return None - if tf_buf is None: - return _DROP_FRAME - tf = tf_buf.get(world, cf, time_point=obs.ts, time_tolerance=tf_tolerance) - return tf if tf is not None else _DROP_FRAME - def _position(obs: Observation[Any]) -> tuple[float, float, float] | None: """Trajectory position for dedup/path: tf lookup, else the stored pose.""" if tf_buf is not None and cloud_frame is not None and cloud_frame != world: @@ -542,7 +538,7 @@ def _position(obs: Observation[Any]) -> tuple[float, float, float] | None: block_count=block_count, device=device, graph=graph, - register=_register, + register=register, carve_columns=carve, progress_cb=progress(n_kept, "pgo pass 2 (rebuilding)"), ) @@ -556,7 +552,7 @@ def _position(obs: Observation[Any]) -> tuple[float, float, float] | None: block_count=block_count, device=device, graph=graph, - register=_register, + register=register, carve_columns=carve, progress_cb=progress(total, "full pgo (rebuilding)"), ) @@ -567,7 +563,7 @@ def _position(obs: Observation[Any]) -> tuple[float, float, float] | None: voxel=voxel, block_count=block_count, device=device, - register=_register, + register=register, carve_columns=carve, progress_cb=progress(n_kept, "reconstructing global map"), ) From d7bf5f79de2b60725fdfb5e03c065abbb4ece6b4 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 04:08:56 +0300 Subject: [PATCH 07/10] mac skip --- dimos/protocol/tf/test_tf.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 24951be375..67a1399501 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -15,6 +15,7 @@ # limitations under the License. import math +import platform import threading import time @@ -633,6 +634,12 @@ def _t(parent: str, child: str, x: float, ts: float) -> Transform: @pytest.fixture(params=["live", "stream_memory", "stream_sqlite"]) def make_tf(request, tmp_path): # type: ignore[no-untyped-def] """Builder fixture: feed it transforms, get back a TFLookup over them.""" + if request.param == "stream_sqlite" and ( + platform.machine() == "aarch64" or platform.system() == "Darwin" + ): + # Same guard as memory2/conftest.py: sqlite-vec ships a 32-bit binary + # in the aarch64 wheel and fails to load on macOS CI. + pytest.skip("sqlite-vec extension not loadable here") stores = [] def build(*transforms: Transform) -> TFLookup: From 2fabfdd149b28e21209cac96d730164c95c3cfc2 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 04:12:16 +0300 Subject: [PATCH 08/10] mac skip --- dimos/protocol/tf/test_tf.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 67a1399501..e011a49086 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -15,7 +15,6 @@ # limitations under the License. import math -import platform import threading import time @@ -631,15 +630,17 @@ def _t(parent: str, child: str, x: float, ts: float) -> Transform: ) -@pytest.fixture(params=["live", "stream_memory", "stream_sqlite"]) +@pytest.fixture( + params=[ + "live", + "stream_memory", + # sqlite-vec ships a 32-bit binary in the aarch64 wheel and fails to + # load on macOS CI (same guard as memory2/conftest.py). + pytest.param("stream_sqlite", marks=[pytest.mark.skipif_aarch64, pytest.mark.skipif_macos]), + ] +) def make_tf(request, tmp_path): # type: ignore[no-untyped-def] """Builder fixture: feed it transforms, get back a TFLookup over them.""" - if request.param == "stream_sqlite" and ( - platform.machine() == "aarch64" or platform.system() == "Darwin" - ): - # Same guard as memory2/conftest.py: sqlite-vec ships a 32-bit binary - # in the aarch64 wheel and fails to load on macOS CI. - pytest.skip("sqlite-vec extension not loadable here") stores = [] def build(*transforms: Transform) -> TFLookup: From b7a9ee935645b49c562803028de5181f31a2d2b0 Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 14:21:47 +0300 Subject: [PATCH 09/10] map uses obs.pose as a backup --- dimos/mapping/utils/cli/map.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/dimos/mapping/utils/cli/map.py b/dimos/mapping/utils/cli/map.py index 4310a3e6eb..106e141797 100644 --- a/dimos/mapping/utils/cli/map.py +++ b/dimos/mapping/utils/cli/map.py @@ -473,13 +473,15 @@ def _register(obs: Observation[Any]) -> Transform | None: register = _register elif cloud_frame is not None: print(f"clouds already in world frame {world!r}; accumulating verbatim") + print("warning: trajectory positions come from stored obs.pose (old dataset)") def _position(obs: Observation[Any]) -> tuple[float, float, float] | None: - """Trajectory position for dedup/path: tf lookup, else the stored pose.""" - if tf_buf is not None and cloud_frame is not None and cloud_frame != world: - tf = tf_buf.get(world, cloud_frame, time_point=obs.ts, time_tolerance=tf_tolerance) - if tf is not None: - return (tf.translation.x, tf.translation.y, tf.translation.z) + """Trajectory position for dedup/path: registration tf, else the stored pose.""" + if register is not None: + tf = register(obs) + if tf is None: + return None + return (tf.translation.x, tf.translation.y, tf.translation.z) pose = obs.pose # Reject placeholder poses: zero translation OR uninitialized rotation. # Same condition as pgo_keyframes so dedup and PGO see the same frames. From d45c2845beae0842fe5d38b6a9c598b53c335d2c Mon Sep 17 00:00:00 2001 From: Ivan Nikolic Date: Fri, 3 Jul 2026 14:23:40 +0300 Subject: [PATCH 10/10] hold _cv across _ensure's check-clear-reload Covers the covered-range check, eviction, and reload atomically so a concurrent reader can't see a stale _covered against cleared buffers; also resets _covered before reloading so a failed _load can't leave the cache claiming coverage it evicted. --- dimos/memory2/tf.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/dimos/memory2/tf.py b/dimos/memory2/tf.py index 5db3ef6c8c..ce36173cff 100644 --- a/dimos/memory2/tf.py +++ b/dimos/memory2/tf.py @@ -69,13 +69,14 @@ def _load(self, lo: float, hi: float) -> None: def _ensure(self, lo: float, hi: float) -> None: """Serve ``[lo, hi]`` from the cache, else re-cache ``[lo, hi + cache_span]``.""" - if self._covered is not None: - clo, chi = self._covered - if clo <= lo and hi <= chi: - return - with self._cv: + with self._cv: + if self._covered is not None: + clo, chi = self._covered + if clo <= lo and hi <= chi: + return self.buffers.clear() - self._load(lo, hi + self.config.cache_span) + self._covered = None + self._load(lo, hi + self.config.cache_span) def get( self,