diff --git a/.python-version b/.python-version
new file mode 100644
index 0000000..2c07333
--- /dev/null
+++ b/.python-version
@@ -0,0 +1 @@
+3.11
diff --git a/Dockerfile b/Dockerfile
index af7400f..304e9ff 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,11 +1,25 @@
-# Uses GPU enabled torch
-# FROM public.ecr.aws/artefacts/go2:mujoco
-# Uses CPU only torch
-FROM public.ecr.aws/artefacts/go2:mujoco-cputorch
+FROM nvidia/cuda:12.8.1-base-ubuntu22.04
+
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility,graphics
+ENV MUJOCO_GL=egl
+SHELL ["/bin/bash", "-c"]
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ libgl1 libglfw3 libgles2 libegl1 libglib2.0-0 \
+ python3-pip python3-dev ffmpeg git curl cmake build-essential \
+ && rm -rf /var/lib/apt/lists/*
+
WORKDIR /ws
+COPY pyproject.toml uv.lock .python-version ./
+# Use cpu-torch (for smaller image size)
+RUN uv lock --index pytorch=https://download.pytorch.org/whl/cpu
+# TODO: Should we move optional dependencies (like rerun-sdk) to extras to decrease the image size?
+RUN uv sync
COPY artefacts.yaml go2_wtw_demo.py utils.py ./
COPY tests/ tests/
COPY resources/ resources/
-CMD artefacts run $ARTEFACTS_JOB_NAME
+CMD uv run artefacts run $ARTEFACTS_JOB_NAME
\ No newline at end of file
diff --git a/Dockerfile.base b/Dockerfile.base
deleted file mode 100644
index 908e0ed..0000000
--- a/Dockerfile.base
+++ /dev/null
@@ -1,60 +0,0 @@
-FROM nvidia/cuda:12.8.1-base-ubuntu22.04
-
-ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility,graphics
-# MuJoCo needs a display context for passive viewer; use EGL for offscreen
-ENV MUJOCO_GL=egl
-
-SHELL ["/bin/bash", "-c"]
-
-# System deps for MuJoCo rendering (EGL/offscreen) and video encoding
-RUN apt-get update && apt-get install -y --no-install-recommends \
- libgl1 \
- libglfw3 \
- libgles2 \
- libegl1 \
- libglib2.0-0 \
- python3-pip \
- python3-dev \
- ffmpeg \
- git \
- curl \
- cmake \
- build-essential \
- && rm -rf /var/lib/apt/lists/*
-
-# CycloneDDS (required by unitree_sdk2_python)
-RUN git clone https://github.com/eclipse-cyclonedds/cyclonedds -b releases/0.10.x --depth 1 /opt/cyclonedds \
- && mkdir /opt/cyclonedds/build /opt/cyclonedds/install \
- && cd /opt/cyclonedds/build \
- && cmake .. -DCMAKE_INSTALL_PREFIX=/opt/cyclonedds/install \
- && cmake --build . --target install
-
-ENV CYCLONEDDS_HOME=/opt/cyclonedds/install
-
-WORKDIR /ws/src
-
-RUN git clone -b high-level-direct --depth 1 https://github.com/art-e-fact/unitree_mujoco.git
-RUN git clone --depth=1 https://github.com/unitreerobotics/unitree_sdk2_python.git
-
-RUN pip install --no-cache-dir ./unitree_sdk2_python
-
-WORKDIR /ws
-
-COPY scripts/fetch_wtw_checkpoints.sh scripts/
-RUN ./scripts/fetch_wtw_checkpoints.sh
-
-# Use this for cpu-torch (smaller image size)
-RUN pip install --no-cache-dir \
- --extra-index-url https://download.pytorch.org/whl/cpu \
- "mujoco>=3.5.0" \
- "numpy>=2.0.0" \
- "pytest>=8.0.0" \
- torch \
- mediapy \
- artefacts-cli \
- pygame \
- opencv-python
-
-# Use this for gpu-torch (CUDA drivers, large image size)
-# COPY requirements.txt .
-# RUN pip install --no-cache-dir -r requirements.txt
diff --git a/README.md b/README.md
index 9d86123..e9c79c8 100644
--- a/README.md
+++ b/README.md
@@ -1,61 +1,42 @@
# Go2 MuJoCo Walk Demo (Artefacts)
-Unitree Go2 walking demo using MuJoCo and the Walk-These-Ways pretrained policy.
+Unitree Go2 walking demo using MuJoCo.
The demo runs on ubuntu, but also MacOS natively without vms or containerization.
A flat scene is added to the `unitree_mujoco` package (see copy stage) to provide a simple demo for the [Go2](https://www.unitree.com/go2) to move.
-With thanks to [Teddy Liao](https://github.com/Teddy-Liao) for the pretrained policy in the [walk-these-ways-go2](https://github.com/Teddy-Liao/walk-these-ways-go2) repository.
## Setup
### Prerequisites
-You will need cyclonedds installed if you do not have it already. We suggest to install outside of this project.
+This project uses `uv` for Python environment management. For installation instructions, see https://docs.astral.sh/uv/getting-started/installation/
+
+
+## Initialize the virtual environment
```bash
-cd ~
-git clone https://github.com/eclipse-cyclonedds/cyclonedds -b releases/0.10.x
-cd cyclonedds && mkdir build install && cd build
-cmake .. -DCMAKE_INSTALL_PREFIX=../install
-cmake --build . --target install
+uv sync
+source .venv/bin/activate
```
-### Project Setup
+## Run
+#### Run follow-on-rails demo with procedurally generated railroad:
```bash
-cd ~/go2-mujoco-artefacts
-# Create virtual environment
-python3 -m venv venv
-source venv/bin/activate
-
-# Clone required repos
-mkdir -p src
-git clone --depth=1 -b high-level-direct https://github.com/art-e-fact/unitree_mujoco.git src/unitree_mujoco
-git clone --depth=1 https://github.com/unitreerobotics/unitree_sdk2_python.git src/unitree_sdk2_python
-# Install the sdk
-cd src/unitree_sdk2_python
-export CYCLONEDDS_HOME=~/cyclonedds/install
-pip install -e .
-
-# Back to root of repository
-cd ../..
-# Download WTW policy checkpoints (3 files only, no full repo clone)
-bash scripts/fetch_wtw_checkpoints.sh
-
-# Install dependencies
-pip install -r requirements.txt
+python go2_rails_demo.py --rerun --heightmap-nav --seed 123
```
-## Run (Linux)
+The robot should walk on the rails following the target object.
+
+_Note: run `python go2_rails_demo.py --help` for more options._
+
+#### Run the robot with WTW policy in a test environment:
```bash
-source venv/bin/activate # if not already
python go2_wtw_demo.py
```
-## Run (MacOS)
-```bash
-source venv/bin/activate # if not already
-mjpython go2_wtw_demo.py
-```
-The robot will walk forward and turn left in a loop.
+
+_Note: run `python go2_wtw_demo.py --help` for more options._
+
+
### Run test with artefacts
diff --git a/go2_rails_demo.py b/go2_rails_demo.py
new file mode 100644
index 0000000..24f4f6f
--- /dev/null
+++ b/go2_rails_demo.py
@@ -0,0 +1,534 @@
+#!/usr/bin/env python3
+"""
+Go2 Rail-Following Demo
+
+Generates a procedural rail network and drives the Go2 along a chosen road
+using a simple pursuit controller via the high-level SportClient API.
+
+Usage:
+ python go2_rails_demo.py # MuJoCo viewer (UI)
+ python go2_rails_demo.py --headless # no display (CI / testing)
+
+"""
+
+import sys
+import os
+import math
+import time
+import threading
+import subprocess
+from utils import get_python_executable, sim_sleep, FrontCameraRecorder
+
+_HERE = os.path.dirname(os.path.abspath(__file__))
+
+
+def _drain(proc, events):
+ """Print subprocess stdout; set events[i] when events[i][0] marker appears."""
+ for line in proc.stdout:
+ print(f" [sim] {line.rstrip()}")
+ for marker, event in events:
+ if marker and marker in line:
+ event.set()
+
+
+def _stop(procs):
+ for proc in reversed(procs):
+ if proc.poll() is None:
+ proc.terminate()
+ try:
+ proc.wait(timeout=5)
+ except subprocess.TimeoutExpired:
+ proc.kill()
+
+
+def _build_waypoints(road):
+ """Extract (x, y, heading_rad) waypoints from a road."""
+ return [(x, y, math.radians(h)) for x, y, h in road]
+
+
+def _build_path(waypoints):
+ """Return cumulative arc-length distances and (x,y) arrays for waypoints."""
+ xs = [w[0] for w in waypoints]
+ ys = [w[1] for w in waypoints]
+ dists = [0.0]
+ for i in range(1, len(xs)):
+ dx = xs[i] - xs[i - 1]
+ dy = ys[i] - ys[i - 1]
+ dists.append(dists[-1] + math.sqrt(dx * dx + dy * dy))
+ return dists, xs, ys
+
+
+def _sample_path(dists, xs, ys, s):
+ """Interpolate (x, y) at arc-length distance s along the path."""
+ s = max(0.0, min(s, dists[-1]))
+ for i in range(1, len(dists)):
+ if dists[i] >= s:
+ t = (s - dists[i - 1]) / (dists[i] - dists[i - 1]) if dists[i] > dists[i - 1] else 0.0
+ return xs[i - 1] + t * (xs[i] - xs[i - 1]), ys[i - 1] + t * (ys[i] - ys[i - 1])
+ return xs[-1], ys[-1]
+
+
+def _path_tangent(dists, xs, ys, s):
+ """Return the tangent angle (rad) of the path at arc-length s."""
+ s = max(0.0, min(s, dists[-1]))
+ for i in range(1, len(dists)):
+ if dists[i] >= s:
+ return math.atan2(ys[i] - ys[i - 1], xs[i] - xs[i - 1])
+ return math.atan2(ys[-1] - ys[-2], xs[-1] - xs[-2])
+
+
+def _closest_path_s(dists, xs, ys, px, py):
+ """Return the arc-length s of the point on the path closest to (px, py)."""
+ best_s, best_d2 = 0.0, float("inf")
+ for i in range(1, len(dists)):
+ # Project (px,py) onto segment [i-1, i]
+ ax, ay = xs[i - 1], ys[i - 1]
+ bx, by = xs[i], ys[i]
+ dx, dy = bx - ax, by - ay
+ seg_len2 = dx * dx + dy * dy
+ if seg_len2 < 1e-12:
+ continue
+ t = max(0.0, min(1.0, ((px - ax) * dx + (py - ay) * dy) / seg_len2))
+ cx, cy = ax + t * dx, ay + t * dy
+ d2 = (px - cx) ** 2 + (py - cy) ** 2
+ if d2 < best_d2:
+ best_d2 = d2
+ best_s = dists[i - 1] + t * (dists[i] - dists[i - 1])
+ return best_s
+
+
+def _analyze_rails(data, width, height, resolution, forward_yaw, radius=1.0):
+ """Detect rail heading and lateral offset from a heightmap.
+
+ Args:
+ data: flat float32 array (width*height), EMPTY=1e9 for missing cells.
+ width, height, resolution: grid parameters.
+ forward_yaw: approximate forward direction (rad, world frame) to
+ disambiguate the ±180° PCA ambiguity.
+ radius: only consider cells within this distance (m) from the robot.
+
+ Returns:
+ (rail_heading_world, lateral_offset) or None if insufficient data.
+ rail_heading_world: heading of the rails in world frame (rad).
+ lateral_offset: signed distance from robot to rail midline (m);
+ positive = robot is to the left of center.
+ """
+ import numpy as np
+
+ valid = data[data < 1e8]
+ if len(valid) < 20:
+ return None
+ # Rails are the highest features — keep cells above the 90th percentile
+ thresh = np.percentile(valid, 99)
+ rail_mask = (data >= thresh) & (data < 1e8)
+ if rail_mask.sum() < 10:
+ return None
+
+ # Cell positions relative to grid center (= robot).
+ # The grid is axis-aligned with the world frame, so these offsets
+ # are world-frame relative displacements from the robot.
+ idx = np.argwhere(rail_mask.reshape(height, width)) # (N, 2) as [iy, ix]
+ cx = (idx[:, 1] - width / 2.0 + 0.5) * resolution # world-x offset
+ cy = (idx[:, 0] - height / 2.0 + 0.5) * resolution # world-y offset
+
+ # Only keep cells within `radius` of the robot to ignore distant turns.
+ near = cx**2 + cy**2 <= radius**2
+ cx, cy = cx[near], cy[near]
+ if len(cx) < 10:
+ return None
+
+ # Split into two rail clusters using the gap in perpendicular projection.
+ perp_fwd = -np.sin(forward_yaw) * cx + np.cos(forward_yaw) * cy
+ order = np.argsort(perp_fwd)
+ sorted_perp = perp_fwd[order]
+ gaps = np.diff(sorted_perp)
+ split = np.argmax(gaps)
+ if gaps[split] < 2 * resolution or min(split + 1, len(cx) - split - 1) < 3:
+ return None # no clear two-rail separation
+
+ # Recenter each rail cluster to remove inter-rail lateral spread.
+ mask_a, mask_b = order[:split + 1], order[split + 1:]
+ ca = np.column_stack([cx[mask_a] - cx[mask_a].mean(),
+ cy[mask_a] - cy[mask_a].mean()])
+ cb = np.column_stack([cx[mask_b] - cx[mask_b].mean(),
+ cy[mask_b] - cy[mask_b].mean()])
+ pts = np.vstack([ca, cb])
+
+ # PCA on recentered points: principal axis = rail direction
+ cov = np.cov(pts[:, 0], pts[:, 1]) # 2×2
+ eigvals, eigvecs = np.linalg.eigh(cov)
+ principal = eigvecs[:, 1] # largest eigenvalue
+ rail_heading = np.arctan2(principal[1], principal[0])
+
+ # Disambiguate ±180°: pick the direction closer to forward_yaw
+ diff = (rail_heading - forward_yaw + np.pi) % (2 * np.pi) - np.pi
+ if abs(diff) > np.pi / 2:
+ rail_heading += np.pi
+ rail_heading = (rail_heading + np.pi) % (2 * np.pi) - np.pi
+
+ # Lateral offset: average of both rail centroids projected onto perp axis.
+ perp_x, perp_y = -np.sin(rail_heading), np.cos(rail_heading)
+ mid_x = (cx[mask_a].mean() + cx[mask_b].mean()) / 2
+ mid_y = (cy[mask_a].mean() + cy[mask_b].mean()) / 2
+ lateral_offset = -(mid_x * perp_x + mid_y * perp_y)
+
+ return rail_heading, lateral_offset, np.column_stack([cx, cy])
+
+
+def main():
+ import argparse
+ import numpy as np
+ from unitree_sdk2py.core.channel import ChannelFactoryInitialize
+ from unitree_sdk2py.go2.sport.sport_client import SportClient
+ from unitree_mujoco import config
+ from rail_gen import RailwayScene, TerrainSpec
+
+ parser = argparse.ArgumentParser(description="Go2 Rail-Following Demo")
+ parser.add_argument("--seed", type=int, default=None, help="Random seed")
+ parser.add_argument("--n-roads", type=int, default=1, help="Number of rail roads")
+ parser.add_argument("--interface", default=config.INTERFACE, help="Network interface")
+ parser.add_argument("--domain", type=int, default=0, help="DDS domain ID")
+ parser.add_argument("--headless", action="store_true", help="No viewer (CI / testing)")
+ parser.add_argument("--telemetry", metavar="PATH", default=None,
+ help="Write simulation state (qpos/qvel) as JSONL to PATH")
+ parser.add_argument("--record", metavar="PATH", default=None,
+ help="Save spectator-view recording (passed to sport_mujoco.py)")
+ parser.add_argument("--record-front", metavar="PATH", default=None,
+ help="Save front-camera recording to PATH")
+ parser.add_argument("--no-heightmap", action="store_false", dest="heightmap", default=True,
+ help="Disable HeightMap_ DDS publishing in the sim")
+ parser.add_argument("--heightmap-debug", action="store_true",
+ help="Visualise height map rays in the viewer")
+ parser.add_argument("--v-forward", type=float, default=1.0, help="Forward velocity (m/s)")
+ parser.add_argument("--yaw-gain", type=float, default=1.0, help="Heading alignment gain")
+ parser.add_argument("--lateral-gain", type=float, default=1.0, help="Lateral centering gain")
+ parser.add_argument("--target-speed", type=float, default=0.5, help="Target speed along path (m/s)")
+ parser.add_argument("--target-lead", type=float, default=1.0, help="Initial target lead distance (m)")
+ parser.add_argument("--no-terrain", action="store_false", dest="terrain", default=True, help="Disable terrain heightfield")
+ parser.add_argument("--teleop", action="store_true", help="Control robot with gamepad instead of auto")
+ parser.add_argument("--rerun", action="store_true", help="Stream data to Rerun viewer")
+ parser.add_argument("--heightmap-nav", action="store_true",
+ help="Steer using heightmap rail detection instead of path geometry")
+ parser.add_argument("--policy", choices=["wtw", "rsl_rl"], default="rsl_rl",
+ help="Locomotion policy (default: rsl_rl)")
+ args = parser.parse_args()
+
+ if args.heightmap_nav:
+ args.heightmap = True
+
+ # --- Generate rail scene ---
+ rng = np.random.default_rng(args.seed)
+ terrain = TerrainSpec() if args.terrain else None
+ scene = RailwayScene.build(rng, n_roads=args.n_roads, terrain=terrain)
+
+ waypoints = _build_waypoints(scene.net.roads[0])
+ start_pos = waypoints[0] if waypoints else None
+ print(f"Rail scene: {len(scene.net.roads)} roads, following road 0 "
+ f"({len(waypoints)} waypoints)")
+
+ scene_path = scene.save_mujoco_scene(_HERE, start_pos=start_pos)
+ print(f"Saved MuJoCo scene to {scene_path}")
+
+ if args.rerun:
+ import rerun as rr
+ rr.init("go2_rails_demo", spawn=True)
+ scene.log_rerun()
+
+ env = {**os.environ, "PYTHONUNBUFFERED": "1"}
+ procs = []
+ recorder = None
+
+ try:
+ # --- sport_mujoco.py ---
+ _sport_mujoco = os.path.join(os.path.dirname(sys.executable), "sport-mujoco")
+ sim_cmd = [get_python_executable(), _sport_mujoco,
+ "--interface", args.interface, "--domain", str(args.domain),
+ "--scene", scene_path, "--keyframe", "rail_start",
+ "--policy", args.policy]
+ if args.headless:
+ sim_cmd.append("--headless")
+ if args.record:
+ sim_cmd += ["--record", os.path.abspath(args.record)]
+ if args.telemetry:
+ sim_cmd += ["--telemetry", os.path.abspath(args.telemetry)]
+ if args.heightmap:
+ sim_cmd.append("--heightmap")
+ if args.heightmap_debug:
+ sim_cmd.append("--heightmap-debug")
+ sim_cmd.append("--uwb")
+
+ sim_proc = subprocess.Popen(
+ sim_cmd,
+ stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
+ text=True, env=env, start_new_session=True,
+ )
+ sim_ready = threading.Event()
+ sim_standing = threading.Event()
+ threading.Thread(target=_drain, args=(sim_proc,
+ [("Serving sport RPC", sim_ready),
+ ("Standing complete.", sim_standing)]),
+ daemon=True).start()
+ procs.append(sim_proc)
+
+ assert sim_ready.wait(timeout=60), "sport_mujoco did not load in time"
+ if not args.headless:
+ time.sleep(1.0)
+ assert sim_standing.wait(timeout=20), "sport_mujoco did not reach standing pose"
+ time.sleep(1.5) # WTW pre-warming
+
+ # --- SportClient ---
+ ChannelFactoryInitialize(args.domain, args.interface)
+ client = SportClient()
+ client.SetTimeout(10.0)
+ client.Init()
+
+ telemetry_path = os.path.abspath(args.telemetry) if args.telemetry else None
+ sleep = (lambda dt: sim_sleep(dt, telemetry_path)) if telemetry_path else time.sleep
+
+ if args.record_front:
+ recorder = FrontCameraRecorder(args.record_front)
+ recorder.start()
+
+ print(f"\n=== Go2 Rail-Following Demo ===")
+ print(f"v_forward={args.v_forward} yaw_gain={args.yaw_gain} "
+ f"target_speed={args.target_speed}")
+ print("=" * 50 + "\n")
+
+ # --- Gamepad (teleop mode) ------------------------------------------
+ joy = None
+ if args.teleop:
+ import pygame
+ pygame.init()
+ if pygame.joystick.get_count() == 0:
+ print("ERROR: --teleop requested but no gamepad found")
+ return
+ joy = pygame.joystick.Joystick(0)
+ joy.init()
+ print(f"[teleop] Using gamepad: {joy.get_name()}")
+
+ # --- Set up UWB subscriber to get robot pose ---
+ from unitree_sdk2py.core.channel import ChannelSubscriber, ChannelPublisher
+ from unitree_sdk2py.idl.unitree_go.msg.dds_ import UwbState_
+ from unitree_sdk2py.idl.geometry_msgs.msg.dds_ import Pose_, Point_, Quaternion_
+
+ marker_pub = ChannelPublisher("rt/uwb_tag_pose", Pose_)
+ marker_pub.Init()
+
+ uwb = {"az": 0.0, "pitch": 0.0, "dist": 0.0, "yaw": 0.0, "ready": False}
+ uwb_lock = threading.Lock()
+
+ def _on_uwb(msg):
+ with uwb_lock:
+ uwb["az"] = msg.orientation_est
+ uwb["pitch"] = msg.pitch_est
+ uwb["dist"] = msg.distance_est
+ uwb["yaw"] = msg.base_yaw
+ uwb["ready"] = True
+
+ uwb_sub = ChannelSubscriber("rt/uwbstate", UwbState_)
+ uwb_sub.Init(_on_uwb, 10)
+
+ # Publish initial marker so UWB has a target to measure
+ path_dists, path_xs, path_ys = _build_path(waypoints)
+ path_s = args.target_lead # current arc-length position of target
+ tx, ty = _sample_path(path_dists, path_xs, path_ys, path_s)
+ marker_pub.Write(Pose_(Point_(tx, ty, 0.9), Quaternion_(0, 0, 0, 1)))
+
+ # Wait for first UWB update
+ for _ in range(100):
+ with uwb_lock:
+ if uwb["ready"]:
+ break
+ time.sleep(0.05)
+
+ # --- Heightmap subscriber (for --rerun and/or --heightmap-nav) ---
+ hm_sub = None
+ hm_data = None # latest heightmap data array
+ hm_msg = None # latest HeightMap_ message
+ if args.rerun or args.heightmap_nav:
+ from unitree_sdk2py.idl.unitree_go.msg.dds_ import HeightMap_
+ hm_sub = ChannelSubscriber("rt/utlidar/height_map_array", HeightMap_)
+ hm_sub.Init()
+
+ def _poll_heightmap():
+ """Read latest HeightMap_, update hm_data/hm_msg, optionally log to Rerun."""
+ nonlocal hm_data, hm_msg
+ if hm_sub is None:
+ return
+ msg = hm_sub.Read()
+ if msg is None:
+ return
+ hm_msg = msg
+ hm_data = np.array(msg.data, dtype=np.float32)
+ if args.rerun:
+ mask = hm_data < 1e8
+ if mask.any():
+ ix = np.arange(len(hm_data)) % msg.width
+ iy = np.arange(len(hm_data)) // msg.width
+ rr.set_time("sim_time", timestamp=msg.stamp)
+ rr.log("heightmap", rr.Points3D(
+ np.column_stack([
+ msg.origin[0] + ix[mask] * msg.resolution,
+ msg.origin[1] + iy[mask] * msg.resolution,
+ hm_data[mask],
+ ]),
+ radii=0.015,
+ ))
+
+ # --- Control loop ---
+ DEADZONE = 0.1
+ dt = 0.1
+ step_count = 0
+ MIN_DIST = 1.0
+
+ if joy:
+ # --- Teleop mode: gamepad controls the robot directly ---
+ print("[teleop] Gamepad active. Left stick = move, right stick X = rotate.")
+ import pygame
+ while True:
+ pygame.event.pump()
+ jlx = joy.get_axis(0) # left stick X → lateral
+ jly = -joy.get_axis(1) # left stick Y → forward (inverted)
+ jrx = joy.get_axis(3) # right stick X → yaw
+
+ vx = jly * args.v_forward if abs(jly) > DEADZONE else 0.0
+ vy = -jlx * 0.3 if abs(jlx) > DEADZONE else 0.0
+ vyaw = -jrx * 2.5 if abs(jrx) > DEADZONE else 0.0
+
+ client.Move(vx, vy, vyaw)
+ sleep(dt)
+ _poll_heightmap()
+ step_count += 1
+ elif args.heightmap_nav:
+ # --- Heightmap-nav mode: steer using rail detection ---
+ while path_s < path_dists[-1]:
+ path_s += args.target_speed * dt
+ tx, ty = _sample_path(path_dists, path_xs, path_ys, path_s)
+ marker_pub.Write(Pose_(Point_(tx, ty, 0.9), Quaternion_(0, 0, 0, 1)))
+
+ with uwb_lock:
+ tag_dist = uwb["dist"]
+ ryaw = uwb["yaw"]
+ az = uwb["az"]
+
+ _poll_heightmap()
+
+ # Approximate forward direction in world frame
+ forward_yaw = ryaw + az
+
+ result = None
+ if hm_data is not None and hm_msg is not None:
+ result = _analyze_rails(hm_data, hm_msg.width, hm_msg.height,
+ hm_msg.resolution, forward_yaw)
+
+ if result is None:
+ client.StopMove()
+ if step_count % 20 == 0:
+ print(" [heightmap-nav] WARNING: no rails detected, stopping")
+ sleep(dt)
+ step_count += 1
+ continue
+
+ rail_heading, lateral_err, rail_xy = result
+ heading_err = (rail_heading - ryaw + math.pi) % (2 * math.pi) - math.pi
+
+ vx = 0.0 if tag_dist < MIN_DIST else args.v_forward * min(1.0, (tag_dist - MIN_DIST) / MIN_DIST)
+ vy = max(-0.6, min(0.6, -args.lateral_gain * lateral_err))
+ vyaw = max(-2.5, min(2.5, args.yaw_gain * heading_err))
+
+
+ client.Move(vx, vy, vyaw)
+
+ if args.rerun:
+ # Robot world position = grid center
+ rx = hm_msg.origin[0] + hm_msg.width * hm_msg.resolution / 2
+ ry = hm_msg.origin[1] + hm_msg.height * hm_msg.resolution / 2
+ rz = 0.35
+ L = 2.0 # arrow length
+ rr.log("rail_analysis/forward_yaw", rr.Arrows3D(
+ origins=[[rx, ry, rz]],
+ vectors=[[L * math.cos(forward_yaw), L * math.sin(forward_yaw), 0]],
+ colors=[[0, 200, 0]],
+ ))
+ rr.log("rail_analysis/rail_heading", rr.Arrows3D(
+ origins=[[rx, ry, rz]],
+ vectors=[[L * math.cos(rail_heading), L * math.sin(rail_heading), 0]],
+ colors=[[255, 255, 0]],
+ ))
+ rr.log("rail_analysis/robot_yaw", rr.Arrows3D(
+ origins=[[rx, ry, rz]],
+ vectors=[[L * math.cos(ryaw), L * math.sin(ryaw), 0]],
+ colors=[[100, 100, 255]],
+ ))
+ # Rail cells (world coords)
+ rr.log("rail_analysis/cells", rr.Points3D(
+ np.column_stack([rx + rail_xy[:, 0], ry + rail_xy[:, 1],
+ np.full(len(rail_xy), rz)]),
+ colors=[[255, 50, 50]], radii=0.03,
+ ))
+ # Lateral offset line (robot → track center)
+ px, py = -math.sin(rail_heading), math.cos(rail_heading)
+ rr.log("rail_analysis/lateral", rr.LineStrips3D(
+ [[[rx, ry, rz],
+ [rx - lateral_err * px, ry - lateral_err * py, rz]]],
+ colors=[[255, 128, 0]], radii=0.02,
+ ))
+
+ sleep(dt)
+ step_count += 1
+
+ if step_count % 50 == 0:
+ print(f" step {step_count}: rail_hdg={math.degrees(rail_heading):.1f}° "
+ f"lat={lateral_err:+.3f}m hdg_err={math.degrees(heading_err):+.1f}°")
+ else:
+ # --- Path-based auto mode: follow path with UWB-based control ---
+ while path_s < path_dists[-1]:
+ path_s += args.target_speed * dt
+ tx, ty = _sample_path(path_dists, path_xs, path_ys, path_s)
+ marker_pub.Write(Pose_(Point_(tx, ty, 0.9), Quaternion_(0, 0, 0, 1)))
+
+ with uwb_lock:
+ tag_dist = uwb["dist"]
+ ryaw = uwb["yaw"]
+
+ az_r = uwb["az"]
+ cos_y, sin_y = math.cos(ryaw), math.sin(ryaw)
+ lx = tag_dist * math.cos(az_r)
+ ly = tag_dist * math.sin(az_r)
+ rx = tx - (cos_y * lx - sin_y * ly)
+ ry = ty - (sin_y * lx + cos_y * ly)
+
+ robot_s = _closest_path_s(path_dists, path_xs, path_ys, rx, ry)
+ tangent = _path_tangent(path_dists, path_xs, path_ys, robot_s)
+ cx, cy = _sample_path(path_dists, path_xs, path_ys, robot_s)
+
+ dx, dy = rx - cx, ry - cy
+ lateral_err = -math.sin(tangent) * dx + math.cos(tangent) * dy
+ heading_err = (tangent - ryaw + math.pi) % (2 * math.pi) - math.pi
+
+ vx = 0.0 if tag_dist < MIN_DIST else args.v_forward * min(1.0, (tag_dist - MIN_DIST) / MIN_DIST)
+ vy = max(-0.3, min(0.3, -args.lateral_gain * lateral_err))
+ vyaw = max(-2.5, min(2.5, args.yaw_gain * heading_err))
+
+ client.Move(vx, vy, vyaw)
+ sleep(dt)
+ _poll_heightmap()
+ step_count += 1
+
+ client.StopMove()
+ print(f"\nReached end of road after {step_count} steps.")
+
+ except KeyboardInterrupt:
+ print("\nInterrupted.")
+ finally:
+ if recorder is not None:
+ recorder.stop()
+ _stop(procs)
+ # Clean up temp scene file
+ if os.path.exists(scene_path):
+ os.remove(scene_path)
+ print(f"Cleaned up {scene_path}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/go2_wtw_demo.py b/go2_wtw_demo.py
index 1b890aa..b9e41c8 100644
--- a/go2_wtw_demo.py
+++ b/go2_wtw_demo.py
@@ -16,14 +16,7 @@
import time
import threading
import subprocess
-from utils import get_python_executable, sim_sleep, last_sim_time, FrontCameraRecorder
-
-_HERE = os.path.dirname(os.path.abspath(__file__))
-_SIM_DIR = os.path.join(_HERE, "src", "unitree_mujoco", "simulate_python")
-_SDK_DIR = os.path.join(_HERE, "src", "unitree_sdk2_python")
-
-sys.path.insert(0, _SDK_DIR)
-sys.path.insert(0, _SIM_DIR)
+from utils import get_python_executable, sim_sleep, FrontCameraRecorder
def _drain(proc, events):
@@ -52,7 +45,7 @@ def main():
import argparse
from unitree_sdk2py.core.channel import ChannelFactoryInitialize
from unitree_sdk2py.go2.sport.sport_client import SportClient
- import config
+ from unitree_mujoco import config
parser = argparse.ArgumentParser(description="Go2 Walk-These-Ways Demo")
parser.add_argument("--cycles", type=int, default=1, help="Number of square-path cycles")
@@ -67,18 +60,23 @@ def main():
help="Save spectator-view recording (passed to sport_mujoco.py)")
parser.add_argument("--record-front", metavar="PATH", default=None,
help="Save front-camera recording to PATH (e.g. front.mp4)")
+ parser.add_argument("--heightmap", action="store_true",
+ help="Enable HeightMap_ DDS publishing in the sim")
+ parser.add_argument("--heightmap-debug", action="store_true",
+ help="Visualise height map rays in the viewer (implies --heightmap)")
parser.add_argument("--v-forward", type=float, default=0.4, help="Forward velocity (m/s)")
parser.add_argument("--v-lateral", type=float, default=0.0, help="Lateral velocity (m/s)")
parser.add_argument("--rotation-speed", type=float, default=2.5, help="Rotation speed (rad/s)")
args = parser.parse_args()
- env = {**os.environ, "PYTHONUNBUFFERED": "1", "PYTHONPATH": _SDK_DIR}
+ env = {**os.environ, "PYTHONUNBUFFERED": "1"}
procs = []
recorder = None
try:
# --- sport_mujoco.py: unified sim + WTW + RPC server in one process ---
- sim_cmd = [get_python_executable(), "-u", os.path.join(_SIM_DIR, "sport_mujoco.py"),
+ _sport_mujoco = os.path.join(os.path.dirname(sys.executable), "sport-mujoco")
+ sim_cmd = [get_python_executable(), _sport_mujoco,
"--interface", args.interface, "--domain", str(args.domain)]
if args.headless:
sim_cmd.append("--headless")
@@ -88,9 +86,13 @@ def main():
sim_cmd += ["--record", os.path.abspath(args.record)]
if args.telemetry:
sim_cmd += ["--telemetry", os.path.abspath(args.telemetry)]
+ if args.heightmap:
+ sim_cmd.append("--heightmap")
+ if args.heightmap_debug:
+ sim_cmd.append("--heightmap-debug")
sim_proc = subprocess.Popen(
- sim_cmd, cwd=_SIM_DIR,
+ sim_cmd,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
text=True, env=env,
)
@@ -114,6 +116,28 @@ def main():
client.SetTimeout(10.0)
client.Init()
+ # --- HeightMap subscriber (verification) ---
+ if args.heightmap:
+ from unitree_sdk2py.core.channel import ChannelSubscriber
+ from unitree_sdk2py.idl.unitree_go.msg.dds_ import HeightMap_
+
+ def _on_heightmap(msg):
+ import numpy as _np
+ arr = _np.array(msg.data, dtype=_np.float32)
+ filled = arr[arr < 1.0e9]
+ if len(filled) > 0:
+ print(f"[heightmap] t={msg.stamp:.2f} {msg.width}x{msg.height} "
+ f"origin=({msg.origin[0]:.2f},{msg.origin[1]:.2f}) "
+ f"cells={len(filled)}/{len(arr)} "
+ f"h=[{filled.min():.3f}, {filled.max():.3f}] "
+ f"avg={filled.mean():.3f} std={filled.std():.3f} "
+ f"median={_np.median(filled):.3f} "
+ f"p95={_np.percentile(filled, 95):.3f} "
+ f"p99={_np.percentile(filled, 99):.3f}")
+
+ hmap_sub = ChannelSubscriber("rt/utlidar/height_map_array", HeightMap_)
+ hmap_sub.Init(_on_heightmap, 10)
+
telemetry_path = os.path.abspath(args.telemetry) if args.telemetry else None
sleep = (lambda dt: sim_sleep(dt, telemetry_path)) if telemetry_path else time.sleep
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..8ec69a4
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[project]
+name = "go2-mujoco-artefacts"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.11"
+dependencies = [
+ "artefacts-cli>=0.15.0",
+ "mediapy>=1.2.6",
+ "mujoco>=3.6.0",
+ "numpy>=2.4.3",
+ "opencv-python>=4.13.0.92",
+ "opensimplex>=0.4.5.1",
+ "perlin-noise==1.14",
+ "pygame>=2.6.1",
+ "pyglm>=2.8.3",
+ "pytest>=9.0.2",
+ "rerun-sdk>=0.31.1",
+ "torch>=2.11.0",
+ "unitree-mujoco",
+ "unitree-sdk2py",
+]
+
+[tool.uv.sources]
+# Pending PR in the official repo: https://github.com/unitreerobotics/unitree_sdk2_python/pull/107
+unitree-sdk2py = { git = "https://github.com/carlosdp/unitree_sdk2_python.git", branch = "proper-python-modules" }
+
+# NOTE: Update the revision (commit hash) if necessary before committing
+unitree-mujoco = { git = "https://github.com/art-e-fact/unitree_mujoco.git", rev = "0bd9111e4f3e1ee83e779f7c6de5c6b609414d2b" }
+# For development, install the local clone in editable mode:
+# unitree-mujoco = { path = "src/unitree_mujoco", editable = true }
+
+
+[tool.ty.environment]
+python = ".venv"
+
diff --git a/rail_gen.py b/rail_gen.py
new file mode 100644
index 0000000..a925d6c
--- /dev/null
+++ b/rail_gen.py
@@ -0,0 +1,859 @@
+from dataclasses import dataclass, field
+import math
+import os
+import numpy as np
+from pyglm import glm
+from xml.etree.ElementTree import Element, SubElement, tostring, indent
+
+
+@dataclass
+class RailSpec:
+ """Cross-section specification for a rail type.
+
+ Attributes:
+ name: Human-readable rail standard name.
+ profile: Cross-section vertices as (x_mm, y_mm) pairs.
+ gauge: Track gauge in meters.
+ """
+
+ name: str
+ profile: list[tuple[float, float]]
+ gauge: float
+
+
+# Data from: https://www.jfe-steel.co.jp/products/katakou/rail/rail_a.html
+def _make_jis60kg() -> RailSpec:
+ B, C, G, F, E, D, A = 145.0, 65.0, 16.5, 30.1, 94.9, 49.0, 174.0
+ return RailSpec(
+ name="JIS 60kg N",
+ gauge=1.067,
+ profile=[
+ (-B/2, 0), (-B/2, F*0.3), # left base up
+ (-G/2-8, F), (-G/2, F), # left foot taper
+ (-G/2, F+E), (-C/2, F+E), # left web & head
+ (-C/2, A), ( C/2, A), # head top
+ ( C/2, F+E), ( G/2, F+E), # right head & web
+ ( G/2, F), ( G/2+8, F), # right foot taper
+ ( B/2, F*0.3), ( B/2, 0), # right base down
+ ], # closed: (B/2,0) → (-B/2,0) = base bottom
+ ) # fmt: skip
+
+
+JIS_60KG = _make_jis60kg()
+
+
+@dataclass
+class TerrainSpec:
+ """Configuration for terrain heightfield generation."""
+
+ margin: float = 3.0 # extra extent beyond network bounds (m)
+ resolution: float = 0.2 # grid cell size (m)
+ base_depth: float = 0.15 # depth below rail level (m)
+ slope_width: float = 1.2 # lateral distance of slope from rail to base (m)
+ flat_radius: float = 0.8 # half-width of flat zone around centerline (m)
+ noise_amplitude: float = 0.2 # noise height std-dev (m)
+ noise_octaves: int = 20 # Perlin noise octaves (higher = more detail)
+
+
+# --- Geometry helpers ---
+
+
+def _obb_overlap_2d(cx1, cy1, a1, cx2, cy2, a2, hx, hy):
+ """Check if two 2D oriented boxes overlap using the Separating Axis Theorem.
+
+ Both boxes share the same half-extents (*hx*, *hy*).
+ """
+ cos1, sin1 = math.cos(a1), math.sin(a1)
+ cos2, sin2 = math.cos(a2), math.sin(a2)
+ # The 4 separating axes: 2 edge normals per box
+ axes = [(cos1, sin1), (-sin1, cos1), (cos2, sin2), (-sin2, cos2)]
+ dx, dy = cx2 - cx1, cy2 - cy1
+ for ax, ay in axes:
+ # Project the center-to-center vector onto the axis
+ d = abs(ax * dx + ay * dy)
+ # Sum of half-extent projections for both boxes
+ r1 = hx * abs(ax * cos1 + ay * sin1) + hy * abs(-ax * sin1 + ay * cos1)
+ r2 = hx * abs(ax * cos2 + ay * sin2) + hy * abs(-ax * sin2 + ay * cos2)
+ if d > r1 + r2:
+ return False # separating axis found
+ return True # no separating axis → overlap
+
+
+def _perlin_noise_2d(nrow, ncol, octaves, seed):
+ """Generate 2D Perlin noise grid using the perlin-noise package."""
+ from perlin_noise import PerlinNoise
+
+ noise = PerlinNoise(octaves=octaves, seed=seed)
+ return np.array(
+ [[noise([r / nrow, c / ncol]) for c in range(ncol)] for r in range(nrow)]
+ )
+
+
+# --- Network ---
+
+
+class RailNetwork:
+ """A network of railroad tracks stored as polylines with heading.
+
+ Each road is a densely-sampled sequence of ``(x, y, heading_deg)`` tuples
+ produced by :class:`RailNetworkBuilder`. Call :meth:`sample_string` to get
+ poses compatible with the mesh extrusion pipeline.
+
+ Attributes:
+ spec: Rail cross-section specification.
+ roads: List of roads, each a list of ``(x, y, heading_deg)`` tuples.
+ sleeper_spacing: Distance between sleepers along the track (m).
+ sleeper_size: Sleeper dimensions as ``(length, width, height)`` in meters.
+ """
+
+ def __init__(
+ self,
+ spec: RailSpec = JIS_60KG,
+ roads: list | None = None,
+ sleeper_spacing: float = 0.6,
+ sleeper_size: tuple[float, float, float] = (0.2, 1.4, 0.05),
+ ):
+ self.spec = spec
+ self.roads: list[list[tuple[float, float, float]]] = roads or []
+ self.sleeper_spacing = sleeper_spacing
+ self.sleeper_size = sleeper_size
+
+ def sample_string(self, si: int, offset: float = 0.0, resolution: float = 0.1):
+ """Subsample road *si* at *resolution* spacing with lateral *offset*.
+
+ Returns:
+ List of ``(glm.vec3, glm.quat)`` poses along the road.
+ """
+ road = self.roads[si]
+ if len(road) < 2:
+ return []
+
+ pts = np.array([(x, y) for x, y, _ in road])
+ headings = np.array([h for _, _, h in road])
+ diffs = np.diff(pts, axis=0)
+ seg_lens = np.linalg.norm(diffs, axis=1)
+ cumlen = np.concatenate(([0.0], np.cumsum(seg_lens)))
+ total = cumlen[-1]
+ if total < 1e-9:
+ return []
+
+ results = []
+ for d in np.linspace(0, total, max(2, int(math.ceil(total / resolution)) + 1)):
+ idx = int(
+ np.clip(
+ np.searchsorted(cumlen, d, side="right") - 1, 0, len(seg_lens) - 1
+ )
+ )
+ frac = (d - cumlen[idx]) / max(seg_lens[idx], 1e-12)
+
+ x = pts[idx, 0] + frac * diffs[idx, 0]
+ y = pts[idx, 1] + frac * diffs[idx, 1]
+ j = min(idx + 1, len(headings) - 1)
+ h = headings[idx] + frac * (headings[j] - headings[idx])
+
+ h_rad = math.radians(h)
+ x -= offset * math.sin(h_rad)
+ y += offset * math.cos(h_rad)
+
+ results.append(
+ (
+ glm.vec3(x, y, 0.0),
+ glm.angleAxis(h_rad, glm.vec3(0, 0, 1)),
+ )
+ )
+ return results
+
+ def sample_sleepers(
+ self, rng: np.random.Generator | None = None
+ ) -> list[tuple[glm.vec3, glm.quat]]:
+ """Sample sleeper poses across all roads, eliminating overlaps.
+
+ Sleepers from all roads are pooled. Pairs whose oriented bounding
+ boxes overlap are found and one is randomly discarded, repeated
+ until no overlaps remain.
+ """
+ sl, sw, _ = self.sleeper_size
+ hx, hy = sl / 2, sw / 2 # half-extents along local axes
+
+ poses = []
+ for si in range(len(self.roads)):
+ poses.extend(self.sample_string(si, resolution=self.sleeper_spacing))
+ if len(poses) < 2:
+ return poses
+
+ rng = rng or np.random.default_rng()
+
+ # Extract center (x, y) and heading for OBB tests
+ cx = np.array([p.x for p, _ in poses])
+ cy = np.array([p.y for p, _ in poses])
+ angles = np.array(
+ [
+ math.atan2(2 * (q.w * q.z + q.x * q.y), 1 - 2 * (q.y * q.y + q.z * q.z))
+ for _, q in poses
+ ]
+ )
+
+ def _find_overlap(cx, cy, angles):
+ """Return index pair (i, j) of first OBB overlap, or None."""
+ n = len(cx)
+ # Quick squared-distance pre-filter (diagonal of box as radius)
+ r = math.hypot(hx, hy)
+ dx = cx[:, None] - cx[None, :]
+ dy = cy[:, None] - cy[None, :]
+ dist2 = dx * dx + dy * dy
+ np.fill_diagonal(dist2, np.inf)
+ candidates = np.argwhere(dist2 < (2 * r) ** 2)
+
+ for idx in range(len(candidates)):
+ i, j = int(candidates[idx, 0]), int(candidates[idx, 1])
+ if i >= j:
+ continue
+ # 2D SAT with 4 axes (2 edge normals per box)
+ if _obb_overlap_2d(
+ cx[i], cy[i], angles[i], cx[j], cy[j], angles[j], hx, hy
+ ):
+ return i, j
+ return None
+
+ while len(poses) >= 2:
+ pair = _find_overlap(cx, cy, angles)
+ if pair is None:
+ break
+ drop = rng.choice(pair)
+ poses.pop(drop)
+ cx = np.delete(cx, drop)
+ cy = np.delete(cy, drop)
+ angles = np.delete(angles, drop)
+
+ return poses
+
+ def generate_terrain(
+ self, spec: TerrainSpec, rng: np.random.Generator | None = None
+ ):
+ """Generate terrain elevation grid: flat under rails, sloping to base depth."""
+ all_pts = np.vstack([np.array([(x, y) for x, y, _ in r]) for r in self.roads])
+ xmin, ymin = all_pts.min(axis=0) - spec.margin
+ xmax, ymax = all_pts.max(axis=0) + spec.margin
+ nrow = int(math.ceil((ymax - ymin) / spec.resolution)) + 1
+ ncol = int(math.ceil((xmax - xmin) / spec.resolution)) + 1
+ xs = np.linspace(xmin, xmax, ncol)
+ ys = np.linspace(ymin, ymax, nrow)
+ gx, gy = np.meshgrid(xs, ys)
+
+ # Min distance to any road centerline (per row for memory efficiency)
+ dist = np.empty((nrow, ncol))
+ for r in range(nrow):
+ row_xy = np.column_stack([gx[r], gy[r]])
+ d2 = ((row_xy[:, None, :] - all_pts[None, :, :]) ** 2).sum(axis=2)
+ dist[r] = np.sqrt(d2.min(axis=1))
+
+ # Smoothstep profile: flat near rails → slope → base
+ t = np.clip((dist - spec.flat_radius) / spec.slope_width, 0, 1)
+ t = t * t * (3 - 2 * t)
+ elevation = -spec.base_depth * t
+
+ # Add noise (masked by t so it's zero under the rails)
+ rng = rng or np.random.default_rng()
+ seed = int(rng.integers(0, 2**31))
+ elevation += (
+ _perlin_noise_2d(nrow, ncol, spec.noise_octaves, seed)
+ * spec.noise_amplitude
+ * t
+ )
+
+ return elevation, (xmin, xmax, ymin, ymax)
+
+
+@dataclass
+class RailNetworkBuilder:
+ """Builds a :class:`RailNetwork` by growing roads step-by-step.
+
+ Each road is a polyline whose heading evolves via a smoothly-varying
+ curvature that randomly changes target. New roads branch off random points
+ on existing roads and are discarded if they collide after a diverge zone.
+
+ Attributes:
+ step_size: Distance between consecutive road points (m).
+ max_turn: Maximum curvature magnitude (deg/step).
+ heading_change_speed: Rate at which curvature approaches its target (deg/m).
+ change_turn_prob: Probability of picking a new target curvature each step.
+ clearance: Minimum distance between non-adjacent roads (m).
+ diverge_steps: Steps at a branch start exempt from collision checks.
+ branch_margin: Minimum distance from the end of a road for branch points (m).
+ min_road_length: Minimum road length (m).
+ max_road_length: Maximum road length (m).
+ rail_spec: Rail specification for the output network.
+ sleeper_spacing: Distance between sleepers along the track (m).
+ sleeper_size: Sleeper dimensions as ``(length, width, height)`` in meters.
+ """
+
+ step_size: float = 0.1
+ max_turn: float = 2.0
+ heading_change_speed: float = 3.0
+ change_turn_prob: float = 0.01
+ clearance: float = 2.0
+ diverge_steps: int = 50
+ branch_margin: float = 3.0
+ min_road_length: float = 10.0
+ max_road_length: float = 30.0
+ rail_spec: RailSpec = field(default_factory=lambda: JIS_60KG)
+ sleeper_spacing: float = 0.6
+ sleeper_size: tuple[float, float, float] = (0.2, 1.4, 0.07)
+
+ def _grow_road(self, rng, start, n_steps):
+ """Grow a single road from *start* ``(x, y, heading_deg)`` for *n_steps*."""
+ x, y, h = start
+ curvature = 0.0
+ target = 0.0
+ road = [(x, y, h)]
+ max_change = self.heading_change_speed * self.step_size
+
+ for _ in range(n_steps):
+ if rng.random() < self.change_turn_prob:
+ target = float(rng.uniform(-self.max_turn, self.max_turn))
+ curvature += max(-max_change, min(max_change, target - curvature))
+ h += curvature
+ x += self.step_size * math.cos(math.radians(h))
+ y += self.step_size * math.sin(math.radians(h))
+ road.append((x, y, h))
+ return road
+
+ def _intersects(self, new_road, roads, skip_start=0):
+ """Check whether *new_road* collides with *roads* after *skip_start* points."""
+ if not roads:
+ return False
+ all_pts = np.vstack([np.array([(x, y) for x, y, _ in r]) for r in roads])
+ cl2 = self.clearance**2
+ for x, y, _ in new_road[skip_start:]:
+ if np.min(np.sum((all_pts - [x, y]) ** 2, axis=1)) < cl2:
+ return True
+ return False
+
+ def build(
+ self, rng: np.random.Generator, n_roads: int = 5, max_trials: int = 1000
+ ) -> RailNetwork:
+ """Build a network of *n_roads* roads.
+
+ The first road starts at the origin. Subsequent roads branch off a
+ random point on an existing road. Branches that collide with other
+ roads (outside the diverge zone) are discarded and retried.
+ """
+ roads: list[list[tuple[float, float, float]]] = []
+
+ for ri in range(n_roads):
+ for _ in range(max_trials):
+ length = float(rng.uniform(self.min_road_length, self.max_road_length))
+ n_steps = int(length / self.step_size)
+
+ if not roads:
+ start = (0.0, 0.0, float(rng.uniform(0, 360)))
+ skip = 0
+ else:
+ parent = roads[int(rng.integers(len(roads)))]
+ margin_steps = int(self.branch_margin / self.step_size)
+ max_idx = max(0, len(parent) - 1 - margin_steps)
+ start = parent[int(rng.integers(max_idx + 1))]
+ skip = self.diverge_steps
+
+ road = self._grow_road(rng, start, n_steps)
+ if not self._intersects(road, roads, skip_start=skip):
+ roads.append(road)
+ break
+ else:
+ print(f"Warning: could not place road {ri} after {max_trials} trials")
+
+ # Center the network at the origin
+ if roads:
+ all_pts = np.vstack([np.array([(x, y) for x, y, _ in r]) for r in roads])
+ center = (all_pts.min(axis=0) + all_pts.max(axis=0)) / 2
+ roads = [
+ [(x - center[0], y - center[1], h) for x, y, h in r] for r in roads
+ ]
+
+ return RailNetwork(
+ spec=self.rail_spec,
+ roads=roads,
+ sleeper_spacing=self.sleeper_spacing,
+ sleeper_size=self.sleeper_size,
+ )
+
+
+# --- Mesh ---
+
+
+@dataclass
+class MeshData:
+ """Triangle mesh with per-vertex normals.
+
+ Attributes:
+ vertices: Vertex positions, shape ``(V, 3)``.
+ faces: Triangle indices, shape ``(F, 3)``.
+ normals: Per-vertex normals, shape ``(V, 3)``.
+ """
+
+ vertices: np.ndarray
+ faces: np.ndarray
+ normals: np.ndarray
+
+
+def _compute_vertex_normals(verts, faces):
+ normals = np.zeros_like(verts)
+ v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]]
+ face_normals = np.cross(v1 - v0, v2 - v0)
+ for i in range(3):
+ np.add.at(normals, faces[:, i], face_normals)
+ lens = np.linalg.norm(normals, axis=1, keepdims=True)
+ return normals / np.where(lens < 1e-12, 1.0, lens)
+
+
+def _extrude_profile(profile_mm, samples, z_offset: float = 0.0):
+ """Extrude a 2D cross-section along sampled curve frames."""
+ n_prof = len(profile_mm)
+ n_samp = len(samples)
+ profile = np.array(profile_mm, dtype=float) * 0.001
+
+ verts = np.zeros((n_samp * n_prof, 3))
+ for si, (pos, quat) in enumerate(samples):
+ left = quat * glm.vec3(0, 1, 0)
+ for pi, (px, py) in enumerate(profile):
+ v = pos + float(px) * left + float(py + z_offset) * glm.vec3(0, 0, 1)
+ verts[si * n_prof + pi] = [v.x, v.y, v.z]
+
+ # Append centroid vertices for end caps (non-convex profile needs centroid fan)
+ sc_idx = n_samp * n_prof
+ ec_idx = n_samp * n_prof + 1
+ verts = np.vstack([
+ verts,
+ verts[:n_prof].mean(axis=0, keepdims=True),
+ verts[(n_samp - 1) * n_prof:].mean(axis=0, keepdims=True),
+ ])
+
+ faces = []
+ for si in range(n_samp - 1):
+ for pi in range(n_prof):
+ pn = (pi + 1) % n_prof
+ a, b = si * n_prof + pi, si * n_prof + pn
+ c, d = (si + 1) * n_prof + pn, (si + 1) * n_prof + pi
+ faces.extend([[a, d, c], [a, c, b]])
+
+ # Start cap: [sc, pi, pn] → normal in -forward direction
+ for pi in range(n_prof):
+ pn = (pi + 1) % n_prof
+ faces.append([sc_idx, pi, pn])
+
+ # End cap: reversed winding → normal in +forward direction
+ base = (n_samp - 1) * n_prof
+ for pi in range(n_prof):
+ pn = (pi + 1) % n_prof
+ faces.append([ec_idx, base + pn, base + pi])
+
+ faces = np.array(faces, dtype=int)
+ return MeshData(verts, faces, _compute_vertex_normals(verts, faces))
+
+
+# --- Output ---
+
+
+def generate_mujoco_xml(
+ net: RailNetwork, resolution: float = 0.2, terrain: TerrainSpec | None = None
+) -> str:
+ """Generate MuJoCo MJCF XML with inline rail meshes and optional terrain hfield."""
+ spec = net.spec
+ root = Element("mujoco", model="rail_network")
+ asset = SubElement(root, "asset")
+ SubElement(
+ asset,
+ "material",
+ name="mat_rail",
+ rgba="0.55 0.55 0.6 1",
+ specular="0.8",
+ shininess="0.9",
+ )
+ SubElement(
+ asset,
+ "material",
+ name="mat_sleeper",
+ rgba="0.4 0.28 0.18 1",
+ specular="0.2",
+ shininess="0.1",
+ )
+ worldbody = SubElement(root, "worldbody")
+ half_g = spec.gauge / 2.0
+ _sl, _sw, sh = net.sleeper_size
+ rail_z = sh / 2 # rails start at top of (half-buried) sleepers
+
+ # Rail cross-section half-extents for box colliders (mm → m)
+ rail_hw = spec.profile[0][0] * 0.001 # half base width (negative, take abs)
+ rail_hw = abs(rail_hw)
+ rail_hh = max(py for _, py in spec.profile) * 0.001 / 2 # half height
+
+ for si in range(len(net.roads)):
+ for tag, off in [("L", half_g), ("R", -half_g)]:
+ samples = net.sample_string(si, offset=off, resolution=resolution)
+ if len(samples) < 2:
+ continue
+ mesh = _extrude_profile(spec.profile, samples, z_offset=rail_z)
+ name = f"rail_s{si}_{tag}"
+ SubElement(
+ asset,
+ "mesh",
+ name=name,
+ vertex=" ".join(f"{v:.6f}" for v in mesh.vertices.ravel()),
+ face=" ".join(str(i) for i in mesh.faces.ravel()),
+ )
+ # Visual only — convex-hull collision won't work for long rails
+ SubElement(
+ worldbody,
+ "geom",
+ name=name,
+ type="mesh",
+ mesh=name,
+ material="mat_rail",
+ contype="0",
+ conaffinity="0",
+ )
+ # Box colliders along the rail for accurate collision
+ for ji in range(len(samples) - 1):
+ p0, _ = samples[ji]
+ p1, _ = samples[ji + 1]
+ mx = (p0.x + p1.x) / 2
+ my = (p0.y + p1.y) / 2
+ mz = rail_z + rail_hh
+ dx, dy = p1.x - p0.x, p1.y - p0.y
+ half_len = math.sqrt(dx * dx + dy * dy) / 2
+ yaw = math.atan2(dy, dx)
+ SubElement(
+ worldbody,
+ "geom",
+ name=f"{name}_col{ji}",
+ type="box",
+ size=f"{half_len:.4f} {rail_hw:.4f} {rail_hh:.4f}",
+ pos=f"{mx:.4f} {my:.4f} {mz:.4f}",
+ euler=f"0 0 {yaw:.6f}",
+ contype="1",
+ conaffinity="1",
+ group="3",
+ )
+
+ # Sleepers (deduplicated across all roads) — sunk halfway into ground
+ sl, sw, _sh = net.sleeper_size
+ for ti, (pos, quat) in enumerate(net.sample_sleepers()):
+ fwd = quat * glm.vec3(1, 0, 0)
+ yaw = math.atan2(float(fwd.y), float(fwd.x))
+ SubElement(
+ worldbody,
+ "geom",
+ name=f"sleeper_{ti}",
+ type="box",
+ size=f"{sl / 2:.4f} {sw / 2:.4f} {sh / 2:.4f}",
+ pos=f"{pos.x:.4f} {pos.y:.4f} {0:.4f}",
+ euler=f"0 0 {yaw:.6f}",
+ material="mat_sleeper",
+ contype="1",
+ conaffinity="1",
+ )
+
+ # Terrain heightfield
+ if terrain is not None:
+ elevation, (xmin, xmax, ymin, ymax) = net.generate_terrain(terrain)
+ nrow, ncol = elevation.shape
+ e_min, e_max = float(elevation.min()), float(elevation.max())
+ e_range = max(e_max - e_min, 1e-6)
+ rx, ry = (xmax - xmin) / 2, (ymax - ymin) / 2
+ cx, cy = (xmin + xmax) / 2, (ymin + ymax) / 2
+ SubElement(
+ asset,
+ "material",
+ name="mat_terrain",
+ rgba="0.45 0.38 0.28 1",
+ specular="0.1",
+ shininess="0.1",
+ )
+ SubElement(
+ asset,
+ "hfield",
+ name="terrain",
+ nrow=str(nrow),
+ ncol=str(ncol),
+ size=f"{rx:.4f} {ry:.4f} {e_range:.4f} 0.5",
+ elevation=" ".join(f"{v:.4f}" for v in elevation[::-1].ravel()),
+ )
+ SubElement(
+ worldbody,
+ "geom",
+ name="terrain",
+ type="hfield",
+ hfield="terrain",
+ pos=f"{cx:.4f} {cy:.4f} {e_min:.4f}",
+ material="mat_terrain",
+ contype="1",
+ conaffinity="1",
+ )
+
+ indent(root, space=" ")
+ return tostring(root, encoding="unicode")
+
+
+def log_network(net: RailNetwork, terrain: TerrainSpec | None = None):
+ """Log the rail network to Rerun: centerlines, meshes, and optional terrain."""
+ import rerun as rr
+
+ spec = net.spec
+ half_g = spec.gauge / 2.0
+ sl, sw, sh = net.sleeper_size
+
+ for si in range(len(net.roads)):
+ samples = net.sample_string(si, resolution=0.2)
+ if samples:
+ pts = np.array([[p.x, p.y, p.z] for p, _ in samples])
+ rr.log(
+ f"network/{si}/center",
+ rr.LineStrips3D([pts], colors=[80, 80, 80], radii=0.01),
+ static=True,
+ )
+
+ for tag, off in [("L", half_g), ("R", -half_g)]:
+ samples = net.sample_string(si, offset=off, resolution=0.2)
+ if len(samples) < 2:
+ continue
+ mesh = _extrude_profile(spec.profile, samples, z_offset=sh / 2)
+ rr.log(
+ f"network/{si}/mesh_{tag}",
+ rr.Mesh3D(
+ vertex_positions=mesh.vertices,
+ triangle_indices=mesh.faces,
+ vertex_normals=mesh.normals,
+ vertex_colors=[140, 140, 155],
+ ),
+ static=True,
+ )
+
+ # Sleepers (deduplicated across all roads) — sunk halfway into ground
+ sleepers = net.sample_sleepers()
+ if sleepers:
+ centers = []
+ half_sizes = []
+ rotations = []
+ for pos, quat in sleepers:
+ fwd = quat * glm.vec3(1, 0, 0)
+ yaw = math.atan2(float(fwd.y), float(fwd.x))
+ centers.append([pos.x, pos.y, 0.0])
+ half_sizes.append([sl / 2, sw / 2, sh / 2])
+ q = glm.angleAxis(yaw, glm.vec3(0, 0, 1))
+ rotations.append(rr.Quaternion(xyzw=[q.x, q.y, q.z, q.w]))
+ rr.log(
+ "network/sleepers",
+ rr.Boxes3D(
+ centers=centers,
+ half_sizes=half_sizes,
+ colors=[[100, 70, 45]],
+ quaternions=rotations,
+ fill_mode="solid",
+ ),
+ static=True,
+ )
+
+ # Terrain mesh
+ if terrain is not None:
+ elevation, (xmin, xmax, ymin, ymax) = net.generate_terrain(terrain)
+ nrow, ncol = elevation.shape
+ xs = np.linspace(xmin, xmax, ncol)
+ ys = np.linspace(ymin, ymax, nrow)
+ gx, gy = np.meshgrid(xs, ys)
+ verts = np.column_stack([gx.ravel(), gy.ravel(), elevation.ravel()])
+ i = (np.arange(nrow - 1)[:, None] * ncol + np.arange(ncol - 1)[None, :]).ravel()
+ faces = np.vstack(
+ [
+ np.column_stack([i, i + ncol, i + 1]),
+ np.column_stack([i + 1, i + ncol, i + ncol + 1]),
+ ]
+ )
+ rr.log(
+ "network/terrain",
+ rr.Mesh3D(
+ vertex_positions=verts,
+ triangle_indices=faces,
+ vertex_normals=_compute_vertex_normals(verts, faces),
+ vertex_colors=[120, 100, 80],
+ ),
+ static=True,
+ )
+
+
+class RailwayScene:
+ """A complete rail scene: network + terrain, with MuJoCo and Rerun output."""
+
+ def __init__(self, net: RailNetwork, terrain: TerrainSpec | None = None):
+ self.net = net
+ self.terrain = terrain
+
+ _TERRAIN_DEFAULT = object() # sentinel: distinguish 'not passed' from None
+
+ @classmethod
+ def build(
+ cls,
+ rng: np.random.Generator,
+ n_roads: int = 5,
+ terrain: TerrainSpec | None | object = _TERRAIN_DEFAULT,
+ **builder_kwargs,
+ ) -> "RailwayScene":
+ if terrain is cls._TERRAIN_DEFAULT:
+ terrain = TerrainSpec()
+ net = RailNetworkBuilder(**builder_kwargs).build(rng, n_roads=n_roads)
+ return cls(net, terrain)
+
+ def log_rerun(self):
+ log_network(self.net, terrain=self.terrain)
+
+ def save_mujoco_scene(
+ self, project_root: str, start_pos: tuple[float, float, float] | None = None
+ ) -> str:
+ """Write a complete MuJoCo scene XML to a temp file and return its path.
+
+ The file includes the Go2 robot, lighting, and all rail/terrain
+ geometry. The caller is responsible for deleting the file when done.
+
+ Args:
+ project_root: Absolute path to the project root (for go2.xml include).
+ start_pos: Optional (x, y, heading_rad) to place the robot.
+ """
+ import tempfile
+ from xml.etree.ElementTree import parse as ET_parse
+ from io import StringIO
+
+ go2_xml = os.path.join(project_root, "resources", "go2.xml")
+
+ # Build the base scene (mirroring resources/scene_flat.xml)
+ root = Element("mujoco", model="go2 rail scene")
+ SubElement(root, "include", file=go2_xml)
+ SubElement(root, "statistic", center="0 0 0.1", extent="0.8")
+
+ vis = SubElement(root, "visual")
+ SubElement(
+ vis,
+ "headlight",
+ diffuse="0.6 0.6 0.6",
+ ambient="0.3 0.3 0.3",
+ specular="0 0 0",
+ )
+ SubElement(vis, "rgba", haze="0.15 0.25 0.35 1")
+ SubElement(
+ vis,
+ "global",
+ azimuth="-130",
+ elevation="-20",
+ offwidth="1280",
+ offheight="720",
+ )
+ SubElement(vis, "map", zfar="200")
+
+ asset = SubElement(root, "asset")
+ SubElement(
+ asset,
+ "texture",
+ type="skybox",
+ builtin="gradient",
+ rgb1="0.3 0.5 0.7",
+ rgb2="0 0 0",
+ width="512",
+ height="3072",
+ )
+ SubElement(
+ asset,
+ "texture",
+ type="2d",
+ name="groundplane",
+ builtin="checker",
+ mark="edge",
+ rgb1="0.2 0.3 0.4",
+ rgb2="0.1 0.2 0.3",
+ markrgb="0.8 0.8 0.8",
+ width="300",
+ height="300",
+ )
+ SubElement(
+ asset,
+ "material",
+ name="groundplane",
+ texture="groundplane",
+ texuniform="true",
+ texrepeat="5 5",
+ reflectance="0.2",
+ )
+
+ worldbody = SubElement(root, "worldbody")
+ SubElement(worldbody, "light", pos="0 0 1.5", dir="0 0 -1", directional="true")
+ SubElement(
+ worldbody,
+ "camera",
+ name="spectator",
+ pos="0 -3 1.5",
+ xyaxes="1 0 0 0 0.447 0.894",
+ )
+
+ # Override the robot start position if requested
+ if start_pos is not None:
+ sx, sy, syaw = start_pos
+ cw, sw_ = math.cos(syaw / 2), math.sin(syaw / 2)
+ qpos = (
+ f"{sx} {sy} 0.27 {cw} 0 0 {sw_} "
+ "0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8"
+ )
+ ctrl = "0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8 0 0.9 -1.8"
+ kf = SubElement(root, "keyframe")
+ SubElement(kf, "key", name="rail_start", qpos=qpos, ctrl=ctrl)
+
+ # Human-marker mocap body (visible + collidable for lidar)
+ marker = SubElement(worldbody, "body", name="uwb_tag", mocap="true")
+ SubElement(
+ marker,
+ "geom",
+ type="cylinder",
+ size="0.2 0.9",
+ rgba="1.0 0.5 0.0 0.5",
+ contype="0",
+ conaffinity="0",
+ )
+
+ # Add a flat floor if no terrain heightfield
+ if self.terrain is None:
+ SubElement(
+ worldbody,
+ "geom",
+ name="floor",
+ type="plane",
+ size="0 0 0.05",
+ material="groundplane",
+ )
+
+ # Merge rail/sleeper/terrain from generate_mujoco_xml()
+ rail_xml = generate_mujoco_xml(self.net, terrain=self.terrain)
+ rail_root = ET_parse(StringIO(rail_xml)).getroot()
+ rail_asset = rail_root.find("asset")
+ rail_wb = rail_root.find("worldbody")
+ if rail_asset is not None:
+ for child in rail_asset:
+ asset.append(child)
+ if rail_wb is not None:
+ for child in rail_wb:
+ worldbody.append(child)
+
+ indent(root, space=" ")
+ xml_str = tostring(root, encoding="unicode")
+
+ fd, path = tempfile.mkstemp(suffix=".xml", prefix="rail_scene_")
+ os.write(fd, xml_str.encode())
+ os.close(fd)
+ return path
+
+
+if __name__ == "__main__":
+ import rerun as rr
+
+ rng = np.random.default_rng()
+ rr.init("rail_network", spawn=True)
+ builder = RailNetworkBuilder()
+ net = builder.build(rng, n_roads=5)
+ terrain = TerrainSpec()
+ print(f"Roads: {len(net.roads)}, points: {sum(len(r) for r in net.roads)}")
+ log_network(net, terrain=terrain)
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index d4d286c..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-mujoco>=3.5.0
-numpy>=2.0.0
-pytest>=8.0.0
-torch
-mediapy
-artefacts-cli
-pygame
-opencv-python
diff --git a/resources/go2.xml b/resources/go2.xml
new file mode 100644
index 0000000..aa99ad0
--- /dev/null
+++ b/resources/go2.xml
@@ -0,0 +1,291 @@
+