diff --git a/package.json b/package.json new file mode 100644 index 00000000..668e1889 --- /dev/null +++ b/package.json @@ -0,0 +1 @@ +{"dependencies": {}} \ No newline at end of file diff --git a/scripts/api.sh b/scripts/api.sh index b69f0e34..9b0b0cf4 100644 --- a/scripts/api.sh +++ b/scripts/api.sh @@ -1,4 +1,4 @@ #!/bin/bash cd workbench -uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload \ No newline at end of file +python -m uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload \ No newline at end of file diff --git a/workbench/_api/main.py b/workbench/_api/main.py index 08755b02..8020ad93 100644 --- a/workbench/_api/main.py +++ b/workbench/_api/main.py @@ -1,13 +1,17 @@ -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse import logging import os +import traceback import anyio -from .routes import lens, patch, models, logit_lens, activation_patching +from .routes import lens, patch, models, logit_lens, activation_patching, causal_mediation from .state import AppState -from dotenv import load_dotenv; load_dotenv() +from pathlib import Path +from dotenv import load_dotenv +load_dotenv(Path(__file__).resolve().parent / ".env") # Configure logging logging.basicConfig( @@ -59,9 +63,19 @@ def fastapi_app(): app.include_router(lens, prefix="/lens") app.include_router(logit_lens, prefix="/logit_lens") app.include_router(activation_patching, prefix="/activation_patching") + app.include_router(causal_mediation, prefix="/causal_mediation", tags=["causal_mediation"]) app.include_router(patch, prefix="/patch") app.include_router(models, prefix="/models") + @app.exception_handler(Exception) + async def global_exception_handler(request: Request, exc: Exception): + tb = traceback.format_exception(type(exc), exc, exc.__traceback__) + logging.error(f"Unhandled exception: {''.join(tb)}") + return JSONResponse( + status_code=500, + content={"detail": str(exc), "traceback": ''.join(tb[-3:])}, + ) + app.state.m = AppState() return app diff --git a/workbench/_api/routes/__init__.py b/workbench/_api/routes/__init__.py index a1b6e8a6..ca438fa6 100644 --- a/workbench/_api/routes/__init__.py +++ b/workbench/_api/routes/__init__.py @@ -3,9 +3,17 @@ from .models import router as models from .logit_lens import router as logit_lens from .activation_patching import router as activation_patching +from .causal_mediation import router as causal_mediation from nnsight import ndif import nnsightful ndif.register(nnsightful) -__all__ = ["lens", "patch", "models", "logit_lens", "activation_patching"] \ No newline at end of file +__all__ = [ + "lens", + "patch", + "models", + "logit_lens", + "activation_patching", + "causal_mediation", +] \ No newline at end of file diff --git a/workbench/_api/routes/causal_mediation.py b/workbench/_api/routes/causal_mediation.py new file mode 100644 index 00000000..1e3f69a5 --- /dev/null +++ b/workbench/_api/routes/causal_mediation.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import Any + +import torch +from fastapi import APIRouter, Depends +from pydantic import BaseModel + +from ..auth import require_user_email +from ..data_models import NDIFResponse +from ..state import AppState, get_state + +from nnsightful.types import LogitLensData + +router = APIRouter() + + +class CausalMediationRequest(BaseModel): + model: str + src_prompt: str + tgt_prompt: str + src_token_pos: int + src_layer: int + tgt_token_pos: int + tgt_layer: int + topk: int = 5 + include_entropy: bool = True + + +class CausalMediationResponse(NDIFResponse): + """Identical shape to LogitLensResponse so the frontend can reuse the + existing logit-lens transform/renderer.""" + data: LogitLensData | None = None + + +def _format_lens( + logits: torch.Tensor, + tokenizer, + model_name: str, + input_tokens: list[str], + n_layers: int, + *, + top_k: int = 5, + include_entropy: bool = True, +) -> dict[str, Any]: + """Inlined mirror of the local `format()` inside + `nnsightful.tools.logit_lens._run`. Turns a [L, T, V] logits tensor into + the dict shape consumed by `LogitLensData(**...)`. + + Why: `LogitLensTool` doesn't override `_format`, so calling + `logit_lens_tool._format(...)` falls through to the abstract `Tool._format` + in `nnsightful/tools/_base.py` whose body is `...` — i.e. returns `None`. + """ + layers = list(range(n_layers)) + positions = list(range(len(input_tokens))) + + if include_entropy: + log_p = torch.nn.functional.log_softmax(logits, dim=-1) + p = log_p.exp() + entropy = torch.round(-(p * log_p).sum(dim=-1), decimals=3).tolist() + else: + entropy = None + + probs = torch.nn.functional.softmax(logits, dim=-1) + logits.to("cpu") # free memory + + _, top_indices = torch.topk(probs, k=top_k, dim=-1) + + topks = [ + [tokenizer.batch_decode(torch.tensor(pos).unsqueeze(dim=1)) for pos in layer] + for layer in top_indices.tolist() + ] + + unique_indices = [ + torch.unique(top_indices[:, pi, :].flatten(), sorted=False).tolist() + for pi in range(top_indices.shape[1]) + ] + probs = probs.permute(1, 2, 0) + trajectories = [ + { + tokenizer.decode(token): torch.round(probs[pos_idx][token], decimals=3).tolist() + for token in pos + } + for pos_idx, pos in enumerate(unique_indices) + ] + + return { + "meta": {"version": 2, "timestamp": "3h", "model": model_name}, + "layers": layers, + "input": input_tokens, + "tracked": trajectories, + "topk": topks, + "entropy": entropy, + "positions": positions, + } + + +def _run_causal_mediation( + model, + src_prompt: str, + tgt_prompt: str, + src_token_pos: int, + src_layer: int, + tgt_token_pos: int, + tgt_layer: int, + *, + remote: bool = False, + backend=None, +) -> dict[str, Any]: + """Capture the source residual at (src_layer, src_token_pos), patch it + into the target prompt at (tgt_layer, tgt_token_pos), then run a logit + lens over the *patched* forward pass. + + Patching pattern mirrors `nnsightful.tools.activation_patching._run`: + a slice-assign on `model.layers_output[i]` is sufficient to register the + intervention and have it propagate through subsequent layers — no + explicit `model.layers[i].output = (...)` reassignment is needed. + """ + n_layers = model.num_layers + + with torch.no_grad(): + with model.session(remote=remote, backend=backend): + # 1) Source pass — capture the residual at (src_layer, src_token_pos). + with model.trace(src_prompt): + src_hidden = model.layers_output[src_layer][0, src_token_pos].save() + + # 2) Target pass — at tgt_layer, slice-assign the source residual + # into the target's tgt_token_pos. project_on_vocab at every + # layer gives us a logit-lens grid over the patched pass. + with model.trace(tgt_prompt): + per_layer_logits = [] + for i in range(n_layers): + hs = model.layers_output[i] + if i == tgt_layer: + hs[0, tgt_token_pos][:] = src_hidden + per_layer_logits.append(model.project_on_vocab(hs)) + # Stack to a single [L, T, V] tensor so backend() returns a + # known shape on the remote path (one saved key: "logits"). + logits = torch.cat(per_layer_logits, dim=0).save() + + if remote and backend is not None: + return {"job_id": backend.job_id} + + return {"logits": logits} + + +@router.post("/start", response_model=CausalMediationResponse) +async def start_causal_mediation( + req: CausalMediationRequest, + state: AppState = Depends(get_state), + user_email: str = Depends(require_user_email), +): + model = state[req.model] + backend = state.make_backend(model=model) + + raw = _run_causal_mediation( + model, + req.src_prompt, + req.tgt_prompt, + req.src_token_pos, + req.src_layer, + req.tgt_token_pos, + req.tgt_layer, + remote=state.remote, + backend=backend, + ) + + if "job_id" in raw: + return {"job_id": raw["job_id"]} + + input_tokens = [ + str(model.tokenizer.decode(token)) + for token in model.tokenizer.encode(req.tgt_prompt) + ] + data = _format_lens( + raw["logits"], + tokenizer=model.tokenizer, + model_name=req.model, + input_tokens=input_tokens, + n_layers=model.num_layers, + top_k=req.topk, + include_entropy=req.include_entropy, + ) + return {"data": data} + + +@router.post("/results/{job_id}", response_model=CausalMediationResponse) +async def collect_causal_mediation( + job_id: str, + req: CausalMediationRequest, + state: AppState = Depends(get_state), + user_email: str = Depends(require_user_email), +): + backend = state.make_backend(job_id=job_id) + results = backend() + + model = state[req.model] + tokenizer = model.tokenizer + input_tokens = [ + str(tokenizer.decode(token)) + for token in tokenizer.encode(req.tgt_prompt) + ] + + data = _format_lens( + results["logits"], + tokenizer=tokenizer, + model_name=req.model, + input_tokens=input_tokens, + n_layers=model.num_layers, + top_k=req.topk, + include_entropy=req.include_entropy, + ) + + return {"data": data} diff --git a/workbench/_web/bun.lock b/workbench/_web/bun.lock index 2acc8147..24cfc38f 100644 --- a/workbench/_web/bun.lock +++ b/workbench/_web/bun.lock @@ -1,5 +1,6 @@ { "lockfileVersion": 1, + "configVersion": 0, "workspaces": { "": { "name": "nextjs-shadcn", @@ -34,6 +35,7 @@ "d3-delaunay": "^6.0.4", "dotenv": "^17.2.1", "drizzle-orm": "^0.44.4", + "edulogitlens": "github:jon-bell/edulogitlens#cm-only", "framer-motion": "^12.23.22", "html-to-image": "^1.11.13", "lexical": "^0.34.0", @@ -662,6 +664,12 @@ "@radix-ui/rect": ["@radix-ui/rect@1.1.1", "", {}, "sha512-HPwpGIzkl28mWyZqG52jiqDJ12waP11Pa1lGoiyUkIEuMLBP0oeK/C89esbXrxsky5we7dfd8U58nm0SgAWpVw=="], + "@react-dnd/asap": ["@react-dnd/asap@5.0.2", "", {}, "sha512-WLyfoHvxhs0V9U+GTsGilGgf2QsPl6ZZ44fnv0/b8T3nQyvzxidxsg/ZltbWssbsRDlYW8UKSQMTGotuTotZ6A=="], + + "@react-dnd/invariant": ["@react-dnd/invariant@4.0.2", "", {}, "sha512-xKCTqAK/FFauOM9Ta2pswIyT3D8AQlfrYdOi/toTPEhqCuAs1v5tcJ3Y08Izh1cJ5Jchwy9SeAXmMg6zrKs2iw=="], + + "@react-dnd/shallowequal": ["@react-dnd/shallowequal@4.0.2", "", {}, "sha512-/RVXdLvJxLg4QKvMoM5WlwNR9ViO9z8B/qPcc+C0Sa/teJY7QG7kJ441DwzOjMYEY7GmU4dj5EcGHIkKZiQZCA=="], + "@react-pdf/fns": ["@react-pdf/fns@3.1.3", "", {}, "sha512-0I7pApDr1/RLAKbizuLy/IHTEa93LSPy/bEwYniboC3Xqnp6Od8xFJKbKEzGw2wh/5zKFFwl00g4t9RwgIMc3w=="], "@react-pdf/font": ["@react-pdf/font@4.0.6", "", { "dependencies": { "@react-pdf/pdfkit": "^5.0.0", "@react-pdf/types": "^2.10.0", "fontkit": "^2.0.2", "is-url": "^1.2.4" } }, "sha512-1RxR/hTyZcbgjESUjrMms574xuS9PLB4ovqQx6jvgdrIHXUyeUtSH6i3Szw1qVfUnA9MfaEm1FBuydQeJD39BQ=="], @@ -1074,6 +1082,8 @@ "dlv": ["dlv@1.1.3", "", {}, "sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA=="], + "dnd-core": ["dnd-core@16.0.1", "", { "dependencies": { "@react-dnd/asap": "^5.0.1", "@react-dnd/invariant": "^4.0.1", "redux": "^4.2.0" } }, "sha512-HK294sl7tbw6F6IeuK16YSBUoorvHpY8RHO+9yFfaJyCDVb6n7PRcezrOEOa2SBCqiYpemh5Jx20ZcjKdFAVng=="], + "doctrine": ["doctrine@2.1.0", "", { "dependencies": { "esutils": "^2.0.2" } }, "sha512-35mSku4ZXK0vfCuHEDAwt55dg2jNajHZ1odvF+8SSr82EsZY4QmXfuWso8oEd8zRhVObSN18aM0CjSdoBX7zIw=="], "dom-helpers": ["dom-helpers@5.2.1", "", { "dependencies": { "@babel/runtime": "^7.8.7", "csstype": "^3.0.2" } }, "sha512-nRCa7CK3VTrM2NmGkIy4cbK7IZlgBE/PYMn55rrXefr5xXDP0LdtfPnblFDoVdcAfslJ7or6iqAUnx0CCGIWQA=="], @@ -1090,6 +1100,8 @@ "easy-table": ["easy-table@1.1.0", "", { "optionalDependencies": { "wcwidth": ">=1.0.1" } }, "sha512-oq33hWOSSnl2Hoh00tZWaIPi1ievrD9aFG82/IgjlycAnW9hHx5PkJiXpxPsgEE+H7BsbVQXFVFST8TEXS6/pA=="], + "edulogitlens": ["edulogitlens@github:jon-bell/edulogitlens#416f56b", { "dependencies": { "lucide-react": "^0.487.0", "motion": "^12.23.24", "react-dnd": "^16.0.1", "react-dnd-html5-backend": "^16.0.1" }, "peerDependencies": { "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" } }, "jon-bell-edulogitlens-416f56b", "sha512-ISmUpe0w3MzO32GVLKWKLolch7/0awOjZKTsQH/mXG5v+ai0HTpGFnNHi9Lb3MDuPW3UDWqwpiAD60ii3oECog=="], + "electron-to-chromium": ["electron-to-chromium@1.5.200", "", {}, "sha512-rFCxROw7aOe4uPTfIAx+rXv9cEcGx+buAF4npnhtTqCJk5KDFRnh3+KYj7rdVh6lsFt5/aPs+Irj9rZ33WMA7w=="], "emoji-regex": ["emoji-regex@9.2.2", "", {}, "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg=="], @@ -1630,6 +1642,10 @@ "react": ["react@18.3.1", "", { "dependencies": { "loose-envify": "^1.1.0" } }, "sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ=="], + "react-dnd": ["react-dnd@16.0.1", "", { "dependencies": { "@react-dnd/invariant": "^4.0.1", "@react-dnd/shallowequal": "^4.0.1", "dnd-core": "^16.0.1", "fast-deep-equal": "^3.1.3", "hoist-non-react-statics": "^3.3.2" }, "peerDependencies": { "@types/hoist-non-react-statics": ">= 3.3.1", "@types/node": ">= 12", "@types/react": ">= 16", "react": ">= 16.14" }, "optionalPeers": ["@types/hoist-non-react-statics", "@types/node", "@types/react"] }, "sha512-QeoM/i73HHu2XF9aKksIUuamHPDvRglEwdHL4jsp784BgUuWcg6mzfxT0QDdQz8Wj0qyRKx2eMg8iZtWvU4E2Q=="], + + "react-dnd-html5-backend": ["react-dnd-html5-backend@16.0.1", "", { "dependencies": { "dnd-core": "^16.0.1" } }, "sha512-Wu3dw5aDJmOGw8WjH1I1/yTH+vlXEL4vmjk5p+MHxP8HuHJS1lAGeIdG/hze1AvNeXWo/JgULV87LyQOr+r5jw=="], + "react-dom": ["react-dom@18.3.1", "", { "dependencies": { "loose-envify": "^1.1.0", "scheduler": "^0.23.2" }, "peerDependencies": { "react": "^18.3.1" } }, "sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw=="], "react-error-boundary": ["react-error-boundary@3.1.4", "", { "dependencies": { "@babel/runtime": "^7.12.5" }, "peerDependencies": { "react": ">=16.13.1" } }, "sha512-uM9uPzZJTF6wRQORmSrvOIgt4lJ9MC1sNgEOj2XGsDTRE4kmpWxg7ENK9EWNKJRMAOY9z0MuF4yIfl6gp4sotA=="], @@ -1660,6 +1676,8 @@ "readdirp": ["readdirp@3.6.0", "", { "dependencies": { "picomatch": "^2.2.1" } }, "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA=="], + "redux": ["redux@4.2.1", "", { "dependencies": { "@babel/runtime": "^7.9.2" } }, "sha512-LAUYz4lc+Do8/g7aeRa8JkyDErK6ekstQaqWQrNRW//MY1TvCEpMtpTWvlQ+FPbWCx+Xixu/6SHt5N0HR+SB4w=="], + "reflect.getprototypeof": ["reflect.getprototypeof@1.0.10", "", { "dependencies": { "call-bind": "^1.0.8", "define-properties": "^1.2.1", "es-abstract": "^1.23.9", "es-errors": "^1.3.0", "es-object-atoms": "^1.0.0", "get-intrinsic": "^1.2.7", "get-proto": "^1.0.1", "which-builtin-type": "^1.2.1" } }, "sha512-00o4I+DVrefhv+nX0ulyi3biSHCPDe+yLv5o/p6d/UVlirijB8E16FtfwSAi4g3tcqrQ4lRAqQSoFEZJehYEcw=="], "regexp.prototype.flags": ["regexp.prototype.flags@1.5.4", "", { "dependencies": { "call-bind": "^1.0.8", "define-properties": "^1.2.1", "es-errors": "^1.3.0", "get-proto": "^1.0.1", "gopd": "^1.2.0", "set-function-name": "^2.0.2" } }, "sha512-dYqgNSZbDwkaJ2ceRd9ojCGjBq+mOm9LmtXnAnEGyHhN/5R7iDW2TRw3h+o/jCFxus3P2LfWIIiwowAjANm7IA=="], @@ -2004,6 +2022,10 @@ "defaults/clone": ["clone@1.0.4", "", {}, "sha512-JQHZ2QMW6l3aH/j6xCqQThY/9OH4D/9ls34cgkUBiEeocRTU04tHfKPBsUK1PqZCUQM7GiA0IIXJSuXHI64Kbg=="], + "edulogitlens/lucide-react": ["lucide-react@0.487.0", "", { "peerDependencies": { "react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "sha512-aKqhOQ+YmFnwq8dWgGjOuLc8V1R9/c/yOd+zDY4+ohsR2Jo05lSGc3WsstYPIzcTpeosN7LoCkLReUUITvaIvw=="], + + "edulogitlens/motion": ["motion@12.38.0", "", { "dependencies": { "framer-motion": "^12.38.0", "tslib": "^2.4.0" }, "peerDependencies": { "@emotion/is-prop-valid": "*", "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" }, "optionalPeers": ["@emotion/is-prop-valid", "react", "react-dom"] }, "sha512-uYfXzeHlgThchzwz5Te47dlv5JOUC7OB4rjJ/7XTUgtBZD8CchMN8qEJ4ZVsUmTyYA44zjV0fBwsiktRuFnn+w=="], + "error-ex/is-arrayish": ["is-arrayish@0.2.1", "", {}, "sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg=="], "esbuild-register/debug": ["debug@4.4.1", "", { "dependencies": { "ms": "^2.1.3" } }, "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ=="], @@ -2182,6 +2204,8 @@ "@typescript-eslint/typescript-estree/minimatch/brace-expansion": ["brace-expansion@2.0.2", "", { "dependencies": { "balanced-match": "^1.0.0" } }, "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ=="], + "edulogitlens/motion/framer-motion": ["framer-motion@12.38.0", "", { "dependencies": { "motion-dom": "^12.38.0", "motion-utils": "^12.36.0", "tslib": "^2.4.0" }, "peerDependencies": { "@emotion/is-prop-valid": "*", "react": "^18.0.0 || ^19.0.0", "react-dom": "^18.0.0 || ^19.0.0" }, "optionalPeers": ["@emotion/is-prop-valid", "react", "react-dom"] }, "sha512-rFYkY/pigbcswl1XQSb7q424kSTQ8q6eAC+YUsSKooHQYuLdzdHjrt6uxUC+PRAO++q5IS7+TamgIw1AphxR+g=="], + "form-data/mime-types/mime-db": ["mime-db@1.52.0", "", {}, "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg=="], "glob/minimatch/brace-expansion": ["brace-expansion@2.0.2", "", { "dependencies": { "balanced-match": "^1.0.0" } }, "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ=="], @@ -2204,6 +2228,10 @@ "@argos-ci/core/sharp/@img/sharp-wasm32/@emnapi/runtime": ["@emnapi/runtime@1.10.0", "", { "dependencies": { "tslib": "^2.4.0" } }, "sha512-ewvYlk86xUoGI0zQRNq/mC+16R1QeDlKQy21Ki3oSYXNgLb45GV1P6A0M+/s6nyCuNDqe5VpaY84BzXGwVbwFA=="], + "edulogitlens/motion/framer-motion/motion-dom": ["motion-dom@12.38.0", "", { "dependencies": { "motion-utils": "^12.36.0" } }, "sha512-pdkHLD8QYRp8VfiNLb8xIBJis1byQ9gPT3Jnh2jqfFtAsWUA3dEepDlsWe/xMpO8McV+VdpKVcp+E+TGJEtOoA=="], + + "edulogitlens/motion/framer-motion/motion-utils": ["motion-utils@12.36.0", "", {}, "sha512-eHWisygbiwVvf6PZ1vhaHCLamvkSbPIeAYxWUuL3a2PD/TROgE7FvfHWTIH4vMl798QLfMw15nRqIaRDXTlYRg=="], + "rimraf/glob/path-scurry/lru-cache": ["lru-cache@11.2.2", "", {}, "sha512-F9ODfyqML2coTIsQpSkRHnLSZMtkU8Q+mSfcaIyKwy58u+8k5nvAYeiNhsyMARvzNcXJ9QfWVrcPsC9e9rAxtg=="], } } diff --git a/workbench/_web/instrumentation.ts b/workbench/_web/instrumentation.ts index ba38ab9f..258d4907 100644 --- a/workbench/_web/instrumentation.ts +++ b/workbench/_web/instrumentation.ts @@ -6,7 +6,7 @@ export function register() { export const onRequestError = async (err: Error, request: NextRequest, context: NextResponse) => { if (process.env.NEXT_RUNTIME === "nodejs") { - const { getPostHogServer } = require("./src/lib/posthog-server"); + const { getPostHogServer } = await import("./src/lib/posthog-server"); const posthog = await getPostHogServer(); // Only capture exception if PostHog is enabled diff --git a/workbench/_web/next.config.js b/workbench/_web/next.config.js index ed00fba2..5f7d1d7a 100644 --- a/workbench/_web/next.config.js +++ b/workbench/_web/next.config.js @@ -8,7 +8,7 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url)); const nextConfig = { reactStrictMode: true, - transpilePackages: ["nnsightful"], + transpilePackages: ["nnsightful", "edulogitlens"], turbopack: { // Expand root so Turbopack can resolve the symlinked nnsightful package // which lives at ../../nnsightful (outside the default _web/ root) diff --git a/workbench/_web/package.json b/workbench/_web/package.json index 5d2b1a74..6fffae2e 100644 --- a/workbench/_web/package.json +++ b/workbench/_web/package.json @@ -46,6 +46,7 @@ "d3-delaunay": "^6.0.4", "dotenv": "^17.2.1", "drizzle-orm": "^0.44.4", + "edulogitlens": "github:jon-bell/edulogitlens#cm-only", "framer-motion": "^12.23.22", "html-to-image": "^1.11.13", "lexical": "^0.34.0", diff --git a/workbench/_web/src/app/dev/logit-lens-intro/page.tsx b/workbench/_web/src/app/dev/logit-lens-intro/page.tsx new file mode 100644 index 00000000..fbaa60c7 --- /dev/null +++ b/workbench/_web/src/app/dev/logit-lens-intro/page.tsx @@ -0,0 +1,104 @@ +"use client"; + +import { LogitLensGrid } from "edulogitlens"; +import type { LogitLensData, LogitCell } from "edulogitlens"; + +function generateMockData(): LogitLensData { + const tokens = [ + "The", + "E", + "iff", + "el", + "Tower", + "is", + "in", + "the", + "city", + "of", + "Paris", + ",", + "France", + ".", + ]; + + const layers = Array.from({ length: 12 }, (_, i) => i); + + const vocab = [ + "t", + "bow", + "illi", + "Tower", + "el", + "France", + "Paris", + "tower", + "city", + "of", + "the", + "in", + "is", + "a", + "and", + "Eiff", + "to", + "built", + "was", + "meters", + "at", + "by", + "from", + "with", + "on", + "for", + "an", + "stands", + "tall", + ]; + + const data: LogitCell[][] = tokens.map((token) => { + return layers.map((_, layerIdx) => { + const convergence = layerIdx / layers.length; + + let primaryToken = token; + let prob: number; + + if (convergence < 0.3) { + primaryToken = vocab[Math.floor(Math.random() * vocab.length)]; + prob = 0.05 + Math.random() * 0.15; + } else if (convergence < 0.6) { + primaryToken = + Math.random() > 0.5 ? token : vocab[Math.floor(Math.random() * vocab.length)]; + prob = 0.2 + Math.random() * 0.3; + } else { + primaryToken = token; + prob = 0.5 + convergence * 0.4 + Math.random() * 0.1; + } + prob = Math.min(prob, 0.95); + + const topTokens: { token: string; prob: number }[] = [{ token: primaryToken, prob }]; + let remaining = (1 - prob) * 0.4; + for (let i = 0; i < 14; i++) { + const candidate = vocab[Math.floor(Math.random() * vocab.length)]; + topTokens.push({ token: candidate, prob: remaining }); + remaining *= 0.7; + } + + return { token: primaryToken, probability: prob, topTokens }; + }); + }); + + return { tokens, layers, data }; +} + +const MOCK_DATA = generateMockData(); + +export default function DevLogitLensIntroPage() { + return ( +
+

