Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"dependencies": {}}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Run Prettier on the root manifest.

This one-line JSON is likely the source of the current formatting warning in CI.

Proposed fix
-{"dependencies": {}}
+{
+  "dependencies": {}
+}
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
{"dependencies": {}}
{
"dependencies": {}
}
🧰 Tools
🪛 GitHub Actions: Run Checks / 0_Run Prettier.txt

[warning] 1-1: Prettier formatting check warning: code style issue found.

🪛 GitHub Actions: Run Checks / Run Prettier

[warning] Prettier reported a formatting warning.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@package.json` at line 1, package.json is not prettified; run the project's
Prettier formatting on the root manifest (package.json) to restore proper JSON
formatting according to the repository Prettier config so the CI formatting
check passes—apply Prettier to package.json and re-commit the updated file.

2 changes: 1 addition & 1 deletion scripts/api.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#!/bin/bash

cd workbench
uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload
python -m uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload
20 changes: 17 additions & 3 deletions workbench/_api/main.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
import logging
import os
import traceback
import anyio

from .routes import lens, patch, models, logit_lens, activation_patching
from .routes import lens, patch, models, logit_lens, activation_patching, causal_mediation
from .state import AppState

from dotenv import load_dotenv; load_dotenv()
from pathlib import Path
from dotenv import load_dotenv
load_dotenv(Path(__file__).resolve().parent / ".env")

# Configure logging
logging.basicConfig(
Expand Down Expand Up @@ -59,9 +63,19 @@ def fastapi_app():
app.include_router(lens, prefix="/lens")
app.include_router(logit_lens, prefix="/logit_lens")
app.include_router(activation_patching, prefix="/activation_patching")
app.include_router(causal_mediation, prefix="/causal_mediation", tags=["causal_mediation"])
app.include_router(patch, prefix="/patch")
app.include_router(models, prefix="/models")

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
logging.error(f"Unhandled exception: {''.join(tb)}")
return JSONResponse(
status_code=500,
content={"detail": str(exc), "traceback": ''.join(tb[-3:])},
)
Comment on lines +70 to +77
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Do not return traceback details in production error responses.

Line 76 exposes internal exception content and stack trace to clients. That leaks implementation details and can expose sensitive data paths. Return a generic error message and keep full traceback only in server logs.

Suggested fix
     `@app.exception_handler`(Exception)
     async def global_exception_handler(request: Request, exc: Exception):
         tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
         logging.error(f"Unhandled exception: {''.join(tb)}")
         return JSONResponse(
             status_code=500,
-            content={"detail": str(exc), "traceback": ''.join(tb[-3:])},
+            content={"detail": "Internal server error"},
         )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@workbench/_api/main.py` around lines 70 - 77, The global_exception_handler
currently logs the full traceback and returns it to clients; change it so
logging.error still records the full traceback (use traceback.format_exception
as now) but the JSONResponse from global_exception_handler returns only a
generic error message (e.g. {"detail":"Internal server error"}) and no traceback
or internal exception text; update the handler around the logging.error and
return JSONResponse (status_code=500) with a generic message and do not include
exc, tb, or exc.__traceback__ in the response body while keeping the existing
logging call in global_exception_handler.


app.state.m = AppState()

return app
Expand Down
10 changes: 9 additions & 1 deletion workbench/_api/routes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,17 @@
from .models import router as models
from .logit_lens import router as logit_lens
from .activation_patching import router as activation_patching
from .causal_mediation import router as causal_mediation

from nnsight import ndif
import nnsightful
ndif.register(nnsightful)

__all__ = ["lens", "patch", "models", "logit_lens", "activation_patching"]
__all__ = [
"lens",
"patch",
"models",
"logit_lens",
"activation_patching",
"causal_mediation",
]
214 changes: 214 additions & 0 deletions workbench/_api/routes/causal_mediation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from __future__ import annotations

from typing import Any

import torch
from fastapi import APIRouter, Depends
from pydantic import BaseModel

from ..auth import require_user_email
from ..data_models import NDIFResponse
from ..state import AppState, get_state

from nnsightful.types import LogitLensData

router = APIRouter()


class CausalMediationRequest(BaseModel):
model: str
src_prompt: str
tgt_prompt: str
src_token_pos: int
src_layer: int
tgt_token_pos: int
tgt_layer: int
topk: int = 5
include_entropy: bool = True
Comment on lines +18 to +27
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate user-controlled indices and topk before tensor indexing.

src_layer, tgt_layer, token positions, and topk are used directly in indexing and topk ops. Invalid values will throw runtime errors and return 500s instead of clear 4xx validation errors.

Suggested fix
 from fastapi import APIRouter, Depends
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from fastapi import HTTPException
@@
 class CausalMediationRequest(BaseModel):
