Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 111 additions & 9 deletions src/electrai/entrypoints/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import os
from pathlib import Path
from types import SimpleNamespace
Expand All @@ -11,6 +12,48 @@

from electrai.lightning import LightningGenerator

logger = logging.getLogger(__name__)


def _resolve_checkpoint(cfg) -> Path:
"""Find the best available checkpoint from config.

Resolution order:
1. cfg.ckpt_file — explicit path to a specific .ckpt file
2. cfg.ckpt_path itself, if it points to a .ckpt file
3. cfg.ckpt_path / "last.ckpt"
4. cfg.ckpt_path / "best.ckpt"
5. Latest ckpt_*.ckpt in cfg.ckpt_path (highest epoch by lexicographic sort)
"""
ckpt_file = getattr(cfg, "ckpt_file", None)
if ckpt_file is not None:
ckpt = Path(ckpt_file)
if ckpt.exists():
return ckpt
raise FileNotFoundError(f"Checkpoint not found: {ckpt}")

ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))

# If ckpt_path is itself a file, use it directly
if ckpt_path.is_file():
return ckpt_path

for name in ("last.ckpt", "best.ckpt"):
candidate = ckpt_path / name
if candidate.exists():
return candidate

# Glob for ckpt_*.ckpt and pick the latest epoch by lexicographic sort
candidates = sorted(ckpt_path.glob("ckpt_*.ckpt"))
if candidates:
return candidates[-1]

raise FileNotFoundError(
f"No checkpoint found in {ckpt_path}. "
"Set ckpt_file to an explicit path, or ensure ckpt_path contains "
"last.ckpt, best.ckpt, or ckpt_*.ckpt files."
)


def test(args):
# -----------------------------
Expand All @@ -30,12 +73,22 @@ def test(args):
# Model (LightningModule handles architecture + loss + optimizer)
# -----------------------------
lit_model = LightningGenerator(cfg)
lit_model.test_cfg = SimpleNamespace(log_dir=cfg.log_dir, out_dir=cfg.out_dir)

# -----------------------------
# Callback
# W&B (optional)
# -----------------------------
ckpt_path = Path(getattr(cfg, "ckpt_path", "./checkpoints"))
wandb_mode = getattr(cfg, "wandb_mode", "disabled").lower()
os.environ["WANDB_MODE"] = wandb_mode
if wandb_mode != "disabled":
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(
project=getattr(cfg, "wb_pname", "electrai"),
entity=getattr(cfg, "entity", None),
config=vars(cfg),
)
else:
wandb_logger = None

