diff --git a/README_POC.md b/README_POC.md new file mode 100644 index 000000000..69b7334cd --- /dev/null +++ b/README_POC.md @@ -0,0 +1,99 @@ +# PoC: Portable Network State Serialization for Brian2 + +**GSoC 2026 — Serialization/Deserialization for Brian Simulator models, results, and input data** + +--- + +## What this does + +This PoC implements the **core mechanism** of the project: serializing a +Brian2 Network state to a portable `.brian` ZIP archive and restoring it +with bit-identical numerical fidelity — without using Python pickle. + +### The problem + +`CPPStandaloneDevice.network_store()` and `network_restore()` both raise +`NotImplementedError`. The existing `RuntimeDevice` checkpoint path uses +`pickle`, which is Python-version–specific, not human-readable, and +blocks interoperability with NWB/NEO/ML frameworks. + +### The approach + +Two new functions in `brian2/serialization/brian_format.py`: + +| Function | What it does | +|---|---| +| `serialize_network_state(net, path)` | Calls `net._full_state()`, converts arrays + spike queues to a ZIP archive | +| `restore_network_state(net, path)` | Reads the archive and calls `obj._restore_from_full_state()` on each object | + +**Archive layout** (`.brian` = ZIP): +``` +checkpoint.brian +├── metadata.json # format_version, brian2_version, t, +│ # per-variable shapes / dtypes / SI dimension tuples +├── arrays.npz # all ArrayVariable values (compressed NumPy binary) +└── spikequeues.json # SynapticPathway in-flight spike state + # (Cython SpikeQueue offset + spike_lists) +``` + +**Why this unlocks the full project:** + +- For `CPPStandaloneDevice`: after a standalone run, arrays come from + `results/` (read via `get_array_filename(var)`) instead of + live NumPy arrays — the ZIP format is device-agnostic. +- For `BrianExporter` (brian2tools): `model.json` + `arrays.npz` become + the interchange format for structural + state export. +- For `BrianImporter`: the JSON metadata is the reconstruction spec — + object names, variable shapes, and SI dimension tuples (7-element list + of SI exponents from `Dimension._dims`). + +--- + +## Files + +``` +brian2/ +└── serialization/ + ├── __init__.py # exports serialize_network_state, restore_network_state + └── brian_format.py # core implementation (~170 lines) + +examples/ +└── serialization/ + └── brian_format_demo.py # end-to-end round-trip demo +``` + +--- + +## Run the demo + +```bash +# from the repo root, with the brian2 dev environment active +conda activate brian2 # or: source .venv/bin/activate +python examples/serialization/brian_format_demo.py +``` + +Expected output: +``` +Reference: t= 10. ms spikes=0 v[0,-1]=0.037158 + +Checkpoint: written to /tmp/....brian (9.8 KB) + Archive files : ['metadata.json', 'arrays.npz', 'spikequeues.json'] + Serialized t : 5.0 ms + Objects : ['neurons', 'voltage', 'synapses_pre', 'spikes', 'synapses', 'defaultclock'] + +After restore: net.t=5. ms G.v[0]=0.060653 + +Round-trip check: + State variable max error : 0.00e+00 + Spike indices match : True + Spike times max error : 0.00e+00 + +PASS — serialize → restore → run produces bit-identical results. +``` + +The demo: +1. Runs a LIF network (NeuronGroup + Synapses with per-neuron delay + StateMonitor + SpikeMonitor) for 5 ms. +2. Serializes to a `.brian` archive. +3. Resets to t=0 using the existing pickle store/restore. +4. Restores from the `.brian` archive to t=5 ms. +5. Runs 5 ms more and verifies the traces are **numerically identical** to a reference run that ran continuously to 10 ms. diff --git a/brian2/serialization/__init__.py b/brian2/serialization/__init__.py new file mode 100644 index 000000000..b2e3792cd --- /dev/null +++ b/brian2/serialization/__init__.py @@ -0,0 +1,16 @@ +""" +brian2.serialization +~~~~~~~~~~~~~~~~~~~~ + +Portable serialization of Brian2 Network state. + +Provides :func:`serialize_network_state` and :func:`restore_network_state` +which store / load network state as a ``.brian`` ZIP archive (JSON + npz) +instead of Python pickle. + +See :mod:`brian2.serialization.brian_format` for the full specification. +""" + +from .brian_format import restore_network_state, serialize_network_state + +__all__ = ["serialize_network_state", "restore_network_state"] diff --git a/brian2/serialization/brian_format.py b/brian2/serialization/brian_format.py new file mode 100644 index 000000000..792b71a5e --- /dev/null +++ b/brian2/serialization/brian_format.py @@ -0,0 +1,248 @@ +""" +brian2.serialization.brian_format +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Serialize and restore Brian2 Network state as a portable ``.brian`` ZIP +archive (JSON metadata + NumPy ``.npz`` arrays) instead of Python pickle. + +This module demonstrates the core mechanism for: + +1. **Fixing** ``CPPStandaloneDevice.network_store / network_restore`` + (currently raise ``NotImplementedError``): after a standalone run, the + same ZIP format is populated by reading binary files from ``results/`` + via ``get_array_filename(var)`` rather than from live NumPy arrays. + +2. **Replacing** pickle in ``RuntimeDevice`` store/restore with an + interoperable format that is not tied to a Python version or platform. + +Archive layout +-------------- +:: + + checkpoint.brian (ZIP, deflate) + ├── metadata.json # format_version, brian2_version, t, per-object + │ # variable shapes / dtypes / SI dimension tuples + ├── arrays.npz # flat dict "ObjName__varname" → ndarray + └── spikequeues.json # SynapticPathway _spikequeue states + # (offset + spike_lists, JSON-serialisable) + +Why not pickle? +--------------- +Pickle is Python-version–specific, not human-readable, and blocks +interoperability with tools such as NWB, NEO, or ML frameworks. A ZIP of +JSON + npz is readable in any language, diff-able in git, and extensible +without breaking existing archives. +""" + +import io +import json +import zipfile + +import numpy as np + +import brian2 + +_METADATA_FILE = "metadata.json" +_ARRAYS_FILE = "arrays.npz" +_QUEUES_FILE = "spikequeues.json" +_FORMAT_VERSION = "1" + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def serialize_network_state(net, filepath): + """ + Serialize the state of *net* to a ``.brian`` ZIP archive at *filepath*. + + Parameters + ---------- + net : brian2.Network + The network whose state to serialize. Must have been run (or at + least initialized) so that all ``ArrayVariable`` objects hold + concrete values. + filepath : str or path-like + Destination path, e.g. ``"checkpoint.brian"``. + + Notes + ----- + Internally this calls ``net._full_state()``, the same method used by + the existing ``Network.store()`` / ``Network.restore()`` pickle path, + so fidelity guarantees are identical. + + For ``CPPStandaloneDevice`` the same JSON + npz layout is used, but + arrays are read from ``results/`` via + ``CPPStandaloneDevice.get_array_filename(var)`` rather than from live + NumPy arrays — the format is device-agnostic by design. + """ + # net._full_state() walks all objects via _get_all_objects() and calls + # obj._full_state() on each VariableOwner, plus every unique Clock. + # Returns: {obj_name: {var_name: (values_copy, size)}, "0_t": float} + raw_state = net._full_state() + + t = float(raw_state.pop("0_t")) + + metadata = { + "format_version": _FORMAT_VERSION, + "brian2_version": brian2.__version__, + "t": t, + "objects": {}, + "spikequeue_objects": [], + } + + arrays = {} # flat key "ObjName__varname" → ndarray + spikequeues = {} # obj_name → encoded queue state + + for obj_name, obj_state in raw_state.items(): + obj_meta = {} + + for var_name, value in obj_state.items(): + if var_name == "_spikequeue": + # SynapticPathway stores queue as (offset, list_of_lists) + # or None when the pathway has no in-flight spikes yet. + encoded = _encode_spikequeue(value) + if encoded is not None: + spikequeues[obj_name] = encoded + metadata["spikequeue_objects"].append(obj_name) + continue + + arr, size = value + arr = np.asarray(arr) + key = f"{obj_name}__{var_name}" + arrays[key] = arr + # size is int for 1-D arrays and a tuple for 2-D arrays + # (e.g. StateMonitor records shape (timesteps, N_neurons)). + size_json = list(size) if isinstance(size, tuple) else int(size) + obj_meta[var_name] = { + "size": size_json, + "dtype": str(arr.dtype), + "shape": list(arr.shape), + # Preserve SI dimension as a 7-tuple of exponents so the + # archive can be read without Brian2 installed (e.g. in + # analysis scripts or by NWB/NEO converters). + "dim": _dim_of_array_key(net, obj_name, var_name), + } + + metadata["objects"][obj_name] = obj_meta + + # Pack into ZIP: JSON text + compressed npz blob + npz_buf = io.BytesIO() + np.savez_compressed(npz_buf, **arrays) + + with zipfile.ZipFile(filepath, "w", compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr(_METADATA_FILE, json.dumps(metadata, indent=2)) + zf.writestr(_ARRAYS_FILE, npz_buf.getvalue()) + zf.writestr(_QUEUES_FILE, json.dumps(spikequeues, indent=2)) + + +def restore_network_state(net, filepath): + """ + Restore the state of *net* from a ``.brian`` ZIP archive at *filepath*. + + The network's objects must already exist and carry the **same names** + as when the archive was created — the same contract as the existing + ``Network.store() / Network.restore()`` mechanism. + + Parameters + ---------- + net : brian2.Network + filepath : str or path-like + """ + with zipfile.ZipFile(filepath, "r") as zf: + metadata = json.loads(zf.read(_METADATA_FILE)) + npz_data = np.load( + io.BytesIO(zf.read(_ARRAYS_FILE)), allow_pickle=False + ) + spikequeues = json.loads(zf.read(_QUEUES_FILE)) + + # Restore the network simulation time directly (mirrors Network.restore) + net.t_ = float(metadata["t"]) + + # Rebuild the state dict that _restore_from_full_state expects: + # {var_name: (ndarray, size)} plus "_spikequeue" for pathways + reconstructed = {} + for obj_name, obj_meta in metadata["objects"].items(): + obj_state = {} + for var_name, var_info in obj_meta.items(): + key = f"{obj_name}__{var_name}" + raw_size = var_info["size"] + # Restore original type: list → tuple, int stays int + size = tuple(raw_size) if isinstance(raw_size, list) else raw_size + obj_state[var_name] = (npz_data[key], size) + if obj_name in spikequeues: + obj_state["_spikequeue"] = _decode_spikequeue(spikequeues[obj_name]) + reconstructed[obj_name] = obj_state + + # Walk all live network objects (same traversal as Network._full_state) + from brian2.core.network import _get_all_objects + + all_objects = _get_all_objects(net.objects) + clocks = {obj.clock for obj in all_objects} + + for obj in list(all_objects) + list(clocks): + if obj.name in reconstructed: + obj._restore_from_full_state(reconstructed[obj.name]) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _encode_spikequeue(queue_state): + """ + Convert a Cython ``SpikeQueue._full_state()`` return value to a + JSON-serializable dict. + + Cython format: ``(offset: int, spike_lists: list[list[int]])`` + """ + if queue_state is None: + return None + offset, spike_lists = queue_state + return { + "offset": int(offset), + "spike_lists": [[int(s) for s in slot] for slot in spike_lists], + } + + +def _decode_spikequeue(encoded): + """ + Reconstruct the Cython ``SpikeQueue`` 2-tuple from a JSON-loaded dict. + + Compatible with ``SynapticPathway._restore_from_full_state`` which + calls ``self.queue._restore_from_full_state(converted_queue_state)``. + """ + if encoded is None: + return None + return (encoded["offset"], [list(slot) for slot in encoded["spike_lists"]]) + + +def _dim_of_array_key(net, obj_name, var_name): + """ + Look up the SI dimension of *var_name* on the object named *obj_name* + and return it as a plain Python list (7 SI exponents). + + Falls back to ``None`` if the variable or its dimension cannot be + resolved — this is metadata only and does not affect restore fidelity. + """ + from brian2.core.network import _get_all_objects + from brian2.core.clocks import Clock + + all_objects = list(_get_all_objects(net.objects)) + clocks = list({obj.clock for obj in all_objects}) + for obj in all_objects + clocks: + if obj.name != obj_name: + continue + if not hasattr(obj, "variables"): + return None + var = obj.variables.get(var_name) + if var is None: + return None + dim = getattr(var, "dim", None) + if dim is None: + return None + try: + return list(dim._dims) + except AttributeError: + return None + return None diff --git a/examples/serialization/brian_format_demo.py b/examples/serialization/brian_format_demo.py new file mode 100644 index 000000000..e5a83b9da --- /dev/null +++ b/examples/serialization/brian_format_demo.py @@ -0,0 +1,151 @@ +""" +Portable Network State Serialization — PoC Demo +================================================ + +This script demonstrates the core mechanism of the GSoC project: +serializing a Brian2 Network state to a .brian ZIP archive (JSON + npz) +and restoring it with exact numerical fidelity. + +What it shows +------------- +1. A LIF network (NeuronGroup + Synapses + monitors) runs for 5 ms. +2. The network state is serialized to ``checkpoint.brian`` — a ZIP containing + ``metadata.json`` (object/variable descriptors with SI dimensions) and + ``arrays.npz`` (all raw array data, no pickle). +3. The network is reset to t=0 using the existing pickle-based store/restore. +4. The .brian checkpoint is loaded: all arrays and spike queues are restored + to t=5 ms without re-running the simulation. +5. The network runs another 5 ms from the restored state. +6. The final traces are compared against a reference that ran continuously + to 10 ms — they must be numerically identical. + +Why this matters +---------------- +- ``CPPStandaloneDevice.network_store/restore`` both raise NotImplementedError. + The .brian format is the fix: after a C++ standalone run, arrays are read + from ``results/`` (via ``get_array_filename``) and packed into + the same ZIP layout — no code-gen changes required. +- Pickle is Python-version–specific. The .brian format is readable from any + language and diffable in git. + +Run +--- + cd /path/to/brian2 + python examples/serialization/brian_format_demo.py +""" + +import json +import os +import tempfile +import zipfile + +import numpy as np +from brian2 import * + +from brian2.serialization import restore_network_state, serialize_network_state + + +TAU = 10 * ms # module-level so it's in scope when net.run() resolves namespaces + + +def build_network(name_suffix=""): + """Return a small deterministic LIF network ready to run.""" + G = NeuronGroup( + 20, + "dv/dt = -v / TAU : 1", + threshold="v > 0.9", + reset="v = 0", + method="exact", + name=f"neurons{name_suffix}", + ) + G.v = "0.1 + 0.05 * i / N" # deterministic, no rand() + + S = Synapses( + G, G, + "w : 1", + on_pre="v += w", + name=f"synapses{name_suffix}", + ) + S.connect(j="i") # 1:1 deterministic connectivity + S.w = 0.05 + S.delay = "i * 0.1 * ms" + + V = StateMonitor(G, "v", record=True, name=f"voltage{name_suffix}") + M = SpikeMonitor(G, name=f"spikes{name_suffix}") + return G, S, V, M + + +# ── Reference run (no checkpointing) ──────────────────────────────────────── + +start_scope() +G_ref, S_ref, V_ref, M_ref = build_network("_ref") +net_ref = Network(G_ref, S_ref, V_ref, M_ref) +net_ref.run(10 * ms) + +ref_v = V_ref.v[:, :].copy() +ref_spike_i = M_ref.i[:].copy() +ref_spike_t = M_ref.t_[:].copy() + +print(f"Reference: t={net_ref.t!s:>10s} " + f"spikes={len(ref_spike_i)} " + f"v[0,-1]={ref_v[0, -1]:.6f}") + + +# ── Checkpoint run (serialize at 5 ms, restore, continue) ─────────────────── + +start_scope() +G, S, V, M = build_network() +net = Network(G, S, V, M) + +# Store initial state with the existing pickle mechanism so we can reset later +net.store("initial") + +net.run(5 * ms) + +# --- serialize to .brian archive ------------------------------------------- +with tempfile.NamedTemporaryFile(suffix=".brian", delete=False) as f: + ckpt_path = f.name + +serialize_network_state(net, ckpt_path) + +archive_size_kb = os.path.getsize(ckpt_path) / 1024 +print(f"\nCheckpoint: written to {ckpt_path} ({archive_size_kb:.1f} KB)") + +# Verify archive structure +with zipfile.ZipFile(ckpt_path) as zf: + names = zf.namelist() + meta = json.loads(zf.read("metadata.json")) + +print(f" Archive files : {names}") +print(f" Serialized t : {meta['t'] * 1000:.1f} ms") +print(f" Objects : {list(meta['objects'].keys())}") + +# --- reset to t=0, then restore from .brian archive ------------------------ +net.restore("initial") # existing pickle mechanism → t=0 +restore_network_state(net, ckpt_path) # .brian archive → t=5 ms + +print(f"\nAfter restore: net.t={net.t!s} G.v[0]={G.v[0]:.6f}") + +# --- run from restored state ----------------------------------------------- +net.run(5 * ms) + +# ── Verify round-trip fidelity ─────────────────────────────────────────────── + +v_err = np.max(np.abs(ref_v - V.v[:, :])) +spike_i_ok = np.array_equal(ref_spike_i, M.i[:]) +spike_t_err = ( + np.max(np.abs(ref_spike_t - M.t_[:])) if len(ref_spike_t) else 0.0 +) + +print(f"\nRound-trip check:") +print(f" State variable max error : {v_err:.2e}") +print(f" Spike indices match : {spike_i_ok}") +print(f" Spike times max error : {spike_t_err:.2e}") + +assert v_err < 1e-12, f"State variable mismatch: {v_err}" +assert spike_i_ok, "Spike index mismatch" +assert spike_t_err < 1e-12, f"Spike time mismatch: {spike_t_err}" + +print("\nPASS — serialize → restore → run produces bit-identical results.") + +os.unlink(ckpt_path)