Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions workbench/_api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if os.environ.get('CONFIG') != "prod":
ALLOWED_ORIGINS.append("http://localhost:3000")
ALLOWED_ORIGINS.append("http://127.0.0.1:3000")

ALLOWED_ORIGIN_REGEX = (
r"^https://workbench-[a-z0-9\-]*-ndif\.vercel\.app$" # dev/staging previews
Expand Down
81 changes: 37 additions & 44 deletions workbench/_api/routes/activation_patching.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
from fastapi import APIRouter, Request, Depends
from typing import List, Union
from pydantic import BaseModel

from ..data_models import NDIFResponse
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from nnsightful.tools.activation_patching import activation_patching

from ..state import AppState
from ..auth import require_user_email
from ..state import get_state

from nnsightful.types import ActivationPatchingData
from nnsightful.tools.activation_patching import activation_patching
from ..sse import MEDIA_TYPE, stream_backend, stream_value
from ..state import AppState, get_state

router = APIRouter()


class ActivationPatchingRequest(BaseModel):
model_name: str
src_prompt: str
Expand All @@ -23,49 +22,43 @@ class ActivationPatchingRequest(BaseModel):
tgt_freeze: List[int] = []
token_ids: List[int]

class ActivationPatchingResponse(NDIFResponse):
data: ActivationPatchingData | None = None


@router.post("/start", response_model=ActivationPatchingResponse)
async def start_activation_patching(
request: ActivationPatchingRequest,
@router.post("/run")
async def run_activation_patching(
req: ActivationPatchingRequest,
state: AppState = Depends(get_state),
user_email: str = Depends(require_user_email),
):
model = state[request.model_name]
backend = state.make_backend(model=model)

raw = activation_patching._run(
model = state[req.model_name]

if not state.remote:
data = activation_patching(
model,
req.src_prompt,
req.tgt_prompt,
req.src_pos,
req.tgt_pos,
req.tgt_freeze,
remote=False,
)
return StreamingResponse(stream_value(data), media_type=MEDIA_TYPE)

backend = state.make_streaming_backend(model=model)
activation_patching._run(
model,
request.src_prompt,
request.tgt_prompt,
request.src_pos,
request.tgt_pos,
request.tgt_freeze,
remote=state.remote,
req.src_prompt,
req.tgt_prompt,
req.src_pos,
req.tgt_pos,
req.tgt_freeze,
remote=True,
backend=backend,
)

if "job_id" in raw:
return {"job_id": raw["job_id"]}

data = activation_patching._format(raw)
return {"data": data}


@router.post("/results/{job_id}", response_model=ActivationPatchingResponse)
async def collect_results(
job_id: str,
request: ActivationPatchingRequest,
state: AppState = Depends(get_state),
user_email: str = Depends(require_user_email),
):
backend = state.make_backend(job_id=job_id)
results = backend()

results["tokenizer"] = state[request.model_name].tokenizer
tokenizer = model.tokenizer

data = activation_patching._format(results)
def process(raw: dict):
raw["tokenizer"] = tokenizer
return activation_patching._format(raw)

return {"data": data}
return StreamingResponse(stream_backend(backend, process), media_type=MEDIA_TYPE)
Loading
Loading