# -----------------------------
# Trainer
Expand All @@ -55,23 +108,72 @@ def test(args):
world_size = int(os.environ.get("WORLD_SIZE", local_world_size))
num_nodes = max(1, world_size // local_world_size)
trainer = Trainer(
logger=None,
logger=wandb_logger,
callbacks=None,
accelerator="gpu" if torch.cuda.is_available() else "cpu",
devices="auto",
num_nodes=num_nodes,
precision=cfg.precision,
precision=getattr(cfg, "model_precision", getattr(cfg, "precision", 32)),
)

lit_model.test_cfg = SimpleNamespace(
log_dir=log_dir, out_dir=out_dir, tmp_dir=tmp_dir, save_pred=cfg.save_pred
)

# -----------------------------
# Train
# Resolve checkpoint and run test
# -----------------------------
ckpt = ckpt_path / "last.ckpt"
if not ckpt.exists():
raise FileNotFoundError(f"Checkpoint not found: {ckpt}")
ckpt = _resolve_checkpoint(cfg)
logger.info("Using checkpoint: %s", ckpt)

trainer.test(model=lit_model, datamodule=datamodule, ckpt_path=ckpt)

# -----------------------------
# Post-test analysis
# -----------------------------
metrics_csv = log_dir / "metrics.csv"
if metrics_csv.exists():
from electrai.scripts.analyze.summarize import plot_distribution, summarize

summary_text = summarize(metrics_csv, output_dir=log_dir)
logger.info("\n%s", summary_text)
plot_distribution(metrics_csv, output_dir=log_dir)

if wandb_logger is not None:
from electrai.scripts.analyze.summarize import log_to_wandb

log_to_wandb(metrics_csv, output_dir=log_dir)

# Optional: saturation analysis (always possible with enriched CSV)
analyze_cfg = getattr(cfg, "analyze", None)
run_analysis = analyze_cfg is None or getattr(analyze_cfg, "enabled", True)

if run_analysis:
from electrai.scripts.analyze.analyze_saturation import analyze_metrics

saturation_dir = log_dir / "saturation"
saturation_dir.mkdir(exist_ok=True, parents=True)
try:
analyze_metrics(metrics_csv, saturation_dir)
except (KeyError, ValueError) as e:
logger.warning("Saturation analysis skipped: %s", e)

# Tail analysis requires metadata CSV
metadata_path = (
getattr(analyze_cfg, "metadata", None) if analyze_cfg else None
)
if metadata_path is not None:
from electrai.scripts.analyze.analyze_tail import main as tail_main

tail_dir = log_dir / "tail"
tail_dir.mkdir(exist_ok=True, parents=True)
tail_main(
[
"--metrics",
str(metrics_csv),
"--metadata",
str(metadata_path),
"--output-dir",
str(tail_dir),
]
)
45 changes: 38 additions & 7 deletions src/electrai/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ def test_step(self, batch):

self.log("test_loss", loss, prog_bar=True, sync_dist=True)

# Per-sample statistics over spatial dims (keep batch dim)
spatial_dims = tuple(range(1, preds.ndim)) # all dims except batch
y_cpu = (
torch.cat([t.unsqueeze(0) for t in y]).detach().cpu()
if isinstance(y, list)
Expand All @@ -122,15 +124,19 @@ def test_step(self, batch):
"target": y_cpu,
"index": indices,
"nmae": loss.detach().cpu(),
"duration": elapsed,
"max_pred": preds.amax(dim=spatial_dims).detach().cpu(),
"max_target": y_cpu.amax(dim=spatial_dims),
"mean_pred": preds.mean(dim=spatial_dims).detach().cpu(),
"mean_target": y_cpu.mean(dim=spatial_dims),
"num_electrons": y_cpu.sum(dim=spatial_dims),
"batch_duration_ms": elapsed,
}
if self.save_pred:
out["pred"] = preds.detach().cpu()
return out

def on_test_batch_end(self, outputs, _batch, batch_idx):
indices = outputs["index"]
nmae = outputs["nmae"]

if self.save_pred:
preds = outputs["pred"]
Expand All @@ -141,14 +147,36 @@ def on_test_batch_end(self, outputs, _batch, batch_idx):
preds[i].squeeze(0).cpu().numpy(),
)

if isinstance(nmae, torch.Tensor) and nmae.ndim == 0:
nmae = nmae.unsqueeze(0)
# Ensure scalar tensors are iterable (batch_size=1 produces 0-d tensors)
per_sample_keys = (
"max_pred",
"max_target",
"mean_pred",
"mean_target",
"num_electrons",
)
for key in per_sample_keys:
val = outputs[key]
if isinstance(val, torch.Tensor) and val.ndim == 0:
outputs[key] = val.unsqueeze(0)

n_samples = len(indices)
avg_duration_ms = outputs["batch_duration_ms"] / n_samples
# nmae is the batch-averaged scalar; broadcast it to every sample row
nmae_val = outputs["nmae"].item()

tmp_csv = (
self.tmp_dir / f"metrics_rank_{self.global_rank}_batch_{batch_idx}.csv"
)
with tmp_csv.open("w") as f:
for idx, n in zip(indices, nmae, strict=True):
f.write(f"rank_{self.global_rank},{idx},{n.item()}\n")
for i, idx in enumerate(indices):
f.write(
f"rank_{self.global_rank},{idx},"
f"{nmae_val},"
f"{outputs['max_pred'][i].item()},{outputs['max_target'][i].item()},"
f"{outputs['mean_pred'][i].item()},{outputs['mean_target'][i].item()},"
f"{outputs['num_electrons'][i].item()},{avg_duration_ms}\n"
)

def on_test_epoch_end(self):
is_dist = dist.is_available() and dist.is_initialized()
Expand Down Expand Up @@ -183,7 +211,10 @@ def on_test_epoch_end(self):
)

with final_csv.open("w") as f_out:
f_out.write("rank,index,nmae\n")
f_out.write(
"rank,index,nmae,max_pred,max_target,"
"mean_pred,mean_target,num_electrons,avg_duration_ms\n"
)
for tmp_csv in all_tmp_csvs:
with tmp_csv.open() as f_in:
for line in f_in:
Expand Down
Empty file.
Empty file.
Loading
Loading