Logit Lens Intro — Dev Preview

+
+ +
+
+ ); +} diff --git a/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx b/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx index 7d430824..8a463489 100644 --- a/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx +++ b/workbench/_web/src/app/workbench/[workspaceId]/activation-patching/[chartId]/components/TokenSelector.tsx @@ -3,7 +3,14 @@ import { useMemo } from "react"; import { X, RotateCcw } from "lucide-react"; import { cn } from "@/lib/utils"; -import Select, { MultiValue, StylesConfig, GroupBase, components } from "react-select"; +import Select, { + MultiValue, + StylesConfig, + GroupBase, + components, + type MultiValueProps, + type OptionProps, +} from "react-select"; // Option type for react-select interface TokenOption { @@ -152,7 +159,7 @@ const selectStyles: StylesConfig> = { }; // Custom MultiValue component with colored indicator -const CustomMultiValue = (props: any) => { +const CustomMultiValue = (props: MultiValueProps>) => { const color = LINE_COLORS[props.data.value % LINE_COLORS.length]; return ( @@ -169,7 +176,9 @@ const CustomMultiValue = (props: any) => { onClick={(e) => { e.preventDefault(); e.stopPropagation(); - props.removeProps.onClick(e); + props.removeProps.onClick?.( + e as unknown as React.MouseEvent, + ); }} onMouseDown={(e) => { e.preventDefault(); @@ -184,7 +193,7 @@ const CustomMultiValue = (props: any) => { }; // Custom Option component with badges for source/target predictions -const CustomOption = (props: any) => { +const CustomOption = (props: OptionProps>) => { const tokenIndex = props.data.value; const badge = tokenIndex === 0 ? "source pred" : tokenIndex === 1 ? "target pred" : null; diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroArea.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroArea.tsx new file mode 100644 index 00000000..32406977 --- /dev/null +++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroArea.tsx @@ -0,0 +1,460 @@ +"use client"; + +import { useState, useEffect, useCallback, useMemo, useRef } from "react"; +import { useParams } from "next/navigation"; +import { useQuery } from "@tanstack/react-query"; +import { ModelSelector } from "@/components/ModelSelector"; +import { AlertCircle, Loader2, Play, X } from "lucide-react"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { Button } from "@/components/ui/button"; +import { useTour } from "@reactour/tour"; +import { CMIntroTutorial } from "@/tutorials/cmIntro"; +import { useCMIntroTutorial, hydrateCMIntroTutorial } from "@/stores/useCMIntroTutorial"; +import { getModels } from "@/lib/api/modelsApi"; +import { useWorkspace } from "@/stores/useWorkspace"; +import { encodeText } from "@/actions/tok"; +import { TokenizerLoadError } from "@/actions/errors"; +import { Token } from "@/types/models"; +import { PatchPromptSection } from "@/components/activation-patching/toolkit"; +import { toast } from "sonner"; +import { useCMIntroLogitLens, CMIntroLensResult } from "@/lib/api/cmIntroApi"; +import type { LogitLensIntroData } from "@/types/logitLensIntro"; + +// Top-1 next-token at the final layer / final input position. Returns null +// if the lens data is empty/malformed. Mirrors the prediction the +// activation-patching tool surfaces in italics next to each prompt. +function getNextTokenPrediction(data: LogitLensIntroData | undefined): string | null { + if (!data) return null; + const { layers, input, tracked, topk } = data; + if (!layers?.length || !input?.length || !tracked?.length || !topk?.length) return null; + const finalLayerIdx = layers.length - 1; + const lastPosIdx = input.length - 1; + const candidates = topk[finalLayerIdx]?.[lastPosIdx]; + if (!candidates?.length) return null; + const posTracked = tracked[lastPosIdx]; + if (!posTracked) return candidates[0]; + let bestToken = candidates[0]; + let bestProb = posTracked[bestToken]?.[finalLayerIdx] ?? 0; + for (const token of candidates) { + const prob = posTracked[token]?.[finalLayerIdx] ?? 0; + if (prob > bestProb) { + bestProb = prob; + bestToken = token; + } + } + return bestToken; +} + +// Default model for the CM intro. 32 layers reads as a manageable heatmap; +// index 0 in the model list is the 70B (80 layers), far too many for a primer. +const DEFAULT_INTRO_MODEL = "meta-llama/Llama-3.1-8B"; + +interface CMIntroAreaProps { + sourcePrompt: string; + targetPrompt: string; + onSourcePromptChange: (value: string) => void; + onTargetPromptChange: (value: string) => void; + onLensResult?: (result: CMIntroLensResult, runSrc: string, runTgt: string) => void; + lensResult?: CMIntroLensResult | null; + lastRunSrcPrompt?: string | null; + lastRunTgtPrompt?: string | null; +} + +function useTutorialAutoStart() { + const { setSteps, setIsOpen, isOpen } = useTour(); + const { completed, markCompleted } = useCMIntroTutorial(); + // Auto-start fires at most once per mount; dismissing then resaving the + // localStorage flag prevents a popup loop (same pattern as lens-intro). + const autoStartedRef = useRef(false); + + useEffect(() => { + hydrateCMIntroTutorial(); + }, []); + + useEffect(() => { + if (autoStartedRef.current || completed || isOpen) return; + if (!setSteps || !setIsOpen) return; + autoStartedRef.current = true; + const steps = CMIntroTutorial.chapters[0]?.steps ?? []; + setSteps(steps); + const id = setTimeout(() => { + setIsOpen(true); + markCompleted(); + }, 600); + return () => clearTimeout(id); + }, [completed, isOpen, setSteps, setIsOpen, markCompleted]); + + const startTutorial = () => { + if (!setSteps || !setIsOpen) return; + const steps = CMIntroTutorial.chapters[0]?.steps ?? []; + setSteps(steps); + setIsOpen(true); + }; + + return { startTutorial }; +} + +export default function CMIntroArea({ + sourcePrompt, + targetPrompt, + onSourcePromptChange, + onTargetPromptChange, + onLensResult, + lensResult, + lastRunSrcPrompt, + lastRunTgtPrompt, +}: CMIntroAreaProps) { + const { chartId } = useParams<{ chartId: string }>(); + const { selectedModelIdx, setSelectedModelIdx } = useWorkspace(); + + const { data: models } = useQuery({ + queryKey: ["models"], + queryFn: getModels, + refetchInterval: 120000, + }); + + // Default to Llama-3.1-8B once when models load, rather than leaving the + // workspace default at index 0 (the 70B, 80 layers). Guarded so a later + // manual model choice is not overridden. + const didDefaultModel = useRef(false); + useEffect(() => { + if (didDefaultModel.current || !models || models.length === 0) return; + didDefaultModel.current = true; + const idx = models.findIndex((m) => m.name === DEFAULT_INTRO_MODEL); + if (idx !== -1 && idx !== selectedModelIdx) { + setSelectedModelIdx(idx); + } + }, [models, selectedModelIdx, setSelectedModelIdx]); + + const selectedModel = useMemo(() => { + if (!models || models.length === 0) return undefined; + return models[selectedModelIdx]?.name || models[0].name; + }, [models, selectedModelIdx]); + + // Source prompt state + const [srcTokens, setSrcTokens] = useState([]); + const [srcEditing, setSrcEditing] = useState(true); + const [srcTokenizedModel, setSrcTokenizedModel] = useState(null); + const srcTextareaRef = useRef(null); + const srcTokenContainerRef = useRef(null); + + // Target prompt state + const [tgtTokens, setTgtTokens] = useState([]); + const [tgtEditing, setTgtEditing] = useState(true); + const [tgtTokenizedModel, setTgtTokenizedModel] = useState(null); + const tgtTextareaRef = useRef(null); + const tgtTokenContainerRef = useRef(null); + + // Auto-focus the source prompt textarea on first mount so new users + // can start typing immediately. Only if it's empty (don't steal focus + // when revisiting a chart that already has prompts). + const didFocusRef = useRef(false); + useEffect(() => { + if (didFocusRef.current) return; + if (!sourcePrompt) { + srcTextareaRef.current?.focus(); + didFocusRef.current = true; + } + // eslint-disable-next-line react-hooks/exhaustive-deps + }, []); + + const tokenize = useCallback(async (text: string, model: string): Promise => { + try { + return await encodeText(text, model); + } catch (error) { + if (error instanceof TokenizerLoadError) { + toast.error( + `Could not load tokenizer for ${model}. The model may be gated and require authentication.`, + ); + } else { + toast.error("Failed to tokenize prompt."); + } + return null; + } + }, []); + + // Initial tokenize on mount when a model becomes available + useEffect(() => { + if (!selectedModel) return; + const run = async () => { + if (sourcePrompt) { + const tokens = await tokenize(sourcePrompt, selectedModel); + if (tokens && tokens.length > 0) { + setSrcTokens(tokens); + setSrcTokenizedModel(selectedModel); + setSrcEditing(false); + } + } + if (targetPrompt) { + const tokens = await tokenize(targetPrompt, selectedModel); + if (tokens && tokens.length > 0) { + setTgtTokens(tokens); + setTgtTokenizedModel(selectedModel); + setTgtEditing(false); + } + } + }; + run(); + // Only auto-tokenize when the model first becomes available + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [selectedModel]); + + const handleSrcBlur = useCallback(() => { + setTimeout(async () => { + const activeElement = document.activeElement; + const withinTextarea = activeElement && srcTextareaRef.current?.contains(activeElement); + const withinToken = + activeElement && srcTokenContainerRef.current?.contains(activeElement); + if (withinTextarea || withinToken) return; + + if (!sourcePrompt || !selectedModel) { + // Clear any stale tokens so the empty state actually reads as + // empty (otherwise the token display would show the previous + // tokenization after the user deletes the text). + setSrcTokens([]); + setSrcTokenizedModel(null); + setSrcEditing(true); + return; + } + const tokens = await tokenize(sourcePrompt, selectedModel); + if (tokens && tokens.length > 0) { + setSrcTokens(tokens); + setSrcTokenizedModel(selectedModel); + setSrcEditing(false); + } + }, 100); + }, [sourcePrompt, selectedModel, tokenize]); + + const handleTgtBlur = useCallback(() => { + setTimeout(async () => { + const activeElement = document.activeElement; + const withinTextarea = activeElement && tgtTextareaRef.current?.contains(activeElement); + const withinToken = + activeElement && tgtTokenContainerRef.current?.contains(activeElement); + if (withinTextarea || withinToken) return; + + if (!targetPrompt || !selectedModel) { + setTgtTokens([]); + setTgtTokenizedModel(null); + setTgtEditing(true); + return; + } + const tokens = await tokenize(targetPrompt, selectedModel); + if (tokens && tokens.length > 0) { + setTgtTokens(tokens); + setTgtTokenizedModel(selectedModel); + setTgtEditing(false); + } + }, 100); + }, [targetPrompt, selectedModel, tokenize]); + + // Explicit clear handlers so the user has a one-click way to reset a + // prompt back to empty (deleting characters in the textarea works too, + // but a button is more discoverable). + const handleClearSrc = useCallback(() => { + onSourcePromptChange(""); + setSrcTokens([]); + setSrcTokenizedModel(null); + setSrcEditing(true); + setTimeout(() => srcTextareaRef.current?.focus(), 0); + }, [onSourcePromptChange]); + + const handleClearTgt = useCallback(() => { + onTargetPromptChange(""); + setTgtTokens([]); + setTgtTokenizedModel(null); + setTgtEditing(true); + setTimeout(() => tgtTextareaRef.current?.focus(), 0); + }, [onTargetPromptChange]); + + const configModelUnavailable = + srcTokenizedModel && selectedModel && srcTokenizedModel !== selectedModel + ? srcTokenizedModel + : null; + + // Predicted next-token from the last lens run. Hidden when the prompt + // currently in the textarea no longer matches what the lens was run on. + const srcPrediction = useMemo(() => { + if (!lensResult?.source) return null; + if (lastRunSrcPrompt == null || sourcePrompt !== lastRunSrcPrompt) return null; + return getNextTokenPrediction(lensResult.source); + }, [lensResult, sourcePrompt, lastRunSrcPrompt]); + + const tgtPrediction = useMemo(() => { + if (!lensResult?.target) return null; + if (lastRunTgtPrompt == null || targetPrompt !== lastRunTgtPrompt) return null; + return getNextTokenPrediction(lensResult.target); + }, [lensResult, targetPrompt, lastRunTgtPrompt]); + + const { mutateAsync: runLogitLens, isPending: isRunning } = useCMIntroLogitLens(); + + // Target is optional: when blank, CM Intro runs in single-prompt mode and + // only computes the source lens. The widget hides the target heatmap and + // disables drag-and-drop patching in that mode. + const canRun = !!selectedModel && !!sourcePrompt.trim() && !isRunning; + + const handleRun = useCallback(async () => { + if (!selectedModel) { + toast.error("Please select a model."); + return; + } + const src = sourcePrompt.trim(); + const tgt = targetPrompt.trim(); + if (!src) { + toast.error("Please enter a source prompt."); + return; + } + + if (!chartId) { + toast.error("Missing chart id."); + return; + } + + try { + const result = await runLogitLens({ + sourcePrompt: src, + targetPrompt: tgt, // empty string is fine — mutation skips the call + model: selectedModel, + chartId, + }); + onLensResult?.(result, src, tgt); + toast.success(tgt ? "Logit lens computed for both prompts." : "Logit lens computed."); + } catch (error) { + // Error toast handled by the mutation's onError. + } + }, [selectedModel, sourcePrompt, targetPrompt, runLogitLens, onLensResult, chartId]); + + const { startTutorial } = useTutorialAutoStart(); + + return ( +
+
+

