-
Notifications
You must be signed in to change notification settings - Fork 10
Edu logit lens #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Edu logit lens #114
Changes from 6 commits
e65ed88
80ab233
b93d528
966f257
be42273
5bd087f
d72b9df
c41dac6
ec5c459
553944a
353ed72
1e2fceb
f5d9408
b577959
7523abf
be6348c
c80eced
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| {"dependencies": {}} | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| #!/bin/bash | ||
|
|
||
| cd workbench | ||
| uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload | ||
| python -m uvicorn _api.main:app --host 0.0.0.0 --port 8000 --reload |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,13 +1,17 @@ | ||
| from fastapi import FastAPI | ||
| from fastapi import FastAPI, Request | ||
| from fastapi.middleware.cors import CORSMiddleware | ||
| from fastapi.responses import JSONResponse | ||
| import logging | ||
| import os | ||
| import traceback | ||
| import anyio | ||
|
|
||
| from .routes import lens, patch, models, logit_lens, activation_patching | ||
| from .routes import lens, patch, models, logit_lens, activation_patching, causal_mediation | ||
| from .state import AppState | ||
|
|
||
| from dotenv import load_dotenv; load_dotenv() | ||
| from pathlib import Path | ||
| from dotenv import load_dotenv | ||
| load_dotenv(Path(__file__).resolve().parent / ".env") | ||
|
|
||
| # Configure logging | ||
| logging.basicConfig( | ||
|
|
@@ -59,9 +63,19 @@ def fastapi_app(): | |
| app.include_router(lens, prefix="/lens") | ||
| app.include_router(logit_lens, prefix="/logit_lens") | ||
| app.include_router(activation_patching, prefix="/activation_patching") | ||
| app.include_router(causal_mediation, prefix="/causal_mediation", tags=["causal_mediation"]) | ||
| app.include_router(patch, prefix="/patch") | ||
| app.include_router(models, prefix="/models") | ||
|
|
||
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler for exceptions no route handled.

    The full traceback is written to the server log. The client only
    receives a generic 500 body: echoing `str(exc)` or stack frames back
    to the caller exposes file paths and implementation details.

    Args:
        request: The request being served when the exception was raised
            (unused; part of the handler signature).
        exc: The unhandled exception.

    Returns:
        A `JSONResponse` with status 500 and a generic `detail` message.
    """
    tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
    # Keep the diagnostic detail server-side only.
    logging.error(f"Unhandled exception: {''.join(tb)}")
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal server error"},
    )
|
Comment on lines
+70
to
+77
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do not return traceback details in production error responses. Line 76 exposes internal exception content and the stack trace to clients. That leaks implementation details and can expose sensitive data paths. Return a generic error message and keep the full traceback only in server logs. Suggested fix: @app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
tb = traceback.format_exception(type(exc), exc, exc.__traceback__)
logging.error(f"Unhandled exception: {''.join(tb)}")
return JSONResponse(
status_code=500,
- content={"detail": str(exc), "traceback": ''.join(tb[-3:])},
+ content={"detail": "Internal server error"},
)🤖 Prompt for AI Agents |
||
|
|
||
| app.state.m = AppState() | ||
|
|
||
| return app | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,214 @@ | ||||||||||||||||
| from __future__ import annotations | ||||||||||||||||
|
|
||||||||||||||||
| from typing import Any | ||||||||||||||||
|
|
||||||||||||||||
| import torch | ||||||||||||||||
| from fastapi import APIRouter, Depends | ||||||||||||||||
| from pydantic import BaseModel | ||||||||||||||||
|
|
||||||||||||||||
| from ..auth import require_user_email | ||||||||||||||||
| from ..data_models import NDIFResponse | ||||||||||||||||
| from ..state import AppState, get_state | ||||||||||||||||
|
|
||||||||||||||||
| from nnsightful.types import LogitLensData | ||||||||||||||||
|
|
||||||||||||||||
| router = APIRouter() | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
class CausalMediationRequest(BaseModel):
    """Request body shared by the causal-mediation endpoints.

    Describes one patching experiment: which activation to read from the
    source prompt, where to write it in the target prompt, and how the
    resulting logit-lens grid should be summarized.
    """

    # Key used to look the model up in the application state.
    model: str
    # Prompt the activation is captured from.
    src_prompt: str
    # Prompt the captured activation is patched into.
    tgt_prompt: str
    # Token position / layer index of the activation to capture.
    src_token_pos: int
    src_layer: int
    # Token position / layer index that receives the captured activation.
    tgt_token_pos: int
    tgt_layer: int
    # Number of top tokens reported per (layer, position) cell.
    topk: int = 5
    # Whether per-cell entropy is included in the response data.
    include_entropy: bool = True
|
Comment on lines
+18
to
+27
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Validate user-controlled indices and `topk`.
Suggested fix from fastapi import APIRouter, Depends
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from fastapi import HTTPException
@@
class CausalMediationRequest(BaseModel):
@@
- src_token_pos: int
- src_layer: int
- tgt_token_pos: int
- tgt_layer: int
- topk: int = 5
+ src_token_pos: int = Field(ge=0)
+ src_layer: int = Field(ge=0)
+ tgt_token_pos: int = Field(ge=0)
+ tgt_layer: int = Field(ge=0)
+ topk: int = Field(default=5, ge=1)
@@
async def start_causal_mediation(
@@
):
model = state[req.model]
+ if req.src_layer >= model.num_layers or req.tgt_layer >= model.num_layers:
+ raise HTTPException(status_code=422, detail="Layer index out of range")Also applies to: 125-136, 181-183, 210-212 🤖 Prompt for AI Agents |
||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
class CausalMediationResponse(NDIFResponse):
    """Response envelope for the causal-mediation endpoints.

    Deliberately shaped like `LogitLensResponse` so the frontend can feed
    it through the existing logit-lens transform/renderer unchanged.
    """

    # Formatted lens grid; left as None when no data is returned.
    data: LogitLensData | None = None
|
|
||||||||||||||||
|
|
||||||||||||||||
def _format_lens(
    logits: torch.Tensor,
    tokenizer,
    model_name: str,
    input_tokens: list[str],
    n_layers: int,
    *,
    top_k: int = 5,
    include_entropy: bool = True,
) -> dict[str, Any]:
    """Turn a [L, T, V] logits tensor into the dict consumed by
    `LogitLensData(**...)`.

    Inlined mirror of the local `format()` inside
    `nnsightful.tools.logit_lens._run`.

    Why: `LogitLensTool` doesn't override `_format`, so calling
    `logit_lens_tool._format(...)` falls through to the abstract `Tool._format`
    in `nnsightful/tools/_base.py` whose body is `...` — i.e. returns `None`.

    Args:
        logits: Per-layer logits, indexed below as [layer, position, vocab].
        tokenizer: Object providing `decode(token_id)` and
            `batch_decode(ids)`.
        model_name: Stored verbatim under `meta["model"]`.
        input_tokens: Decoded prompt tokens; only their count and values
            are used (for `positions` and `input`).
        n_layers: Number of layers; used for the `layers` index list.
        top_k: How many top tokens to report per (layer, position). Clamped
            to the vocabulary size so an oversized request cannot make
            `torch.topk` raise.
        include_entropy: When true, include the per-cell entropy of the
            softmax distribution, rounded to 3 decimals.

    Returns:
        Dict with keys `meta`, `layers`, `input`, `tracked`, `topk`,
        `entropy` (None when `include_entropy` is false) and `positions`.
    """
    layers = list(range(n_layers))
    positions = list(range(len(input_tokens)))

    if include_entropy:
        log_p = torch.nn.functional.log_softmax(logits, dim=-1)
        p = log_p.exp()
        entropy = torch.round(-(p * log_p).sum(dim=-1), decimals=3).tolist()
    else:
        entropy = None

    probs = torch.nn.functional.softmax(logits, dim=-1)

    # torch.topk rejects k larger than the reduced dimension.
    k = min(top_k, probs.shape[-1])
    _, top_indices = torch.topk(probs, k=k, dim=-1)

    # `Tensor.to(...)` / `.cpu()` return a new tensor rather than moving the
    # receiver, so the result must be rebound. Everything below indexes these
    # tensors element-by-element from Python, so do the device transfer once.
    top_indices = top_indices.cpu()
    probs = probs.cpu()

    # topks[layer][position] -> list of the k decoded top tokens.
    topks = [
        [tokenizer.batch_decode(torch.tensor(pos).unsqueeze(dim=1)) for pos in layer]
        for layer in top_indices.tolist()
    ]

    # Per position: every token id that was in the top-k at any layer.
    unique_indices = [
        torch.unique(top_indices[:, pi, :].flatten(), sorted=False).tolist()
        for pi in range(top_indices.shape[1])
    ]
    # [L, T, V] -> [T, V, L] so probs[pos][token] is that token's
    # probability across layers.
    probs = probs.permute(1, 2, 0)
    trajectories = [
        {
            tokenizer.decode(token): torch.round(probs[pos_idx][token], decimals=3).tolist()
            for token in token_ids
        }
        for pos_idx, token_ids in enumerate(unique_indices)
    ]

    return {
        "meta": {"version": 2, "timestamp": "3h", "model": model_name},
        "layers": layers,
        "input": input_tokens,
        "tracked": trajectories,
        "topk": topks,
        "entropy": entropy,
        "positions": positions,
    }
|
|
||||||||||||||||
|
|
||||||||||||||||
def _run_causal_mediation(
    model,
    src_prompt: str,
    tgt_prompt: str,
    src_token_pos: int,
    src_layer: int,
    tgt_token_pos: int,
    tgt_layer: int,
    *,
    remote: bool = False,
    backend=None,
) -> dict[str, Any]:
    """Capture the source residual at (src_layer, src_token_pos), patch it
    into the target prompt at (tgt_layer, tgt_token_pos), then run a logit
    lens over the *patched* forward pass.

    Patching pattern mirrors `nnsightful.tools.activation_patching._run`:
    a slice-assign on `model.layers_output[i]` is sufficient to register the
    intervention and have it propagate through subsequent layers — no
    explicit `model.layers[i].output = (...)` reassignment is needed.

    Args:
        model: Model wrapper exposing `num_layers`, `session`, `trace`,
            `layers_output` and `project_on_vocab`.
        src_prompt: Prompt whose activation is captured.
        tgt_prompt: Prompt that receives the captured activation.
        src_token_pos: Token position read in the source pass.
        src_layer: Layer read in the source pass. Negative values count
            from the last layer.
        tgt_token_pos: Token position overwritten in the target pass.
        tgt_layer: Layer overwritten in the target pass. Negative values
            count from the last layer.
        remote: Forwarded to `model.session`.
        backend: Forwarded to `model.session`; its `job_id` is returned on
            the remote path.

    Returns:
        `{"job_id": ...}` when `remote` is true and a backend was given,
        otherwise `{"logits": <saved per-layer logits>}`.

    Raises:
        ValueError: If `src_layer` or `tgt_layer` is outside
            `[-num_layers, num_layers)`.
    """
    n_layers = model.num_layers

    # An out-of-range tgt_layer would never equal the loop index below, so the
    # patch would be skipped silently and the caller would get an *unpatched*
    # lens that looks like a valid result. Fail loudly instead.
    for label, layer in (("src_layer", src_layer), ("tgt_layer", tgt_layer)):
        if not -n_layers <= layer < n_layers:
            raise ValueError(
                f"{label}={layer} is out of range for a model with {n_layers} layers"
            )
    # Normalize negative indices so both layers are compared and indexed the
    # same way.
    src_layer %= n_layers
    tgt_layer %= n_layers

    with torch.no_grad():
        with model.session(remote=remote, backend=backend):
            # 1) Source pass — capture the residual at (src_layer, src_token_pos).
            with model.trace(src_prompt):
                src_hidden = model.layers_output[src_layer][0, src_token_pos].save()

            # 2) Target pass — at tgt_layer, slice-assign the source residual
            #    into the target's tgt_token_pos. project_on_vocab at every
            #    layer gives us a logit-lens grid over the patched pass.
            with model.trace(tgt_prompt):
                per_layer_logits = []
                for i in range(n_layers):
                    hs = model.layers_output[i]
                    if i == tgt_layer:
                        hs[0, tgt_token_pos][:] = src_hidden
                    per_layer_logits.append(model.project_on_vocab(hs))
                # Stack to a single [L, T, V] tensor so backend() returns a
                # known shape on the remote path (one saved key: "logits").
                logits = torch.cat(per_layer_logits, dim=0).save()

    if remote and backend is not None:
        return {"job_id": backend.job_id}

    return {"logits": logits}
|
|
||||||||||||||||
|
|
||||||||||||||||
@router.post("/start", response_model=CausalMediationResponse)
async def start_causal_mediation(
    req: CausalMediationRequest,
    state: AppState = Depends(get_state),
    user_email: str = Depends(require_user_email),
):
    """Launch a causal-mediation run for the requested model.

    Returns `{"job_id": ...}` when the run produced a job id, otherwise
    formats the returned logits and answers with `{"data": ...}`.
    """
    model = state[req.model]
    backend = state.make_backend(model=model)

    outcome = _run_causal_mediation(
        model,
        src_prompt=req.src_prompt,
        tgt_prompt=req.tgt_prompt,
        src_token_pos=req.src_token_pos,
        src_layer=req.src_layer,
        tgt_token_pos=req.tgt_token_pos,
        tgt_layer=req.tgt_layer,
        remote=state.remote,
        backend=backend,
    )

    # A job id means the result is collected later via the results endpoint.
    if "job_id" in outcome:
        return {"job_id": outcome["job_id"]}

    # Otherwise the logits are already here: label each target-prompt
    # position with its decoded token and format immediately.
    token_labels = []
    for token_id in model.tokenizer.encode(req.tgt_prompt):
        token_labels.append(str(model.tokenizer.decode(token_id)))

    return {
        "data": _format_lens(
            outcome["logits"],
            tokenizer=model.tokenizer,
            model_name=req.model,
            input_tokens=token_labels,
            n_layers=model.num_layers,
            top_k=req.topk,
            include_entropy=req.include_entropy,
        )
    }
|
|
||||||||||||||||
|
|
||||||||||||||||
@router.post("/results/{job_id}", response_model=CausalMediationResponse)
async def collect_causal_mediation(
    job_id: str,
    req: CausalMediationRequest,
    state: AppState = Depends(get_state),
    user_email: str = Depends(require_user_email),
):
    """Fetch the saved logits for `job_id` and format them as lens data.

    The request body is sent again because formatting needs the model,
    target prompt, `topk` and entropy flag that the run was started with.
    """
    # Fetch the job's saved values first, before any model lookup.
    fetched = state.make_backend(job_id=job_id)()

    model = state[req.model]
    tok = model.tokenizer

    token_labels = []
    for token_id in tok.encode(req.tgt_prompt):
        token_labels.append(str(tok.decode(token_id)))

    formatted = _format_lens(
        fetched["logits"],
        tokenizer=tok,
        model_name=req.model,
        input_tokens=token_labels,
        n_layers=model.num_layers,
        top_k=req.topk,
        include_entropy=req.include_entropy,
    )
    return {"data": formatted}
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Run Prettier on the root manifest.
This one-line JSON is likely the source of the current formatting warning in CI.
Proposed fix
📝 Committable suggestion
🧰 Tools
🪛 GitHub Actions: Run Checks / 0_Run Prettier.txt
[warning] 1-1: Prettier formatting check warning: code style issue found.
🪛 GitHub Actions: Run Checks / Run Prettier
[warning] Prettier reported a formatting warning.
🤖 Prompt for AI Agents