@@
-    src_token_pos: int
-    src_layer: int
-    tgt_token_pos: int
-    tgt_layer: int
-    topk: int = 5
+    src_token_pos: int = Field(ge=0)
+    src_layer: int = Field(ge=0)
+    tgt_token_pos: int = Field(ge=0)
+    tgt_layer: int = Field(ge=0)
+    topk: int = Field(default=5, ge=1)
@@
 async def start_causal_mediation(
@@
 ):
     model = state[req.model]
+    if req.src_layer >= model.num_layers or req.tgt_layer >= model.num_layers:
+        raise HTTPException(status_code=422, detail="Layer index out of range")

Also applies to: 125-136, 181-183, 210-212

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@workbench/_api/routes/causal_mediation.py` around lines 18 - 27, Validate
user-controlled indices and topk before any tensor indexing: add Pydantic
validators on CausalMediationRequest for src_layer, tgt_layer, src_token_pos,
tgt_token_pos, and topk to ensure they are integers, non-negative, and within
sensible bounds (e.g., topk >=1 and <= a safe max), and/or perform explicit
checks at the start of the request handler that compute tensors (the places
using these fields around the code regions referenced) and raise a 400 HTTP
error with a clear message for invalid values; ensure checks compare layer
indices against the model's available layer count and token positions against
sequence length (or defer validation until you have the tensor shapes), and
clamp or reject out-of-range topk against the axis size before any torch/numpy
indexing or topk ops to prevent runtime exceptions.



class CausalMediationResponse(NDIFResponse):
"""Identical shape to LogitLensResponse so the frontend can reuse the
existing logit-lens transform/renderer."""
data: LogitLensData | None = None


def _format_lens(
logits: torch.Tensor,
tokenizer,
model_name: str,
input_tokens: list[str],
n_layers: int,
*,
top_k: int = 5,
include_entropy: bool = True,
) -> dict[str, Any]:
"""Inlined mirror of the local `format()` inside
`nnsightful.tools.logit_lens._run`. Turns a [L, T, V] logits tensor into
the dict shape consumed by `LogitLensData(**...)`.

Why: `LogitLensTool` doesn't override `_format`, so calling
`logit_lens_tool._format(...)` falls through to the abstract `Tool._format`
in `nnsightful/tools/_base.py` whose body is `...` — i.e. returns `None`.
"""
layers = list(range(n_layers))
positions = list(range(len(input_tokens)))

if include_entropy:
log_p = torch.nn.functional.log_softmax(logits, dim=-1)
p = log_p.exp()
entropy = torch.round(-(p * log_p).sum(dim=-1), decimals=3).tolist()
else:
entropy = None

probs = torch.nn.functional.softmax(logits, dim=-1)
logits.to("cpu") # free memory

_, top_indices = torch.topk(probs, k=top_k, dim=-1)
Comment on lines +64 to +67
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🌐 Web query:

In PyTorch, does tensor.to("cpu") mutate the tensor in place, or return a new tensor that must be assigned?

💡 Result:

In PyTorch, the tensor.to("cpu") method does not mutate the tensor in place [1]. Instead, it returns a new tensor that resides on the specified device [1][2]. Because the original tensor remains unchanged, you must assign the result of the operation to a variable if you wish to use the moved tensor [3][2]. For example: tensor = tensor.to("cpu") If the tensor is already on the target device and no copy is required, the method may return the original tensor itself, but it is still standard practice to assign the result to ensure your code handles all cases correctly [1].

Citations:


🏁 Script executed:

# Find and examine the file
fd "causal_mediation.py" --type f

Repository: ndif-team/workbench

Length of output: 105


🏁 Script executed:

# Read the file around lines 64-67
cat -n workbench/_api/routes/causal_mediation.py | sed -n '55,80p'

Repository: ndif-team/workbench

Length of output: 1033


🏁 Script executed:

# Check if logits is referenced after line 65
cat -n workbench/_api/routes/causal_mediation.py | sed -n '64,100p'

Repository: ndif-team/workbench

Length of output: 1357


🏁 Script executed:

# Check the broader function context to see if logits is used elsewhere or has special handling
cat -n workbench/_api/routes/causal_mediation.py | sed -n '40,70p'

Repository: ndif-team/workbench

Length of output: 1382


logits.to("cpu") is a no-op here.

Line 65 does not mutate logits; it returns a new tensor that must be assigned. Since logits is never used after this line, the intent to free memory is not achieved. If memory pressure is a concern, move probs instead, which is actually used in subsequent operations:

Suggested fix
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    logits.to("cpu")  # free memory
+    probs = torch.nn.functional.softmax(logits, dim=-1).cpu()
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
probs = torch.nn.functional.softmax(logits, dim=-1)
logits.to("cpu") # free memory
_, top_indices = torch.topk(probs, k=top_k, dim=-1)
probs = torch.nn.functional.softmax(logits, dim=-1).cpu()
_, top_indices = torch.topk(probs, k=top_k, dim=-1)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@workbench/_api/routes/causal_mediation.py` around lines 64 - 67, The call
logits.to("cpu") is a no-op because .to() returns a new tensor; assign the
result or, as suggested, move the actual tensor used later: replace the
standalone logits.to("cpu") with either logits = logits.to("cpu") or,
preferably, move probs instead by assigning probs = probs.to("cpu") after
computing probs = torch.nn.functional.softmax(logits, dim=-1), and then delete
or let logits go out of scope (e.g., del logits) to free GPU memory before
calling torch.topk on probs to compute top_indices.