CM Intro

+
+ {configModelUnavailable && ( + + + + + +

+ Tokens last computed with "{configModelUnavailable}". + Click a prompt and blur to retokenize. +

+
+
+ )} + + +
+
+ +
+
+ {sourcePrompt && ( + + )} + {}} + predictionToken={srcPrediction} + /> +

+ The prompt you want to steal state from + . Pick something with a clear, specific prediction — its internal + activations will be the source of the patch. +

+
+ +
+ {targetPrompt && ( + + )} + {}} + predictionToken={tgtPrediction} + /> +

+ The prompt you want to patch into. + Usually similar grammar but a different answer — any change in its + prediction after a patch reveals what the source state carried. Leave blank + to view the source prompt alone (no patching). +

+
+ + +
+
+ ); +} diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroDisplay.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroDisplay.tsx new file mode 100644 index 00000000..55ffbb07 --- /dev/null +++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/components/CMIntroDisplay.tsx @@ -0,0 +1,240 @@ +"use client"; + +import { useCallback, useMemo } from "react"; +import { useParams } from "next/navigation"; +import { useQuery } from "@tanstack/react-query"; +import { getModels } from "@/lib/api/modelsApi"; +import { getChartById } from "@/lib/queries/chartQueries"; +import { queryKeys } from "@/lib/queryKeys"; +import { useWorkspace } from "@/stores/useWorkspace"; +import { CausalMediationExplorer } from "edulogitlens"; +import type { LogitLensData, LogitCell, Intervention } from "edulogitlens"; +import { CMIntroLensResult, useCMIntroIntervention } from "@/lib/api/cmIntroApi"; +import type { LogitLensIntroData } from "@/types/logitLensIntro"; +import type { CMIntroChartData } from "@/types/cmIntro"; + +function CMSkeleton({ message, showTarget }: { message: string; showTarget: boolean }) { + const SkeletonGrid = () => ( +
+
+
+ {Array.from({ length: 40 }).map((_, i) => ( +
+ ))} +
+
+ ); + + return ( +
+
+ {message} +
+
+ + {showTarget && } +
+
+ ); +} + +interface CMIntroDisplayProps { + sourcePrompt: string; + targetPrompt: string; + lensResult?: CMIntroLensResult | null; + // Snapshots of the prompts the ephemeral lensResult was actually computed + // for. Used to detect "user edited the prompts after the last run" so we + // can show a placeholder instead of a stale heatmap. + lastRunSrcPrompt?: string | null; + lastRunTgtPrompt?: string | null; +} + +/** + * nnsightful LogitLensData → edulogitlens LogitLensData. Mirrors the transform + * used in LogitLensIntroDisplay so the CM explorer sees the same cell shape. + */ +function transformToEduFormat(data: LogitLensIntroData): LogitLensData | undefined { + if (!data) return undefined; + + const raw = data as unknown as Record; + const input = raw.input as string[] | undefined; + const layers = raw.layers as number[] | undefined; + const tracked = raw.tracked as Record[] | undefined; + const topk = raw.topk as string[][][] | undefined; + + if (!input || !layers || !tracked || !topk) return undefined; + + // NOTE: do NOT strip the BOS token here. CM interventions send the clicked + // cell's token position to the backend, which indexes the BOS-inclusive + // tokenization absolutely (causal_mediation.py). Dropping position 0 would + // patch the wrong token. BOS-hiding for CM must happen in the widget while + // preserving absolute positions. + const cellData: LogitCell[][] = input.map((_, posIdx) => { + const posTracked = tracked[posIdx] ?? {}; + return layers.map((_, layerIdx) => { + const topTokenStrs = topk[layerIdx]?.[posIdx] ?? []; + const topTokens = topTokenStrs.map((t) => ({ + token: t, + prob: posTracked[t]?.[layerIdx] ?? 0, + })); + topTokens.sort((a, b) => b.prob - a.prob); + + const best = topTokens[0]; + return { + token: best?.token ?? "", + probability: best?.prob ?? 0, + topTokens, + }; + }); + }); + + return { tokens: input, layers, data: cellData }; +} + +export function CMIntroDisplay({ + sourcePrompt, + targetPrompt, + lensResult, + lastRunSrcPrompt, + lastRunTgtPrompt, +}: CMIntroDisplayProps) { + const { chartId } = useParams<{ chartId: string }>(); + const { selectedModelIdx } = useWorkspace(); + + const { data: models } = useQuery({ + queryKey: ["models"], + queryFn: getModels, + refetchInterval: 120000, + }); + + const selectedModel = useMemo(() => { + if (!models || models.length === 0) return undefined; + return models[selectedModelIdx]?.name || models[0].name; + }, [models, selectedModelIdx]); + + // Hydrate the persisted cm-intro chart row so revisiting the page restores the intervention result. + const { data: chart } = useQuery({ + queryKey: queryKeys.charts.chart(chartId as string), + queryFn: () => getChartById(chartId as string), + enabled: !!chartId, + }); + + // Source is always required; target may be absent in single-prompt mode. + const persistedData = useMemo(() => { + const raw = chart?.data as unknown; + if (!raw || typeof raw !== "object") return null; + const maybe = raw as Partial; + if (!maybe.source) return null; + return maybe as CMIntroChartData; + }, [chart]); + + const trimmedTarget = targetPrompt.trim(); + const targetExpected = trimmedTarget.length > 0; + + // "Live" data with its provenance (what prompts produced it). The ephemeral + // lensResult wins over persisted; we use the *RunPrompt snapshots to decide + // whether the current textareas still match what was run. + const liveSourceRaw = lensResult?.source ?? persistedData?.source; + const liveTargetRaw = lensResult?.target ?? persistedData?.target; + const liveSrcRun = + lensResult?.source != null + ? (lastRunSrcPrompt ?? null) + : (persistedData?.lastRunSourcePrompt ?? null); + const liveTgtRun = + lensResult?.source != null + ? (lastRunTgtPrompt ?? null) + : (persistedData?.lastRunTargetPrompt ?? null); + + const hasAnyData = !!liveSourceRaw; + const isStale = + hasAnyData && + ((liveSrcRun !== null && liveSrcRun !== sourcePrompt) || + (liveTgtRun !== null && liveTgtRun !== targetPrompt)); + // No source data yet, OR target was expected (two-prompt mode) but is missing. + const isMissingExpectedData = !liveSourceRaw || (targetExpected && !liveTargetRaw); + + const showPlaceholder = isStale || isMissingExpectedData; + + const sourceData = useMemo( + () => (showPlaceholder || !liveSourceRaw ? undefined : transformToEduFormat(liveSourceRaw)), + [showPlaceholder, liveSourceRaw], + ); + const targetData = useMemo( + () => (showPlaceholder || !liveTargetRaw ? undefined : transformToEduFormat(liveTargetRaw)), + [showPlaceholder, liveTargetRaw], + ); + + // Undefined (not null) when absent, so CausalMediationExplorer treats the + // result as uncontrolled and falls back to internal state populated by the + // handleIntervention promise. When a persisted result IS present, we pass + // it as a controlled override so revisits restore the UI. + const persistedResultData = useMemo(() => { + if (!persistedData?.result) return undefined; + return transformToEduFormat(persistedData.result); + }, [persistedData]); + + const { mutateAsync: runIntervention, isPending: isInterventionPending } = + useCMIntroIntervention(); + + const handleIntervention = useCallback( + async (i: Intervention): Promise => { + if (!chartId || !selectedModel) return null; + try { + const result = await runIntervention({ + model: selectedModel, + srcPrompt: sourcePrompt, + tgtPrompt: targetPrompt, + chartId, + intervention: { + srcTokenPos: i.sourceTokenPosition, + srcLayer: i.sourceLayer, + tgtTokenPos: i.targetTokenPosition, + tgtLayer: i.targetLayer, + }, + }); + return transformToEduFormat(result) ?? null; + } catch { + return null; + } + }, + [chartId, selectedModel, sourcePrompt, targetPrompt, runIntervention], + ); + + if (showPlaceholder) { + const message = isStale + ? "Prompts changed since the last run. Click Run to recompute the lens." + : "No analysis yet. Enter a prompt and click Run to compute the lens."; + return ( +
+ +
+ ); + } + + return ( +
+ +
+ ); +} diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/layout.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/layout.tsx new file mode 100644 index 00000000..13389d10 --- /dev/null +++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/layout.tsx @@ -0,0 +1,10 @@ +import { CaptureProvider } from "@/components/providers/CaptureProvider"; +import { TooltipProvider } from "@/components/ui/tooltip"; + +export default function CMIntroChartLayout({ children }: { children: React.ReactNode }) { + return ( + + {children} + + ); +} diff --git a/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/page.tsx b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/page.tsx new file mode 100644 index 00000000..e250f430 --- /dev/null +++ b/workbench/_web/src/app/workbench/[workspaceId]/cm-intro/[chartId]/page.tsx @@ -0,0 +1,160 @@ +"use client"; + +import { useCallback, useEffect, useRef, useState } from "react"; +import { useParams } from "next/navigation"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { ResizableHandle, ResizablePanel, ResizablePanelGroup } from "@/components/ui/resizable"; +import ChartCardsSidebar from "../../components/ChartCardsSidebar"; +import CMIntroArea from "./components/CMIntroArea"; +import { CMIntroDisplay } from "./components/CMIntroDisplay"; +import { useIsMobile } from "@/hooks/useIsMobile"; +import { MobileSidebarDrawer } from "../../components/MobileSidebarDrawer"; +import { MobileCollapsibleControls } from "../../components/MobileCollapsibleControls"; +import { GitBranch } from "lucide-react"; +import { CMIntroLensResult } from "@/lib/api/cmIntroApi"; +import { getChartById, setChartData } from "@/lib/queries/chartQueries"; +import { queryKeys } from "@/lib/queryKeys"; +import type { CMIntroChartData } from "@/types/cmIntro"; + +const PROMPT_AUTOSAVE_DEBOUNCE_MS = 600; + +export default function CMIntroChartPage() { + const isMobile = useIsMobile(); + const { chartId } = useParams<{ chartId: string }>(); + const queryClient = useQueryClient(); + const [sourcePrompt, setSourcePrompt] = useState(""); + const [targetPrompt, setTargetPrompt] = useState(""); + const [lensResult, setLensResult] = useState(null); + // Snapshot of the prompts the current lensResult was computed for. Used by + // CMIntroArea to gate the predicted-next-token hint so it disappears once + // the user starts editing. + const [lastRunSrcPrompt, setLastRunSrcPrompt] = useState(null); + const [lastRunTgtPrompt, setLastRunTgtPrompt] = useState(null); + + const { data: chart } = useQuery({ + queryKey: queryKeys.charts.chart(chartId as string), + queryFn: () => getChartById(chartId as string), + enabled: !!chartId, + }); + + // hydratedRef gates autosave: we must absorb any persisted prompts before + // the autosave effect is allowed to write, otherwise the first render + // would clobber a stored prompt with the default placeholder. + const hydratedRef = useRef(false); + useEffect(() => { + if (hydratedRef.current) return; + if (chart === undefined) return; + const data = chart?.data as Partial | undefined; + if (typeof data?.sourcePrompt === "string") setSourcePrompt(data.sourcePrompt); + if (typeof data?.targetPrompt === "string") setTargetPrompt(data.targetPrompt); + if (data?.source && data?.target) { + setLensResult({ source: data.source, target: data.target }); + } + if (typeof data?.lastRunSourcePrompt === "string") { + setLastRunSrcPrompt(data.lastRunSourcePrompt); + } + if (typeof data?.lastRunTargetPrompt === "string") { + setLastRunTgtPrompt(data.lastRunTargetPrompt); + } + hydratedRef.current = true; + }, [chart]); + + const handleLensResult = useCallback( + (result: CMIntroLensResult, runSrc: string, runTgt: string) => { + setLensResult(result); + setLastRunSrcPrompt(runSrc); + setLastRunTgtPrompt(runTgt); + }, + [], + ); + + // Autosave the prompts into the chart row so they survive navigation even + // if the user never runs the lens. Debounced to avoid a write per keystroke. + useEffect(() => { + if (!hydratedRef.current || !chartId) return; + const handle = setTimeout(async () => { + const existing = await getChartById(chartId); + const existingData = (existing?.data ?? {}) as Partial; + if ( + existingData.sourcePrompt === sourcePrompt && + existingData.targetPrompt === targetPrompt + ) { + return; + } + const merged: CMIntroChartData = { + ...existingData, + sourcePrompt, + targetPrompt, + }; + await setChartData(chartId, merged, "cm-intro"); + queryClient.invalidateQueries({ + queryKey: queryKeys.charts.chart(chartId), + }); + }, PROMPT_AUTOSAVE_DEBOUNCE_MS); + return () => clearTimeout(handle); + }, [sourcePrompt, targetPrompt, chartId, queryClient]); + + if (isMobile === undefined) return null; + + if (isMobile) { + return ( +
+ + + +
+ +
+ +
+ ); + } + + return ( +
+ +
+ + + + + + + + + +
+
+ ); +} diff --git a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx index 5e47e38d..30b4c7c0 100644 --- a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx +++ b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCard.tsx @@ -2,7 +2,15 @@ import React from "react"; import { useParams, useRouter } from "next/navigation"; -import { Grid3X3, ChartLine, Trash2, Copy, MoreVertical, GitBranch } from "lucide-react"; +import { + Grid3X3, + ChartLine, + Trash2, + Copy, + MoreVertical, + GitBranch, + GraduationCap, +} from "lucide-react"; import Image from "next/image"; import { ChartMetadata } from "@/types/charts"; import { cn } from "@/lib/utils"; @@ -40,6 +48,13 @@ export default function ChartCard({ metadata, handleDelete, canDelete }: ChartCa chart.chartType === "activation-patching" ) { router.push(`/workbench/${workspaceId}/activation-patching/${chart.id}`); + } else if ( + chart.toolType === "logit-lens-intro" || + chart.chartType === "logit-lens-intro" + ) { + router.push(`/workbench/${workspaceId}/logit-lens-intro/${chart.id}`); + } else if (chart.toolType === "cm-intro" || chart.chartType === "cm-intro") { + router.push(`/workbench/${workspaceId}/cm-intro/${chart.id}`); } else { router.push(`/workbench/${workspaceId}/${chart.id}`); } @@ -94,6 +109,20 @@ export default function ChartCard({ metadata, handleDelete, canDelete }: ChartCa Act. Patching ); + if (chartType === "logit-lens-intro") + return ( + + + LL Intro + + ); + if (chartType === "cm-intro") + return ( + + + CM Intro + + ); return ( diff --git a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx index ca505caa..b95803b3 100644 --- a/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx +++ b/workbench/_web/src/app/workbench/[workspaceId]/components/ChartCardsSidebar.tsx @@ -5,6 +5,8 @@ import { getChartsMetadata } from "@/lib/queries/chartQueries"; import { useParams, useRouter } from "next/navigation"; import { useCreateLens2ChartPair, + useCreateLogitLensIntroChartPair, + useCreateCMIntroChartPair, useCreatePatchChartPair, useCreateActivationPatchingChartPair, useDeleteChart, @@ -29,6 +31,7 @@ import { FileText, Layers, GitBranch, + GraduationCap, } from "lucide-react"; import { Button } from "@/components/ui/button"; import { useCallback, useEffect, useMemo, useRef, useState } from "react"; @@ -64,6 +67,9 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b ); const { mutate: createLens2Pair, isPending: isCreatingLens2 } = useCreateLens2ChartPair(); + const { mutate: createLogitLensIntroPair, isPending: isCreatingLogitLensIntro } = + useCreateLogitLensIntroChartPair(); + const { mutate: createCMIntroPair, isPending: isCreatingCMIntro } = useCreateCMIntroChartPair(); const { mutate: createPatchPair, isPending: isCreatingPatch } = useCreatePatchChartPair(); const { mutate: createActivationPatchingPair, isPending: isCreatingActivationPatching } = useCreateActivationPatchingChartPair(); @@ -161,6 +167,10 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b router.push(`/workbench/${workspaceId}/lens2/${chartId}`); } else if (toolType === "activation-patching") { router.push(`/workbench/${workspaceId}/activation-patching/${chartId}`); + } else if (toolType === "logit-lens-intro") { + router.push(`/workbench/${workspaceId}/logit-lens-intro/${chartId}`); + } else if (toolType === "cm-intro") { + router.push(`/workbench/${workspaceId}/cm-intro/${chartId}`); } else { router.push(`/workbench/${workspaceId}/${chartId}`); } @@ -170,7 +180,9 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b router.push(`/workbench/${workspaceId}/overview/${documentId}`); }; - const handleCreate = (toolType: "lens2" | "patch" | "activation-patching") => { + const handleCreate = ( + toolType: "lens2" | "patch" | "activation-patching" | "logit-lens-intro" | "cm-intro", + ) => { if (toolType === "lens2") { createLens2Pair( { workspaceId: workspaceId as string }, @@ -180,6 +192,24 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b ); return; } + if (toolType === "logit-lens-intro") { + createLogitLensIntroPair( + { workspaceId: workspaceId as string }, + { + onSuccess: ({ chart }) => navigateToChart(chart.id, "logit-lens-intro"), + }, + ); + return; + } + if (toolType === "cm-intro") { + createCMIntroPair( + { workspaceId: workspaceId as string }, + { + onSuccess: ({ chart }) => navigateToChart(chart.id, "cm-intro"), + }, + ); + return; + } if (toolType === "activation-patching") { createActivationPatchingPair( { workspaceId: workspaceId as string }, @@ -245,7 +275,12 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b }; const isCreatingAny = - isCreatingLens2 || isCreatingPatch || isCreatingActivationPatching || isCreatingDocument; + isCreatingLens2 || + isCreatingLogitLensIntro || + isCreatingCMIntro || + isCreatingPatch || + isCreatingActivationPatching || + isCreatingDocument; const actionButtons = (
@@ -263,6 +298,34 @@ export default function ChartCardsSidebar({ fillWidth = false }: { fillWidth?: b )} Logit Lens + + + +