diff --git a/package.json b/package.json
new file mode 100644
index 00000000..668e1889
--- /dev/null
+++ b/package.json
@@ -0,0 +1 @@
+{"dependencies": {}}
\ No newline at end of file
diff --git a/scripts/api.sh b/scripts/api.sh
index b69f0e34..9b0b0cf4 100644
--- a/scripts/api.sh
+++ b/scripts/api.sh
@@ -1,4 +1,4 @@
#!/bin/bash
cd workbench
-uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload
\ No newline at end of file
+python -m uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload
\ No newline at end of file
diff --git a/workbench/_api/main.py b/workbench/_api/main.py
index 08755b02..8020ad93 100644
--- a/workbench/_api/main.py
+++ b/workbench/_api/main.py
@@ -1,13 +1,17 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
import logging
import os
+import traceback
import anyio
-from .routes import lens, patch, models, logit_lens, activation_patching
+from .routes import lens, patch, models, logit_lens, activation_patching, causal_mediation
from .state import AppState
-from dotenv import load_dotenv; load_dotenv()
+from pathlib import Path
+from dotenv import load_dotenv
+load_dotenv(Path(__file__).resolve().parent / ".env")
# Configure logging
logging.basicConfig(
@@ -59,9 +63,19 @@ def fastapi_app():
app.include_router(lens, prefix="/lens")
app.include_router(logit_lens, prefix="/logit_lens")
app.include_router(activation_patching, prefix="/activation_patching")
+ app.include_router(causal_mediation, prefix="/causal_mediation", tags=["causal_mediation"])
app.include_router(patch, prefix="/patch")
app.include_router(models, prefix="/models")
+ @app.exception_handler(Exception)
+ async def global_exception_handler(request: Request, exc: Exception):
+ tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
+ logging.error(f"Unhandled exception: {''.join(tb)}")
+ return JSONResponse(
+ status_code=500,
+ content={"detail": str(exc), "traceback": ''.join(tb[-3:])},
+ )
+
app.state.m = AppState()
return app
diff --git a/workbench/_api/routes/__init__.py b/workbench/_api/routes/__init__.py
index a1b6e8a6..ca438fa6 100644
--- a/workbench/_api/routes/__init__.py
+++ b/workbench/_api/routes/__init__.py
@@ -3,9 +3,17 @@
from .models import router as models
from .logit_lens import router as logit_lens
from .activation_patching import router as activation_patching
+from .causal_mediation import router as causal_mediation
from nnsight import ndif
import nnsightful
ndif.register(nnsightful)
-__all__ = ["lens", "patch", "models", "logit_lens", "activation_patching"]
\ No newline at end of file
+__all__ = [
+ "lens",
+ "patch",
+ "models",
+ "logit_lens",
+ "activation_patching",
+ "causal_mediation",
+]
\ No newline at end of file
diff --git a/workbench/_api/routes/causal_mediation.py b/workbench/_api/routes/causal_mediation.py
new file mode 100644
index 00000000..1e3f69a5
--- /dev/null
+++ b/workbench/_api/routes/causal_mediation.py
@@ -0,0 +1,214 @@
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
+
+from ..auth import require_user_email
+from ..data_models import NDIFResponse
+from ..state import AppState, get_state
+
+from nnsightful.types import LogitLensData
+
+router = APIRouter()
+
+
+class CausalMediationRequest(BaseModel):
+ model: str
+ src_prompt: str
+ tgt_prompt: str
+ src_token_pos: int
+ src_layer: int
+ tgt_token_pos: int
+ tgt_layer: int
+ topk: int = 5
+ include_entropy: bool = True
+
+
+class CausalMediationResponse(NDIFResponse):
+ """Identical shape to LogitLensResponse so the frontend can reuse the
+ existing logit-lens transform/renderer."""
+ data: LogitLensData | None = None
+
+
+def _format_lens(
+ logits: torch.Tensor,
+ tokenizer,
+ model_name: str,
+ input_tokens: list[str],
+ n_layers: int,
+ *,
+ top_k: int = 5,
+ include_entropy: bool = True,
+) -> dict[str, Any]:
+ """Inlined mirror of the local `format()` inside
+ `nnsightful.tools.logit_lens._run`. Turns a [L, T, V] logits tensor into
+ the dict shape consumed by `LogitLensData(**...)`.
+
+ Why: `LogitLensTool` doesn't override `_format`, so calling
+ `logit_lens_tool._format(...)` falls through to the abstract `Tool._format`
+ in `nnsightful/tools/_base.py` whose body is `...` — i.e. returns `None`.
+ """
+ layers = list(range(n_layers))
+ positions = list(range(len(input_tokens)))
+
+ if include_entropy:
+ log_p = torch.nn.functional.log_softmax(logits, dim=-1)
+ p = log_p.exp()
+ entropy = torch.round(-(p * log_p).sum(dim=-1), decimals=3).tolist()
+ else:
+ entropy = None
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+ logits.to("cpu") # free memory
+
+ _, top_indices = torch.topk(probs, k=top_k, dim=-1)
+
+ topks = [
+ [tokenizer.batch_decode(torch.tensor(pos).unsqueeze(dim=1)) for pos in layer]
+ for layer in top_indices.tolist()
+ ]
+
+ unique_indices = [
+ torch.unique(top_indices[:, pi, :].flatten(), sorted=False).tolist()
+ for pi in range(top_indices.shape[1])
+ ]
+ probs = probs.permute(1, 2, 0)
+ trajectories = [
+ {
+ tokenizer.decode(token): torch.round(probs[pos_idx][token], decimals=3).tolist()
+ for token in pos
+ }
+ for pos_idx, pos in enumerate(unique_indices)
+ ]
+
+ return {
+ "meta": {"version": 2, "timestamp": "3h", "model": model_name},
+ "layers": layers,
+ "input": input_tokens,
+ "tracked": trajectories,
+ "topk": topks,
+ "entropy": entropy,
+ "positions": positions,
+ }
+
+
+def _run_causal_mediation(
+ model,
+ src_prompt: str,
+ tgt_prompt: str,
+ src_token_pos: int,
+ src_layer: int,
+ tgt_token_pos: int,
+ tgt_layer: int,
+ *,
+ remote: bool = False,
+ backend=None,
+) -> dict[str, Any]:
+ """Capture the source residual at (src_layer, src_token_pos), patch it
+ into the target prompt at (tgt_layer, tgt_token_pos), then run a logit
+ lens over the *patched* forward pass.
+
+ Patching pattern mirrors `nnsightful.tools.activation_patching._run`:
+ a slice-assign on `model.layers_output[i]` is sufficient to register the
+ intervention and have it propagate through subsequent layers — no
+ explicit `model.layers[i].output = (...)` reassignment is needed.
+ """
+ n_layers = model.num_layers
+
+ with torch.no_grad():
+ with model.session(remote=remote, backend=backend):
+ # 1) Source pass — capture the residual at (src_layer, src_token_pos).
+ with model.trace(src_prompt):
+ src_hidden = model.layers_output[src_layer][0, src_token_pos].save()
+
+ # 2) Target pass — at tgt_layer, slice-assign the source residual
+ # into the target's tgt_token_pos. project_on_vocab at every
+ # layer gives us a logit-lens grid over the patched pass.
+ with model.trace(tgt_prompt):
+ per_layer_logits = []
+ for i in range(n_layers):
+ hs = model.layers_output[i]
+ if i == tgt_layer:
+ hs[0, tgt_token_pos][:] = src_hidden
+ per_layer_logits.append(model.project_on_vocab(hs))
+ # Stack to a single [L, T, V] tensor so backend() returns a
+ # known shape on the remote path (one saved key: "logits").
+ logits = torch.cat(per_layer_logits, dim=0).save()
+
+ if remote and backend is not None:
+ return {"job_id": backend.job_id}
+
+ return {"logits": logits}
+
+
+@router.post("/start", response_model=CausalMediationResponse)
+async def start_causal_mediation(
+ req: CausalMediationRequest,
+ state: AppState = Depends(get_state),
+ user_email: str = Depends(require_user_email),
+):
+ model = state[req.model]
+ backend = state.make_backend(model=model)
+
+ raw = _run_causal_mediation(
+ model,
+ req.src_prompt,
+ req.tgt_prompt,
+ req.src_token_pos,
+ req.src_layer,
+ req.tgt_token_pos,
+ req.tgt_layer,
+ remote=state.remote,
+ backend=backend,
+ )
+
+ if "job_id" in raw:
+ return {"job_id": raw["job_id"]}
+
+ input_tokens = [
+ str(model.tokenizer.decode(token))
+ for token in model.tokenizer.encode(req.tgt_prompt)
+ ]
+ data = _format_lens(
+ raw["logits"],
+ tokenizer=model.tokenizer,
+ model_name=req.model,
+ input_tokens=input_tokens,
+ n_layers=model.num_layers,
+ top_k=req.topk,
+ include_entropy=req.include_entropy,
+ )
+ return {"data": data}
+
+
+@router.post("/results/{job_id}", response_model=CausalMediationResponse)
+async def collect_causal_mediation(
+ job_id: str,
+ req: CausalMediationRequest,
+ state: AppState = Depends(get_state),
+ user_email: str = Depends(require_user_email),
+):
+ backend = state.make_backend(job_id=job_id)
+ results = backend()
+
+ model = state[req.model]
+ tokenizer = model.tokenizer
+ input_tokens = [
+ str(tokenizer.decode(token))
+ for token in tokenizer.encode(req.tgt_prompt)
+ ]
+
+ data = _format_lens(
+ results["logits"],
+ tokenizer=tokenizer,
+ model_name=req.model,
+ input_tokens=input_tokens,
+ n_layers=model.num_layers,
+ top_k=req.topk,
+ include_entropy=req.include_entropy,
+ )
+
+ return {"data": data}
diff --git a/workbench/_web/bun.lock b/workbench/_web/bun.lock
index 2acc8147..24cfc38f 100644
--- a/workbench/_web/bun.lock
+++ b/workbench/_web/bun.lock
@@ -1,5 +1,6 @@
{
"lockfileVersion": 1,
+ "configVersion": 0,
"workspaces": {
"": {
"name": "nextjs-shadcn",
@@ -34,6 +35,7 @@
"d3-delaunay": "^6.0.4",
"dotenv": "^17.2.1",
"drizzle-orm": "^0.44.4",
+ "edulogitlens": "github:jon-bell/edulogitlens#cm-only",
"framer-motion": "^12.23.22",
"html-to-image": "^1.11.13",
"lexical": "^0.34.0",
@@ -662,6 +664,12 @@
"@radix-ui/rect": ["@radix-ui/rect@1.1.1", "", {}, "sha512-HPwpGIzkl28mWyZqG52jiqDJ12waP11Pa1lGoiyUkIEuMLBP0oeK/C89esbXrxsky5we7dfd8U58nm0SgAWpVw=="],
+ "@react-dnd/asap": ["@react-dnd/asap@5.0.2", "", {}, "sha512-WLyfoHvxhs0V9U+GTsGilGgf2QsPl6ZZ44fnv0/b8T3nQyvzxidxsg/ZltbWssbsRDlYW8UKSQMTGotuTotZ6A=="],
+
+ "@react-dnd/invariant": ["@react-dnd/invariant@4.0.2", "", {}, "sha512-xKCTqAK/FFauOM9Ta2pswIyT3D8AQlfrYdOi/toTPEhqCuAs1v5tcJ3Y08Izh1cJ5Jchwy9SeAXmMg6zrKs2iw=="],
+
+ "@react-dnd/shallowequal": ["@react-dnd/shallowequal@4.0.2", "", {}, "sha512-/RVXdLvJxLg4QKvMoM5WlwNR9ViO9z8B/qPcc+C0Sa/teJY7QG7kJ441DwzOjMYEY7GmU4dj5EcGHIkKZiQZCA=="],
+
"@react-pdf/fns": ["@react-pdf/fns@3.1.3", "", {}, "sha512-0I7pApDr1/RLAKbizuLy/IHTEa93LSPy/bEwYniboC3Xqnp6Od8xFJKbKEzGw2wh/5zKFFwl00g4t9RwgIMc3w=="],
"@react-pdf/font": ["@react-pdf/font@4.0.6", "", { "dependencies": { "@react-pdf/pdfkit": "^5.0.0", "@react-pdf/types": "^2.10.0", "fontkit": "^2.0.2", "is-url": "^1.2.4" } }, "sha512-1RxR/hTyZcbgjESUjrMms574xuS9PLB4ovqQx6jvgdrIHXUyeUtSH6i3Szw1qVfUnA9MfaEm1FBuydQeJD39BQ=="],
@@ -1074,6 +1082,8 @@
"dlv": ["dlv@1.1.3", "", {}, "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA=="],
+ "dnd-core": ["dnd-core@16.0.1", "", { "dependencies": { "@react-dnd/asap": "^5.0.1", "@react-dnd/invariant": "^4.0.1", "redux": "^4.2.0" } }, "sha512-HK294sl7tbw6F6IeuK16YSBUoorvHpY8RHO+9yFfaJyCDVb6n7PRcezrOEOa2SBCqiYpemh5Jx20ZcjKdFAVng=="],
+
"doctrine": ["doctrine@2.1.0", "", { "dependencies": { "esutils": "^2.0.2" } }, "sha512-35mSku4ZXK0vfCuHEDAwt55dg2jNajHZ1odvF+8SSr82EsZY4QmXfuWso8oEd8zRhVObSN18aM0CjSdoBX7zIw=="],
"dom-helpers": ["dom-helpers@5.2.1", "", { "dependencies": { "@babel/runtime": "^7.8.7", "csstype": "^3.0.2" } }, "sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA=="],
@@ -1090,6 +1100,8 @@
"easy-table": ["easy-table@1.1.0", "", { "optionalDependencies": { "wcwidth": ">=1.0.1" } }, "sha512-oq33hWOSSnl2Hoh00tZWaIPi1ievrD9aFG82/IgjlycAnW9hHx5PkJiXpxPsgEE+H7BsbVQXFVFST8TEXS6/pA=="],
+ "edulogitlens": ["edulogitlens@github:jon-bell/edulogitlens#416f56b", { "dependencies": { "lucide-react": "^0.487.0", "motion": "^12.23.24", "react-dnd": "^16.0.1", "react-dnd-html5-backend": "^16.0.1" }, "peerDependencies": { "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" } }, "jon-bell-edulogitlens-416f56b", "sha512-ISmUpe0w3MzO32GVLKWKLolch7/0awOjZKTsQH/mXG5v+ai0HTpGFnNHi9Lb3MDuPW3UDWqwpiAD60ii3oECog=="],
+
"electron-to-chromium": ["electron-to-chromium@1.5.200", "", {}, "sha512-rFCxROw7aOe4uPTfIAx+rXv9cEcGx+buAF4npnhtTqCJk5KDFRnh3+KYj7rdVh6lsFt5/aPs+Irj9rZ33WMA7w=="],
"emoji-regex": ["emoji-regex@9.2.2", "", {}, "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg=="],
@@ -1630,6 +1642,10 @@
"react": ["react@18.3.1", "", { "dependencies": { "loose-envify": "^1.1.0" } }, "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ=="],
+ "react-dnd": ["react-dnd@16.0.1", "", { "dependencies": { "@react-dnd/invariant": "^4.0.1", "@react-dnd/shallowequal": "^4.0.1", "dnd-core": "^16.0.1", "fast-deep-equal": "^3.1.3", "hoist-non-react-statics": "^3.3.2" }, "peerDependencies": { "@types/hoist-non-react-statics": ">= 3.3.1", "@types/node": ">= 12", "@types/react": ">= 16", "react": ">= 16.14" }, "optionalPeers": ["@types/hoist-non-react-statics", "@types/node", "@types/react"] }, "sha512-QeoM/i73HHu2XF9aKksIUuamHPDvRglEwdHL4jsp784BgUuWcg6mzfxT0QDdQz8Wj0qyRKx2eMg8iZtWvU4E2Q=="],
+
+ "react-dnd-html5-backend": ["react-dnd-html5-backend@16.0.1", "", { "dependencies": { "dnd-core": "^16.0.1" } }, "sha512-Wu3dw5aDJmOGw8WjH1I1/yTH+vlXEL4vmjk5p+MHxP8HuHJS1lAGeIdG/hze1AvNeXWo/JgULV87LyQOr+r5jw=="],
+
"react-dom": ["react-dom@18.3.1", "", { "dependencies": { "loose-envify": "^1.1.0", "scheduler": "^0.23.2" }, "peerDependencies": { "react": "^18.3.1" } }, "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw=="],
"react-error-boundary": ["react-error-boundary@3.1.4", "", { "dependencies": { "@babel/runtime": "^7.12.5" }, "peerDependencies": { "react": ">=16.13.1" } }, "sha512-uM9uPzZJTF6wRQORmSrvOIgt4lJ9MC1sNgEOj2XGsDTRE4kmpWxg7ENK9EWNKJRMAOY9z0MuF4yIfl6gp4sotA=="],
@@ -1660,6 +1676,8 @@
"readdirp": ["readdirp@3.6.0", "", { "dependencies": { "picomatch": "^2.2.1" } }, "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA=="],
+ "redux": ["redux@4.2.1", "", { "dependencies": { "@babel/runtime": "^7.9.2" } }, "sha512-LAUYz4lc+Do8/g7aeRa8JkyDErK6ekstQaqWQrNRW//MY1TvCEpMtpTWvlQ+FPbWCx+Xixu/6SHt5N0HR+SB4w=="],
+
"reflect.getprototypeof": ["reflect.getprototypeof@1.0.10", "", { "dependencies": { "call-bind": "^1.0.8", "define-properties": "^1.2.1", "es-abstract": "^1.23.9", "es-errors": "^1.3.0", "es-object-atoms": "^1.0.0", "get-intrinsic": "^1.2.7", "get-proto": "^1.0.1", "which-builtin-type": "^1.2.1" } }, "sha512-00o4I+DVrefhv+nX0ulyi3biSHCPDe+yLv5o/p6d/UVlirijB8E16FtfwSAi4g3tcqrQ4lRAqQSoFEZJehYEcw=="],
"regexp.prototype.flags": ["regexp.prototype.flags@1.5.4", "", { "dependencies": { "call-bind": "^1.0.8", "define-properties": "^1.2.1", "es-errors": "^1.3.0", "get-proto": "^1.0.1", "gopd": "^1.2.0", "set-function-name": "^2.0.2" } }, "sha512-dYqgNSZbDwkaJ2ceRd9ojCGjBq+mOm9LmtXnAnEGyHhN/5R7iDW2TRw3h+o/jCFxus3P2LfWIIiwowAjANm7IA=="],
@@ -2004,6 +2022,10 @@
"defaults/clone": ["clone@1.0.4", "", {}, "sha512-JQHZ2QMW6l3aH/j6xCqQThY/9OH4D/9ls34cgkUBiEeocRTU04tHfKPBsUK1PqZCUQM7GiA0IIXJSuXHI64Kbg=="],
+ "edulogitlens/lucide-react": ["lucide-react@0.487.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-aKqhOQ+YmFnwq8dWgGjOuLc8V1R9/c/yOd+zDY4+ohsR2Jo05lSGc3WsstYPIzcTpeosN7LoCkLReUUITvaIvw=="],
+
+ "edulogitlens/motion": ["motion@12.38.0", "", { "dependencies": { "framer-motion": "^12.38.0", "tslib": "^2.4.0" }, "peerDependencies": { "@emotion/is-prop-valid": "*", "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" }, "optionalPeers": ["@emotion/is-prop-valid", "react", "react-dom"] }, "sha512-uYfXzeHlgThchzwz5Te47dlv5JOUC7OB4rjJ/7XTUgtBZD8CchMN8qEJ4ZVsUmTyYA44zjV0fBwsiktRuFnn+w=="],
+
"error-ex/is-arrayish": ["is-arrayish@0.2.1", "", {}, "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg=="],
"esbuild-register/debug": ["debug@4.4.1", "", { "dependencies": { "ms": "^2.1.3" } }, "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ=="],
@@ -2182,6 +2204,8 @@
"@typescript-eslint/typescript-estree/minimatch/brace-expansion": ["brace-expansion@2.0.2", "", { "dependencies": { "balanced-match": "^1.0.0" } }, "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ=="],
+ "edulogitlens/motion/framer-motion": ["framer-motion@12.38.0", "", { "dependencies": { "motion-dom": "^12.38.0", "motion-utils": "^12.36.0", "tslib": "^2.4.0" }, "peerDependencies": { "@emotion/is-prop-valid": "*", "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" }, "optionalPeers": ["@emotion/is-prop-valid", "react", "react-dom"] }, "sha512-rFYkY/pigbcswl1XQSb7q424kSTQ8q6eAC+YUsSKooHQYuLdzdHjrt6uxUC+PRAO++q5IS7+TamgIw1AphxR+g=="],
+
"form-data/mime-types/mime-db": ["mime-db@1.52.0", "", {}, "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg=="],
"glob/minimatch/brace-expansion": ["brace-expansion@2.0.2", "", { "dependencies": { "balanced-match": "^1.0.0" } }, "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ=="],
@@ -2204,6 +2228,10 @@
"@argos-ci/core/sharp/@img/sharp-wasm32/@emnapi/runtime": ["@emnapi/runtime@1.10.0", "", { "dependencies": { "tslib": "^2.4.0" } }, "sha512-ewvYlk86xUoGI0zQRNq/mC+16R1QeDlKQy21Ki3oSYXNgLb45GV1P6A0M+/s6nyCuNDqe5VpaY84BzXGwVbwFA=="],
+ "edulogitlens/motion/framer-motion/motion-dom": ["motion-dom@12.38.0", "", { "dependencies": { "motion-utils": "^12.36.0" } }, "sha512-pdkHLD8QYRp8VfiNLb8xIBJis1byQ9gPT3Jnh2jqfFtAsWUA3dEepDlsWe/xMpO8McV+VdpKVcp+E+TGJEtOoA=="],
+
+ "edulogitlens/motion/framer-motion/motion-utils": ["motion-utils@12.36.0", "", {}, "sha512-eHWisygbiwVvf6PZ1vhaHCLamvkSbPIeAYxWUuL3a2PD/TROgE7FvfHWTIH4vMl798QLfMw15nRqIaRDXTlYRg=="],
+
"rimraf/glob/path-scurry/lru-cache": ["lru-cache@11.2.2", "", {}, "sha512-F9ODfyqML2coTIsQpSkRHnLSZMtkU8Q+mSfcaIyKwy58u+8k5nvAYeiNhsyMARvzNcXJ9QfWVrcPsC9e9rAxtg=="],
}
}
diff --git a/workbench/_web/instrumentation.ts b/workbench/_web/instrumentation.ts
index ba38ab9f..258d4907 100644
--- a/workbench/_web/instrumentation.ts
+++ b/workbench/_web/instrumentation.ts
@@ -6,7 +6,7 @@ export function register() {
export const onRequestError = async (err: Error, request: NextRequest, context: NextResponse) => {
if (process.env.NEXT_RUNTIME === "nodejs") {
- const { getPostHogServer } = require("./src/lib/posthog-server");
+ const { getPostHogServer } = await import("./src/lib/posthog-server");
const posthog = await getPostHogServer();
// Only capture exception if PostHog is enabled
diff --git a/workbench/_web/next.config.js b/workbench/_web/next.config.js
index ed00fba2..5f7d1d7a 100644
--- a/workbench/_web/next.config.js
+++ b/workbench/_web/next.config.js
@@ -8,7 +8,7 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url));
const nextConfig = {
reactStrictMode: true,
- transpilePackages: ["nnsightful"],
+ transpilePackages: ["nnsightful", "edulogitlens"],
turbopack: {
// Expand root so Turbopack can resolve the symlinked nnsightful package
// which lives at ../../nnsightful (outside the default _web/ root)
diff --git a/workbench/_web/package.json b/workbench/_web/package.json
index 5d2b1a74..6fffae2e 100644
--- a/workbench/_web/package.json
+++ b/workbench/_web/package.json
@@ -46,6 +46,7 @@
"d3-delaunay": "^6.0.4",
"dotenv": "^17.2.1",
"drizzle-orm": "^0.44.4",
+ "edulogitlens": "github:jon-bell/edulogitlens#cm-only",
"framer-motion": "^12.23.22",
"html-to-image": "^1.11.13",
"lexical": "^0.34.0",
diff --git a/workbench/_web/src/app/dev/logit-lens-intro/page.tsx b/workbench/_web/src/app/dev/logit-lens-intro/page.tsx
new file mode 100644
index 00000000..fbaa60c7
--- /dev/null
+++ b/workbench/_web/src/app/dev/logit-lens-intro/page.tsx
@@ -0,0 +1,104 @@
+"use client";
+
+import { LogitLensGrid } from "edulogitlens";
+import type { LogitLensData, LogitCell } from "edulogitlens";
+
+function generateMockData(): LogitLensData {
+ const tokens = [
+ "The",
+ "E",
+ "iff",
+ "el",
+ "Tower",
+ "is",
+ "in",
+ "the",
+ "city",
+ "of",
+ "Paris",
+ ",",
+ "France",
+ ".",
+ ];
+
+ const layers = Array.from({ length: 12 }, (_, i) => i);
+
+ const vocab = [
+ "t",
+ "bow",
+ "illi",
+ "Tower",
+ "el",
+ "France",
+ "Paris",
+ "tower",
+ "city",
+ "of",
+ "the",
+ "in",
+ "is",
+ "a",
+ "and",
+ "Eiff",
+ "to",
+ "built",
+ "was",
+ "meters",
+ "at",
+ "by",
+ "from",
+ "with",
+ "on",
+ "for",
+ "an",
+ "stands",
+ "tall",
+ ];
+
+ const data: LogitCell[][] = tokens.map((token) => {
+ return layers.map((_, layerIdx) => {
+ const convergence = layerIdx / layers.length;
+
+ let primaryToken = token;
+ let prob: number;
+
+ if (convergence < 0.3) {
+ primaryToken = vocab[Math.floor(Math.random() * vocab.length)];
+ prob = 0.05 + Math.random() * 0.15;
+ } else if (convergence < 0.6) {
+ primaryToken =
+ Math.random() > 0.5 ? token : vocab[Math.floor(Math.random() * vocab.length)];
+ prob = 0.2 + Math.random() * 0.3;
+ } else {
+ primaryToken = token;
+ prob = 0.5 + convergence * 0.4 + Math.random() * 0.1;
+ }
+ prob = Math.min(prob, 0.95);
+
+ const topTokens: { token: string; prob: number }[] = [{ token: primaryToken, prob }];
+ let remaining = (1 - prob) * 0.4;
+ for (let i = 0; i < 14; i++) {
+ const candidate = vocab[Math.floor(Math.random() * vocab.length)];
+ topTokens.push({ token: candidate, prob: remaining });
+ remaining *= 0.7;
+ }
+
+ return { token: primaryToken, probability: prob, topTokens };
+ });
+ });
+
+ return { tokens, layers, data };
+}
+
+const MOCK_DATA = generateMockData();
+
+export default function DevLogitLensIntroPage() {
+ return (
+
+
Logit Lens Intro — Dev Preview
+
+
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx b/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx
index 7d430824..8a463489 100644
--- a/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx
+++ b/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx
@@ -3,7 +3,14 @@
import { useMemo } from "react";
import { X, RotateCcw } from "lucide-react";
import { cn } from "@/lib/utils";
-import Select, { MultiValue, StylesConfig, GroupBase, components } from "react-select";
+import Select, {
+ MultiValue,
+ StylesConfig,
+ GroupBase,
+ components,
+ type MultiValueProps,
+ type OptionProps,
+} from "react-select";
// Option type for react-select
interface TokenOption {
@@ -152,7 +159,7 @@ const selectStyles: StylesConfig> = {
};
// Custom MultiValue component with colored indicator
-const CustomMultiValue = (props: any) => {
+const CustomMultiValue = (props: MultiValueProps>) => {
const color = LINE_COLORS[props.data.value % LINE_COLORS.length];
return (
@@ -169,7 +176,9 @@ const CustomMultiValue = (props: any) => {
onClick={(e) => {
e.preventDefault();
e.stopPropagation();
- props.removeProps.onClick(e);
+ props.removeProps.onClick?.(
+ e as unknown as React.MouseEvent,
+ );
}}
onMouseDown={(e) => {
e.preventDefault();
@@ -184,7 +193,7 @@ const CustomMultiValue = (props: any) => {
};
// Custom Option component with badges for source/target predictions
-const CustomOption = (props: any) => {
+const CustomOption = (props: OptionProps>) => {
const tokenIndex = props.data.value;
const badge = tokenIndex === 0 ? "source pred" : tokenIndex === 1 ? "target pred" : null;
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroArea.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroArea.tsx
new file mode 100644
index 00000000..32406977
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroArea.tsx
@@ -0,0 +1,460 @@
+"use client";
+
+import { useState, useEffect, useCallback, useMemo, useRef } from "react";
+import { useParams } from "next/navigation";
+import { useQuery } from "@tanstack/react-query";
+import { ModelSelector } from "@/components/ModelSelector";
+import { AlertCircle, Loader2, Play, X } from "lucide-react";
+import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
+import { Button } from "@/components/ui/button";
+import { useTour } from "@reactour/tour";
+import { CMIntroTutorial } from "@/tutorials/cmIntro";
+import { useCMIntroTutorial, hydrateCMIntroTutorial } from "@/stores/useCMIntroTutorial";
+import { getModels } from "@/lib/api/modelsApi";
+import { useWorkspace } from "@/stores/useWorkspace";
+import { encodeText } from "@/actions/tok";
+import { TokenizerLoadError } from "@/actions/errors";
+import { Token } from "@/types/models";
+import { PatchPromptSection } from "@/components/activation-patching/toolkit";
+import { toast } from "sonner";
+import { useCMIntroLogitLens, CMIntroLensResult } from "@/lib/api/cmIntroApi";
+import type { LogitLensIntroData } from "@/types/logitLensIntro";
+
+// Top-1 next-token at the final layer / final input position. Returns null
+// if the lens data is empty/malformed. Mirrors the prediction the
+// activation-patching tool surfaces in italics next to each prompt.
+function getNextTokenPrediction(data: LogitLensIntroData | undefined): string | null {
+ if (!data) return null;
+ const { layers, input, tracked, topk } = data;
+ if (!layers?.length || !input?.length || !tracked?.length || !topk?.length) return null;
+ const finalLayerIdx = layers.length - 1;
+ const lastPosIdx = input.length - 1;
+ const candidates = topk[finalLayerIdx]?.[lastPosIdx];
+ if (!candidates?.length) return null;
+ const posTracked = tracked[lastPosIdx];
+ if (!posTracked) return candidates[0];
+ let bestToken = candidates[0];
+ let bestProb = posTracked[bestToken]?.[finalLayerIdx] ?? 0;
+ for (const token of candidates) {
+ const prob = posTracked[token]?.[finalLayerIdx] ?? 0;
+ if (prob > bestProb) {
+ bestProb = prob;
+ bestToken = token;
+ }
+ }
+ return bestToken;
+}
+
+// Default model for the CM intro. 32 layers reads as a manageable heatmap;
+// index 0 in the model list is the 70B (80 layers), far too many for a primer.
+const DEFAULT_INTRO_MODEL = "meta-llama/Llama-3.1-8B";
+
+interface CMIntroAreaProps {
+ sourcePrompt: string;
+ targetPrompt: string;
+ onSourcePromptChange: (value: string) => void;
+ onTargetPromptChange: (value: string) => void;
+ onLensResult?: (result: CMIntroLensResult, runSrc: string, runTgt: string) => void;
+ lensResult?: CMIntroLensResult | null;
+ lastRunSrcPrompt?: string | null;
+ lastRunTgtPrompt?: string | null;
+}
+
+function useTutorialAutoStart() {
+ const { setSteps, setIsOpen, isOpen } = useTour();
+ const { completed, markCompleted } = useCMIntroTutorial();
+ // Auto-start fires at most once per mount; dismissing then resaving the
+ // localStorage flag prevents a popup loop (same pattern as lens-intro).
+ const autoStartedRef = useRef(false);
+
+ useEffect(() => {
+ hydrateCMIntroTutorial();
+ }, []);
+
+ useEffect(() => {
+ if (autoStartedRef.current || completed || isOpen) return;
+ if (!setSteps || !setIsOpen) return;
+ autoStartedRef.current = true;
+ const steps = CMIntroTutorial.chapters[0]?.steps ?? [];
+ setSteps(steps);
+ const id = setTimeout(() => {
+ setIsOpen(true);
+ markCompleted();
+ }, 600);
+ return () => clearTimeout(id);
+ }, [completed, isOpen, setSteps, setIsOpen, markCompleted]);
+
+ const startTutorial = () => {
+ if (!setSteps || !setIsOpen) return;
+ const steps = CMIntroTutorial.chapters[0]?.steps ?? [];
+ setSteps(steps);
+ setIsOpen(true);
+ };
+
+ return { startTutorial };
+}
+
+export default function CMIntroArea({
+ sourcePrompt,
+ targetPrompt,
+ onSourcePromptChange,
+ onTargetPromptChange,
+ onLensResult,
+ lensResult,
+ lastRunSrcPrompt,
+ lastRunTgtPrompt,
+}: CMIntroAreaProps) {
+ const { chartId } = useParams<{ chartId: string }>();
+ const { selectedModelIdx, setSelectedModelIdx } = useWorkspace();
+
+ const { data: models } = useQuery({
+ queryKey: ["models"],
+ queryFn: getModels,
+ refetchInterval: 120000,
+ });
+
+ // Default to Llama-3.1-8B once when models load, rather than leaving the
+ // workspace default at index 0 (the 70B, 80 layers). Guarded so a later
+ // manual model choice is not overridden.
+ const didDefaultModel = useRef(false);
+ useEffect(() => {
+ if (didDefaultModel.current || !models || models.length === 0) return;
+ didDefaultModel.current = true;
+ const idx = models.findIndex((m) => m.name === DEFAULT_INTRO_MODEL);
+ if (idx !== -1 && idx !== selectedModelIdx) {
+ setSelectedModelIdx(idx);
+ }
+ }, [models, selectedModelIdx, setSelectedModelIdx]);
+
+ const selectedModel = useMemo(() => {
+ if (!models || models.length === 0) return undefined;
+ return models[selectedModelIdx]?.name || models[0].name;
+ }, [models, selectedModelIdx]);
+
+ // Source prompt state
+ const [srcTokens, setSrcTokens] = useState([]);
+ const [srcEditing, setSrcEditing] = useState(true);
+ const [srcTokenizedModel, setSrcTokenizedModel] = useState(null);
+ const srcTextareaRef = useRef(null);
+ const srcTokenContainerRef = useRef(null);
+
+ // Target prompt state
+ const [tgtTokens, setTgtTokens] = useState([]);
+ const [tgtEditing, setTgtEditing] = useState(true);
+ const [tgtTokenizedModel, setTgtTokenizedModel] = useState(null);
+ const tgtTextareaRef = useRef(null);
+ const tgtTokenContainerRef = useRef(null);
+
+ // Auto-focus the source prompt textarea on first mount so new users
+ // can start typing immediately. Only if it's empty (don't steal focus
+ // when revisiting a chart that already has prompts).
+ const didFocusRef = useRef(false);
+ useEffect(() => {
+ if (didFocusRef.current) return;
+ if (!sourcePrompt) {
+ srcTextareaRef.current?.focus();
+ didFocusRef.current = true;
+ }
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, []);
+
+ const tokenize = useCallback(async (text: string, model: string): Promise => {
+ try {
+ return await encodeText(text, model);
+ } catch (error) {
+ if (error instanceof TokenizerLoadError) {
+ toast.error(
+ `Could not load tokenizer for ${model}. The model may be gated and require authentication.`,
+ );
+ } else {
+ toast.error("Failed to tokenize prompt.");
+ }
+ return null;
+ }
+ }, []);
+
+ // Initial tokenize on mount when a model becomes available
+ useEffect(() => {
+ if (!selectedModel) return;
+ const run = async () => {
+ if (sourcePrompt) {
+ const tokens = await tokenize(sourcePrompt, selectedModel);
+ if (tokens && tokens.length > 0) {
+ setSrcTokens(tokens);
+ setSrcTokenizedModel(selectedModel);
+ setSrcEditing(false);
+ }
+ }
+ if (targetPrompt) {
+ const tokens = await tokenize(targetPrompt, selectedModel);
+ if (tokens && tokens.length > 0) {
+ setTgtTokens(tokens);
+ setTgtTokenizedModel(selectedModel);
+ setTgtEditing(false);
+ }
+ }
+ };
+ run();
+ // Only auto-tokenize when the model first becomes available
+ // eslint-disable-next-line react-hooks/exhaustive-deps
+ }, [selectedModel]);
+
+ const handleSrcBlur = useCallback(() => {
+ setTimeout(async () => {
+ const activeElement = document.activeElement;
+ const withinTextarea = activeElement && srcTextareaRef.current?.contains(activeElement);
+ const withinToken =
+ activeElement && srcTokenContainerRef.current?.contains(activeElement);
+ if (withinTextarea || withinToken) return;
+
+ if (!sourcePrompt || !selectedModel) {
+ // Clear any stale tokens so the empty state actually reads as
+ // empty (otherwise the token display would show the previous
+ // tokenization after the user deletes the text).
+ setSrcTokens([]);
+ setSrcTokenizedModel(null);
+ setSrcEditing(true);
+ return;
+ }
+ const tokens = await tokenize(sourcePrompt, selectedModel);
+ if (tokens && tokens.length > 0) {
+ setSrcTokens(tokens);
+ setSrcTokenizedModel(selectedModel);
+ setSrcEditing(false);
+ }
+ }, 100);
+ }, [sourcePrompt, selectedModel, tokenize]);
+
+ const handleTgtBlur = useCallback(() => {
+ setTimeout(async () => {
+ const activeElement = document.activeElement;
+ const withinTextarea = activeElement && tgtTextareaRef.current?.contains(activeElement);
+ const withinToken =
+ activeElement && tgtTokenContainerRef.current?.contains(activeElement);
+ if (withinTextarea || withinToken) return;
+
+ if (!targetPrompt || !selectedModel) {
+ setTgtTokens([]);
+ setTgtTokenizedModel(null);
+ setTgtEditing(true);
+ return;
+ }
+ const tokens = await tokenize(targetPrompt, selectedModel);
+ if (tokens && tokens.length > 0) {
+ setTgtTokens(tokens);
+ setTgtTokenizedModel(selectedModel);
+ setTgtEditing(false);
+ }
+ }, 100);
+ }, [targetPrompt, selectedModel, tokenize]);
+
+ // Explicit clear handlers so the user has a one-click way to reset a
+ // prompt back to empty (deleting characters in the textarea works too,
+ // but a button is more discoverable).
+ const handleClearSrc = useCallback(() => {
+ onSourcePromptChange("");
+ setSrcTokens([]);
+ setSrcTokenizedModel(null);
+ setSrcEditing(true);
+ setTimeout(() => srcTextareaRef.current?.focus(), 0);
+ }, [onSourcePromptChange]);
+
+ const handleClearTgt = useCallback(() => {
+ onTargetPromptChange("");
+ setTgtTokens([]);
+ setTgtTokenizedModel(null);
+ setTgtEditing(true);
+ setTimeout(() => tgtTextareaRef.current?.focus(), 0);
+ }, [onTargetPromptChange]);
+
+ const configModelUnavailable =
+ srcTokenizedModel && selectedModel && srcTokenizedModel !== selectedModel
+ ? srcTokenizedModel
+ : null;
+
+ // Predicted next-token from the last lens run. Hidden when the prompt
+ // currently in the textarea no longer matches what the lens was run on.
+ const srcPrediction = useMemo(() => {
+ if (!lensResult?.source) return null;
+ if (lastRunSrcPrompt == null || sourcePrompt !== lastRunSrcPrompt) return null;
+ return getNextTokenPrediction(lensResult.source);
+ }, [lensResult, sourcePrompt, lastRunSrcPrompt]);
+
+ const tgtPrediction = useMemo(() => {
+ if (!lensResult?.target) return null;
+ if (lastRunTgtPrompt == null || targetPrompt !== lastRunTgtPrompt) return null;
+ return getNextTokenPrediction(lensResult.target);
+ }, [lensResult, targetPrompt, lastRunTgtPrompt]);
+
+ const { mutateAsync: runLogitLens, isPending: isRunning } = useCMIntroLogitLens();
+
+ // Target is optional: when blank, CM Intro runs in single-prompt mode and
+ // only computes the source lens. The widget hides the target heatmap and
+ // disables drag-and-drop patching in that mode.
+ const canRun = !!selectedModel && !!sourcePrompt.trim() && !isRunning;
+
+ const handleRun = useCallback(async () => {
+ if (!selectedModel) {
+ toast.error("Please select a model.");
+ return;
+ }
+ const src = sourcePrompt.trim();
+ const tgt = targetPrompt.trim();
+ if (!src) {
+ toast.error("Please enter a source prompt.");
+ return;
+ }
+
+ if (!chartId) {
+ toast.error("Missing chart id.");
+ return;
+ }
+
+ try {
+ const result = await runLogitLens({
+ sourcePrompt: src,
+ targetPrompt: tgt, // empty string is fine — mutation skips the call
+ model: selectedModel,
+ chartId,
+ });
+ onLensResult?.(result, src, tgt);
+ toast.success(tgt ? "Logit lens computed for both prompts." : "Logit lens computed.");
+ } catch (error) {
+ // Error toast handled by the mutation's onError.
+ }
+ }, [selectedModel, sourcePrompt, targetPrompt, runLogitLens, onLensResult, chartId]);
+
+ const { startTutorial } = useTutorialAutoStart();
+
+ return (
+
+
+
CM Intro
+
+ {configModelUnavailable && (
+
+
+
+
+
+
+ Tokens last computed with "{configModelUnavailable}".
+ Click a prompt and blur to retokenize.
+
+
+
+ )}
+
+ Tutorial
+
+
+
+
+
+
+
+ {sourcePrompt && (
+
+
+ Clear
+
+ )}
+
{}}
+ predictionToken={srcPrediction}
+ />
+
+ The prompt you want to steal state from
+ . Pick something with a clear, specific prediction — its internal
+ activations will be the source of the patch.
+
+
+
+
+ {targetPrompt && (
+
+
+ Clear
+
+ )}
+
{}}
+ predictionToken={tgtPrediction}
+ />
+
+ The prompt you want to patch into .
+ Usually similar grammar but a different answer — any change in its
+ prediction after a patch reveals what the source state carried. Leave blank
+ to view the source prompt alone (no patching).
+
+
+
+
+ {isRunning ? (
+ <>
+
+ Computing...
+ >
+ ) : (
+ <>
+
+ Run Logit Lens
+ >
+ )}
+
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroDisplay.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroDisplay.tsx
new file mode 100644
index 00000000..55ffbb07
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroDisplay.tsx
@@ -0,0 +1,240 @@
+"use client";
+
+import { useCallback, useMemo } from "react";
+import { useParams } from "next/navigation";
+import { useQuery } from "@tanstack/react-query";
+import { getModels } from "@/lib/api/modelsApi";
+import { getChartById } from "@/lib/queries/chartQueries";
+import { queryKeys } from "@/lib/queryKeys";
+import { useWorkspace } from "@/stores/useWorkspace";
+import { CausalMediationExplorer } from "edulogitlens";
+import type { LogitLensData, LogitCell, Intervention } from "edulogitlens";
+import { CMIntroLensResult, useCMIntroIntervention } from "@/lib/api/cmIntroApi";
+import type { LogitLensIntroData } from "@/types/logitLensIntro";
+import type { CMIntroChartData } from "@/types/cmIntro";
+
+function CMSkeleton({ message, showTarget }: { message: string; showTarget: boolean }) {
+ const SkeletonGrid = () => (
+
+
+
+ {Array.from({ length: 40 }).map((_, i) => (
+
+ ))}
+
+
+ );
+
+ return (
+
+
+ {message}
+
+
+
+ {showTarget && }
+
+
+ );
+}
+
+interface CMIntroDisplayProps {
+ sourcePrompt: string;
+ targetPrompt: string;
+ lensResult?: CMIntroLensResult | null;
+ // Snapshots of the prompts the ephemeral lensResult was actually computed
+ // for. Used to detect "user edited the prompts after the last run" so we
+ // can show a placeholder instead of a stale heatmap.
+ lastRunSrcPrompt?: string | null;
+ lastRunTgtPrompt?: string | null;
+}
+
+/**
+ * nnsightful LogitLensData → edulogitlens LogitLensData. Mirrors the transform
+ * used in LogitLensIntroDisplay so the CM explorer sees the same cell shape.
+ */
+function transformToEduFormat(data: LogitLensIntroData): LogitLensData | undefined {
+ if (!data) return undefined;
+
+ const raw = data as unknown as Record;
+ const input = raw.input as string[] | undefined;
+ const layers = raw.layers as number[] | undefined;
+ const tracked = raw.tracked as Record[] | undefined;
+ const topk = raw.topk as string[][][] | undefined;
+
+ if (!input || !layers || !tracked || !topk) return undefined;
+
+ // NOTE: do NOT strip the BOS token here. CM interventions send the clicked
+ // cell's token position to the backend, which indexes the BOS-inclusive
+ // tokenization absolutely (causal_mediation.py). Dropping position 0 would
+ // patch the wrong token. BOS-hiding for CM must happen in the widget while
+ // preserving absolute positions.
+ const cellData: LogitCell[][] = input.map((_, posIdx) => {
+ const posTracked = tracked[posIdx] ?? {};
+ return layers.map((_, layerIdx) => {
+ const topTokenStrs = topk[layerIdx]?.[posIdx] ?? [];
+ const topTokens = topTokenStrs.map((t) => ({
+ token: t,
+ prob: posTracked[t]?.[layerIdx] ?? 0,
+ }));
+ topTokens.sort((a, b) => b.prob - a.prob);
+
+ const best = topTokens[0];
+ return {
+ token: best?.token ?? "",
+ probability: best?.prob ?? 0,
+ topTokens,
+ };
+ });
+ });
+
+ return { tokens: input, layers, data: cellData };
+}
+
+export function CMIntroDisplay({
+ sourcePrompt,
+ targetPrompt,
+ lensResult,
+ lastRunSrcPrompt,
+ lastRunTgtPrompt,
+}: CMIntroDisplayProps) {
+ const { chartId } = useParams<{ chartId: string }>();
+ const { selectedModelIdx } = useWorkspace();
+
+ const { data: models } = useQuery({
+ queryKey: ["models"],
+ queryFn: getModels,
+ refetchInterval: 120000,
+ });
+
+ const selectedModel = useMemo(() => {
+ if (!models || models.length === 0) return undefined;
+ return models[selectedModelIdx]?.name || models[0].name;
+ }, [models, selectedModelIdx]);
+
+ // Hydrate the persisted cm-intro chart row so revisiting the page restores the intervention result.
+ const { data: chart } = useQuery({
+ queryKey: queryKeys.charts.chart(chartId as string),
+ queryFn: () => getChartById(chartId as string),
+ enabled: !!chartId,
+ });
+
+ // Source is always required; target may be absent in single-prompt mode.
+ const persistedData = useMemo(() => {
+ const raw = chart?.data as unknown;
+ if (!raw || typeof raw !== "object") return null;
+ const maybe = raw as Partial;
+ if (!maybe.source) return null;
+ return maybe as CMIntroChartData;
+ }, [chart]);
+
+ const trimmedTarget = targetPrompt.trim();
+ const targetExpected = trimmedTarget.length > 0;
+
+ // "Live" data with its provenance (what prompts produced it). The ephemeral
+ // lensResult wins over persisted; we use the *RunPrompt snapshots to decide
+ // whether the current textareas still match what was run.
+ const liveSourceRaw = lensResult?.source ?? persistedData?.source;
+ const liveTargetRaw = lensResult?.target ?? persistedData?.target;
+ const liveSrcRun =
+ lensResult?.source != null
+ ? (lastRunSrcPrompt ?? null)
+ : (persistedData?.lastRunSourcePrompt ?? null);
+ const liveTgtRun =
+ lensResult?.source != null
+ ? (lastRunTgtPrompt ?? null)
+ : (persistedData?.lastRunTargetPrompt ?? null);
+
+ const hasAnyData = !!liveSourceRaw;
+ const isStale =
+ hasAnyData &&
+ ((liveSrcRun !== null && liveSrcRun !== sourcePrompt) ||
+ (liveTgtRun !== null && liveTgtRun !== targetPrompt));
+ // No source data yet, OR target was expected (two-prompt mode) but is missing.
+ const isMissingExpectedData = !liveSourceRaw || (targetExpected && !liveTargetRaw);
+
+ const showPlaceholder = isStale || isMissingExpectedData;
+
+ const sourceData = useMemo(
+ () => (showPlaceholder || !liveSourceRaw ? undefined : transformToEduFormat(liveSourceRaw)),
+ [showPlaceholder, liveSourceRaw],
+ );
+ const targetData = useMemo(
+ () => (showPlaceholder || !liveTargetRaw ? undefined : transformToEduFormat(liveTargetRaw)),
+ [showPlaceholder, liveTargetRaw],
+ );
+
+ // Undefined (not null) when absent, so CausalMediationExplorer treats the
+ // result as uncontrolled and falls back to internal state populated by the
+ // handleIntervention promise. When a persisted result IS present, we pass
+ // it as a controlled override so revisits restore the UI.
+ const persistedResultData = useMemo(() => {
+ if (!persistedData?.result) return undefined;
+ return transformToEduFormat(persistedData.result);
+ }, [persistedData]);
+
+ const { mutateAsync: runIntervention, isPending: isInterventionPending } =
+ useCMIntroIntervention();
+
+ const handleIntervention = useCallback(
+ async (i: Intervention): Promise => {
+ if (!chartId || !selectedModel) return null;
+ try {
+ const result = await runIntervention({
+ model: selectedModel,
+ srcPrompt: sourcePrompt,
+ tgtPrompt: targetPrompt,
+ chartId,
+ intervention: {
+ srcTokenPos: i.sourceTokenPosition,
+ srcLayer: i.sourceLayer,
+ tgtTokenPos: i.targetTokenPosition,
+ tgtLayer: i.targetLayer,
+ },
+ });
+ return transformToEduFormat(result) ?? null;
+ } catch {
+ return null;
+ }
+ },
+ [chartId, selectedModel, sourcePrompt, targetPrompt, runIntervention],
+ );
+
+ if (showPlaceholder) {
+ const message = isStale
+ ? "Prompts changed since the last run. Click Run to recompute the lens."
+ : "No analysis yet. Enter a prompt and click Run to compute the lens.";
+ return (
+
+
+
+ );
+ }
+
+ return (
+
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/layout.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/layout.tsx
new file mode 100644
index 00000000..13389d10
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/layout.tsx
@@ -0,0 +1,10 @@
+import { CaptureProvider } from "@/components/providers/CaptureProvider";
+import { TooltipProvider } from "@/components/ui/tooltip";
+
+export default function CMIntroChartLayout({ children }: { children: React.ReactNode }) {
+ return (
+
+ {children}
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/page.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/page.tsx
new file mode 100644
index 00000000..e250f430
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/page.tsx
@@ -0,0 +1,160 @@
+"use client";
+
+import { useCallback, useEffect, useRef, useState } from "react";
+import { useParams } from "next/navigation";
+import { useQuery, useQueryClient } from "@tanstack/react-query";
+import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from "@/components/ui/resizable";
+import ChartCardsSidebar from "../../components/ChartCardsSidebar";
+import CMIntroArea from "./components/CMIntroArea";
+import { CMIntroDisplay } from "./components/CMIntroDisplay";
+import { useIsMobile } from "@/hooks/useIsMobile";
+import { MobileSidebarDrawer } from "../../components/MobileSidebarDrawer";
+import { MobileCollapsibleControls } from "../../components/MobileCollapsibleControls";
+import { GitBranch } from "lucide-react";
+import { CMIntroLensResult } from "@/lib/api/cmIntroApi";
+import { getChartById, setChartData } from "@/lib/queries/chartQueries";
+import { queryKeys } from "@/lib/queryKeys";
+import type { CMIntroChartData } from "@/types/cmIntro";
+
+const PROMPT_AUTOSAVE_DEBOUNCE_MS = 600;
+
+export default function CMIntroChartPage() {
+ const isMobile = useIsMobile();
+ const { chartId } = useParams<{ chartId: string }>();
+ const queryClient = useQueryClient();
+ const [sourcePrompt, setSourcePrompt] = useState("");
+ const [targetPrompt, setTargetPrompt] = useState("");
+ const [lensResult, setLensResult] = useState(null);
+ // Snapshot of the prompts the current lensResult was computed for. Used by
+ // CMIntroArea to gate the predicted-next-token hint so it disappears once
+ // the user starts editing.
+ const [lastRunSrcPrompt, setLastRunSrcPrompt] = useState(null);
+ const [lastRunTgtPrompt, setLastRunTgtPrompt] = useState(null);
+
+ const { data: chart } = useQuery({
+ queryKey: queryKeys.charts.chart(chartId as string),
+ queryFn: () => getChartById(chartId as string),
+ enabled: !!chartId,
+ });
+
+ // hydratedRef gates autosave: we must absorb any persisted prompts before
+ // the autosave effect is allowed to write, otherwise the first render
+ // would clobber a stored prompt with the default placeholder.
+ const hydratedRef = useRef(false);
+ useEffect(() => {
+ if (hydratedRef.current) return;
+ if (chart === undefined) return;
+ const data = chart?.data as Partial | undefined;
+ if (typeof data?.sourcePrompt === "string") setSourcePrompt(data.sourcePrompt);
+ if (typeof data?.targetPrompt === "string") setTargetPrompt(data.targetPrompt);
+ if (data?.source && data?.target) {
+ setLensResult({ source: data.source, target: data.target });
+ }
+ if (typeof data?.lastRunSourcePrompt === "string") {
+ setLastRunSrcPrompt(data.lastRunSourcePrompt);
+ }
+ if (typeof data?.lastRunTargetPrompt === "string") {
+ setLastRunTgtPrompt(data.lastRunTargetPrompt);
+ }
+ hydratedRef.current = true;
+ }, [chart]);
+
+ const handleLensResult = useCallback(
+ (result: CMIntroLensResult, runSrc: string, runTgt: string) => {
+ setLensResult(result);
+ setLastRunSrcPrompt(runSrc);
+ setLastRunTgtPrompt(runTgt);
+ },
+ [],
+ );
+
+ // Autosave the prompts into the chart row so they survive navigation even
+ // if the user never runs the lens. Debounced to avoid a write per keystroke.
+ useEffect(() => {
+ if (!hydratedRef.current || !chartId) return;
+ const handle = setTimeout(async () => {
+ const existing = await getChartById(chartId);
+ const existingData = (existing?.data ?? {}) as Partial;
+ if (
+ existingData.sourcePrompt === sourcePrompt &&
+ existingData.targetPrompt === targetPrompt
+ ) {
+ return;
+ }
+ const merged: CMIntroChartData = {
+ ...existingData,
+ sourcePrompt,
+ targetPrompt,
+ };
+ await setChartData(chartId, merged, "cm-intro");
+ queryClient.invalidateQueries({
+ queryKey: queryKeys.charts.chart(chartId),
+ });
+ }, PROMPT_AUTOSAVE_DEBOUNCE_MS);
+ return () => clearTimeout(handle);
+ }, [sourcePrompt, targetPrompt, chartId, queryClient]);
+
+ if (isMobile === undefined) return null;
+
+ if (isMobile) {
+ return (
+
+ );
+ }
+
+ return (
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx
index 5e47e38d..30b4c7c0 100644
--- a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx
+++ b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx
@@ -2,7 +2,15 @@
import React from "react";
import { useParams, useRouter } from "next/navigation";
-import { Grid3X3, ChartLine, Trash2, Copy, MoreVertical, GitBranch } from "lucide-react";
+import {
+ Grid3X3,
+ ChartLine,
+ Trash2,
+ Copy,
+ MoreVertical,
+ GitBranch,
+ GraduationCap,
+} from "lucide-react";
import Image from "next/image";
import { ChartMetadata } from "@/types/charts";
import { cn } from "@/lib/utils";
@@ -40,6 +48,13 @@ export default function ChartCard({ metadata, handleDelete, canDelete }: ChartCa
chart.chartType === "activation-patching"
) {
router.push(`/workbench/${workspaceId}/activation-patching/${chart.id}`);
+ } else if (
+ chart.toolType === "logit-lens-intro" ||
+ chart.chartType === "logit-lens-intro"
+ ) {
+ router.push(`/workbench/${workspaceId}/logit-lens-intro/${chart.id}`);
+ } else if (chart.toolType === "cm-intro" || chart.chartType === "cm-intro") {
+ router.push(`/workbench/${workspaceId}/cm-intro/${chart.id}`);
} else {
router.push(`/workbench/${workspaceId}/${chart.id}`);
}
@@ -94,6 +109,20 @@ export default function ChartCard({ metadata, handleDelete, canDelete }: ChartCa
Act. Patching
);
+ if (chartType === "logit-lens-intro")
+ return (
+
+
+ LL Intro
+
+ );
+ if (chartType === "cm-intro")
+ return (
+
+
+ CM Intro
+
+ );
return (
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx
index ca505caa..b95803b3 100644
--- a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx
+++ b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx
@@ -5,6 +5,8 @@ import { getChartsMetadata } from "@/lib/queries/chartQueries";
import { useParams, useRouter } from "next/navigation";
import {
useCreateLens2ChartPair,
+ useCreateLogitLensIntroChartPair,
+ useCreateCMIntroChartPair,
useCreatePatchChartPair,
useCreateActivationPatchingChartPair,
useDeleteChart,
@@ -29,6 +31,7 @@ import {
FileText,
Layers,
GitBranch,
+ GraduationCap,
} from "lucide-react";
import { Button } from "@/components/ui/button";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
@@ -64,6 +67,9 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b
);
const { mutate: createLens2Pair, isPending: isCreatingLens2 } = useCreateLens2ChartPair();
+ const { mutate: createLogitLensIntroPair, isPending: isCreatingLogitLensIntro } =
+ useCreateLogitLensIntroChartPair();
+ const { mutate: createCMIntroPair, isPending: isCreatingCMIntro } = useCreateCMIntroChartPair();
const { mutate: createPatchPair, isPending: isCreatingPatch } = useCreatePatchChartPair();
const { mutate: createActivationPatchingPair, isPending: isCreatingActivationPatching } =
useCreateActivationPatchingChartPair();
@@ -161,6 +167,10 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b
router.push(`/workbench/${workspaceId}/lens2/${chartId}`);
} else if (toolType === "activation-patching") {
router.push(`/workbench/${workspaceId}/activation-patching/${chartId}`);
+ } else if (toolType === "logit-lens-intro") {
+ router.push(`/workbench/${workspaceId}/logit-lens-intro/${chartId}`);
+ } else if (toolType === "cm-intro") {
+ router.push(`/workbench/${workspaceId}/cm-intro/${chartId}`);
} else {
router.push(`/workbench/${workspaceId}/${chartId}`);
}
@@ -170,7 +180,9 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b
router.push(`/workbench/${workspaceId}/overview/${documentId}`);
};
- const handleCreate = (toolType: "lens2" | "patch" | "activation-patching") => {
+ const handleCreate = (
+ toolType: "lens2" | "patch" | "activation-patching" | "logit-lens-intro" | "cm-intro",
+ ) => {
if (toolType === "lens2") {
createLens2Pair(
{ workspaceId: workspaceId as string },
@@ -180,6 +192,24 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b
);
return;
}
+ if (toolType === "logit-lens-intro") {
+ createLogitLensIntroPair(
+ { workspaceId: workspaceId as string },
+ {
+ onSuccess: ({ chart }) => navigateToChart(chart.id, "logit-lens-intro"),
+ },
+ );
+ return;
+ }
+ if (toolType === "cm-intro") {
+ createCMIntroPair(
+ { workspaceId: workspaceId as string },
+ {
+ onSuccess: ({ chart }) => navigateToChart(chart.id, "cm-intro"),
+ },
+ );
+ return;
+ }
if (toolType === "activation-patching") {
createActivationPatchingPair(
{ workspaceId: workspaceId as string },
@@ -245,7 +275,12 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b
};
const isCreatingAny =
- isCreatingLens2 || isCreatingPatch || isCreatingActivationPatching || isCreatingDocument;
+ isCreatingLens2 ||
+ isCreatingLogitLensIntro ||
+ isCreatingCMIntro ||
+ isCreatingPatch ||
+ isCreatingActivationPatching ||
+ isCreatingDocument;
const actionButtons = (
@@ -263,6 +298,34 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b
)}
Logit Lens
+
handleCreate("logit-lens-intro")}
+ disabled={isCreatingAny}
+ className="w-full"
+ title="Logit Lens Intro"
+ >
+ {isCreatingLogitLensIntro ? (
+
+ ) : (
+
+ )}
+ LL Intro
+
+
handleCreate("cm-intro")}
+ disabled={isCreatingAny}
+ className="w-full"
+ title="Causal Mediation Intro"
+ >
+ {isCreatingCMIntro ? (
+
+ ) : (
+
+ )}
+ CM Intro
+
handleCreate("activation-patching")}
@@ -324,6 +387,34 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b
)}
+
handleCreate("logit-lens-intro")}
+ disabled={isCreatingAny}
+ className="h-7 w-7 hover:bg-muted opacity-60 hover:opacity-100 transition-opacity"
+ title="New Logit Lens Intro"
+ >
+ {isCreatingLogitLensIntro ? (
+
+ ) : (
+
+ )}
+
+
handleCreate("cm-intro")}
+ disabled={isCreatingAny}
+ className="h-7 w-7 hover:bg-muted opacity-60 hover:opacity-100 transition-opacity"
+ title="New CM Intro"
+ >
+ {isCreatingCMIntro ? (
+
+ ) : (
+
+ )}
+
();
+
+ const { data: config } = useQuery({
+ queryKey: queryKeys.charts.configByChart(chartId),
+ queryFn: () => getConfigForChart(chartId),
+ enabled: !!chartId,
+ });
+
+ const { data: chart } = useQuery({
+ queryKey: queryKeys.charts.chart(chartId),
+ queryFn: () => getChartById(chartId as string),
+ enabled: !!chartId,
+ });
+
+ const { selectedModelIdx, setSelectedModelIdx } = useWorkspace();
+ const [configModelUnavailable, setConfigModelUnavailable] = useState(null);
+
+ const { data: models } = useQuery({
+ queryKey: ["models"],
+ queryFn: getModels,
+ refetchInterval: 120000,
+ });
+
+ useEffect(() => {
+ if (config && models && models.length > 0) {
+ const introConfig = config as LogitLensIntroConfig;
+ const configModel = introConfig.data?.model;
+
+ if (configModel && configModel.length > 0) {
+ const modelIndex = models.findIndex((m) => m.name === configModel);
+ if (modelIndex !== -1) {
+ setSelectedModelIdx(modelIndex);
+ setConfigModelUnavailable(null);
+ } else {
+ setConfigModelUnavailable(configModel);
+ }
+ } else {
+ setConfigModelUnavailable(null);
+ }
+ }
+ }, [config?.id, models, setSelectedModelIdx]);
+
+ const selectedModel = useMemo(() => {
+ if (!models || models.length === 0) return undefined;
+ return models[selectedModelIdx]?.name || models[0].name;
+ }, [models, selectedModelIdx]);
+
+ if (!config || !selectedModel) {
+ return (
+
+
+
Logit Lens Intro
+
+ {configModelUnavailable && (
+
+
+
+
+
+
+ Model "{configModelUnavailable}" is not currently
+ available. Please select a different model.
+
+
+
+ )}
+
+
+
+
+
+ );
+ }
+
+ return (
+
+
+
Logit Lens Intro
+
+ {configModelUnavailable && (
+
+
+
+
+
+
+ Model "{configModelUnavailable}" is not currently
+ available. Please select a different model.
+
+
+
+ )}
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroControls.tsx b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroControls.tsx
new file mode 100644
index 00000000..f1987397
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroControls.tsx
@@ -0,0 +1,371 @@
+"use client";
+
+import { useState, useEffect, useCallback, useRef } from "react";
+import { useParams } from "next/navigation";
+import { Button } from "@/components/ui/button";
+import { Textarea } from "@/components/ui/textarea";
+import { Label } from "@/components/ui/label";
+import { Loader2, Play, TriangleAlert } from "lucide-react";
+import { useLogitLensIntro } from "@/lib/api/logitLensIntroApi";
+import { useUpdateChartConfig } from "@/lib/api/configApi";
+import { LogitLensIntroConfigData } from "@/types/logitLensIntro";
+import { Slider } from "@/components/ui/slider";
+import { Checkbox } from "@/components/ui/checkbox";
+import { encodeText } from "@/actions/tok";
+import { TokenizerLoadError } from "@/actions/errors";
+import { Token } from "@/types/models";
+import { cn } from "@/lib/utils";
+import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip";
+import { toast } from "sonner";
+
+interface LogitLensIntroConfig {
+ id: string;
+ data: LogitLensIntroConfigData;
+ type: string;
+}
+
+interface LogitLensIntroControlsProps {
+ initialConfig: LogitLensIntroConfig;
+ selectedModel: string;
+}
+
+const TOKEN_STYLES = {
+ base: "!text-sm !leading-5 whitespace-pre-wrap break-words select-none !box-border relative",
+ hover: "hover:bg-primary/20 hover:ring-1 hover:ring-primary/30 hover:ring-inset",
+} as const;
+
+const fixTokenText = (text: string) => {
+ const numNewlines = (text.match(/\n/g) || []).length;
+ const result = text
+ .replace(/\r\n/g, "\\r\\n")
+ .replace(/\n/g, "\\n")
+ .replace(/\r/g, "\\r")
+ .replace(/\t/g, "\\t");
+ return { result, numNewlines };
+};
+
+function TokenDisplay({ tokens, loading }: { tokens: Token[]; loading: boolean }) {
+ return (
+
+ {tokens.map((token, idx) => {
+ const { result, numNewlines } = fixTokenText(token.text);
+ return (
+
+
+ {result}
+
+ {numNewlines > 0 && "\n".repeat(numNewlines)}
+
+ );
+ })}
+
+ );
+}
+
+export function LogitLensIntroControls({
+ initialConfig,
+ selectedModel,
+}: LogitLensIntroControlsProps) {
+ const { workspaceId, chartId } = useParams<{ workspaceId: string; chartId: string }>();
+
+ const [prompt, setPrompt] = useState(initialConfig.data?.prompt || "");
+ const [topk, setTopk] = useState(initialConfig.data?.topk ?? 5);
+ const [includeEntropy, setIncludeEntropy] = useState(
+ initialConfig.data?.includeEntropy ?? true,
+ );
+
+ const [tokenData, setTokenData] = useState([]);
+ const [editingText, setEditingText] = useState(true);
+ const [tokenizedModel, setTokenizedModel] = useState(null);
+
+ const lastSyncedPromptRef = useRef(initialConfig.data?.prompt || "");
+ const textareaRef = useRef(null);
+ const tokenContainerRef = useRef(null);
+
+ const { mutateAsync: computeLens, isPending: isComputing } = useLogitLensIntro();
+ const { mutateAsync: updateConfig, isPending: isUpdatingConfig } = useUpdateChartConfig();
+
+ const isExecuting = isComputing || isUpdatingConfig;
+
+ useEffect(() => {
+ const configPrompt = initialConfig.data?.prompt || "";
+ if (configPrompt && configPrompt !== lastSyncedPromptRef.current) {
+ setPrompt(configPrompt);
+ lastSyncedPromptRef.current = configPrompt;
+ }
+ }, [initialConfig.data?.prompt]);
+
+ useEffect(() => {
+ const fetchTokens = async () => {
+ if (initialConfig.data?.prompt && selectedModel) {
+ const tokens = await encodeText(initialConfig.data.prompt, selectedModel);
+ if (tokens.length > 0) {
+ setTokenData(tokens);
+ setTokenizedModel(selectedModel);
+ setEditingText(false);
+ }
+ }
+ };
+ fetchTokens();
+ }, [initialConfig.id, selectedModel]);
+
+ const autoResizeTextarea = useCallback(() => {
+ if (textareaRef.current) {
+ textareaRef.current.style.height = "auto";
+ textareaRef.current.style.height = `${textareaRef.current.scrollHeight}px`;
+ }
+ }, []);
+
+ useEffect(() => {
+ if (editingText) autoResizeTextarea();
+ }, [prompt, editingText, autoResizeTextarea]);
+
+ const escapeTokenArea = useCallback(() => {
+ setEditingText(true);
+ setTimeout(() => {
+ if (textareaRef.current) {
+ textareaRef.current.focus();
+ const length = textareaRef.current.value.length;
+ textareaRef.current.setSelectionRange(length, length);
+ }
+ }, 0);
+ }, []);
+
+ const handleTokenize = useCallback(async () => {
+ if (!prompt.trim()) {
+ toast.error("Please enter a prompt.");
+ return;
+ }
+
+ let tokens: Token[];
+ try {
+ tokens = await encodeText(prompt, selectedModel);
+ } catch (error) {
+ if (error instanceof TokenizerLoadError) {
+ toast.error(
+ `Could not load tokenizer for ${selectedModel}. The model may be gated and require authentication.`,
+ );
+ } else {
+ toast.error("Failed to tokenize prompt.");
+ }
+ return;
+ }
+ if (tokens.length <= 1) {
+ toast.error("Please enter a longer prompt.");
+ return;
+ }
+
+ setTokenData(tokens);
+ setTokenizedModel(selectedModel);
+ setEditingText(false);
+ }, [prompt, selectedModel]);
+
+ const handleSubmit = useCallback(async () => {
+ const trimmedPrompt = prompt.trim();
+ if (!trimmedPrompt) return;
+
+ let tokens: Token[];
+ try {
+ tokens = await encodeText(trimmedPrompt, selectedModel);
+ } catch (error) {
+ if (error instanceof TokenizerLoadError) {
+ toast.error(
+ `Could not load tokenizer for ${selectedModel}. The model may be gated and require authentication.`,
+ );
+ } else {
+ toast.error("Failed to tokenize prompt.");
+ }
+ return;
+ }
+ if (tokens.length <= 1) {
+ toast.error("Please enter a longer prompt.");
+ return;
+ }
+ setTokenData(tokens);
+ setTokenizedModel(selectedModel);
+
+ const config: LogitLensIntroConfigData = {
+ model: selectedModel,
+ prompt: trimmedPrompt,
+ topk,
+ includeEntropy,
+ };
+
+ await computeLens({
+ lensRequest: {
+ completion: config,
+ chartId,
+ },
+ configId: initialConfig.id,
+ });
+
+ await updateConfig({
+ configId: initialConfig.id,
+ config: {
+ data: config,
+ workspaceId,
+ type: "logit-lens-intro",
+ },
+ });
+
+ lastSyncedPromptRef.current = trimmedPrompt;
+ setEditingText(false);
+ }, [
+ prompt,
+ topk,
+ includeEntropy,
+ selectedModel,
+ chartId,
+ initialConfig.id,
+ workspaceId,
+ computeLens,
+ updateConfig,
+ ]);
+
+ const handlePromptChange = useCallback((e: React.ChangeEvent) => {
+ setPrompt(e.target.value);
+ }, []);
+
+ const handleKeyDown = useCallback(
+ (e: React.KeyboardEvent) => {
+ if ((e.metaKey || e.ctrlKey) && e.key === "Enter") {
+ e.preventDefault();
+ handleSubmit();
+ }
+ },
+ [handleSubmit],
+ );
+
+ const handleTextareaBlur = useCallback(() => {
+ setTimeout(() => {
+ const activeElement = document.activeElement;
+ const withinTextarea = activeElement && textareaRef.current?.contains(activeElement);
+ const withinToken = activeElement && tokenContainerRef.current?.contains(activeElement);
+ const popoverOpen = document.querySelector("[data-radix-popper-content-wrapper]");
+
+ if (withinTextarea || withinToken || popoverOpen) return;
+
+ if (prompt.trim()) {
+ handleTokenize();
+ }
+ }, 100);
+ }, [prompt, handleTokenize]);
+
+ const modelMismatch =
+ tokenizedModel && tokenizedModel !== selectedModel && tokenData.length > 0;
+
+ return (
+
+
+
+
+
+
+ Top-K Predictions
+
+ {topk}
+
+
setTopk(value)}
+ disabled={isExecuting}
+ className="w-full"
+ />
+
+
+
+ setIncludeEntropy(checked === true)}
+ disabled={isExecuting}
+ />
+
+ Include Entropy
+
+
+
+
+ {isExecuting ? (
+ <>
+
+ Computing...
+ >
+ ) : (
+ <>
+
+ Run Logit Lens
+ >
+ )}
+
+
+
+ ⌘ +{" "}
+ Enter to run
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroDisplay.tsx b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroDisplay.tsx
new file mode 100644
index 00000000..f2691df4
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/components/LogitLensIntroDisplay.tsx
@@ -0,0 +1,201 @@
+"use client";
+
+import { useMemo } from "react";
+import { useParams } from "next/navigation";
+import { useQuery, useIsMutating } from "@tanstack/react-query";
+import { getChartById } from "@/lib/queries/chartQueries";
+import { queryKeys } from "@/lib/queryKeys";
+import { Loader2 } from "lucide-react";
+import { LogitLensGrid } from "edulogitlens";
+import type { LogitLensData, LogitCell } from "edulogitlens";
+import type { LogitLensIntroData } from "@/types/logitLensIntro";
+
+interface LogitLensIntroChart {
+ id: string;
+ data: LogitLensIntroData | null;
+ type: string;
+}
+
+/**
+ * Transform the nnsightful LogitLensData format into the edulogitlens format.
+ *
+ * nnsightful returns:
+ * input: string[] — input token strings
+ * layers: number[] — layer indices
+ * tracked: Record[] — per-position dict of token → prob-per-layer
+ * topk: string[][][] — topk[layer][position] = top-k token strings
+ *
+ * edulogitlens expects:
+ * tokens: string[]
+ * layers: number[]
+ * data: LogitCell[][] — data[position][layer]
+ */
+function transformToEduFormat(data: LogitLensIntroData): LogitLensData | null {
+ if (!data) return null;
+
+ const raw = data as Record;
+ const input = raw.input as string[] | undefined;
+ const layers = raw.layers as number[] | undefined;
+ const tracked = raw.tracked as Record[] | undefined;
+ const topk = raw.topk as string[][][] | undefined;
+
+ if (!input || !layers || !tracked || !topk) return null;
+
+ const cellData: LogitCell[][] = input.map((_, posIdx) => {
+ const posTracked = tracked[posIdx] ?? {};
+ return layers.map((_, layerIdx) => {
+ const topTokenStrs = topk[layerIdx]?.[posIdx] ?? [];
+ const topTokens = topTokenStrs.map((t) => ({
+ token: t,
+ prob: posTracked[t]?.[layerIdx] ?? 0,
+ }));
+ topTokens.sort((a, b) => b.prob - a.prob);
+
+ const best = topTokens[0];
+ return {
+ token: best?.token ?? "",
+ probability: best?.prob ?? 0,
+ topTokens,
+ };
+ });
+ });
+
+ return { tokens: input, layers, data: cellData };
+}
+
+function generateMockData(): LogitLensData {
+ const tokens = [
+ "The",
+ "E",
+ "iff",
+ "el",
+ "Tower",
+ "is",
+ "in",
+ "the",
+ "city",
+ "of",
+ "Paris",
+ ",",
+ "France",
+ ".",
+ ];
+ const layers = Array.from({ length: 12 }, (_, i) => i);
+ const vocab = [
+ "t",
+ "bow",
+ "illi",
+ "Tower",
+ "el",
+ "France",
+ "Paris",
+ "tower",
+ "city",
+ "of",
+ "the",
+ "in",
+ "is",
+ "a",
+ "and",
+ "Eiff",
+ "to",
+ "built",
+ "was",
+ "meters",
+ "at",
+ "by",
+ "from",
+ "with",
+ "on",
+ "for",
+ "an",
+ "stands",
+ "tall",
+ ];
+
+ const data: LogitCell[][] = tokens.map((token) => {
+ return layers.map((_, layerIdx) => {
+ const convergence = layerIdx / layers.length;
+ let primaryToken = token;
+ let prob: number;
+
+ if (convergence < 0.3) {
+ primaryToken = vocab[Math.floor(Math.random() * vocab.length)];
+ prob = 0.05 + Math.random() * 0.15;
+ } else if (convergence < 0.6) {
+ primaryToken =
+ Math.random() > 0.5 ? token : vocab[Math.floor(Math.random() * vocab.length)];
+ prob = 0.2 + Math.random() * 0.3;
+ } else {
+ primaryToken = token;
+ prob = 0.5 + convergence * 0.4 + Math.random() * 0.1;
+ }
+ prob = Math.min(prob, 0.95);
+
+ const topTokens: { token: string; prob: number }[] = [{ token: primaryToken, prob }];
+ let remaining = (1 - prob) * 0.4;
+ for (let i = 0; i < 14; i++) {
+ topTokens.push({
+ token: vocab[Math.floor(Math.random() * vocab.length)],
+ prob: remaining,
+ });
+ remaining *= 0.7;
+ }
+
+ return { token: primaryToken, probability: prob, topTokens };
+ });
+ });
+
+ return { tokens, layers, data };
+}
+
+export function LogitLensIntroDisplay() {
+ const { chartId } = useParams<{ chartId: string }>();
+
+ const isRunning = useIsMutating({ mutationKey: ["logitLensIntro"] }) > 0;
+
+ const { data: chart, isLoading } = useQuery({
+ queryKey: queryKeys.charts.chart(chartId),
+ queryFn: () => getChartById(chartId as string),
+ enabled: !!chartId,
+ });
+
+ const introChart = chart as LogitLensIntroChart | undefined;
+ const hasData = introChart?.data && "input" in introChart.data && "topk" in introChart.data;
+
+ const mockData = useMemo(() => generateMockData(), []);
+
+ if (isLoading) {
+ return (
+
+
+
+ );
+ }
+
+ if (isRunning) {
+ return (
+
+
+
Computing logit lens visualization...
+
+ );
+ }
+
+ // Use real data if available, otherwise show mock data
+ const eduData = hasData ? transformToEduFormat(introChart.data!) : mockData;
+
+ if (!eduData) {
+ return (
+
+
+
+ );
+ }
+
+ return (
+
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/layout.tsx b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/layout.tsx
new file mode 100644
index 00000000..999a6697
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/layout.tsx
@@ -0,0 +1,10 @@
+import { CaptureProvider } from "@/components/providers/CaptureProvider";
+import { TooltipProvider } from "@/components/ui/tooltip";
+
+export default function LogitLensIntroChartLayout({ children }: { children: React.ReactNode }) {
+ return (
+
+ {children}
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/page.tsx b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/page.tsx
new file mode 100644
index 00000000..a2f7fc87
--- /dev/null
+++ b/workbench/_web/src/app/workbench/[workspaceId]/logit-lens-intro/[chartId]/page.tsx
@@ -0,0 +1,56 @@
+"use client";
+
+import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from "@/components/ui/resizable";
+import ChartCardsSidebar from "../../components/ChartCardsSidebar";
+import LogitLensIntroArea from "./components/LogitLensIntroArea";
+import { LogitLensIntroDisplay } from "./components/LogitLensIntroDisplay";
+import { useIsMobile } from "@/hooks/useIsMobile";
+import { MobileSidebarDrawer } from "../../components/MobileSidebarDrawer";
+import { MobileCollapsibleControls } from "../../components/MobileCollapsibleControls";
+import { GraduationCap } from "lucide-react";
+import { useIsMutating } from "@tanstack/react-query";
+
+export default function LogitLensIntroChartPage() {
+ const isMobile = useIsMobile();
+ const isRunning = useIsMutating({ mutationKey: ["logitLensIntro"] }) > 0;
+
+ if (isMobile === undefined) return null;
+
+ if (isMobile) {
+ return (
+
+ );
+ }
+
+ return (
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ );
+}
diff --git a/workbench/_web/src/app/workbench/[workspaceId]/page.tsx b/workbench/_web/src/app/workbench/[workspaceId]/page.tsx
index ee8182a2..0cbd05ec 100644
--- a/workbench/_web/src/app/workbench/[workspaceId]/page.tsx
+++ b/workbench/_web/src/app/workbench/[workspaceId]/page.tsx
@@ -43,6 +43,10 @@ export default async function Page({
redirect(`/workbench/${workspaceId}/lens2/${chart.id}`);
} else if (chartType === "activation-patching") {
redirect(`/workbench/${workspaceId}/activation-patching/${chart.id}`);
+ } else if (chartType === "logit-lens-intro") {
+ redirect(`/workbench/${workspaceId}/logit-lens-intro/${chart.id}`);
+ } else if (chartType === "cm-intro") {
+ redirect(`/workbench/${workspaceId}/cm-intro/${chart.id}`);
} else {
redirect(`/workbench/${workspaceId}/${chart.id}`);
}
diff --git a/workbench/_web/src/app/workbench/page.tsx b/workbench/_web/src/app/workbench/page.tsx
index a863a9e4..75b5d755 100644
--- a/workbench/_web/src/app/workbench/page.tsx
+++ b/workbench/_web/src/app/workbench/page.tsx
@@ -55,7 +55,7 @@ export default async function WorkbenchPage({
const tgtFreeze = params?.tgtFreeze;
// If no workspaces exist OR createNew flag is set, create a new workspace
- let shouldCreateWorkspace = !workspaces || workspaces.length === 0 || createNew;
+ const shouldCreateWorkspace = !workspaces || workspaces.length === 0 || createNew;
// If workspaceId is provided, use the existing workspace instead of creating new
const useExistingWorkspace = workspaceId && !shouldCreateWorkspace;
diff --git a/workbench/_web/src/components/charts/ChartDisplay.tsx b/workbench/_web/src/components/charts/ChartDisplay.tsx
index 63b13de9..a9e3e3bf 100644
--- a/workbench/_web/src/components/charts/ChartDisplay.tsx
+++ b/workbench/_web/src/components/charts/ChartDisplay.tsx
@@ -46,10 +46,12 @@ export function ChartDisplay() {
const isHeatmapData =
Array.isArray(chart?.data) &&
chart.data.some(
- (row: any) =>
+ (row: { data?: unknown }) =>
row.data &&
Array.isArray(row.data) &&
- row.data.some((cell: any) => "label" in cell),
+ row.data.some(
+ (cell: unknown) => typeof cell === "object" && cell !== null && "label" in cell,
+ ),
);
return (
diff --git a/workbench/_web/src/components/charts/heatmap/Heatmap.tsx b/workbench/_web/src/components/charts/heatmap/Heatmap.tsx
index 85111bcb..ff118ecc 100644
--- a/workbench/_web/src/components/charts/heatmap/Heatmap.tsx
+++ b/workbench/_web/src/components/charts/heatmap/Heatmap.tsx
@@ -17,7 +17,7 @@ interface HeatmapProps {
useTooltip?: boolean;
onMouseMove?: (e: React.MouseEvent) => void;
onMouseLeave?: () => void;
- onMouseDown?: (e: React.MouseEvent) => void;
+ onMouseDown?: (e: React.MouseEvent) => void;
statisticType?: Metrics;
}
diff --git a/workbench/_web/src/components/providers/TourProvider.tsx b/workbench/_web/src/components/providers/TourProvider.tsx
index 70e7e095..a24b017d 100644
--- a/workbench/_web/src/components/providers/TourProvider.tsx
+++ b/workbench/_web/src/components/providers/TourProvider.tsx
@@ -7,26 +7,68 @@ interface TourProviderProps {
children: ReactNode;
}
-function ContentComponent({
- currentStep,
- steps,
- setIsOpen,
- setCurrentStep,
- ...props
-}: PopoverContentProps) {
- const content = steps[currentStep].content;
+function ContentComponent({ currentStep, steps, setIsOpen, setCurrentStep }: PopoverContentProps) {
+ const step = steps[currentStep];
+ const content = step.content;
if (typeof content === "function") {
return Unsupported content type
;
}
- if (steps[currentStep].selector === "sidebar") {
+ if (step.selector === "sidebar") {
return <>>;
}
+ const isFirst = currentStep === 0;
+ const isLast = currentStep === steps.length - 1;
+
return (
-
- {renderTextWithBackticks(content as string)}
+
+
+ {renderTextWithBackticks(content as string)}
+
+
+
+ {currentStep + 1} / {steps.length}
+
+
+ {!isFirst && (
+ setCurrentStep(currentStep - 1)}
+ className="text-xs px-2 py-1 rounded border hover:bg-muted transition-colors"
+ >
+ Prev
+
+ )}
+ {!isLast ? (
+ setCurrentStep(currentStep + 1)}
+ className="text-xs px-3 py-1 rounded bg-primary text-primary-foreground hover:bg-primary/90 transition-colors"
+ >
+ Next
+
+ ) : (
+ setIsOpen(false)}
+ className="text-xs px-3 py-1 rounded bg-primary text-primary-foreground hover:bg-primary/90 transition-colors"
+ >
+ Done
+
+ )}
+ setIsOpen(false)}
+ aria-label="Close tour"
+ title="Close tour"
+ className="text-xs px-2 py-1 rounded border hover:bg-muted transition-colors"
+ >
+ ×
+
+
+
);
}
diff --git a/workbench/_web/src/components/transformer/EmbedComponent.tsx b/workbench/_web/src/components/transformer/EmbedComponent.tsx
index 4c6ad7b0..a86ac41d 100644
--- a/workbench/_web/src/components/transformer/EmbedComponent.tsx
+++ b/workbench/_web/src/components/transformer/EmbedComponent.tsx
@@ -70,7 +70,7 @@ export default function EmbedComponent({
// Function to add event handlers to components
const addComponentHandlers = (
- element: d3.Selection
,
+ element: d3.Selection,
tokenIndex: number,
) => {
element
diff --git a/workbench/_web/src/components/transformer/InteractiveTransformer.tsx b/workbench/_web/src/components/transformer/InteractiveTransformer.tsx
index cf50193e..9ffcfa07 100644
--- a/workbench/_web/src/components/transformer/InteractiveTransformer.tsx
+++ b/workbench/_web/src/components/transformer/InteractiveTransformer.tsx
@@ -226,7 +226,7 @@ export default function LensTransformer({
// Function to add event handlers to components
const addComponentHandlers = (
- element: d3.Selection,
+ element: d3.Selection,
tokenIndex: number,
layerIndex: number,
componentType: "resid" | "attn" | "mlp" | "embed" | "unembed",
diff --git a/workbench/_web/src/components/transformer/UnembedComponent.tsx b/workbench/_web/src/components/transformer/UnembedComponent.tsx
index d807f630..7236ed93 100644
--- a/workbench/_web/src/components/transformer/UnembedComponent.tsx
+++ b/workbench/_web/src/components/transformer/UnembedComponent.tsx
@@ -72,7 +72,7 @@ export default function UnembedComponent({
// Function to add event handlers to components
const addComponentHandlers = (
- element: d3.Selection,
+ element: d3.Selection,
tokenIndex: number,
) => {
element
diff --git a/workbench/_web/src/db/__tests__/local-db.test.ts b/workbench/_web/src/db/__tests__/local-db.test.ts
index f49a5e97..167af3cc 100644
--- a/workbench/_web/src/db/__tests__/local-db.test.ts
+++ b/workbench/_web/src/db/__tests__/local-db.test.ts
@@ -191,7 +191,7 @@ describe("Chart Queries", () => {
patches: [{ layer: 1, position: 0, value: 0.5 }],
};
- const { chart, config } = await createPatchChartPair(workspaceId, patchConfig as any);
+ const { chart, config } = await createPatchChartPair(workspaceId, patchConfig as never);
expect(chart).toBeDefined();
expect(config.type).toBe("patch");
@@ -224,7 +224,7 @@ describe("Chart Queries", () => {
labels: ["a", "b", "c"],
};
- await setChartData(chart.id, chartData as any, "line");
+ await setChartData(chart.id, chartData as never, "line");
const updated = await getChartById(chart.id);
expect(updated!.data).toEqual(chartData);
@@ -278,7 +278,7 @@ describe("Chart Queries", () => {
createTestLensConfig("original"),
);
await updateChartName(original.id, "Original Chart");
- await setChartData(original.id, { test: "data" } as any, "heatmap");
+ await setChartData(original.id, { test: "data" } as never, "heatmap");
const copy = await copyChart(original.id);
@@ -394,7 +394,7 @@ describe("View Queries", () => {
it("should update view data", async () => {
const view = await createView({ chartId, data: { original: true } });
- const updated = await updateView(view.id, { updated: true, zoom: 2 } as any);
+ const updated = await updateView(view.id, { updated: true, zoom: 2 } as never);
expect(updated.data).toEqual({ updated: true, zoom: 2 });
});
@@ -424,7 +424,9 @@ describe("Document Queries", () => {
expect(doc.workspaceId).toBe(workspaceId);
expect(doc.content).toBeDefined();
// Default content has a heading "Overview"
- expect((doc.content as any).root.children[0].type).toBe("heading");
+ expect(
+ (doc.content as { root: { children: { type: string }[] } }).root.children[0].type,
+ ).toBe("heading");
});
it("should get document by ID", async () => {
@@ -458,7 +460,7 @@ describe("Document Queries", () => {
},
};
- const updated = await updateDocument(doc.id, newContent as any);
+ const updated = await updateDocument(doc.id, newContent as never);
expect(updated.content).toEqual(newContent);
});
@@ -509,7 +511,7 @@ describe("JSON Storage in SQLite", () => {
booleanFalse: false,
};
- await setChartData(chart.id, complexData as any, "heatmap");
+ await setChartData(chart.id, complexData as never, "heatmap");
const retrieved = await getChartById(chart.id);
expect(retrieved!.data).toEqual(complexData);
@@ -523,7 +525,7 @@ describe("JSON Storage in SQLite", () => {
unicode: "Unicode: 日本語, émojis 🎉",
};
- await setChartData(chart.id, specialData as any, "line");
+ await setChartData(chart.id, specialData as never, "line");
const retrieved = await getChartById(chart.id);
expect(retrieved!.data).toEqual(specialData);
@@ -558,7 +560,7 @@ describe("Cross-Table Relationships", () => {
);
const { chart: chart2 } = await createPatchChartPair(workspace.id, {
patches: [],
- } as any);
+ } as never);
// Verify charts belong to workspace
const metadata = await getChartsMetadata(workspace.id);
diff --git a/workbench/_web/src/lib/api/chartApi.ts b/workbench/_web/src/lib/api/chartApi.ts
index 427a6507..73b86106 100644
--- a/workbench/_web/src/lib/api/chartApi.ts
+++ b/workbench/_web/src/lib/api/chartApi.ts
@@ -4,6 +4,8 @@ import {
setChartData,
deleteChart,
createLens2ChartPair,
+ createLogitLensIntroChartPair,
+ createCMIntroChartPair,
createPatchChartPair,
createActivationPatchingChartPair,
updateChartName,
@@ -12,6 +14,7 @@ import {
} from "@/lib/queries/chartQueries";
import { LensConfigData } from "@/types/lens";
import { Lens2ConfigData } from "@/types/lens2";
+import { LogitLensIntroConfigData } from "@/types/logitLensIntro";
import { PatchingConfig } from "@/types/patching";
import { ActivationPatchingConfigData } from "@/types/activationPatching";
import { useCapture } from "@/components/providers/CaptureProvider";
@@ -57,7 +60,7 @@ export const useLensLine = () => {
const chartKey = queryKeys.charts.chart(lensRequest.chartId);
await queryClient.cancelQueries({ queryKey: chartKey });
const previousChart = queryClient.getQueryData(chartKey);
- queryClient.setQueryData(chartKey, (old: any) => {
+ queryClient.setQueryData(chartKey, (old: Record | undefined) => {
if (!old) return old;
return { ...old, type: "line" };
});
@@ -79,7 +82,7 @@ export const useLensLine = () => {
},
onError: (error, variables, context) => {
if (context?.previousChart) {
- queryClient.setQueryData(context.chartKey, context.previousChart as any);
+ queryClient.setQueryData(context.chartKey, context.previousChart);
}
toast.error("Failed to compute lens line (timeout or error)");
},
@@ -100,7 +103,9 @@ export const useLensLine = () => {
// GenerateButton), and useUpdateChartConfig owns that invalidation.
// Triggering it here races the in-flight config write and can cache
// the pre-write (stale) row.
- const chart = queryClient.getQueryData(chartKey) as any;
+ const chart = queryClient.getQueryData(chartKey) as
+ | { workspaceId?: string }
+ | undefined;
if (chart?.workspaceId) {
queryClient.invalidateQueries({
queryKey: queryKeys.charts.sidebar(chart.workspaceId),
@@ -144,7 +149,7 @@ export const useLensGrid = () => {
const chartKey = queryKeys.charts.chart(lensRequest.chartId);
await queryClient.cancelQueries({ queryKey: chartKey });
const previousChart = queryClient.getQueryData(chartKey);
- queryClient.setQueryData(chartKey, (old: any) => {
+ queryClient.setQueryData(chartKey, (old: Record | undefined) => {
if (!old) return old;
return { ...old, type: "heatmap" };
});
@@ -166,7 +171,7 @@ export const useLensGrid = () => {
},
onError: (error, variables, context) => {
if (context?.previousChart) {
- queryClient.setQueryData(context.chartKey, context.previousChart as any);
+ queryClient.setQueryData(context.chartKey, context.previousChart);
}
toast.error("Failed to compute grid lens (timeout or error)");
},
@@ -187,7 +192,9 @@ export const useLensGrid = () => {
// GenerateButton), and useUpdateChartConfig owns that invalidation.
// Triggering it here races the in-flight config write and can cache
// the pre-write (stale) row.
- const chart = queryClient.getQueryData(chartKey) as any;
+ const chart = queryClient.getQueryData(chartKey) as
+ | { workspaceId?: string }
+ | undefined;
if (chart?.workspaceId) {
queryClient.invalidateQueries({
queryKey: queryKeys.charts.sidebar(chart.workspaceId),
@@ -271,6 +278,45 @@ export const useCreateLens2ChartPair = () => {
});
};
+export const useCreateLogitLensIntroChartPair = () => {
+ const queryClient = useQueryClient();
+
+ const defaultConfig: LogitLensIntroConfigData = {
+ prompt: "",
+ model: "",
+ topk: 5,
+ includeEntropy: true,
+ };
+
+ return useMutation({
+ mutationFn: async ({
+ workspaceId,
+ config = defaultConfig,
+ }: {
+ workspaceId: string;
+ config?: LogitLensIntroConfigData;
+ }) => {
+ return await createLogitLensIntroChartPair(workspaceId, config);
+ },
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: queryKeys.charts.sidebar() });
+ },
+ });
+};
+
+export const useCreateCMIntroChartPair = () => {
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationFn: async ({ workspaceId }: { workspaceId: string }) => {
+ return await createCMIntroChartPair(workspaceId);
+ },
+ onSuccess: () => {
+ queryClient.invalidateQueries({ queryKey: queryKeys.charts.sidebar() });
+ },
+ });
+};
+
// TODO(cadentj): FIX THIS
export const useCreatePatchChartPair = () => {
const queryClient = useQueryClient();
diff --git a/workbench/_web/src/lib/api/cmIntroApi.ts b/workbench/_web/src/lib/api/cmIntroApi.ts
new file mode 100644
index 00000000..30e26953
--- /dev/null
+++ b/workbench/_web/src/lib/api/cmIntroApi.ts
@@ -0,0 +1,167 @@
+/**
+ * CM Intro API — reuses the /logit_lens backend to compute logit lens
+ * results for a source and target prompt pair, plus /causal_mediation
+ * for the cell-drop intervention.
+ */
+
+import config from "@/lib/config";
+import { useMutation, useQueryClient } from "@tanstack/react-query";
+import { startAndPoll } from "../startAndPoll";
+import { createUserHeadersAction } from "@/actions/auth";
+import { setChartData, getChartById } from "@/lib/queries/chartQueries";
+import { queryKeys } from "../queryKeys";
+import { toast } from "sonner";
+import type { LogitLensIntroData } from "@/types/logitLensIntro";
+import type { CMIntroChartData, CMIntroInterventionSpec } from "@/types/cmIntro";
+
+export interface CMIntroLensRequest {
+ sourcePrompt: string;
+ targetPrompt: string;
+ model: string;
+ chartId: string;
+ topk?: number;
+ includeEntropy?: boolean;
+}
+
+export interface CMIntroLensResult {
+ source: LogitLensIntroData;
+ // Optional: when the user runs CM Intro in single-prompt mode (target blank)
+ // we only compute the source lens.
+ target: LogitLensIntroData | null;
+}
+
+export interface CMIntroInterventionRequest {
+ model: string;
+ srcPrompt: string;
+ tgtPrompt: string;
+ intervention: CMIntroInterventionSpec;
+ chartId: string;
+ topk?: number;
+ includeEntropy?: boolean;
+}
+
+const runLogitLens = async (
+ prompt: string,
+ model: string,
+ topk: number,
+ includeEntropy: boolean,
+ headers: Record,
+): Promise => {
+ return await startAndPoll(
+ config.endpoints.startLens2,
+ {
+ model,
+ prompt,
+ topk,
+ include_entropy: includeEntropy,
+ },
+ config.endpoints.resultsLens2,
+ headers,
+ );
+};
+
+export const useCMIntroLogitLens = () => {
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationKey: ["cmIntroLogitLens"],
+ mutationFn: async (request: CMIntroLensRequest): Promise => {
+ const headers = await createUserHeadersAction();
+ const topk = request.topk ?? 5;
+ const includeEntropy = request.includeEntropy ?? true;
+ const hasTarget = !!request.targetPrompt && request.targetPrompt.trim().length > 0;
+
+ const [source, target] = await Promise.all([
+ runLogitLens(request.sourcePrompt, request.model, topk, includeEntropy, headers),
+ hasTarget
+ ? runLogitLens(
+ request.targetPrompt,
+ request.model,
+ topk,
+ includeEntropy,
+ headers,
+ )
+ : Promise.resolve(null as unknown as LogitLensIntroData | null),
+ ]);
+
+ // Running the base lens invalidates any prior intervention/result,
+ // so we persist only the fresh { source, target } pair (alongside
+ // the prompts they were computed from, so revisits restore the UI).
+ // lastRunSourcePrompt/lastRunTargetPrompt snapshot the prompts that
+ // were actually run, so the predicted-next-token hint can hide on
+ // edits even though autosave keeps sourcePrompt/targetPrompt fresh.
+ const persisted: CMIntroChartData = {
+ sourcePrompt: request.sourcePrompt,
+ targetPrompt: request.targetPrompt,
+ lastRunSourcePrompt: request.sourcePrompt,
+ lastRunTargetPrompt: request.targetPrompt,
+ source,
+ ...(target ? { target } : {}),
+ };
+ await setChartData(request.chartId, persisted, "cm-intro");
+
+ return { source, target };
+ },
+ onError: () => {
+ toast.error("Failed to run logit lens.");
+ },
+ onSuccess: async (_data, variables) => {
+ const chartKey = queryKeys.charts.chart(variables.chartId);
+ await queryClient.invalidateQueries({ queryKey: chartKey });
+ },
+ });
+};
+
+export const useCMIntroIntervention = () => {
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationKey: ["cmIntroIntervention"],
+ mutationFn: async (request: CMIntroInterventionRequest): Promise => {
+ const headers = await createUserHeadersAction();
+ const topk = request.topk ?? 5;
+ const includeEntropy = request.includeEntropy ?? true;
+
+ const body = {
+ model: request.model,
+ src_prompt: request.srcPrompt,
+ tgt_prompt: request.tgtPrompt,
+ src_token_pos: request.intervention.srcTokenPos,
+ src_layer: request.intervention.srcLayer,
+ tgt_token_pos: request.intervention.tgtTokenPos,
+ tgt_layer: request.intervention.tgtLayer,
+ topk,
+ include_entropy: includeEntropy,
+ };
+
+ const result = await startAndPoll(
+ config.endpoints.startCausalMediation,
+ body,
+ config.endpoints.resultsCausalMediation,
+ headers,
+ );
+
+ // Merge onto existing chart data so we preserve source/target and
+ // the prompts they were computed from.
+ const existingChart = await getChartById(request.chartId);
+ const existingData = (existingChart?.data ?? {}) as Partial;
+ const merged: CMIntroChartData = {
+ ...existingData,
+ sourcePrompt: existingData.sourcePrompt ?? request.srcPrompt,
+ targetPrompt: existingData.targetPrompt ?? request.tgtPrompt,
+ intervention: request.intervention,
+ result,
+ };
+ await setChartData(request.chartId, merged, "cm-intro");
+
+ return result;
+ },
+ onError: () => {
+ toast.error("Failed to run causal mediation intervention.");
+ },
+ onSuccess: async (_data, variables) => {
+ const chartKey = queryKeys.charts.chart(variables.chartId);
+ await queryClient.invalidateQueries({ queryKey: chartKey });
+ },
+ });
+};
diff --git a/workbench/_web/src/lib/api/logitLensIntroApi.ts b/workbench/_web/src/lib/api/logitLensIntroApi.ts
new file mode 100644
index 00000000..f02c8663
--- /dev/null
+++ b/workbench/_web/src/lib/api/logitLensIntroApi.ts
@@ -0,0 +1,95 @@
+/**
+ * Logit Lens Intro API — reuses the same /logit_lens backend endpoints as lens2
+ */
+
+import config from "@/lib/config";
+import { useMutation, useQueryClient } from "@tanstack/react-query";
+import { setChartData } from "@/lib/queries/chartQueries";
+import { LogitLensIntroConfigData, LogitLensIntroData } from "@/types/logitLensIntro";
+import { queryKeys } from "../queryKeys";
+import { toast } from "sonner";
+import { startAndPoll } from "../startAndPoll";
+import { createUserHeadersAction } from "@/actions/auth";
+
+interface LogitLensIntroRequest {
+ completion: LogitLensIntroConfigData;
+ chartId: string;
+}
+
+const getLogitLensIntro = async (
+ lensRequest: LogitLensIntroRequest,
+): Promise => {
+ const headers = await createUserHeadersAction();
+
+ const request = {
+ model: lensRequest.completion.model,
+ prompt: lensRequest.completion.prompt,
+ topk: lensRequest.completion.topk ?? 5,
+ include_entropy: lensRequest.completion.includeEntropy ?? true,
+ };
+
+ return await startAndPoll(
+ config.endpoints.startLens2,
+ request,
+ config.endpoints.resultsLens2,
+ headers,
+ );
+};
+
+export const useLogitLensIntro = () => {
+ const queryClient = useQueryClient();
+
+ return useMutation({
+ mutationKey: ["logitLensIntro"],
+ onMutate: async ({
+ lensRequest,
+ }: {
+ lensRequest: LogitLensIntroRequest;
+ configId: string;
+ }) => {
+ const chartKey = queryKeys.charts.chart(lensRequest.chartId);
+ await queryClient.cancelQueries({ queryKey: chartKey });
+ const previousChart = queryClient.getQueryData(chartKey);
+ queryClient.setQueryData(chartKey, (old: unknown) => {
+ if (!old) return old;
+ return { ...(old as object), type: "logit-lens-intro" };
+ });
+ return { previousChart, chartKey } as {
+ previousChart: unknown;
+ chartKey: ReturnType;
+ };
+ },
+ mutationFn: async ({
+ lensRequest,
+ }: {
+ lensRequest: LogitLensIntroRequest;
+ configId: string;
+ }) => {
+ const response = await getLogitLensIntro(lensRequest);
+ await setChartData(lensRequest.chartId, response, "logit-lens-intro");
+ return response;
+ },
+ onError: (error, variables, context) => {
+ if (context?.previousChart) {
+ queryClient.setQueryData(context.chartKey, context.previousChart);
+ }
+ toast.error("Failed to compute logit lens visualization");
+ },
+ onSuccess: async (data, variables) => {
+ const chartKey = queryKeys.charts.chart(variables.lensRequest.chartId);
+ await queryClient.invalidateQueries({ queryKey: chartKey });
+
+ const chart = queryClient.getQueryData(chartKey) as
+ | { workspaceId?: string }
+ | undefined;
+ if (chart?.workspaceId) {
+ queryClient.invalidateQueries({
+ queryKey: ["chartsForSidebar", chart.workspaceId],
+ });
+ queryClient.invalidateQueries({
+ queryKey: queryKeys.charts.configByChart(variables.lensRequest.chartId),
+ });
+ }
+ },
+ });
+};
diff --git a/workbench/_web/src/lib/config.ts b/workbench/_web/src/lib/config.ts
index fdf94077..34230b6c 100644
--- a/workbench/_web/src/lib/config.ts
+++ b/workbench/_web/src/lib/config.ts
@@ -16,6 +16,9 @@ const config = {
startLens2: "/logit_lens/start",
resultsLens2: (jobId: string) => `/logit_lens/results/${jobId}`,
+ startCausalMediation: "/causal_mediation/start",
+ resultsCausalMediation: (jobId: string) => `/causal_mediation/results/${jobId}`,
+
startActivationPatching: "/activation_patching/start",
resultsActivationPatching: (jobId: string) => `/activation_patching/results/${jobId}`,
diff --git a/workbench/_web/src/lib/queries/chartQueries.ts b/workbench/_web/src/lib/queries/chartQueries.ts
index ad0320f9..3460c795 100644
--- a/workbench/_web/src/lib/queries/chartQueries.ts
+++ b/workbench/_web/src/lib/queries/chartQueries.ts
@@ -7,6 +7,7 @@ import { LensConfigData } from "@/types/lens";
import { Lens2ConfigData } from "@/types/lens2";
import { PatchingConfig } from "@/types/patching";
import { ActivationPatchingConfigData } from "@/types/activationPatching";
+import { LogitLensIntroConfigData } from "@/types/logitLensIntro";
import { eq, asc, desc } from "drizzle-orm";
import { touchWorkspace, getNextWorkspaceItemPosition } from "@/lib/queries/workspaceQueries";
@@ -56,7 +57,9 @@ type ConfigPayload =
| { type: "lens"; data: LensConfigData }
| { type: "lens2"; data: Lens2ConfigData }
| { type: "patch"; data: PatchingConfig }
- | { type: "activation-patching"; data: ActivationPatchingConfigData };
+ | { type: "activation-patching"; data: ActivationPatchingConfigData }
+ | { type: "logit-lens-intro"; data: LogitLensIntroConfigData }
+ | { type: "cm-intro"; data: Record };
// Creates a chart, its config, and the link between them, with the chart
// positioned at the bottom of the unified sidebar list.
@@ -86,6 +89,14 @@ export const createLensChartPair = async (
export const createLens2ChartPair = async (workspaceId: string, defaultConfig: Lens2ConfigData) =>
createChartConfigPair(workspaceId, { type: "lens2", data: defaultConfig });
+export const createLogitLensIntroChartPair = async (
+ workspaceId: string,
+ defaultConfig: LogitLensIntroConfigData,
+) => createChartConfigPair(workspaceId, { type: "logit-lens-intro", data: defaultConfig });
+
+export const createCMIntroChartPair = async (workspaceId: string) =>
+ createChartConfigPair(workspaceId, { type: "cm-intro", data: {} });
+
export const createPatchChartPair = async (workspaceId: string, defaultConfig: PatchingConfig) =>
createChartConfigPair(workspaceId, { type: "patch", data: defaultConfig });
diff --git a/workbench/_web/src/lib/queries/documentQueries.ts b/workbench/_web/src/lib/queries/documentQueries.ts
index ad0447a9..1bd36504 100644
--- a/workbench/_web/src/lib/queries/documentQueries.ts
+++ b/workbench/_web/src/lib/queries/documentQueries.ts
@@ -38,7 +38,16 @@ export const updateDocument = async (
// Extract plain text from a Lexical SerializedEditorState
function extractPlainTextFromLexical(content: SerializedEditorState): string {
try {
- const visit = (node: any, lines: string[], currentLine: string[]) => {
+ type LexicalNode = {
+ type?: string;
+ text?: unknown;
+ children?: LexicalNode[];
+ };
+ const visit = (
+ node: LexicalNode | null | undefined,
+ lines: string[],
+ currentLine: string[],
+ ) => {
if (!node) return;
const type = node.type;
if (type === "text" && typeof node.text === "string") {
@@ -54,7 +63,10 @@ function extractPlainTextFromLexical(content: SerializedEditorState): string {
currentLine.length = 0;
}
// For block-level nodes, terminate the line
- if (["paragraph", "heading", "quote", "list", "listitem", "code"].includes(type)) {
+ if (
+ typeof type === "string" &&
+ ["paragraph", "heading", "quote", "list", "listitem", "code"].includes(type)
+ ) {
if (currentLine.length > 0) {
lines.push(currentLine.join(""));
currentLine.length = 0;
@@ -65,7 +77,7 @@ function extractPlainTextFromLexical(content: SerializedEditorState): string {
}
};
const lines: string[] = [];
- visit((content as any).root, lines, []);
+ visit((content as { root?: LexicalNode }).root, lines, []);
return lines.join("\n");
} catch {
return "";
diff --git a/workbench/_web/src/lib/supabase/server.ts b/workbench/_web/src/lib/supabase/server.ts
index 39b89ec1..7b74da22 100644
--- a/workbench/_web/src/lib/supabase/server.ts
+++ b/workbench/_web/src/lib/supabase/server.ts
@@ -28,7 +28,7 @@ export async function createClient() {
storage: {
from: () => ({}),
},
- } as any;
+ } as unknown as ReturnType;
}
return createServerClient(
diff --git a/workbench/_web/src/lib/utils.ts b/workbench/_web/src/lib/utils.ts
index 4c0f37ec..445ba7e0 100644
--- a/workbench/_web/src/lib/utils.ts
+++ b/workbench/_web/src/lib/utils.ts
@@ -19,27 +19,27 @@ export function hslFromCssVar(name: string, fallback = "#000000"): string {
* Recursively processes a theme object, converting CSS variable strings to concrete colors
* for Canvas compatibility. Only processes strings that match the pattern "hsl(var(--...))"
*/
-export function resolveThemeCssVars(obj: any): any {
+export function resolveThemeCssVars(obj: T): T {
if (typeof obj === "string") {
// Match CSS variable pattern: hsl(var(--variable-name))
const cssVarMatch = obj.match(/^hsl\(var\((--.+?)\)\)$/);
if (cssVarMatch) {
const varName = cssVarMatch[1];
- return hslFromCssVar(varName, obj); // fallback to original if resolution fails
+ return hslFromCssVar(varName, obj) as T; // fallback to original if resolution fails
}
return obj;
}
if (Array.isArray(obj)) {
- return obj.map(resolveThemeCssVars);
+ return obj.map(resolveThemeCssVars) as T;
}
if (obj !== null && typeof obj === "object") {
- const result: any = {};
+ const result: Record = {};
for (const [key, value] of Object.entries(obj)) {
result[key] = resolveThemeCssVars(value);
}
- return result;
+ return result as T;
}
return obj;
diff --git a/workbench/_web/src/stores/useCMIntroTutorial.ts b/workbench/_web/src/stores/useCMIntroTutorial.ts
new file mode 100644
index 00000000..cdc7e21c
--- /dev/null
+++ b/workbench/_web/src/stores/useCMIntroTutorial.ts
@@ -0,0 +1,38 @@
+import { create } from "zustand";
+
+const STORAGE_KEY = "workbench:cm-intro-tutorial-completed:v1";
+
+interface CMIntroTutorialState {
+ completed: boolean;
+ markCompleted: () => void;
+ reset: () => void;
+}
+
+function readCompleted(): boolean {
+ if (typeof window === "undefined") return false;
+ return localStorage.getItem(STORAGE_KEY) === "true";
+}
+
+export const useCMIntroTutorial = create()((set) => ({
+ completed: false,
+ markCompleted: () => {
+ if (typeof window !== "undefined") {
+ localStorage.setItem(STORAGE_KEY, "true");
+ }
+ set({ completed: true });
+ },
+ reset: () => {
+ if (typeof window !== "undefined") {
+ localStorage.removeItem(STORAGE_KEY);
+ }
+ set({ completed: false });
+ },
+}));
+
+/**
+ * Hydrate from localStorage on the client. Call once at mount in a client
+ * component.
+ */
+export function hydrateCMIntroTutorial() {
+ useCMIntroTutorial.setState({ completed: readCompleted() });
+}
diff --git a/workbench/_web/src/tutorials/cmIntro.ts b/workbench/_web/src/tutorials/cmIntro.ts
new file mode 100644
index 00000000..12d9605d
--- /dev/null
+++ b/workbench/_web/src/tutorials/cmIntro.ts
@@ -0,0 +1,70 @@
+import type { ExtendedStepType, TutorialChapterProgress, TutorialProgress } from "@/types/tutorial";
+
+/**
+ * Tutorial for the CM (causal mediation / activation patching) intro feature.
+ *
+ * Selector ids must be present in CMIntroArea.tsx and CMIntroDisplay.tsx.
+ */
+const CMIntroSteps: ExtendedStepType[] = [
+ {
+ selector: "#cm-intro-welcome",
+ content:
+ "Welcome to activation patching.\n\nThis tool lets you take a piece of internal state from one prompt's forward pass and inject it into another prompt's forward pass at the same layer and position. By watching what changes downstream, you can locate the parts of the model that 'carry' a particular piece of information.",
+ styles: {
+ maskArea: (base) => ({ ...base, display: "none" }),
+ },
+ },
+ {
+ selector: "#cm-intro-source-prompt",
+ content:
+ "Start with a Source prompt — the prompt whose internal state you want to STEAL from. Pick something that produces a clear, specific prediction (e.g. 'The Eiffel Tower is in the city of').",
+ },
+ {
+ selector: "#cm-intro-target-prompt",
+ content:
+ "Add a Target prompt — the prompt whose forward pass you'll PATCH INTO. Typically you pick something with similar grammar but a different answer (e.g. 'The Big Ben is in the city of'), so any change in the target's prediction tells you that the patched activation carried the answer.\n\nYou can leave Target blank to use CM Intro as a single-prompt lens viewer.",
+ },
+ {
+ selector: "#cm-intro-run",
+ content: "Click Run to compute the logit lens for both prompts.",
+ trigger: {
+ type: "click",
+ target: "#cm-intro-run",
+ },
+ },
+ {
+ selector: "#cm-intro-display",
+ content:
+ "You now see two heatmaps, one per prompt. Same axes as the lens: rows are token positions, columns are layers. Click any cell to see a crosshair + the 'About this view' legend in the side panel.",
+ },
+ {
+ selector: "#cm-intro-display",
+ content:
+ "To run an intervention, DRAG a cell from the Source heatmap and DROP it on a cell in the Target heatmap. Workbench will:\n\n 1. Run the Source prompt and capture the residual stream at the cell you dragged.\n 2. Run the Target prompt, but overwrite the residual stream at the drop position with the captured Source value.\n 3. Show you the Target's NEW per-layer predictions in a third heatmap below.",
+ },
+ {
+ selector: "#cm-intro-display",
+ content:
+ "Reading the result heatmap: compare it to the original Target heatmap. Any cell that CHANGED means the patched activation influenced that downstream computation. Cells that stayed the same were unaffected — that part of the network ignored the swap.\n\nThe pattern of changes tells you where in the model the patched information is being read by later layers.",
+ },
+ {
+ selector: "#cm-intro-display",
+ content:
+ "A classic experimental design: pick a Source and Target that differ in ONE controlled way (e.g. 'Paris' vs 'London' as the answer). Then patch one layer at a time. The earliest layer where patching flips the Target's prediction to the Source's answer tells you where that fact is 'stored' in the residual stream.\n\nThis is the basic recipe for causal mediation analysis.",
+ },
+];
+
+const CMIntroChapters: TutorialChapterProgress[] = [
+ {
+ title: "Getting started",
+ steps: CMIntroSteps,
+ completed: false,
+ },
+];
+
+export const CMIntroTutorial: TutorialProgress = {
+ chapters: CMIntroChapters,
+ currentChapter: 0,
+ description:
+ "A walkthrough of CM Intro (causal mediation / activation patching). Learn how to set up source and target prompts, run an intervention by dragging a cell from one heatmap to another, and interpret the resulting changes.",
+};
diff --git a/workbench/_web/src/types/charts.ts b/workbench/_web/src/types/charts.ts
index e0c676c0..0b2bfcee 100644
--- a/workbench/_web/src/types/charts.ts
+++ b/workbench/_web/src/types/charts.ts
@@ -1,5 +1,7 @@
import { LensConfigData } from "./lens";
import { Lens2ConfigData, Lens2Data } from "./lens2";
+import { LogitLensIntroConfigData, LogitLensIntroData } from "./logitLensIntro";
+import { CMIntroChartData } from "./cmIntro";
import { PatchingConfig } from "./patching";
import { ActivationPatchingConfigData, ActivationPatchingData } from "./activationPatching";
@@ -63,16 +65,35 @@ export interface LineViewData {
// Combined Types
-export type ChartData = Line[] | HeatmapRow[] | Lens2Data | ActivationPatchingData;
+export type ChartData =
+ | Line[]
+ | HeatmapRow[]
+ | Lens2Data
+ | ActivationPatchingData
+ | LogitLensIntroData
+ | CMIntroChartData;
export type ChartView = HeatmapViewData | LineViewData;
export type ConfigData =
| LensConfigData
| Lens2ConfigData
| PatchingConfig
- | ActivationPatchingConfigData;
-
-export type ChartType = "line" | "heatmap" | "lens2" | "activation-patching";
-export type ToolType = "lens" | "lens2" | "patch" | "activation-patching";
+ | ActivationPatchingConfigData
+ | LogitLensIntroConfigData;
+
+export type ChartType =
+ | "line"
+ | "heatmap"
+ | "lens2"
+ | "activation-patching"
+ | "logit-lens-intro"
+ | "cm-intro";
+export type ToolType =
+ | "lens"
+ | "lens2"
+ | "patch"
+ | "activation-patching"
+ | "logit-lens-intro"
+ | "cm-intro";
export type ChartMetadata = {
id: string;
diff --git a/workbench/_web/src/types/cmIntro.ts b/workbench/_web/src/types/cmIntro.ts
new file mode 100644
index 00000000..6317b10c
--- /dev/null
+++ b/workbench/_web/src/types/cmIntro.ts
@@ -0,0 +1,31 @@
+/**
+ * CM Intro chart data types.
+ *
+ * Persisted into the chart row's `data` field when `type = "cm-intro"`.
+ * `source`/`target` are the per-prompt logit-lens results; `intervention`/`result`
+ * are populated after a causal-mediation cell-drop on the explorer.
+ */
+
+import type { LogitLensIntroData } from "./logitLensIntro";
+
+export interface CMIntroInterventionSpec {
+ srcTokenPos: number;
+ srcLayer: number;
+ tgtTokenPos: number;
+ tgtLayer: number;
+}
+
+export interface CMIntroChartData {
+ sourcePrompt: string;
+ targetPrompt: string;
+ source?: LogitLensIntroData;
+ target?: LogitLensIntroData;
+ // Snapshot of the prompts the persisted lens was actually computed from.
+ // Distinct from sourcePrompt/targetPrompt (which autosave on every edit) —
+ // these only update on a successful lens run, so the UI can hide the
+ // predicted-next-token hint when the user starts editing.
+ lastRunSourcePrompt?: string;
+ lastRunTargetPrompt?: string;
+ intervention?: CMIntroInterventionSpec;
+ result?: LogitLensIntroData;
+}
diff --git a/workbench/_web/src/types/logitLensIntro.ts b/workbench/_web/src/types/logitLensIntro.ts
new file mode 100644
index 00000000..357330f3
--- /dev/null
+++ b/workbench/_web/src/types/logitLensIntro.ts
@@ -0,0 +1,16 @@
+/**
+ * Logit Lens Intro Types
+ *
+ * Uses the same backend as lens2 but renders with the edulogitlens widget.
+ */
+
+import type { LogitLensData } from "nnsightful";
+
+export type LogitLensIntroData = LogitLensData;
+
+export interface LogitLensIntroConfigData {
+ model: string;
+ prompt: string;
+ topk?: number;
+ includeEntropy?: boolean;
+}
diff --git a/workbench/_web/tailwind.config.ts b/workbench/_web/tailwind.config.ts
index 0b140a46..f73b39b8 100644
--- a/workbench/_web/tailwind.config.ts
+++ b/workbench/_web/tailwind.config.ts
@@ -6,6 +6,11 @@ export default {
"./src/pages/**/*.{js,ts,jsx,tsx,mdx}",
"./src/components/**/*.{js,ts,jsx,tsx,mdx}",
"./src/app/**/*.{js,ts,jsx,tsx,mdx}",
+ // edulogitlens is a file: workspace dep; its components use Tailwind
+ // classes (e.g. bg-white on the cell-prediction popup) that need to be
+ // scanned here or they don't make it into the compiled CSS bundle.
+ "./node_modules/edulogitlens/src/**/*.{js,ts,jsx,tsx}",
+ "./node_modules/edulogitlens/index.ts",
],
theme: {
extend: {
diff --git a/workbench/_web/tests/k6/utils.ts b/workbench/_web/tests/k6/utils.ts
index 0a16ef01..21afc12c 100644
--- a/workbench/_web/tests/k6/utils.ts
+++ b/workbench/_web/tests/k6/utils.ts
@@ -19,7 +19,7 @@ const config = {
interface JobStartResponse {
job_id: string | null;
data?: T;
- [key: string]: any;
+ [key: string]: unknown;
}
function awaitNDIFJob(jobId: string): void {
@@ -52,7 +52,7 @@ function awaitNDIFJob(jobId: string): void {
}
}
-function startJob(url: string, body: any): JobStartResponse {
+function startJob(url: string, body: unknown): JobStartResponse {
const response = http.post(url, JSON.stringify(body), {
headers: { "Content-Type": "application/json" },
});
@@ -64,7 +64,7 @@ function startJob(url: string, body: any): JobStartResponse {
return JSON.parse(response.body as string) as JobStartResponse;
}
-function fetchResults(url: string, body: any): T {
+function fetchResults(url: string, body: unknown): T {
const resp = http.post(url, JSON.stringify(body), {
headers: { "Content-Type": "application/json" },
});
@@ -83,7 +83,7 @@ export interface PollResult {
export function startAndPoll(
startEndpoint: string,
- body: any,
+ body: unknown,
resultsEndpoint: (jobId: string) => string,
): PollResult {
const startTime = Date.now();
@@ -95,7 +95,7 @@ export function startAndPoll(
awaitNDIFJob(jobId);
const resultsUrl = config.getApiUrl(resultsEndpoint(jobId));
- const results = fetchResults(resultsUrl, body);
+ const results = fetchResults(resultsUrl, body);
const pollDuration = Date.now() - startTime;
diff --git a/workbench/package-lock.json b/workbench/package-lock.json
new file mode 100644
index 00000000..7bc22c1b
--- /dev/null
+++ b/workbench/package-lock.json
@@ -0,0 +1,6 @@
+{
+ "name": "workbench",
+ "lockfileVersion": 3,
+ "requires": true,
+ "packages": {}
+}