topks = [
[tokenizer.batch_decode(torch.tensor(pos).unsqueeze(dim=1)) for pos in layer]
for layer in top_indices.tolist()
]

unique_indices = [
torch.unique(top_indices[:, pi, :].flatten(), sorted=False).tolist()
for pi in range(top_indices.shape[1])
]
probs = probs.permute(1, 2, 0)
trajectories = [
{
tokenizer.decode(token): torch.round(probs[pos_idx][token], decimals=3).tolist()
for token in pos
}
for pos_idx, pos in enumerate(unique_indices)
]

return {
"meta": {"version": 2, "timestamp": "3h", "model": model_name},
"layers": layers,
"input": input_tokens,
"tracked": trajectories,
"topk": topks,
"entropy": entropy,
"positions": positions,
}


def _run_causal_mediation(
model,
src_prompt: str,
tgt_prompt: str,
src_token_pos: int,
src_layer: int,
tgt_token_pos: int,
tgt_layer: int,
*,
remote: bool = False,
backend=None,
) -> dict[str, Any]:
"""Capture the source residual at (src_layer, src_token_pos), patch it
into the target prompt at (tgt_layer, tgt_token_pos), then run a logit
lens over the *patched* forward pass.

Patching pattern mirrors `nnsightful.tools.activation_patching._run`:
a slice-assign on `model.layers_output[i]` is sufficient to register the
intervention and have it propagate through subsequent layers — no
explicit `model.layers[i].output = (...)` reassignment is needed.
"""
n_layers = model.num_layers

with torch.no_grad():
with model.session(remote=remote, backend=backend):
# 1) Source pass — capture the residual at (src_layer, src_token_pos).
with model.trace(src_prompt):
src_hidden = model.layers_output[src_layer][0, src_token_pos].save()

# 2) Target pass — at tgt_layer, slice-assign the source residual
# into the target's tgt_token_pos. project_on_vocab at every
# layer gives us a logit-lens grid over the patched pass.
with model.trace(tgt_prompt):
per_layer_logits = []
for i in range(n_layers):
hs = model.layers_output[i]
if i == tgt_layer:
hs[0, tgt_token_pos][:] = src_hidden
per_layer_logits.append(model.project_on_vocab(hs))
# Stack to a single [L, T, V] tensor so backend() returns a
# known shape on the remote path (one saved key: "logits").
logits = torch.cat(per_layer_logits, dim=0).save()

if remote and backend is not None:
return {"job_id": backend.job_id}

return {"logits": logits}


@router.post("/start", response_model=CausalMediationResponse)
async def start_causal_mediation(
req: CausalMediationRequest,
state: AppState = Depends(get_state),
user_email: str = Depends(require_user_email),
):
model = state[req.model]
backend = state.make_backend(model=model)

raw = _run_causal_mediation(
model,
req.src_prompt,
req.tgt_prompt,
req.src_token_pos,
req.src_layer,
req.tgt_token_pos,
req.tgt_layer,
remote=state.remote,
backend=backend,
)

if "job_id" in raw:
return {"job_id": raw["job_id"]}

input_tokens = [
str(model.tokenizer.decode(token))
for token in model.tokenizer.encode(req.tgt_prompt)
]
data = _format_lens(
raw["logits"],
tokenizer=model.tokenizer,
model_name=req.model,
input_tokens=input_tokens,
n_layers=model.num_layers,
top_k=req.topk,
include_entropy=req.include_entropy,
)
return {"data": data}


@router.post("/results/{job_id}", response_model=CausalMediationResponse)
async def collect_causal_mediation(
job_id: str,
req: CausalMediationRequest,
state: AppState = Depends(get_state),
user_email: str = Depends(require_user_email),
):
backend = state.make_backend(job_id=job_id)
results = backend()

model = state[req.model]
tokenizer = model.tokenizer
input_tokens = [
str(tokenizer.decode(token))
for token in tokenizer.encode(req.tgt_prompt)
]

data = _format_lens(
results["logits"],
tokenizer=tokenizer,
model_name=req.model,
input_tokens=input_tokens,
n_layers=model.num_layers,
top_k=req.topk,
include_entropy=req.include_entropy,
)

return {"data": data}
27 changes: 27 additions & 0 deletions workbench/_web/bun.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion workbench/_web/next.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const __dirname = path.dirname(fileURLToPath(import.meta.url));

const nextConfig = {
reactStrictMode: true,
transpilePackages: ["nnsightful"],
transpilePackages: ["nnsightful", "edulogitlens"],
turbopack: {
// Expand root so Turbopack can resolve the symlinked nnsightful package
// which lives at ../../nnsightful (outside the default _web/ root)
Expand Down
Loading
Loading