diff --git a/.github/workflows/test_pr_and_main.yml b/.github/workflows/test_pr_and_main.yml index b0804b50d..e6d67a4c6 100644 --- a/.github/workflows/test_pr_and_main.yml +++ b/.github/workflows/test_pr_and_main.yml @@ -1064,7 +1064,9 @@ jobs: mpisppy/tests/test_extensions.py \ mpisppy/tests/test_jensens.py \ mpisppy/tests/test_proper_bundler.py \ - mpisppy/tests/test_incumbent_writing.py + mpisppy/tests/test_incumbent_writing.py \ + mpisppy/tests/test_buffer_inspect.py \ + mpisppy/tests/test_comm_lor_check.py - name: Upload coverage data if: always() diff --git a/doc/designs/async_buffer_sanity_design.md b/doc/designs/async_buffer_sanity_design.md new file mode 100644 index 000000000..31baeff83 --- /dev/null +++ b/doc/designs/async_buffer_sanity_design.md @@ -0,0 +1,282 @@ +# Async buffer sanity inspector + +Status: sketch implemented on branch `async-buffer-sanity-design`; design open +for review. + +Motivating incident: an xhat spoke received a `SHUTDOWN` signal that was not +sent by `Hub.send_terminate`, suggesting one of the MPI-RMA buffer regions +was being corrupted. The library had no runtime sanity check on buffer +contents; `_validate_recv_field` at `mpisppy/cylinders/spcommunicator.py:307` +only checks layout at registration time. This design adds a passive set of +content checks that can be called manually during debugging or wired in via +an environment variable on hot paths. + +--- + +## 0. Goals + +1. Inspect any `SendArray` or `RecvArray` after-the-fact for content + sanity, given an optional caller-supplied context. +2. Make producers (`hub.py`, `spwindow.py`, send paths) untouchable; + inspection is purely a consumer concern. +3. Provide a single entry point that works for both send and receive + buffers and for every `Field`. +4. Catch the SHUTDOWN-stomp signature first, with a framework that + extends to other fields. 
+ +Non-goals: + +- Cross-cylinder agreement checks (the existing `synchronize=True` path + inside `get_receive_buffer` already does this, and we are not paying + another barrier+allreduce). +- Module-level state that remembers per-buffer history. Tempting for + detecting trailing-slot oscillation, but it hides state in the + inspector and grows unbounded as buffers come and go. +- Modifying any producer to publish extra metadata. + +--- + +## 1. Invariants we lean on + +These come from reading `communicator_array` at +`mpisppy/cylinders/spcommunicator.py:33-58` and the `FieldArray` hierarchy +at lines 61-148. + +1. `communicator_array` allocates a padded MPI memory region and + initializes the *entire* region (data + id + padding) to NaN, then + writes 0.0 into the trailing logical slot (the write_id). +2. The hub's send path only ever writes into the logical `value_array()` + region and bumps the trailing write_id via + `SendArray._next_write_id()` (line 122). +3. RMA `get` copies the full padded buffer over. +4. `RecvArray._pull_id()` (line 142) records the last id whose + corresponding payload was accepted by `get_receive_buffer`. + +Consequences the inspector relies on: + +- The padding region (`window_array()[logical_len:padded_len]`) must + remain NaN. A finite value anywhere in padding is a write that ran + past the field's logical length. +- A data slot may be NaN *only* when write_id == 0 (no publish has + happened). Once write_id >= 1, NaN in data is a corrupted publish. +- `buf.array()[-1]` (the just-arrived trailing slot) must be >= + `buf.id()` (the last accepted id). A regression means the trailing + slot was rewritten to a smaller value. +- For a `SendArray`, the trailing slot must equal `buf.id()` between + publishes. + +--- + +## 2. API + +Module: `mpisppy/debug_utils/buffer_inspect.py`. 
+ +```python +@dataclass +class InspectContext: + expected_write_id: Optional[int] = None + last_write_id: Optional[int] = None # caller-tracked baseline + nonant_count: Optional[int] = None + nonant_lower: Optional[np.ndarray] = None + nonant_upper: Optional[np.ndarray] = None + spbase: Optional[Any] = None # duck-typed; fallback source + + def get_nonant_count(self) -> Optional[int]: ... + +@dataclass +class Report: + ok: bool = True + findings: list[str] = field(default_factory=list) + severity: str = "warn" # "warn" | "error" + dump: Optional[str] = None # populated when verbose=True + +def inspect_buffer(buf, fld: Field, ctx: Optional[InspectContext] = None, + *, send: bool = False, verbose: bool = False) -> Report +``` + +`InspectContext` fields are all optional; the inspector silently skips +any check whose context is missing. `spbase` is looked up via +`getattr(spbase, "nonant_length", None)`, so a partial mock that exposes +just `nonant_length` works. Explicit fields on the context take +precedence over `spbase`-derived ones. + +`Report.severity` starts at "warn" and escalates to "error" once any error-level finding is added. Caller decides what to do +with a non-OK report — log, raise, or set `_is_new = False`. The +inspector never raises. + +`verbose=True` populates `Report.dump` with a small text block +containing `buf.array()`, `buf.window_array()`, `buf.id()`, +`data_len`, `logical_len`, `padded_len`, and the field name. Cheap for small +fields; for large fields the caller decides. + +--- + +## 3. 
Generic checks (run for every field) + +`_check_generic` (called first): + +| Check | Severity | Condition | +|---|---|---| +| trailing slot finite | error | `not np.isfinite(buf.array()[-1])` | +| trailing slot is integer-valued | error | `abs(raw - round(raw)) > 1e-9` | +| trailing slot non-negative | error | `write_id < 0` | +| send: trailing slot == `buf.id()` | error | when `send=True` | +| recv: trailing slot >= `buf.id()` | error | when `send=False` | +| recv: trailing slot >= `ctx.last_write_id` | error | when supplied | +| trailing slot == `ctx.expected_write_id` | warn | when supplied | +| no inf in data | error | always | +| no NaN in data when write_id >= 1 | error | publish should have overwritten initial NaN | +| padding region all-NaN | error | `not np.all(np.isnan(window_array()[logical_len:padded_len]))` | + +--- + +## 4. Per-field registry + +`CHECKERS: dict[Field, Callable]`. Entries currently implemented: + +- `SHUTDOWN`: only two states are legitimate -- (a) data[0] == 1.0 + with write_id >= 1 (post `Hub.send_terminate`), or (b) data[0] is + NaN with write_id == 0 (the initial state from `communicator_array`). + `send_terminate` is the only producer and writes nothing but 1.0, + so any other value (0.0, fractional, negative, +/-inf, NaN with + write_id >= 1) can only come from a stomp, an RMA race, or a + producer bug. The checker rejects all such values as an error. +- `NONANT`: length must be a positive multiple of `ctx.get_nonant_count()` + (the publisher's local nonant length is `nonant_count * len(local_scenarios)`, + so multi-scenario hubs publish wider buffers); data in + `[ctx.nonant_lower, ctx.nonant_upper]` componentwise (when supplied + and the buffer is single-scenario wide). +- `NONANT_LOWER_BOUNDS` / `NONANT_UPPER_BOUNDS`: length check; if the + caller passes the *other* bound (via `ctx.nonant_upper` / + `ctx.nonant_lower`), check componentwise consistency. +- `OBJECTIVE_INNER_BOUND` / `OBJECTIVE_OUTER_BOUND`: length == 1. 
+- `BEST_XHAT`: length >= `ctx.get_nonant_count()` (the buffer also + carries per-scenario costs); the nonant prefix is bounds-checked when + bounds are supplied. + +Fields without an entry (`DUALS`, `RELAXED_NONANT`, +`CROSS_SCENARIO_CUT`, `CROSS_SCENARIO_COST`, `EXPECTED_REDUCED_COST`, +`SCENARIO_REDUCED_COST`, `RECENT_XHATS`, `BEST_OBJECTIVE_BOUNDS`, +`WHOLE`) get generic checks only. They can be filled in as needs arise. + +Adding a checker: write a function with signature +`(buf, report, ctx) -> None`, register it in `CHECKERS`. No producer +changes required. + +--- + +## 5. Wiring + +Two modes, both opt-in. + +**Manual.** A developer or test imports `inspect_buffer` and pokes it at +a buffer of interest. Example uses already in +`mpisppy/debug_utils/buffer_inspect.py` smoke tests: + +```python +from mpisppy.debug_utils.buffer_inspect import inspect_buffer, InspectContext +rep = inspect_buffer(shutdown_buf, Field.SHUTDOWN, send=False, verbose=True) +if not rep.ok: + print(rep) +``` + +**CLI-gated check at the shutdown moment.** `_BoundSpoke.got_kill_signal` +in `mpisppy/cylinders/spoke.py:24-30` is the most likely place to catch +the motivating bug. A new flag `inspect_buffers_on_shutdown` is added +in `Config.popular_args` and propagated through `cfg_vanilla.shared_options` +into `opt.options`. The hook runs the inspector only when the kill +fires *and* the flag is set: + +```python +fired = bool(shutdown_buf[0] == 1.0) +if fired and self.opt.options.get("inspect_buffers_on_shutdown"): + self._inspect_buffers_on_shutdown(shutdown_buf) +``` + +`_inspect_buffers_on_shutdown` sweeps every registered receive and send +buffer through `inspect_buffer`, not just SHUTDOWN. SHUTDOWN goes first +and verbose (the diagnostic dump lands in the warning); the rest run +non-verbose. `InspectContext(spbase=self.opt)` is threaded through so the +per-field checkers that need nonant length pick it up via the spbase +fallback. 
The sweep is what gives us real-buffer false-positive coverage +for every checker (not just SHUTDOWN) once the smoke runs. + +A failed inspection emits a `RuntimeWarning` (not a `print`) so the signal +is filterable, capturable in tests via `warnings.catch_warnings(record=True)`, +and escalatable to a hard error via `python -W error::RuntimeWarning:mpisppy.cylinders.spoke`. +We do not `raise` here: `got_kill_signal` runs during the collective shutdown +path, and aborting on one rank would leave peers blocked on the next barrier; +the inspector's job at this site is to observe the suspect shutdown, not +abort it. Hot-path call sites added later (e.g. `update_nonants`) may +choose to raise on `rep.severity == "error"`. + +When the flag is unset (default), the inspector is not called. We fire +at the moment of detection rather than every poll because a spurious +shutdown is most diagnostic when the buffer state has just arrived and +not yet been overwritten by later activity. + +**Smoke coverage.** `mpisppy/tests/straight_tests.py` runs the existing +multi-stage Aircond cylinder invocation (PH hub + lagranger + fwph + +xhatshuffle on 4 ranks) with `--inspect-buffers-on-shutdown` and +`python -W error::RuntimeWarning:mpisppy.cylinders.spoke`. The shutdown +sweep visits every buffer used by that run -- SHUTDOWN, NONANT, +OBJECTIVE_INNER_BOUND, OBJECTIVE_OUTER_BOUND, BEST_XHAT, plus any +others these cylinders register -- so a healthy run is also a +no-false-positives guard for the corresponding checkers. A regression +that produces a warning trips the escalation and the subprocess exits +non-zero. `NONANT_LOWER_BOUNDS` / `NONANT_UPPER_BOUNDS` remain uncovered +by smoke because the cylinders in this run don't produce them; a separate +smoke with a reduced-costs or nonant-bounds spoke would close that gap. 
+ +The unit-level test `TestSpokeGotKillSignalWarning` in +`mpisppy/tests/test_buffer_inspect.py` drives `Spoke.got_kill_signal` on +a stub: with a hand-stomped SHUTDOWN buffer (warning fires); with a legit +SHUTDOWN (silent); with the flag off (sweep not invoked); and with a +multi-buffer sweep that mixes healthy NONANT recv/send buffers with a +stomped OBJECTIVE_INNER_BOUND recv (exactly one warning, naming the bad +field). + +Other hot paths (`update_nonants`, `sync_bounds`, etc.) can be wired +the same way later. They are intentionally not wired in this round so +that the flag-gated overhead surface stays small while we shake the +inspector out. + +--- + +## 6. Detecting the motivating SHUTDOWN bug + +Suspected signature: an xhat spoke sees `shutdown_buf[0] == 1.0` while +the hub has not called `send_terminate`. At least one of the following +inspector findings should fire under that scenario: + +1. `SHUTDOWN data[0]==1.0 but write_id==0` — the hub never published. +2. `recv write_id N < buf.id() M` — the trailing slot regressed after + a previously-accepted read. +3. `padding region modified: K non-NaN slot(s) at offsets ...` — an + adjacent field's write overran into the SHUTDOWN region's padding. + +If none of the three fire and the spurious shutdown still occurs, the +write_id slot and data slot are both consistent with a legitimate +shutdown; the next hypothesis would shift away from "stomp" toward a +producer that publishes shutdown out of band. + +--- + +## 7. Future work + +- Fill `DUALS`, `RECENT_XHATS`, `CROSS_SCENARIO_*` checkers as they + become useful. +- Add a `synchronize=True` variant of the SHUTDOWN read in a debug mode + to cross-check write_id agreement (already a one-line change on + `spoke.py:29`). +- An EF-side inspector for the EF solve path is a separate effort; the + buffer layout differs. + +--- + +## 8. Files + +- `mpisppy/debug_utils/__init__.py` — package marker; re-exports the public API. +- `mpisppy/debug_utils/buffer_inspect.py` — the inspector. 
+- `mpisppy/cylinders/spoke.py` — flag-gated (`inspect_buffers_on_shutdown`) hook in `got_kill_signal`. +- `doc/designs/async_buffer_sanity_design.md` — this document. diff --git a/doc/src/debug_utils.rst b/doc/src/debug_utils.rst new file mode 100644 index 000000000..fd07403b9 --- /dev/null +++ b/doc/src/debug_utils.rst @@ -0,0 +1,127 @@ +.. _debug_utils: + +Buffer Sanity Inspector +======================= + +The ``mpisppy.debug_utils`` package provides a passive content-check +utility for the MPI-RMA send and receive buffers used by the +hub-and-spoke system. It is intended as a debugging aid when you suspect that +a buffer is being written to from somewhere unexpected (for example, +when a spoke sees a shutdown signal that the hub did not send). + +The inspector does *not* modify producer code, does *not* introduce +new MPI traffic, and costs nothing when not invoked. + +When to Use This +---------------- + +- **A spoke is acting on data it should not have received.** A spurious + ``SHUTDOWN`` is the canonical example, but the same idea applies to + any field — nonants outside their bounds, a write_id that went + backwards, NaN data on a buffer the hub claims to have published, etc. +- **A new field/cylinder is being introduced** and you want a cheap + invariant check during development. +- **Reproducing intermittent buffer-content bugs** where adding a print + in the hot path is too noisy unless gated. + +What the Inspector Checks +------------------------- + +Generic checks (run for every field): + +- Trailing ``write_id`` slot is a finite, non-negative, integer-valued + double. +- Send buffers: the trailing slot equals ``buf.id()``. +- Receive buffers: the trailing slot is not less than ``buf.id()`` + (the last id that ``get_receive_buffer`` accepted). An optional + ``ctx.last_write_id`` provides an additional, stricter baseline. +- Data region: no ``inf`` values; no ``NaN`` values once + ``write_id >= 1``. 
+- Padding region (between ``logical_len`` and ``padded_len``) remains + ``NaN`` — its canonical state from ``communicator_array``. A non-NaN + value anywhere in padding is a write that ran past the field's + logical length. + +Per-``Field`` checks: + +- ``SHUTDOWN``: data[0] is ``1.0`` with ``write_id >= 1`` (only + ``Hub.send_terminate`` writes it, and it bumps the id) or ``NaN`` with + ``write_id == 0`` (never published); anything else, including ``0.0``, is an error. +- ``NONANT``: length is a positive multiple of ``ctx.get_nonant_count()``; data within + ``[ctx.nonant_lower, ctx.nonant_upper]`` componentwise when the bound lengths match. +- ``NONANT_LOWER_BOUNDS`` / ``NONANT_UPPER_BOUNDS``: length check; + consistency with the counterpart bound when supplied via ``ctx``. +- ``OBJECTIVE_INNER_BOUND`` / ``OBJECTIVE_OUTER_BOUND``: length 1. +- ``BEST_XHAT``: length at least ``ctx.get_nonant_count()``; nonant + prefix within bounds when supplied. + +Manual Use +---------- + +.. code-block:: python + + from mpisppy.debug_utils import inspect_buffer, InspectContext + from mpisppy.cylinders.spwindow import Field + + ctx = InspectContext(nonant_count=spbase.nonant_length) + rep = inspect_buffer(some_recv_buf, Field.NONANT, ctx, verbose=True) + if not rep.ok: + print(rep) + +``Report`` is a small dataclass with ``ok``, ``findings`` (list of +strings), ``severity`` (``"warn"`` or ``"error"``), and an optional +``dump`` populated when ``verbose=True``. The inspector never raises; +the caller decides whether to log, raise, or treat the read as stale. + +Command-Line Trigger at Cylinder Shutdown +----------------------------------------- + +The ``--inspect-buffers-on-shutdown`` flag, exposed through the +standard ``Config`` system (``popular_args``), causes each spoke to +sweep its registered receive and send buffers (``SHUTDOWN`` first) through the inspector **at the moment a +shutdown is detected** (inside ``got_kill_signal``, only when the +signal fires — not on every poll). Each non-ok report is emitted as a +``RuntimeWarning`` that includes the field and rank info: + +.. 
code-block:: bash + + mpiexec -np N python my_driver.py --inspect-buffers-on-shutdown + +When the flag is unset (the default), the inspector is never called +and the shutdown-poll cost is unchanged. + +Choice of trigger point: a spurious ``SHUTDOWN`` is most diagnostic at +the moment of detection — the relevant buffer state has just arrived +and has not yet been overwritten by later activity. The check fires +once per spoke per cylinder shutdown, regardless of whether the +signal was legitimate; legitimate shutdowns produce an empty +findings list and print nothing. + +Extending: Adding a Field Checker +--------------------------------- + +Each per-field check is a function with the signature +``(buf, report, ctx) -> None`` that appends findings to ``report`` +when invariants are violated. Register it in the +``CHECKERS`` dict in ``mpisppy/debug_utils/buffer_inspect.py``: + +.. code-block:: python + + def _check_my_field(buf, report, ctx): + data = buf.value_array() + if len(data) != some_expected_length: + report.add(f"MY_FIELD wrong length: {len(data)}", severity="error") + + CHECKERS[Field.MY_FIELD] = _check_my_field + +Producers are intentionally left untouched; any context the checker +needs (lengths, bounds, scenario tree info) is passed in via +``InspectContext``. + +See Also +-------- + +The internal design document is at +``doc/designs/async_buffer_sanity_design.md``, including the +invariants the inspector relies on and explicit non-goals +(cross-cylinder consensus, module-level history state). diff --git a/doc/src/index.rst b/doc/src/index.rst index 33c47fea3..7387f98ef 100644 --- a/doc/src/index.rst +++ b/doc/src/index.rst @@ -71,6 +71,7 @@ MPI is used. 
drivers.rst internals.rst + debug_utils.rst contributors.rst amalgamator.rst code_coverage.rst diff --git a/mpisppy/cylinders/spoke.py b/mpisppy/cylinders/spoke.py index 1c10db59f..b070f1dd4 100644 --- a/mpisppy/cylinders/spoke.py +++ b/mpisppy/cylinders/spoke.py @@ -15,6 +15,7 @@ from mpisppy.cylinders.spcommunicator import SPCommunicator, SendCircularBuffer from mpisppy.cylinders.spwindow import Field +from mpisppy.debug_utils.buffer_inspect import InspectContext, inspect_buffer from mpisppy.utils import sputils @@ -29,7 +30,47 @@ def got_kill_signal(self): """ shutdown_buf = self.receive_buffers[self._make_key(Field.SHUTDOWN, 0)] self.get_receive_buffer(shutdown_buf, Field.SHUTDOWN, 0, synchronize=False) - return self.allreduce_or(shutdown_buf[0] == 1.0) + fired = bool(shutdown_buf[0] == 1.0) + if fired and self.opt.options.get("inspect_buffers_on_shutdown"): + self._inspect_buffers_on_shutdown(shutdown_buf) + return self.allreduce_or(fired) + + def _inspect_buffers_on_shutdown(self, shutdown_buf): + """Sweep every registered receive and send buffer through the + inspector and emit a RuntimeWarning for each non-OK report. + + The SHUTDOWN recv buffer is inspected first and verbose so the + diagnostic dump is in the warning text. All other buffers are + inspected non-verbose. InspectContext.spbase is set so checkers + that need nonant length pick it up via the fallback. 
+ """ + ctx = InspectContext(spbase=self.opt) + shutdown_key = self._make_key(Field.SHUTDOWN, 0) + + rep = inspect_buffer(shutdown_buf, Field.SHUTDOWN, ctx, + send=False, verbose=True) + self._warn_if_buffer_bad(rep, Field.SHUTDOWN, direction="recv") + + for key, buf in self.receive_buffers.items(): + if key == shutdown_key: + continue + fld, _origin = self._split_key(key) + rep = inspect_buffer(buf, fld, ctx, send=False, verbose=False) + self._warn_if_buffer_bad(rep, fld, direction="recv") + + for fld, buf in self.send_buffers.items(): + rep = inspect_buffer(buf, fld, ctx, send=True, verbose=False) + self._warn_if_buffer_bad(rep, fld, direction="send") + + def _warn_if_buffer_bad(self, rep, fld, *, direction): + if rep.ok: + return + warnings.warn( + f"[buffer_inspect] field={fld.name} ({direction}) " + f"{self.cylinder_rank=} {self.strata_rank=} {self.global_rank=}\n{rep}", + RuntimeWarning, + stacklevel=3, + ) def is_converged(self, screen_trace=False): """ Alias for got_kill_signal; useful for algorithms working as both diff --git a/mpisppy/debug_utils/__init__.py b/mpisppy/debug_utils/__init__.py new file mode 100644 index 000000000..10352999e --- /dev/null +++ b/mpisppy/debug_utils/__init__.py @@ -0,0 +1,17 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. 
+############################################################################### + +from mpisppy.debug_utils.buffer_inspect import ( + InspectContext, + Report, + inspect_buffer, +) +from mpisppy.debug_utils.comm_lor_check import comm_lor_check + +__all__ = ["InspectContext", "Report", "inspect_buffer", "comm_lor_check"] diff --git a/mpisppy/debug_utils/buffer_inspect.py b/mpisppy/debug_utils/buffer_inspect.py new file mode 100644 index 000000000..a51f1bfca --- /dev/null +++ b/mpisppy/debug_utils/buffer_inspect.py @@ -0,0 +1,401 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### +"""Sanity-check utilities for send and receive buffers. + +Producers of the buffers (hub and spoke code) are intentionally not modified. +Callers that want to validate a buffer instantiate an InspectContext (any +field optional) and call inspect_buffer(buf, field, ctx). + +See doc/designs/async_buffer_sanity_design.md (in progress) for the rationale +behind each check. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +import numpy as np + +from mpisppy.cylinders.spwindow import Field + + +# ---- public dataclasses ----------------------------------------------------- + + +@dataclass +class InspectContext: + """Caller-supplied context; every field optional. + + Inspector skips any check whose required context is None. + + spbase, when supplied, is queried via getattr for fallback values + (e.g. nonant_count -> spbase.nonant_length). 
Explicit fields on the + context take precedence over spbase-derived ones. + """ + expected_write_id: Optional[int] = None + last_write_id: Optional[int] = None + nonant_count: Optional[int] = None + nonant_lower: Optional[np.ndarray] = None + nonant_upper: Optional[np.ndarray] = None + spbase: Optional[Any] = None + + def get_nonant_count(self) -> Optional[int]: + if self.nonant_count is not None: + return self.nonant_count + return getattr(self.spbase, "nonant_length", None) + + +@dataclass +class Report: + ok: bool = True + findings: list[str] = field(default_factory=list) + severity: str = "warn" # "warn" or "error" + dump: Optional[str] = None + + def add(self, msg: str, severity: str = "warn") -> None: + self.ok = False + self.findings.append(msg) + if severity == "error": + self.severity = "error" + + def __str__(self) -> str: + head = "OK" if self.ok else f"FAIL ({self.severity})" + body = "\n".join(f" - {f}" for f in self.findings) or " (no findings)" + out = f"[buffer_inspect] {head}\n{body}" + if self.dump: + out += f"\n dump:\n{self.dump}" + return out + + +# ---- main entry point ------------------------------------------------------- + + +_INT_TOL = 1e-9 + + +def inspect_buffer(buf, + fld: Field, + ctx: Optional[InspectContext] = None, + *, + send: bool = False, + verbose: bool = False) -> Report: + """Inspect a SendArray or RecvArray for sanity. + + Args: + buf: a FieldArray (SendArray or RecvArray) from spcommunicator. + fld: the Field this buffer carries. + ctx: optional caller context (expected_write_id, bounds, ...). + send: True for SendArray, False for RecvArray. + verbose: include a raw dump in the returned Report. + + Returns: + Report with ok / findings / severity / optional dump. 
+ """ + if ctx is None: + ctx = InspectContext() + report = Report() + + _check_generic(buf, report, ctx, send=send) + checker = CHECKERS.get(fld) + if checker is not None: + checker(buf, report, ctx) + + if verbose: + report.dump = _format_dump(buf, fld) + + return report + + +# ---- generic checks --------------------------------------------------------- + + +def _check_generic(buf, report: Report, ctx: InspectContext, *, send: bool) -> None: + logical = buf.array() + raw_id = logical[-1] + write_id = _check_write_id_slot(raw_id, report) + + _check_padding_is_nan(buf, report) + + if write_id is None: + # write_id slot itself was malformed; data-NaN check can't reason + # about whether NaNs are "expected initial state" or corruption. + return + + _check_data_nan_consistency(buf, write_id, report) + + if send: + if write_id != buf.id(): + report.add( + f"send buffer write_id slot ({write_id}) != buf.id() ({buf.id()})", + severity="error", + ) + else: + # buf.id() == last write_id that get_receive_buffer accepted via + # _pull_id; the trailing slot must never go below it. 
+ if write_id < buf.id(): + report.add( + f"recv write_id {write_id} < buf.id() {buf.id()} " + "(trailing slot went backwards since last accepted read)", + severity="error", + ) + if ctx.last_write_id is not None and write_id < ctx.last_write_id: + report.add( + f"recv buffer write_id ({write_id}) regressed below " + f"ctx.last_write_id ({ctx.last_write_id})", + severity="error", + ) + + if ctx.expected_write_id is not None and write_id != ctx.expected_write_id: + report.add( + f"write_id {write_id} != ctx.expected_write_id {ctx.expected_write_id}" + ) + + +def _check_write_id_slot(raw_id: float, report: Report) -> Optional[int]: + if not np.isfinite(raw_id): + report.add(f"write_id slot is non-finite: {raw_id!r}", severity="error") + return None + rounded = round(raw_id) + if abs(raw_id - rounded) > _INT_TOL: + report.add( + f"write_id slot {raw_id!r} is not integer-valued", + severity="error", + ) + return None + if rounded < 0: + report.add(f"write_id slot {rounded} is negative", severity="error") + return None + return int(rounded) + + +def _check_padding_is_nan(buf, report: Report) -> None: + # communicator_array initializes the entire window to NaN and only + # rewrites the logical view. Padding that is no longer NaN means + # something has written past the field's logical length. + full = buf.window_array() + logical_len = buf.logical_len() + padded_len = buf.padded_len() + if padded_len == logical_len: + return + pad = full[logical_len:padded_len] + if not np.all(np.isnan(pad)): + bad = np.where(~np.isnan(pad))[0] + report.add( + f"padding region modified: {len(bad)} non-NaN slot(s) " + f"at offsets {bad.tolist()[:8]}{'...' 
if len(bad) > 8 else ''}", + severity="error", + ) + + +def _check_data_nan_consistency(buf, write_id: int, report: Report) -> None: + data = buf.value_array() + has_nan = bool(np.any(np.isnan(data))) + has_inf = bool(np.any(np.isinf(data))) + if write_id >= 1 and has_nan: + report.add( + "data contains NaN but write_id >= 1 (publish should have " + "overwritten initial NaN)", + severity="error", + ) + if has_inf: + report.add("data contains inf", severity="error") + + +# ---- field-specific checkers ----------------------------------------------- + + +def _check_shutdown(buf, report: Report, ctx: InspectContext) -> None: + # Only two legitimate states exist: + # - NaN with write_id == 0: initial state from communicator_array, + # no publish has happened yet. + # - 1.0 with write_id >= 1: Hub.send_terminate has fired. + # No producer ever writes 0.0 or any other value, so anything else + # is a stomp, an RMA race, or a producer bug. + data = buf.value_array() + val = data[0] + raw_id = buf.array()[-1] + write_id = int(round(raw_id)) if np.isfinite(raw_id) else None + + if write_id == 0 and np.isnan(val): + return + + if val != 1.0: + report.add( + f"SHUTDOWN data[0]={val!r}; only 1.0 is ever published " + "(or NaN with write_id==0 for the initial state)", + severity="error", + ) + if val == 1.0 and write_id is not None and write_id < 1: + report.add( + f"SHUTDOWN data[0]==1.0 but write_id=={write_id}; " + "only send_terminate writes 1.0 and it bumps the id", + severity="error", + ) + + +def _check_nonant(buf, report: Report, ctx: InspectContext) -> None: + # NONANT buffers are sized to the *publisher's* sum of per-scenario + # nonant counts across that publisher's local scenarios: + # len(data) == nonant_count * len(publisher.local_scenarios). + # When the publisher holds many scenarios (e.g. all 24 leaves of a + # multi-stage Aircond on one hub rank), len(data) >> nonant_count. 
+ # The buffer's logical length is already enforced at registration + # by _validate_recv_field, so here we only insist that the size + # be a positive multiple of nonant_count, and that the componentwise + # bounds compare runs only when the buffer happens to be a single + # scenario wide. + data = buf.value_array() + n = ctx.get_nonant_count() + if n is not None and n > 0 and (len(data) == 0 or len(data) % n != 0): + report.add( + f"NONANT data length {len(data)} is not a positive multiple " + f"of nonant_count {n}", + severity="error", + ) + lo, hi = ctx.nonant_lower, ctx.nonant_upper + raw_id = buf.array()[-1] + write_id = int(round(raw_id)) if np.isfinite(raw_id) else 0 + # Bounds compare only makes sense once data has been published. + if write_id < 1: + return + if lo is not None and len(lo) == len(data): + bad = np.where(data < lo)[0] + if bad.size: + report.add( + f"NONANT below lower bound at {bad.size} index(es); " + f"first: idx {int(bad[0])} value {float(data[bad[0]])!r} " + f"< lo {float(lo[bad[0]])!r}" + ) + if hi is not None and len(hi) == len(data): + bad = np.where(data > hi)[0] + if bad.size: + report.add( + f"NONANT above upper bound at {bad.size} index(es); " + f"first: idx {int(bad[0])} value {float(data[bad[0]])!r} " + f"> hi {float(hi[bad[0]])!r}" + ) + + +def _check_lower_bounds(buf, report: Report, ctx: InspectContext) -> None: + _check_bound_pair(buf, report, ctx, is_lower=True) + + +def _check_upper_bounds(buf, report: Report, ctx: InspectContext) -> None: + _check_bound_pair(buf, report, ctx, is_lower=False) + + +def _check_bound_pair(buf, report: Report, ctx: InspectContext, + *, is_lower: bool) -> None: + data = buf.value_array() + # Length: if caller knows nonant_count, the bound buffer should match. + # (Conservative: total_number_nonants is the real length, but typical + # two-stage problems set both equal, and a strict total-nonant ctx + # field would have to be threaded; lean on nonant_count for now.) 
+ n = ctx.get_nonant_count() + if n is not None and len(data) != n: + report.add( + f"{'LOWER' if is_lower else 'UPPER'}_BOUNDS data length " + f"{len(data)} != expected {n}", + severity="error", + ) + raw_id = buf.array()[-1] + write_id = int(round(raw_id)) if np.isfinite(raw_id) else 0 + if write_id < 1: + return + # If the caller passed the *other* bound, check componentwise consistency. + other = ctx.nonant_upper if is_lower else ctx.nonant_lower + if other is None or len(other) != len(data): + return + if is_lower: + bad = np.where(data > other)[0] + msg_dir = "lower > upper" + else: + bad = np.where(data < other)[0] + msg_dir = "upper < lower" + if bad.size: + report.add( + f"{msg_dir} at {bad.size} index(es); first: idx {int(bad[0])} " + f"this={float(data[bad[0]])!r} other={float(other[bad[0]])!r}", + severity="error", + ) + + +def _check_objective_scalar(buf, report: Report, ctx: InspectContext) -> None: + data = buf.value_array() + if len(data) != 1: + report.add( + f"objective bound data length {len(data)} != 1", + severity="error", + ) + # Finiteness already handled by _check_data_nan_consistency once + # write_id >= 1; nothing else generic to assert about a scalar bound. + + +def _check_best_xhat(buf, report: Report, ctx: InspectContext) -> None: + # BEST_XHAT = [nonant values..., per-scenario costs...]; we don't + # have scenario count in ctx, so length is only loosely constrained. 
+ data = buf.value_array() + n = ctx.get_nonant_count() + if n is not None and len(data) < n: + report.add( + f"BEST_XHAT data length {len(data)} < nonant_count {n}", + severity="error", + ) + return + raw_id = buf.array()[-1] + write_id = int(round(raw_id)) if np.isfinite(raw_id) else 0 + if write_id < 1 or n is None: + return + xhat = data[:n] + lo, hi = ctx.nonant_lower, ctx.nonant_upper + if lo is not None and len(lo) == n: + bad = np.where(xhat < lo)[0] + if bad.size: + report.add( + f"BEST_XHAT nonant portion below lower bound at " + f"{bad.size} index(es); first: idx {int(bad[0])} " + f"value {float(xhat[bad[0]])!r} < lo {float(lo[bad[0]])!r}" + ) + if hi is not None and len(hi) == n: + bad = np.where(xhat > hi)[0] + if bad.size: + report.add( + f"BEST_XHAT nonant portion above upper bound at " + f"{bad.size} index(es); first: idx {int(bad[0])} " + f"value {float(xhat[bad[0]])!r} > hi {float(hi[bad[0]])!r}" + ) + + +CHECKERS: dict[Field, Callable[[Any, Report, InspectContext], None]] = { + Field.SHUTDOWN: _check_shutdown, + Field.NONANT: _check_nonant, + Field.NONANT_LOWER_BOUNDS: _check_lower_bounds, + Field.NONANT_UPPER_BOUNDS: _check_upper_bounds, + Field.OBJECTIVE_INNER_BOUND: _check_objective_scalar, + Field.OBJECTIVE_OUTER_BOUND: _check_objective_scalar, + Field.BEST_XHAT: _check_best_xhat, +} + + +# ---- dump helper ------------------------------------------------------------ + + +def _format_dump(buf, fld: Field) -> str: + logical = buf.array() + full = buf.window_array() + return ( + f" field = {fld.name} ({int(fld)})\n" + f" buf.id() = {buf.id()}\n" + f" data_len = {buf.data_len()}\n" + f" logical_len= {buf.logical_len()}\n" + f" padded_len = {buf.padded_len()}\n" + f" logical = {logical.tolist()}\n" + f" padded = {full.tolist()}" + ) diff --git a/mpisppy/debug_utils/comm_lor_check.py b/mpisppy/debug_utils/comm_lor_check.py new file mode 100644 index 000000000..96efafd62 --- /dev/null +++ b/mpisppy/debug_utils/comm_lor_check.py @@ -0,0 +1,53 @@ 
+############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### +"""Collective sanity check for an MPI communicator. + +Each rank contributes uint8 0 and the values are combined with MPI.LOR. Since +every contribution is 0, the reduced value is always 0 in normal operation, +so the call is effectively a barrier that announces itself. If MPI hangs or +the comm is broken, the missing STOP line is itself diagnostic. + +The non-zero branch is a defensive guard: if any rank somehow sees a non-zero +reduced value, that rank prints the value, the comm name, and a stack trace. +No real rank contributes a non-zero value (the contribution is hard-coded to +0), so the tests reach that branch with a fake comm whose Allreduce writes 1. +""" + +import traceback + +import numpy as np + +import mpisppy.MPI as MPI + + +def comm_lor_check(comm, name): + """Collective LOR sanity check on `comm`. + + Args: + comm: an MPI communicator (or the single-rank mock from mpisppy.MPI). + name (str): human-readable label for the comm, used in messages. 
+ """ + rank = comm.Get_rank() + if rank == 0: + print(f"[comm_lor_check] START comm={name!r}", flush=True) + + local = np.zeros(1, dtype=np.uint8) + reduced = np.zeros(1, dtype=np.uint8) + comm.Allreduce(local, reduced, op=MPI.LOR) + + if reduced[0] != 0: + print( + f"[comm_lor_check] NONZERO value={int(reduced[0])} " + f"comm={name!r} rank={rank}", + flush=True, + ) + traceback.print_stack() + + if rank == 0: + print(f"[comm_lor_check] STOP comm={name!r}", flush=True) diff --git a/mpisppy/tests/straight_tests.py b/mpisppy/tests/straight_tests.py index 355b691ac..92e2722c2 100644 --- a/mpisppy/tests/straight_tests.py +++ b/mpisppy/tests/straight_tests.py @@ -99,8 +99,14 @@ def _doone(cmdstr: str) -> bool: # Use the exact interpreter running this test (important for CI/conda/venv) pyexe = shlex.quote(sys.executable) +# -W: escalate any RuntimeWarning from the buffer inspector to a hard error, +# so a healthy run with --inspect-buffers-on-shutdown is also a buffer health +# check. (Doubles as a general multi-stage cylinder health check.) +# Must come before python_args -- when python_args is "-m coverage run ...", +# -W after it gets parsed by coverage instead of by Python. 
+warning_filter = "-W error::RuntimeWarning:mpisppy.cylinders.spoke" cmdstr = ( - f"mpiexec -np 4 {pyexe} {python_args} -m mpi4py {shlex.quote(fpath)} " + f"mpiexec -np 4 {pyexe} {warning_filter} {python_args} -m mpi4py {shlex.quote(fpath)} " f"--bundles-per-rank=0 --max-iterations=5 --default-rho=1 " f"--solver-name={shlex.quote(solver_name)} " f'--branching-factors "4 3 2" ' @@ -108,7 +114,8 @@ def _doone(cmdstr: str) -> bool: f"--rel-gap 0.01 --mu-dev 0 --sigma-dev 40 " f"--max-solver-threads 2 --start-seed 0 " f"--lagranger --lagranger-rho-rescale-factors-json {shlex.quote(jpath)} " - f"--fwph --fwph-save-file {shlex.quote(fwphSaveFile)} --xhatshuffle" + f"--fwph --fwph-save-file {shlex.quote(fwphSaveFile)} --xhatshuffle " + f"--inspect-buffers-on-shutdown" ) ok = _doone(cmdstr) diff --git a/mpisppy/tests/test_buffer_inspect.py b/mpisppy/tests/test_buffer_inspect.py new file mode 100644 index 000000000..c765068c5 --- /dev/null +++ b/mpisppy/tests/test_buffer_inspect.py @@ -0,0 +1,563 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. 
+############################################################################### +"""Tests for mpisppy.debug_utils.buffer_inspect.""" + +import types +import unittest +import warnings +from types import SimpleNamespace + +import numpy as np + +from mpisppy.cylinders.spcommunicator import RecvArray, SendArray +from mpisppy.cylinders.spoke import Spoke +from mpisppy.cylinders.spwindow import Field +from mpisppy.debug_utils import InspectContext, Report, inspect_buffer + + +class _FakeSP: + """Duck-typed SPBase substitute for tests that exercise ctx.spbase.""" + def __init__(self, nonant_length): + self.nonant_length = nonant_length + + +# Helper to publish a SendArray cleanly: set values then bump the id. +def _publish(buf, values): + for i, v in enumerate(values): + buf[i] = v + buf._next_write_id() + + +# ---- generic checks --------------------------------------------------------- + + +class TestGenericChecks(unittest.TestCase): + + def test_fresh_send_buffer_passes(self): + buf = SendArray(5) + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=5), send=True) + self.assertTrue(rep.ok, msg=str(rep)) + + def test_fresh_recv_buffer_passes(self): + buf = RecvArray(5) + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=5), send=False) + self.assertTrue(rep.ok, msg=str(rep)) + + def test_padding_overrun_detected(self): + buf = RecvArray(2) + buf._full_array[5] = 42.0 # write into padding region + rep = inspect_buffer(buf, Field.NONANT, send=False) + self.assertFalse(rep.ok) + self.assertTrue(any("padding region modified" in f for f in rep.findings)) + + def test_write_id_must_be_integer_valued(self): + buf = RecvArray(1) + buf._array[-1] = 3.5 + rep = inspect_buffer(buf, Field.NONANT, send=False) + self.assertFalse(rep.ok) + self.assertTrue(any("not integer-valued" in f for f in rep.findings)) + + def test_write_id_must_be_finite(self): + buf = RecvArray(1) + buf._array[-1] = np.inf + rep = inspect_buffer(buf, 
Field.NONANT, send=False) + self.assertFalse(rep.ok) + self.assertTrue(any("non-finite" in f for f in rep.findings)) + + def test_write_id_must_be_non_negative(self): + buf = RecvArray(1) + buf._array[-1] = -2.0 + rep = inspect_buffer(buf, Field.NONANT, send=False) + self.assertFalse(rep.ok) + self.assertTrue(any("negative" in f for f in rep.findings)) + + def test_send_id_mismatch_detected(self): + buf = SendArray(3) + _publish(buf, [1.0, 2.0, 3.0]) + # Tamper with the trailing slot post-publish + buf._array[-1] = 99.0 + rep = inspect_buffer(buf, Field.NONANT, send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("!= buf.id()" in f for f in rep.findings)) + + def test_recv_write_id_regression_detected(self): + buf = RecvArray(1) + buf._id = 5 + buf._array[0] = 0.0 + buf._array[-1] = 2.0 + rep = inspect_buffer(buf, Field.SHUTDOWN, send=False) + self.assertFalse(rep.ok) + self.assertTrue(any("went backwards" in f for f in rep.findings)) + + def test_recv_last_write_id_baseline(self): + buf = RecvArray(1) + buf._array[0] = 0.0 + buf._array[-1] = 2.0 + ctx = InspectContext(last_write_id=10) + rep = inspect_buffer(buf, Field.SHUTDOWN, ctx=ctx, send=False) + self.assertFalse(rep.ok) + self.assertTrue(any("ctx.last_write_id" in f for f in rep.findings)) + + def test_expected_write_id_mismatch(self): + buf = SendArray(1) + _publish(buf, [0.0]) + ctx = InspectContext(expected_write_id=99) + rep = inspect_buffer(buf, Field.SHUTDOWN, ctx=ctx, send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("expected_write_id" in f for f in rep.findings)) + + def test_nan_in_data_after_publish_is_finding(self): + buf = SendArray(3) + buf._next_write_id() # publish without setting values: data stays NaN + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=3), send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("contains NaN" in f for f in rep.findings)) + + def test_inf_in_data_is_finding(self): + buf = SendArray(3) + buf[0] = 1.0 + buf[1] = 
np.inf + buf[2] = 3.0 + buf._next_write_id() + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=3), send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("contains inf" in f for f in rep.findings)) + + +# ---- per-field checks ------------------------------------------------------- + + +class TestShutdownChecks(unittest.TestCase): + + def test_legit_shutdown_passes(self): + buf = SendArray(1) + _publish(buf, [1.0]) + rep = inspect_buffer(buf, Field.SHUTDOWN, send=True) + self.assertTrue(rep.ok, msg=str(rep)) + + def test_stomped_shutdown_caught(self): + # data=1.0 but write_id stayed at 0: the suspected stomp signature + buf = RecvArray(1) + buf._array[0] = 1.0 + rep = inspect_buffer(buf, Field.SHUTDOWN, send=False) + self.assertFalse(rep.ok) + self.assertTrue(any("write_id==0" in f for f in rep.findings)) + + def test_shutdown_value_not_in_set(self): + buf = SendArray(1) + buf[0] = 0.5 + buf._next_write_id() + rep = inspect_buffer(buf, Field.SHUTDOWN, send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("only 1.0 is ever published" in f for f in rep.findings)) + + def test_shutdown_zero_is_rejected(self): + # No producer ever writes 0.0 -- it would only appear via a stomp, + # an RMA race, or a producer bug. Must be flagged. + buf = SendArray(1) + buf[0] = 0.0 + buf._next_write_id() + rep = inspect_buffer(buf, Field.SHUTDOWN, send=True) + self.assertFalse(rep.ok) + # Match on the stable part of the message; the numeric repr varies + # between numpy 1.x (0.0) and 2.x (np.float64(0.0)). 
+ self.assertTrue(any("only 1.0 is ever published" in f for f in rep.findings)) + + def test_initial_state_passes(self): + # data=NaN, write_id=0: canonical initial state; allowed + buf = RecvArray(1) + rep = inspect_buffer(buf, Field.SHUTDOWN, send=False) + self.assertTrue(rep.ok, msg=str(rep)) + + +class TestNonantChecks(unittest.TestCase): + + def test_length_mismatch_via_explicit_count(self): + buf = SendArray(5) + _publish(buf, [0.0] * 5) + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=7), send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("data length 5" in f for f in rep.findings)) + + def test_length_mismatch_via_spbase_fallback(self): + buf = SendArray(5) + _publish(buf, [0.0] * 5) + ctx = InspectContext(spbase=_FakeSP(nonant_length=7)) + rep = inspect_buffer(buf, Field.NONANT, ctx=ctx, send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("data length 5" in f for f in rep.findings)) + + def test_explicit_count_takes_precedence_over_spbase(self): + buf = SendArray(5) + _publish(buf, [0.0] * 5) + ctx = InspectContext(nonant_count=5, spbase=_FakeSP(nonant_length=7)) + rep = inspect_buffer(buf, Field.NONANT, ctx=ctx, send=True) + self.assertTrue(rep.ok, msg=str(rep)) + + def test_multi_scenario_buffer_passes(self): + # Publisher with K scenarios publishes a NONANT buffer of length + # nonant_count * K. The checker must accept any positive multiple. 
+ buf = SendArray(24) # e.g., 4 scenarios * 6 nonants + _publish(buf, [0.0] * 24) + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=6), send=True) + self.assertTrue(rep.ok, msg=str(rep)) + + def test_non_multiple_length_caught(self): + buf = SendArray(10) # 10 is not a multiple of 6 + _publish(buf, [0.0] * 10) + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=6), send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("not a positive multiple" in f for f in rep.findings)) + + def test_out_of_bounds_componentwise(self): + buf = SendArray(4) + _publish(buf, [0.5, 7.0, 2.0, -1.0]) + lo = np.array([0.0, 0.0, 0.0, 0.0]) + hi = np.array([5.0, 5.0, 5.0, 5.0]) + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=4, + nonant_lower=lo, + nonant_upper=hi), + send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("below lower bound" in f for f in rep.findings)) + self.assertTrue(any("above upper bound" in f for f in rep.findings)) + + def test_unpublished_nonant_skips_bounds_check(self): + # Fresh SendArray: write_id=0, data=NaN. Bounds compare must NOT fire. 
+ buf = SendArray(3) + lo = np.array([0.0, 0.0, 0.0]) + hi = np.array([1.0, 1.0, 1.0]) + rep = inspect_buffer(buf, Field.NONANT, + ctx=InspectContext(nonant_count=3, + nonant_lower=lo, + nonant_upper=hi), + send=True) + self.assertTrue(rep.ok, msg=str(rep)) + + +class TestBoundsBufferChecks(unittest.TestCase): + + def test_lower_above_upper(self): + buf = SendArray(3) + _publish(buf, [0.0, 2.0, 5.0]) # lowers + upper = np.array([1.0, 1.0, 6.0]) + rep = inspect_buffer(buf, Field.NONANT_LOWER_BOUNDS, + ctx=InspectContext(nonant_count=3, + nonant_upper=upper), + send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("lower > upper" in f for f in rep.findings)) + + def test_upper_below_lower(self): + buf = SendArray(3) + _publish(buf, [10.0, 0.5, 6.0]) # uppers + lower = np.array([0.0, 1.0, 5.0]) + rep = inspect_buffer(buf, Field.NONANT_UPPER_BOUNDS, + ctx=InspectContext(nonant_count=3, + nonant_lower=lower), + send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("upper < lower" in f for f in rep.findings)) + + +class TestObjectiveChecks(unittest.TestCase): + + def test_inner_bound_length_1_ok(self): + buf = SendArray(1) + _publish(buf, [42.5]) + rep = inspect_buffer(buf, Field.OBJECTIVE_INNER_BOUND, send=True) + self.assertTrue(rep.ok, msg=str(rep)) + + def test_inner_bound_wrong_length_caught(self): + buf = SendArray(3) + _publish(buf, [1.0, 2.0, 3.0]) + rep = inspect_buffer(buf, Field.OBJECTIVE_INNER_BOUND, send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("data length 3 != 1" in f for f in rep.findings)) + + +class TestBestXhatChecks(unittest.TestCase): + + def test_length_too_short(self): + buf = SendArray(2) + _publish(buf, [0.0, 1.0]) + rep = inspect_buffer(buf, Field.BEST_XHAT, + ctx=InspectContext(nonant_count=5), send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("BEST_XHAT data length 2" in f for f in rep.findings)) + + def test_nonant_prefix_out_of_bounds(self): + buf = SendArray(6) + # First 3 are nonants, last 3 are 
per-scenario costs + _publish(buf, [0.5, 10.0, 2.0, 100.0, 200.0, 300.0]) + lo = np.array([0.0, 0.0, 0.0]) + hi = np.array([5.0, 5.0, 5.0]) + rep = inspect_buffer(buf, Field.BEST_XHAT, + ctx=InspectContext(nonant_count=3, + nonant_lower=lo, + nonant_upper=hi), + send=True) + self.assertFalse(rep.ok) + self.assertTrue(any("BEST_XHAT nonant portion above upper bound" in f + for f in rep.findings)) + + +# ---- report and verbose ----------------------------------------------------- + + +class TestReportAndVerbose(unittest.TestCase): + + def test_verbose_populates_dump(self): + buf = SendArray(1) + _publish(buf, [1.0]) + rep = inspect_buffer(buf, Field.SHUTDOWN, send=True, verbose=True) + self.assertIsNotNone(rep.dump) + self.assertIn("SHUTDOWN", rep.dump) + self.assertIn("logical_len", rep.dump) + self.assertIn("padded_len", rep.dump) + + def test_non_verbose_dump_is_none(self): + buf = SendArray(1) + _publish(buf, [1.0]) + rep = inspect_buffer(buf, Field.SHUTDOWN, send=True, verbose=False) + self.assertIsNone(rep.dump) + + def test_str_round_trip(self): + buf = RecvArray(1) + buf._array[0] = 1.0 + rep = inspect_buffer(buf, Field.SHUTDOWN, send=False, verbose=True) + s = str(rep) + self.assertIn("FAIL", s) + self.assertIn("SHUTDOWN", s) + + def test_report_severity_ladders_to_error(self): + r = Report() + r.add("warn 1", severity="warn") + self.assertEqual(r.severity, "warn") + r.add("err 1", severity="error") + self.assertEqual(r.severity, "error") + # subsequent warn doesn't downgrade + r.add("warn 2", severity="warn") + self.assertEqual(r.severity, "error") + + +# ---- config flag wiring ----------------------------------------------------- + + +class TestConfigFlagWiring(unittest.TestCase): + """The CLI flag must register in Config and propagate to shared_options.""" + + def test_default_false(self): + from mpisppy.utils.config import Config + cfg = Config() + cfg.popular_args() + self.assertFalse(cfg.inspect_buffers_on_shutdown) + + def 
test_propagated_to_shared_options(self): + from mpisppy.utils.config import Config + import mpisppy.utils.cfg_vanilla as vanilla + cfg = Config() + cfg.popular_args() + cfg.solver_name = "gurobi" + cfg.inspect_buffers_on_shutdown = True + opts = vanilla.shared_options(cfg, is_hub=False) + self.assertTrue(opts.get("inspect_buffers_on_shutdown")) + + def test_default_propagates_as_false(self): + from mpisppy.utils.config import Config + import mpisppy.utils.cfg_vanilla as vanilla + cfg = Config() + cfg.popular_args() + cfg.solver_name = "gurobi" + opts = vanilla.shared_options(cfg, is_hub=True) + self.assertFalse(opts.get("inspect_buffers_on_shutdown")) + + +# ---- integration with Spoke.got_kill_signal -------------------------------- + + +def _make_spoke_stub(shutdown_buf, *, inspect_on=True, + extra_recv=None, send=None, nonant_length=None): + """Build a duck-typed stub sufficient to drive Spoke.got_kill_signal. + + extra_recv: optional dict mapping (Field, origin) -> RecvArray that the + sweep will visit in addition to the SHUTDOWN entry. + send: optional dict mapping Field -> SendArray for the send-side sweep. + nonant_length: if set, exposed via stub.opt.nonant_length so checkers + that fall back to spbase pick it up. + """ + stub = SimpleNamespace() + recv = {(Field.SHUTDOWN, 0): shutdown_buf} + if extra_recv: + recv.update(extra_recv) + stub.receive_buffers = recv + stub.send_buffers = dict(send) if send else {} + stub._make_key = lambda field, origin: (field, origin) + stub._split_key = lambda key: key + # The real method copies from the MPI window into the buffer; for the + # stub the buffer is already populated, so this is a no-op. 
+ stub.get_receive_buffer = lambda buf, field, origin, synchronize=True: True + stub.opt = SimpleNamespace( + options={"inspect_buffers_on_shutdown": inspect_on}, + ) + if nonant_length is not None: + stub.opt.nonant_length = nonant_length + stub.cylinder_rank = 0 + stub.strata_rank = 1 + stub.global_rank = 1 + stub.allreduce_or = lambda v: v + # Bind the sweep helpers from the real class onto the stub so that + # got_kill_signal can call self._inspect_buffers_on_shutdown(...). + stub._inspect_buffers_on_shutdown = types.MethodType( + Spoke._inspect_buffers_on_shutdown, stub) + stub._warn_if_buffer_bad = types.MethodType( + Spoke._warn_if_buffer_bad, stub) + return stub + + +class TestSpokeGotKillSignalWarning(unittest.TestCase): + """The print -> warnings.warn switch in Spoke.got_kill_signal.""" + + def test_stomped_shutdown_emits_runtime_warning(self): + # data=1.0 but write_id stayed at 0: the suspected stomp signature + buf = RecvArray(1) + buf._array[0] = 1.0 + stub = _make_spoke_stub(buf, inspect_on=True) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", RuntimeWarning) + fired = Spoke.got_kill_signal(stub) + self.assertTrue(fired) + runtime_warnings = [w for w in caught if issubclass(w.category, RuntimeWarning)] + self.assertEqual(len(runtime_warnings), 1, msg=[str(w.message) for w in caught]) + self.assertIn("buffer_inspect", str(runtime_warnings[0].message)) + + def test_legit_shutdown_emits_no_warning(self): + # Properly published shutdown signal: data=1.0 with write_id>=1 + buf = RecvArray(1) + buf._array[0] = 1.0 + buf._array[-1] = 1.0 # write_id slot + buf._id = 1 + stub = _make_spoke_stub(buf, inspect_on=True) + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + fired = Spoke.got_kill_signal(stub) + self.assertTrue(fired) + + def test_flag_off_suppresses_warning_on_stomped_buffer(self): + # Inspector must not run when the flag is off, even with a bad buffer. 
+ buf = RecvArray(1) + buf._array[0] = 1.0 + stub = _make_spoke_stub(buf, inspect_on=False) + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + fired = Spoke.got_kill_signal(stub) + self.assertTrue(fired) + + def test_sweep_inspects_every_buffer(self): + # SHUTDOWN legit + healthy NONANT recv + healthy NONANT send: no warnings. + # Then add a stomped OBJECTIVE_INNER_BOUND recv: exactly one warning, + # naming that field. + good_shutdown = RecvArray(1) + good_shutdown._array[0] = 1.0 + good_shutdown._array[-1] = 1.0 + good_shutdown._id = 1 + + good_nonant_recv = RecvArray(3) + for i, v in enumerate([0.0, 1.0, 2.0]): + good_nonant_recv._array[i] = v + good_nonant_recv._array[-1] = 1.0 + good_nonant_recv._id = 1 + + good_nonant_send = SendArray(3) + _publish(good_nonant_send, [0.0, 1.0, 2.0]) + + bad_inner = RecvArray(1) + bad_inner._array[0] = 0.5 + bad_inner._full_array[3] = 7.0 # write into padding region + + extra_recv = { + (Field.NONANT, 1): good_nonant_recv, + (Field.OBJECTIVE_INNER_BOUND, 1): bad_inner, + } + send = {Field.NONANT: good_nonant_send} + stub = _make_spoke_stub( + good_shutdown, + inspect_on=True, + extra_recv=extra_recv, + send=send, + nonant_length=3, + ) + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always", RuntimeWarning) + fired = Spoke.got_kill_signal(stub) + self.assertTrue(fired) + runtime_warnings = [w for w in caught if issubclass(w.category, RuntimeWarning)] + self.assertEqual( + len(runtime_warnings), 1, + msg=[str(w.message) for w in runtime_warnings], + ) + msg = str(runtime_warnings[0].message) + self.assertIn("OBJECTIVE_INNER_BOUND", msg) + self.assertIn("(recv)", msg) + + def test_sweep_healthy_run_emits_no_warnings(self): + # Multiple healthy buffers (the false-positive regression guard). 
+ good_shutdown = RecvArray(1) + good_shutdown._array[0] = 1.0 + good_shutdown._array[-1] = 1.0 + good_shutdown._id = 1 + + good_nonant_recv = RecvArray(3) + for i, v in enumerate([0.0, 1.0, 2.0]): + good_nonant_recv._array[i] = v + good_nonant_recv._array[-1] = 1.0 + good_nonant_recv._id = 1 + + good_inner_recv = RecvArray(1) + good_inner_recv._array[0] = 12.5 + good_inner_recv._array[-1] = 1.0 + good_inner_recv._id = 1 + + good_nonant_send = SendArray(3) + _publish(good_nonant_send, [0.0, 1.0, 2.0]) + + good_outer_send = SendArray(1) + _publish(good_outer_send, [10.0]) + + stub = _make_spoke_stub( + good_shutdown, + inspect_on=True, + extra_recv={ + (Field.NONANT, 1): good_nonant_recv, + (Field.OBJECTIVE_INNER_BOUND, 1): good_inner_recv, + }, + send={ + Field.NONANT: good_nonant_send, + Field.OBJECTIVE_OUTER_BOUND: good_outer_send, + }, + nonant_length=3, + ) + + with warnings.catch_warnings(): + warnings.simplefilter("error", RuntimeWarning) + fired = Spoke.got_kill_signal(stub) + self.assertTrue(fired) + + +if __name__ == "__main__": + unittest.main() diff --git a/mpisppy/tests/test_comm_lor_check.py b/mpisppy/tests/test_comm_lor_check.py new file mode 100644 index 000000000..084aa1606 --- /dev/null +++ b/mpisppy/tests/test_comm_lor_check.py @@ -0,0 +1,81 @@ +############################################################################### +# mpi-sppy: MPI-based Stochastic Programming in PYthon +# +# Copyright (c) 2024, Lawrence Livermore National Security, LLC, Alliance for +# Sustainable Energy, LLC, The Regents of the University of California, et al. +# All rights reserved. Please see the files COPYRIGHT.md and LICENSE.md for +# full copyright and license information. +############################################################################### +"""Tests for mpisppy.debug_utils.comm_lor_check. + +The all-zero path is exercised with the real (mock) MPI.COMM_WORLD. 
The +non-zero path is exercised with a fake comm whose Allreduce writes 1 into +the recv buffer, so the failure-print code path runs without changing the +function under test. +""" + +import contextlib +import io +import unittest + +from mpisppy import MPI +from mpisppy.debug_utils import comm_lor_check + + +class _FakeNonzeroComm: + """Stand-in for an MPI comm that forces the LOR result to 1. + + Get_rank() returns 0 so START/STOP also print, letting the test confirm + both the announcement and the failure-print run in the same call. + """ + + def Get_rank(self): + return 0 + + def Allreduce(self, sendbuf, recvbuf, op=None): + recvbuf[0] = 1 + + +class TestCommLorCheck(unittest.TestCase): + + def test_prints_start_and_stop_with_comm_name(self): + # Single-rank MPI.COMM_WORLD works for both mpi4py and the mock shim. + buf = io.StringIO() + with contextlib.redirect_stdout(buf): + comm_lor_check(MPI.COMM_WORLD, "world") + out = buf.getvalue() + self.assertIn("START", out) + self.assertIn("STOP", out) + self.assertIn("'world'", out) + + def test_no_nonzero_announcement_and_no_stack_dump(self): + out_buf = io.StringIO() + err_buf = io.StringIO() + with contextlib.redirect_stdout(out_buf), \ + contextlib.redirect_stderr(err_buf): + comm_lor_check(MPI.COMM_WORLD, "world") + # NONZERO header goes to stdout; stack trace goes to stderr. + self.assertNotIn("NONZERO", out_buf.getvalue()) + self.assertNotIn('File "', err_buf.getvalue()) + + def test_nonzero_path_prints_value_name_and_stack_dump(self): + out_buf = io.StringIO() + err_buf = io.StringIO() + with contextlib.redirect_stdout(out_buf), \ + contextlib.redirect_stderr(err_buf): + comm_lor_check(_FakeNonzeroComm(), "broken") + out = out_buf.getvalue() + err = err_buf.getvalue() + self.assertIn("NONZERO", out) + self.assertIn("value=1", out) + self.assertIn("'broken'", out) + self.assertIn("rank=0", out) + # traceback.print_stack writes 'File "..."' frames to stderr. 
+ self.assertIn('File "', err) + # START/STOP still announced on this rank. + self.assertIn("START", out) + self.assertIn("STOP", out) + + +if __name__ == "__main__": + unittest.main() diff --git a/mpisppy/utils/cfg_vanilla.py b/mpisppy/utils/cfg_vanilla.py index 4b53e056d..19f3cd2ad 100644 --- a/mpisppy/utils/cfg_vanilla.py +++ b/mpisppy/utils/cfg_vanilla.py @@ -82,6 +82,7 @@ def shared_options(cfg, is_hub=False): # Optional initial xhat candidate file (.npy); None disables. # Consumed by XhatInnerBoundBase._try_file_xhat. "xhat_from_file" : cfg.get("xhat_from_file", None), + "inspect_buffers_on_shutdown" : cfg.get("inspect_buffers_on_shutdown", False), # Optional filename prefix; if set, _BoundSpoke.update_if_improving # writes a first-stage solution snapshot on each new best incumbent. "incumbent_on_improvement_filename_prefix" : cfg.get( @@ -91,7 +92,7 @@ def shared_options(cfg, is_hub=False): # axis 2: any CLI flags below override file entries at the same # predicate. Load and apply it first so the rest of the axis-2 # chain can overlay on top. 
- if _hasit(cfg, "solver_options_file"): + if cfg.get("solver_options_file"): # treats None and "" as unset file_data = sputils.load_solver_options_file(cfg.solver_options_file) shoptions["iter0_solver_options"].update(file_data["default"]) shoptions["iter0_solver_options"].update(file_data["iter0"]) @@ -175,7 +176,7 @@ def apply_solver_specs(name, spoke, cfg): options["iterk_solver_options"].update(spoke_file_blocks["iterk"]) options["solver_options_layers"].extend( sputils.options_file_section_to_layers(spoke_file_blocks)) - if _hasit(cfg, name+"_solver_options_file"): + if cfg.get(name+"_solver_options_file"): # treats None and "" as unset spoke_file_data = sputils.load_solver_options_file( cfg.get(name+"_solver_options_file")) # Per-spoke files only consume their own predicates; the diff --git a/mpisppy/utils/config.py b/mpisppy/utils/config.py index 0c040fb68..d3235701e 100644 --- a/mpisppy/utils/config.py +++ b/mpisppy/utils/config.py @@ -230,6 +230,13 @@ def popular_args(self): domain=bool, default=False) + self.add_to_config("inspect_buffers_on_shutdown", + description="When a spoke detects a shutdown signal, run " + "mpisppy.debug_utils.buffer_inspect on the SHUTDOWN receive " + "buffer and print any findings. Off by default.", + domain=bool, + default=False) + self.add_to_config("warmstart_subproblems", description="Warmstart subproblems from prior solution.", domain=bool,