diff --git a/workbench/_api/main.py b/workbench/_api/main.py index e4c17540..e4d8fa64 100644 --- a/workbench/_api/main.py +++ b/workbench/_api/main.py @@ -23,6 +23,7 @@ if os.environ.get('CONFIG') != "prod": ALLOWED_ORIGINS.append("http://localhost:3000") + ALLOWED_ORIGINS.append("http://127.0.0.1:3000") ALLOWED_ORIGIN_REGEX = ( r"^https://workbench-[a-z0-9\-]*-ndif\.vercel\.app$" # dev/staging previews diff --git a/workbench/_api/routes/activation_patching.py b/workbench/_api/routes/activation_patching.py index 1f8a348f..830a7336 100644 --- a/workbench/_api/routes/activation_patching.py +++ b/workbench/_api/routes/activation_patching.py @@ -1,19 +1,18 @@ -from fastapi import APIRouter, Request, Depends from typing import List, Union -from pydantic import BaseModel -from ..data_models import NDIFResponse +from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse +from pydantic import BaseModel +from nnsightful.tools.activation_patching import activation_patching -from ..state import AppState from ..auth import require_user_email -from ..state import get_state - -from nnsightful.types import ActivationPatchingData -from nnsightful.tools.activation_patching import activation_patching +from ..sse import MEDIA_TYPE, stream_backend, stream_value +from ..state import AppState, get_state router = APIRouter() + class ActivationPatchingRequest(BaseModel): model_name: str src_prompt: str @@ -23,49 +22,43 @@ class ActivationPatchingRequest(BaseModel): tgt_freeze: List[int] = [] token_ids: List[int] -class ActivationPatchingResponse(NDIFResponse): - data: ActivationPatchingData | None = None - -@router.post("/start", response_model=ActivationPatchingResponse) -async def start_activation_patching( - request: ActivationPatchingRequest, +@router.post("/run") +async def run_activation_patching( + req: ActivationPatchingRequest, state: AppState = Depends(get_state), user_email: str = Depends(require_user_email), ): - model = state[request.model_name] - backend = state.make_backend(model=model) - - raw = activation_patching._run( + model = state[req.model_name] + + if not state.remote: + data = activation_patching( + model, + req.src_prompt, + req.tgt_prompt, + req.src_pos, + req.tgt_pos, + req.tgt_freeze, + remote=False, + ) + return StreamingResponse(stream_value(data), media_type=MEDIA_TYPE) + + backend = state.make_streaming_backend(model=model) + activation_patching._run( model, - request.src_prompt, - request.tgt_prompt, - request.src_pos, - request.tgt_pos, - request.tgt_freeze, - remote=state.remote, + req.src_prompt, + req.tgt_prompt, + req.src_pos, + req.tgt_pos, + req.tgt_freeze, + remote=True, backend=backend, ) - if "job_id" in raw: - return {"job_id": raw["job_id"]} - - data = activation_patching._format(raw) - return {"data": data} - - -@router.post("/results/{job_id}", response_model=ActivationPatchingResponse) -async def collect_results( - job_id: str, - request: ActivationPatchingRequest, - state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email), -): - backend = state.make_backend(job_id=job_id) - results = backend() - - results["tokenizer"] = state[request.model_name].tokenizer + tokenizer = model.tokenizer - data = activation_patching._format(results) + def process(raw: dict): + raw["tokenizer"] = tokenizer + return activation_patching._format(raw) - return {"data": data} \ No newline at end of file + return StreamingResponse(stream_backend(backend, process), media_type=MEDIA_TYPE) diff --git a/workbench/_api/routes/lens.py b/workbench/_api/routes/lens.py index adf02d69..11f13f3c 100644 --- a/workbench/_api/routes/lens.py +++ b/workbench/_api/routes/lens.py @@ -3,19 +3,27 @@ import torch as t from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel from ..auth import require_user_email, user_has_model_access -from ..data_models import NDIFResponse, Token +from ..data_models import Token +from ..sse import MEDIA_TYPE, stream_backend, stream_value from ..state import AppState, get_state -############ LINE ############ class LensStatistic(str, Enum): PROBABILITY = "probability" RANK = "rank" ENTROPY = "entropy" + +router = APIRouter() + + +# -------------------------------- LINE ------------------------------------ + + class LensLineRequest(BaseModel): model: str stat: LensStatistic @@ -33,14 +41,8 @@ class Line(BaseModel): data: list[Point] -class LensLineResponse(NDIFResponse): - data: list[Line] | None = None - - -router = APIRouter() - - -def line(req: LensLineRequest, state: AppState) -> list[t.Tensor]: +def _trace_line(req: LensLineRequest, state: AppState, backend): + """Run the lens-line trace. Saves a list of per-layer tensors under key 'results'.""" model = state[req.model] idx = req.token.idx target_ids = req.token.target_ids @@ -49,12 +51,16 @@ def _compute_top_probs(logits): return t.nn.functional.softmax(logits, dim=-1) def _compute_rank(logits): - sorted_probs, sorted_indices = t.nn.functional.softmax(logits, dim=-1).sort(descending=True, dim=-1) + sorted_probs, sorted_indices = t.nn.functional.softmax(logits, dim=-1).sort( + descending=True, dim=-1 + ) rank_map = t.empty_like(sorted_indices) rank_map.scatter_( -1, sorted_indices, - t.arange(1, logits.size(-1)+1).expand_as(sorted_indices).to(logits.device) + t.arange(1, logits.size(-1) + 1) + .expand_as(sorted_indices) + .to(logits.device), ) return rank_map @@ -62,12 +68,13 @@ def _compute_rank(logits): _compute_func = _compute_top_probs elif req.stat == LensStatistic.RANK: _compute_func = _compute_rank + else: + raise HTTPException( + status_code=400, + detail=f"Unsupported statistic for lens-line: {req.stat}", + ) - with model.trace( - req.prompt, - remote=state.remote, - backend=state.make_backend(model=model), - ) as tracer: + with model.trace(req.prompt, remote=state.remote, backend=backend): results = [] for layer in model.model.layers: hidden_BLD = layer.output @@ -81,29 +88,17 @@ def _compute_rank(logits): target_probs_X = t.gather(metrics, 0, target_ids_tensor) results.append(target_probs_X) - results.save() - - if state.remote: - return tracer.backend.job_id + results = results.save() return results -def get_remote_line(user_email: str, job_id: str, state: AppState): - backend = state.make_backend(job_id=job_id) - results = backend() - return results["results"] - - -def process_line_results( - results: list[t.Tensor], - req: LensLineRequest, - state: AppState, -): +def _format_line(raw: dict, req: LensLineRequest, state: AppState) -> list[Line]: tok = state[req.model].tokenizer + results = raw["results"] target_token_strs = tok.batch_decode(req.token.target_ids) - lines = [] + lines: list[Line] = [] for layer_idx, probs in enumerate(results): for line_idx, prob in enumerate(probs.tolist()): @@ -120,51 +115,42 @@ def process_line_results( return lines -@router.post("/start-line", response_model=LensLineResponse) -async def start_line( +@router.post("/run-line") +async def run_line( req: LensLineRequest, state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) + user_email: str = Depends(require_user_email), ): + if state.remote and not user_has_model_access(user_email, req.model, state): + raise HTTPException( + status_code=403, + detail=f"User does not have access to {req.model}", + ) - if state.remote: - if not user_has_model_access(user_email, req.model, state): - message = f"User does not have access to {req.model}" - raise HTTPException(status_code=403, detail=message) + if not state.remote: + results = _trace_line(req, state, backend=None) + lines = _format_line({"results": results}, req, state) + return StreamingResponse(stream_value(lines), media_type=MEDIA_TYPE) - try: - result = line(req, state) - except Exception as e: - raise e + model = state[req.model] + backend = state.make_streaming_backend(model=model) + _trace_line(req, state, backend=backend) - if state.remote: - return {"job_id": result} + def process(raw: dict) -> list[Line]: + return _format_line(raw, req, state) - return {"data": process_line_results(result, req, state)} + return StreamingResponse(stream_backend(backend, process), media_type=MEDIA_TYPE) -@router.post("/results-line/{job_id}", response_model=LensLineResponse) -async def collect_line( - job_id: str, - req: LensLineRequest, - state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) -): +# -------------------------------- GRID ------------------------------------ - try: - results = get_remote_line(user_email, job_id, state) - except Exception as e: - raise e - - return {"data": process_line_results(results, req, state)} - -############ GRID ############ class GridLensRequest(BaseModel): model: str stat: LensStatistic prompt: str + class GridCell(Point): label: str @@ -175,13 +161,8 @@ class GridRow(BaseModel): right_axis_label: str | None = None -class GridLensResponse(NDIFResponse): - data: list[GridRow] | None = None - - -def heatmap( - req: GridLensRequest, state: AppState -) -> tuple[list[t.Tensor], list[t.Tensor]]: +def _trace_grid(req: GridLensRequest, state: AppState, backend): + """Run the grid-lens trace. Saves 'stats' and 'pred_ids' lists.""" model = state[req.model] def _compute_top_probs(hs_decoded, logits): @@ -204,17 +185,21 @@ def _compute_rank(hs_decoded, logits): top_tokens = logits.argmax(dim=-1) for hs in hs_decoded: - sorted_probs, sorted_indices = t.nn.functional.softmax(hs, dim=-1).sort(descending=True, dim=-1) + sorted_probs, sorted_indices = t.nn.functional.softmax(hs, dim=-1).sort( + descending=True, dim=-1 + ) rank_map = t.empty_like(sorted_indices) rank_map.scatter_( 2, sorted_indices, - t.arange(1, logits.size(-1)+1).expand_as(sorted_indices).to(hs.device) + t.arange(1, logits.size(-1) + 1) + .expand_as(sorted_indices) + .to(hs.device), ) ranks_L = rank_map.gather(2, top_tokens.unsqueeze(-1)).squeeze(-1) ranks.append(ranks_L[0].to("cpu").tolist()) - return ranks, top_tokens[0].to('cpu').tolist() + return ranks, top_tokens[0].to("cpu").tolist() def _compute_entropy(hs_decoded, logits): entropies = [] @@ -234,12 +219,13 @@ def _compute_entropy(hs_decoded, logits): _compute_func = _compute_rank elif req.stat == LensStatistic.ENTROPY: _compute_func = _compute_entropy + else: + raise HTTPException( + status_code=400, + detail=f"Unsupported statistic for lens-grid: {req.stat}", + ) - with model.trace( - req.prompt, - remote=state.remote, - backend=state.make_backend(model=model), - ) as tracer: + with model.trace(req.prompt, remote=state.remote, backend=backend): hs_decoded = [] for layer in model.model.layers[:-1]: @@ -252,36 +238,22 @@ def _compute_entropy(hs_decoded, logits): hs_decoded.append(logits) stats, pred_ids = _compute_func(hs_decoded, logits) - stats.save() - pred_ids.save() - - if state.remote: - return tracer.backend.job_id + stats = stats.save() + pred_ids = pred_ids.save() return stats, pred_ids -def get_remote_heatmap( - user_email: str, - job_id: str, - state: AppState -) -> tuple[list[t.Tensor], list[t.Tensor]]: - backend = state.make_backend(job_id=job_id) - results = backend() - return results["stats"], results["pred_ids"] - - -def process_grid_results( - stats: list[t.Tensor], - pred_ids: list[t.Tensor], - lens_request: GridLensRequest, - state: AppState, -): - tok = state[lens_request.model].tokenizer - input_strs = tok.batch_decode(tok.encode(lens_request.prompt)) - rows = [] +def _format_grid(raw: dict, req: GridLensRequest, state: AppState) -> list[GridRow]: + tok = state[req.model].tokenizer + stats = raw["stats"] + pred_ids = raw["pred_ids"] + + input_strs = tok.batch_decode(tok.encode(req.prompt)) + + rows: list[GridRow] = [] for seq_idx, input_str in enumerate(input_strs): - if lens_request.stat == LensStatistic.PROBABILITY: + if req.stat == LensStatistic.PROBABILITY: points = [ GridCell( x=layer_idx, @@ -291,7 +263,7 @@ def process_grid_results( for layer_idx, (stat, pred_id) in enumerate(zip(stats, pred_ids)) ] rows.append(GridRow(id=f"{input_str}-{seq_idx}", data=points)) - elif lens_request.stat == LensStatistic.RANK: + elif req.stat == LensStatistic.RANK: points = [ GridCell( x=layer_idx, @@ -300,8 +272,14 @@ def process_grid_results( ) for layer_idx, stat in enumerate(stats) ] - rows.append(GridRow(id=f"{input_str}-{seq_idx}", data=points, right_axis_label=tok.decode(pred_ids[seq_idx]))) - elif lens_request.stat == LensStatistic.ENTROPY: + rows.append( + GridRow( + id=f"{input_str}-{seq_idx}", + data=points, + right_axis_label=tok.decode(pred_ids[seq_idx]), + ) + ) + elif req.stat == LensStatistic.ENTROPY: points = [ GridCell( x=layer_idx, @@ -310,44 +288,39 @@ def process_grid_results( ) for layer_idx, stat in enumerate(stats) ] - rows.append(GridRow(id=f"{input_str}-{seq_idx}", data=points, right_axis_label=tok.decode(pred_ids[seq_idx]))) + rows.append( + GridRow( + id=f"{input_str}-{seq_idx}", + data=points, + right_axis_label=tok.decode(pred_ids[seq_idx]), + ) + ) return rows -@router.post("/start-grid", response_model=GridLensResponse) -async def get_grid( +@router.post("/run-grid") +async def run_grid( req: GridLensRequest, state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) + user_email: str = Depends(require_user_email), ): - if state.remote: - if not user_has_model_access(user_email, req.model, state): - message = f"User does not have access to {req.model}" - raise HTTPException(status_code=403, detail=message) - - try: - result = heatmap(req, state) - except Exception as e: - raise e - - if state.remote: - return {"job_id": result} + if state.remote and not user_has_model_access(user_email, req.model, state): + raise HTTPException( + status_code=403, + detail=f"User does not have access to {req.model}", + ) - probs, pred_ids = result - return {"data": process_grid_results(probs, pred_ids, req, state)} + if not state.remote: + stats, pred_ids = _trace_grid(req, state, backend=None) + rows = _format_grid({"stats": stats, "pred_ids": pred_ids}, req, state) + return StreamingResponse(stream_value(rows), media_type=MEDIA_TYPE) + model = state[req.model] + backend = state.make_streaming_backend(model=model) + _trace_grid(req, state, backend=backend) -@router.post("/results-grid/{job_id}", response_model=GridLensResponse) -async def collect_grid( - job_id: str, - lens_request: GridLensRequest, - state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) -): - try: - probs, pred_ids = get_remote_heatmap(user_email, job_id, state) - except Exception as e: - raise e + def process(raw: dict) -> list[GridRow]: + return _format_grid(raw, req, state) - return {"data": process_grid_results(probs, pred_ids, lens_request, state)} + return StreamingResponse(stream_backend(backend, process), media_type=MEDIA_TYPE) diff --git a/workbench/_api/routes/logit_lens.py b/workbench/_api/routes/logit_lens.py index 337e32db..a50a2221 100644 --- a/workbench/_api/routes/logit_lens.py +++ b/workbench/_api/routes/logit_lens.py @@ -1,72 +1,55 @@ from fastapi import APIRouter, Depends +from fastapi.responses import StreamingResponse from pydantic import BaseModel -from ..state import AppState, get_state -from ..auth import require_user_email - -from ..data_models import NDIFResponse -from nnsightful.types import LogitLensData from nnsightful.tools.logit_lens import logit_lens +from ..auth import require_user_email +from ..sse import MEDIA_TYPE, stream_backend, stream_value +from ..state import AppState, get_state + router = APIRouter() + class LogitLensRequest(BaseModel): model: str prompt: str - topk: int = 5 # Number of top-k predictions per cell - include_entropy: bool = True # Whether to include entropy data - - -class LogitLensResponse(NDIFResponse): - data: LogitLensData | None = None + topk: int = 5 + include_entropy: bool = True -@router.post("/start", response_model=LogitLensResponse) -async def start_logit_lens( +@router.post("/run") +async def run_logit_lens( req: LogitLensRequest, state: AppState = Depends(get_state), user_email: str = Depends(require_user_email), ): model = state[req.model] - backend = state.make_backend(model=model) - - raw = logit_lens._run(model, req.prompt, remote=state.remote, backend=backend) - - if "job_id" in raw: - return {"job_id": raw["job_id"]} - - data = logit_lens._format( - raw, - top_k=req.topk, - include_entropy=req.include_entropy, - ) - return {"data": data} - - -@router.post("/results/{job_id}", response_model=LogitLensResponse) -async def collect_logit_lens( - job_id: str, - req: LogitLensRequest, - state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email), -): - backend = state.make_backend(job_id=job_id) - results = backend() - - print("logit_lens collect keys:", list(results.keys()) if isinstance(results, dict) else type(results)) - - tokenizer = state[req.model].tokenizer - results["tokenizer"] = tokenizer - results["model_name"] = req.model - results["input_tokens"] = [ - str(tokenizer.decode(token)) - for token in tokenizer.encode(req.prompt) - ] - - data = logit_lens._format( - results, - top_k=req.topk, - include_entropy=req.include_entropy, - ) - return {"data": data} \ No newline at end of file + if not state.remote: + data = logit_lens( + model, + req.prompt, + remote=False, + top_k=req.topk, + include_entropy=req.include_entropy, + ) + return StreamingResponse(stream_value(data), media_type=MEDIA_TYPE) + + backend = state.make_streaming_backend(model=model) + logit_lens._run(model, req.prompt, remote=True, backend=backend) + + tokenizer = model.tokenizer + input_tokens = [str(tokenizer.decode(t)) for t in tokenizer.encode(req.prompt)] + + def process(raw: dict): + raw["tokenizer"] = tokenizer + raw["model_name"] = req.model + raw["input_tokens"] = input_tokens + return logit_lens._format( + raw, + top_k=req.topk, + include_entropy=req.include_entropy, + ) + + return StreamingResponse(stream_backend(backend, process), media_type=MEDIA_TYPE) diff --git a/workbench/_api/routes/models.py b/workbench/_api/routes/models.py index a47fd02d..3786197c 100644 --- a/workbench/_api/routes/models.py +++ b/workbench/_api/routes/models.py @@ -4,31 +4,28 @@ import requests import torch as t from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import StreamingResponse from pydantic import BaseModel from ..auth import get_user_email, require_user_email, user_has_model_access -from ..data_models import NDIFResponse, Token -from ..telemetry import TelemetryClient, RequestStatus +from ..data_models import Token +from ..sse import MEDIA_TYPE, stream_backend, stream_value from ..state import AppState, get_state -from ..data_models import Token, NDIFResponse -from ..auth import require_user_email, get_user_email - -import logging +from ..telemetry import RequestStatus, TelemetryClient logger = logging.getLogger(__name__) router = APIRouter() -MODELS = list() +MODELS: list = [] MODELS_LAST_UPDATED = 0 MODEL_INTERVAL = 60 -def get_remote_models(state: AppState, is_user_signed_in: bool): +def get_remote_models(state: AppState, is_user_signed_in: bool): global MODELS, MODELS_LAST_UPDATED if MODELS_LAST_UPDATED == 0 or time.time() - MODELS_LAST_UPDATED > 60: - ping_resp = requests.get(f"{state.ndif_backend_url}/ping", timeout=30) logger.info(f"Call NDIF_BACKEND/ping: {ping_resp.status_code}") @@ -45,39 +42,43 @@ def get_remote_models(state: AppState, is_user_signed_in: bool): data = stats_resp.json() for deployment_state in data["deployments"].values(): - if deployment_state == {'application_state': 'UNHEALTHY'}: + if deployment_state == {"application_state": "UNHEALTHY"}: continue - if deployment_state['deployment_level'] == "HOT" and deployment_state['application_state'] == "RUNNING": - state.add_model(deployment_state['repo_id']) + if ( + deployment_state["deployment_level"] == "HOT" + and deployment_state["application_state"] == "RUNNING" + ): + state.add_model(deployment_state["repo_id"]) else: - state.remove_model(deployment_state['repo_id']) + state.remove_model(deployment_state["repo_id"]) MODELS = state.get_active_model_list() MODELS_LAST_UPDATED = time.time() models = [model.copy() for model in MODELS] for model in models: - if not is_user_signed_in and model['gated']: - model['allowed'] = False + if not is_user_signed_in and model["gated"]: + model["allowed"] = False else: - model['allowed'] = True + model["allowed"] = True return models + @router.get("/") async def get_models( state: AppState = Depends(get_state), - user_email: str = Depends(get_user_email) + user_email: str = Depends(get_user_email), ): if state.remote: is_user_signed_in: bool = user_email is not None and user_email != "guest@localhost" - models = get_remote_models(state, is_user_signed_in) + return get_remote_models(state, is_user_signed_in) + + return state.get_all_model_list() - return models - else: - return state.get_all_model_list() +# ------------------------------ Prediction --------------------------------- class LensCompletion(BaseModel): @@ -86,62 +87,34 @@ class LensCompletion(BaseModel): token: Token -def prediction( - req: LensCompletion, state: AppState -) -> tuple[t.Tensor, t.Tensor] | str: +class Prediction(BaseModel): + idx: int + ids: list[int] + probs: list[float] + texts: list[str] + + +def _trace_prediction(req: LensCompletion, state: AppState, backend): + """Run the prediction trace. Saves values_LV + indices_LV on the tracer.""" model = state[req.model] idx = req.token.idx - with model.trace( - req.prompt, - remote=state.remote, - backend=state.make_backend(model=model), - ) as tracer: + with model.trace(req.prompt, remote=state.remote, backend=backend): logits_BLV = model.logits - - # Get logits for the correct index logits_LV = logits_BLV[0, [idx], :].softmax(dim=-1) - - # Sort logits by descending probability values_LV_indices_LV = t.sort(logits_LV, dim=-1, descending=True) - values_LV = values_LV_indices_LV[0].save() indices_LV = values_LV_indices_LV[1].save() - if state.remote: - return tracer.backend.job_id - return values_LV, indices_LV -def get_remote_prediction( - job_id: str, state: AppState -) -> tuple[t.Tensor, t.Tensor]: - backend = state.make_backend(job_id=job_id) - results = backend() - return results["values_LV"], results["indices_LV"] - - -class Prediction(BaseModel): - idx: int - ids: list[int] - probs: list[float] - texts: list[str] - -class PredictionResponse(NDIFResponse): - data: Prediction | None = None - - -def process_prediction( - values_LV: t.Tensor, - indices_LV: t.Tensor, - req: LensCompletion, - state: AppState, -): +def _format_prediction(raw: dict, req: LensCompletion, state: AppState) -> Prediction: tok = state[req.model].tokenizer - idxs = [req.token.idx] - # Round values to 2 decimal places + values_LV = raw["values_LV"] + indices_LV = raw["indices_LV"] + idx_values = t.round(values_LV[0] * 100) / 100 nonzero = idx_values > 0 @@ -149,99 +122,60 @@ def process_prediction( nonzero_indices = indices_LV[0][nonzero].tolist() nonzero_texts = tok.batch_decode(nonzero_indices) - prediction = Prediction( - idx=idxs[0], + return Prediction( + idx=req.token.idx, ids=nonzero_indices, probs=nonzero_values, texts=nonzero_texts, ) - return prediction - -@router.post("/start-prediction", response_model=PredictionResponse) -async def start_prediction( - prediction_request: LensCompletion, +@router.post("/run-prediction") +async def run_prediction( + req: LensCompletion, state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) + user_email: str = Depends(require_user_email), ): - if state.remote: - if not user_has_model_access(user_email, prediction_request.model, state): - message = f"User does not have access to {prediction_request.model}" - TelemetryClient.log_request( - RequestStatus.ERROR, - user_email, - method="PREDICTION", - type="NEXT_TOKEN", - msg=message, - ) - raise HTTPException(status_code=403, detail=message) + if state.remote and not user_has_model_access(user_email, req.model, state): + message = f"User does not have access to {req.model}" + TelemetryClient.log_request( + RequestStatus.ERROR, + user_email, + method="PREDICTION", + type="NEXT_TOKEN", + msg=message, + ) + raise HTTPException(status_code=403, detail=message) TelemetryClient.log_request( - RequestStatus.STARTED, + RequestStatus.STARTED, user_email, method="PREDICTION", type="NEXT_TOKEN", ) - try: - result = prediction(prediction_request, state) - except Exception as e: - TelemetryClient.log_request( - RequestStatus.ERROR, - user_email, - method="PREDICTION", - type="NEXT_TOKEN", - msg=str(e), + if not state.remote: + values_LV, indices_LV = _trace_prediction(req, state, backend=None) + data = _format_prediction( + {"values_LV": values_LV, "indices_LV": indices_LV}, req, state ) - raise e - - if state.remote: - TelemetryClient.log_request( - RequestStatus.READY, - user_email, - method="PREDICTION", - type="NEXT_TOKEN", - job_id=result - ) - return {"job_id": result} + return StreamingResponse(stream_value(data), media_type=MEDIA_TYPE) - values_LV, indices_LV = result - data = process_prediction(values_LV, indices_LV, prediction_request, state) - return {"data": data} + model = state[req.model] + backend = state.make_streaming_backend(model=model) + _trace_prediction(req, state, backend=backend) + # job_id isn't assigned until iteration actually submits the request, so + # the READY/COMPLETE milestones previously logged here would carry None. + # Skip them for now; STARTED + downstream errors are still captured. -@router.post("/results-prediction/{job_id}", response_model=PredictionResponse) -async def results_prediction( - job_id: str, - prediction_request: LensCompletion, - state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) -): + def process(raw: dict) -> Prediction: + return _format_prediction(raw, req, state) - try: - values_LV, indices_LV = get_remote_prediction(job_id, state) - data = process_prediction(values_LV, indices_LV, prediction_request, state) - except Exception as e: - TelemetryClient.log_request( - RequestStatus.ERROR, - user_email, - job_id=job_id, - method="PREDICTION", - type="NEXT_TOKEN", - msg=str(e), - ) - raise e + return StreamingResponse(stream_backend(backend, process), media_type=MEDIA_TYPE) - TelemetryClient.log_request( - RequestStatus.COMPLETE, - user_email, - job_id=job_id, - method="PREDICTION", - type="NEXT_TOKEN", - ) - return {"data": data} +# ------------------------------ Generation --------------------------------- class Completion(BaseModel): @@ -255,20 +189,17 @@ class Generation(BaseModel): last_token_prediction: Prediction -class GenerationResponse(NDIFResponse): - data: Generation | None = None - - -def generate(req: Completion, state: AppState): +def _trace_generate(req: Completion, state: AppState, backend): + """Run the generation trace. Saves values_V, indices_V, new_token_ids.""" model = state[req.model] last_iter = req.max_new_tokens - 1 + with model.generate( req.prompt, max_new_tokens=req.max_new_tokens, remote=state.remote, - backend=state.make_backend(model=model), + backend=backend, ) as tracer: - with tracer.iter[last_iter]: logits = model.logits @@ -279,28 +210,16 @@ def generate(req: Completion, state: AppState): new_token_ids = model.generator.output[0].save() - if state.remote: - return tracer.backend.job_id - return values_V, indices_V, new_token_ids -def get_remote_generate( - job_id: str, state: AppState -) -> tuple[t.Tensor, t.Tensor, t.Tensor]: - backend = state.make_backend(job_id=job_id) - results = backend() - return results["values_V"], results["indices_V"], results["new_token_ids"] +def _format_generation(raw: dict, req: Completion, state: AppState) -> Generation: + tok = state[req.model].tokenizer + values_V = raw["values_V"] + indices_V = raw["indices_V"] + new_token_ids = raw["new_token_ids"] -def process_generation_results( - values_V: t.Tensor, - indices_V: t.Tensor, - new_token_ids: t.Tensor, - req: Completion, - state: AppState, -): - tok = state[req.model].tokenizer new_token_text = tok.batch_decode(new_token_ids) tokens = [ @@ -308,7 +227,6 @@ def process_generation_results( for i, text in enumerate(new_token_text) ] - # Round values to 2 decimal places idx_values = t.round(values_V * 100) / 100 nonzero = idx_values > 0 @@ -317,106 +235,60 @@ def process_generation_results( nonzero_texts = tok.batch_decode(nonzero_indices) last_token_prediction = Prediction( - idx=new_token_ids[-1], + idx=new_token_ids[-1].item(), ids=nonzero_indices, probs=nonzero_values, texts=nonzero_texts, - ).model_dump() - - return { - "completion": tokens, - "last_token_prediction": last_token_prediction, - } - - -@router.post("/start-generate", response_model=GenerationResponse) -async def start_generate( - req: Completion, - state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) -): - - if state.remote: - if not user_has_model_access(user_email, req.model, state): - message = f"User does not have access to {req.model}" - TelemetryClient.log_request( - RequestStatus.ERROR, - user_email, - method="GENERATE", - type="NEXT_TOKEN", - msg=message, - ) - raise HTTPException(status_code=403, detail=message) - - TelemetryClient.log_request( - RequestStatus.STARTED, - user_email, - method="GENERATE", - type="NEXT_TOKEN", ) - try: - result = generate(req, state) - except Exception as e: - TelemetryClient.log_request( - RequestStatus.ERROR, - user_email, - method="GENERATE", - type="NEXT_TOKEN", - msg=str(e), - ) - raise e - - if state.remote: - TelemetryClient.log_request( - RequestStatus.READY, - user_email, - method="GENERATE", - type="NEXT_TOKEN", - job_id=result - ) - print("Hollla") - return {"job_id": result} - - else: - values_V, indices_V, new_token_ids = result - - data = process_generation_results( - values_V, indices_V, new_token_ids, req, state - ) - return {"data": data} + return Generation( + completion=tokens, + last_token_prediction=last_token_prediction, + ) -@router.post("/results-generate/{job_id}", response_model=GenerationResponse) -async def results_generate( - job_id: str, +@router.post("/run-generate") +async def run_generate( req: Completion, state: AppState = Depends(get_state), - user_email: str = Depends(require_user_email) + user_email: str = Depends(require_user_email), ): - - try: - values_V, indices_V, new_token_ids = get_remote_generate(job_id, state) - data = process_generation_results( - values_V, indices_V, new_token_ids, req, state - ) - except Exception as e: + if state.remote and not user_has_model_access(user_email, req.model, state): + message = f"User does not have access to {req.model}" TelemetryClient.log_request( - RequestStatus.ERROR, + RequestStatus.ERROR, user_email, - job_id=job_id, method="GENERATE", type="NEXT_TOKEN", - msg=str(e), + msg=message, ) - raise e + raise HTTPException(status_code=403, detail=message) TelemetryClient.log_request( - RequestStatus.COMPLETE, + RequestStatus.STARTED, user_email, - job_id=job_id, method="GENERATE", type="NEXT_TOKEN", ) - return {"data": data} + if not state.remote: + values_V, indices_V, new_token_ids = _trace_generate(req, state, backend=None) + data = _format_generation( + { + "values_V": values_V, + "indices_V": indices_V, + "new_token_ids": new_token_ids, + }, + req, + state, + ) + return StreamingResponse(stream_value(data), media_type=MEDIA_TYPE) + + model = state[req.model] + backend = state.make_streaming_backend(model=model) + _trace_generate(req, state, backend=backend) + + def process(raw: dict) -> Generation: + return _format_generation(raw, req, state) + + return StreamingResponse(stream_backend(backend, process), media_type=MEDIA_TYPE) diff --git a/workbench/_api/sse.py b/workbench/_api/sse.py new file mode 100644 index 00000000..a17cf049 --- /dev/null +++ b/workbench/_api/sse.py @@ -0,0 +1,72 @@ +"""Server-Sent Events helpers shared by workbench SSE routes. + +Each SSE route emits a sequence of `status` events during execution followed +by a single terminal event: + - `data` — the formatted payload, JSON-encoded + - `error` — a JSON object with an `error` string + +The helpers here keep the three routes that use :class:`StreamingRemoteBackend` +consistent without duplicating the generator boilerplate. +""" + +from __future__ import annotations + +import json +from typing import Any, AsyncIterator, Awaitable, Callable, Union + +from nnsight.schema.response import ResponseModel +from pydantic import BaseModel + +from .streaming_backend import StreamingRemoteBackend + +MEDIA_TYPE = "text/event-stream" + +ProcessFn = Callable[[dict], Union[BaseModel, dict, list, Awaitable[Any]]] + + +def sse_event(event: str, data: str) -> str: + """Format a single Server-Sent Events frame.""" + return f"event: {event}\ndata: {data}\n\n" + + +def _jsonify(payload: Any) -> str: + """JSON-encode a Pydantic model, list-of-models, or plain JSON-compatible value.""" + if isinstance(payload, BaseModel): + return payload.model_dump_json() + if isinstance(payload, list) and payload and isinstance(payload[0], BaseModel): + return json.dumps([p.model_dump() for p in payload]) + return json.dumps(payload) + + +async def stream_backend( + backend: StreamingRemoteBackend, + process: ProcessFn, +) -> AsyncIterator[str]: + """Iterate a streaming backend and yield SSE frames. + + Forwards every non-terminal :class:`ResponseModel` as a `status` frame, + calls ``process`` on the downloaded dict when ``COMPLETED`` arrives, and + emits a terminal `data` frame with the JSON-encoded result. Any exception + is caught and emitted as an `error` frame so the stream closes cleanly. + """ + try: + async for response in backend: + if response.status == ResponseModel.JobStatus.COMPLETED: + result = process(response.data) + if hasattr(result, "__await__"): + result = await result # type: ignore[misc] + yield sse_event("data", _jsonify(result)) + else: + yield sse_event("status", response.model_dump_json(exclude={"data"})) + except Exception as e: + yield sse_event("error", json.dumps({"error": str(e)})) + + +async def stream_value(value: Any) -> AsyncIterator[str]: + """Single-event stream for local (non-remote) endpoints.""" + yield sse_event("data", _jsonify(value)) + + +async def stream_error(message: str) -> AsyncIterator[str]: + """Single-event stream emitting a single `error` frame.""" + yield sse_event("error", json.dumps({"error": message})) diff --git a/workbench/_api/state.py b/workbench/_api/state.py index 0e26bfb3..56da8241 100644 --- a/workbench/_api/state.py +++ b/workbench/_api/state.py @@ -12,6 +12,7 @@ from nnsight.intervention.backends.remote import RemoteBackend from pydantic import BaseModel +from .streaming_backend import StreamingRemoteBackend from .telemetry import TelemetryClient # Set up logger for this module @@ -136,6 +137,12 @@ def make_backend(self, model: StandardizedTransformer | None = None, job_id: str else: return None + def make_streaming_backend(self, model: StandardizedTransformer) -> StreamingRemoteBackend | None: + """Backend for SSE routes: async-iterable, defers submission until iteration.""" + if self.remote: + return StreamingRemoteBackend(model_key=model.to_model_key()) + return None + def __getitem__(self, model_name: str): return self.get_model(model_name) diff --git a/workbench/_api/streaming_backend.py b/workbench/_api/streaming_backend.py new file mode 100644 index 00000000..e52ff569 --- /dev/null +++ b/workbench/_api/streaming_backend.py @@ -0,0 +1,102 @@ +"""Streaming remote backend used by workbench SSE endpoints. + +Unlike the parent ``RemoteBackend`` — which runs the WebSocket receive loop +inline and blocks ``__call__`` until the job finishes — this subclass defers +both submission and status-waiting so the calling FastAPI route can drive +the lifecycle asynchronously and forward each update to the browser as a +Server-Sent Event. + +Lifecycle: + + 1. nnsight's trace / session ``__exit__`` invokes ``__call__(tracer)`` + synchronously. We capture the tracer and serialize the request + payload, but perform no I/O. + 2. The route opens the SSE stream and does ``async for response in backend``. + On the first step, the backend opens an async WebSocket, stamps the + socket session id into the request headers, POSTs the submit, and + begins yielding :class:`ResponseModel` updates as they arrive. + 3. When ``COMPLETED`` arrives, the result is downloaded via the parent's + ``async_get_result``; ``response.data`` is replaced with the downloaded + dict of save-keyed tensors, the response is yielded one last time, + and iteration ends. + 4. On ``ERROR``, the response is yielded and :class:`RemoteException` + is raised. +""" + +from __future__ import annotations + +from typing import Any, AsyncIterator, Dict, Optional + +import socketio + +from nnsight.intervention.backends.remote import RemoteBackend, RemoteException +from nnsight.schema.response import ResponseModel + + +class StreamingRemoteBackend(RemoteBackend): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._request_data: Optional[bytes] = None + self._request_headers: Optional[Dict[str, str]] = None + self._tracer = None + + def __call__(self, tracer=None): + """Capture tracer and serialize the request. Fires on trace/session __exit__.""" + if tracer is None: + return None + self._tracer = tracer + self._request_data, self._request_headers = self.request(tracer) + + async def __aiter__(self) -> AsyncIterator[ResponseModel]: + if self._request_data is None: + raise RuntimeError( + "StreamingRemoteBackend is not primed; it must be passed as the " + "`backend` argument to a model.trace(...) or model.session(...) " + "context before iteration." + ) + + async with socketio.AsyncSimpleClient(reconnection_attempts=10) as sio: + await sio.connect( + self.ws_address, + socketio_path="/ws/socket.io", + transports=["websocket"], + wait_timeout=10, + ) + + headers = dict(self._request_headers) + headers["ndif-session_id"] = sio.sid + + initial = await self.async_submit_request(self._request_data, headers) + + if initial.status == ResponseModel.JobStatus.COMPLETED: + await self._async_finalize(initial) + yield initial + return + if initial.status == ResponseModel.JobStatus.ERROR: + yield initial + raise RemoteException(initial.description) + + yield initial + + while True: + msg = await sio.receive(timeout=None) + response = ResponseModel.unpickle(msg[1]) + + if response.status == ResponseModel.JobStatus.COMPLETED: + await self._async_finalize(response) + yield response + return + if response.status == ResponseModel.JobStatus.ERROR: + yield response + raise RemoteException(response.description) + + yield response + + async def _async_finalize(self, response: ResponseModel) -> None: + """Download the final payload (if delivered as a URL) and inline it.""" + result: Any = response.data + if isinstance(result, str): + result = await self.async_get_result(result) + elif isinstance(result, (tuple, list)): + result = await self.async_get_result(*result) + response.data = result diff --git a/workbench/_web/src/lib/api/activationPatchingApi.ts b/workbench/_web/src/lib/api/activationPatchingApi.ts index f39d3a7e..531649ae 100644 --- a/workbench/_web/src/lib/api/activationPatchingApi.ts +++ b/workbench/_web/src/lib/api/activationPatchingApi.ts @@ -12,7 +12,7 @@ import { } from "@/types/activationPatching"; import { queryKeys } from "../queryKeys"; import { toast } from "sonner"; -import { startAndPoll } from "../startAndPoll"; +import { runAndStream } from "../runAndStream"; import { createUserHeadersAction } from "@/actions/auth"; /** @@ -42,10 +42,9 @@ const getActivationPatching = async ( token_ids: [], // Backend will use src_pred and clean_pred from results }; - return await startAndPoll( - config.endpoints.startActivationPatching, + return await runAndStream( + config.endpoints.runActivationPatching, apiRequest, - config.endpoints.resultsActivationPatching, headers, ); }; diff --git a/workbench/_web/src/lib/api/chartApi.ts b/workbench/_web/src/lib/api/chartApi.ts index aa2c70e8..7b1c9975 100644 --- a/workbench/_web/src/lib/api/chartApi.ts +++ b/workbench/_web/src/lib/api/chartApi.ts @@ -19,7 +19,7 @@ import { useCapture } from "@/components/providers/CaptureProvider"; import { Line, HeatmapRow, ChartView } from "@/types/charts"; import { queryKeys } from "../queryKeys"; import { toast } from "sonner"; -import { startAndPoll } from "../startAndPoll"; +import { runAndStream } from "../runAndStream"; import { useHeatmapView, useLineView } from "@/components/charts/ViewProvider"; import { createUserHeadersAction } from "@/actions/auth"; @@ -34,10 +34,9 @@ const getLensLine = async (lensRequest: { completion: LensConfigData; chartId: s token: lensRequest.completion.token, }; - return await startAndPoll( - config.endpoints.startLensLine, + return await runAndStream( + config.endpoints.runLensLine, lineRequest, - config.endpoints.resultsLensLine, headers, ); }; @@ -121,10 +120,9 @@ const getLensGrid = async (lensRequest: { completion: LensConfigData; chartId: s prompt: lensRequest.completion.prompt, }; - return await startAndPoll( - config.endpoints.startLensGrid, + return await runAndStream( + config.endpoints.runLensGrid, gridRequest, - config.endpoints.resultsLensGrid, headers, ); }; diff --git a/workbench/_web/src/lib/api/lensApi.ts b/workbench/_web/src/lib/api/lensApi.ts index 10948419..7613668a 100644 --- a/workbench/_web/src/lib/api/lensApi.ts +++ b/workbench/_web/src/lib/api/lensApi.ts @@ -8,7 +8,7 @@ import { setChartData } from "@/lib/queries/chartQueries"; import { Lens2ConfigData, Lens2Data } from "@/types/lens2"; import { queryKeys } from "../queryKeys"; import { toast } from "sonner"; -import { startAndPoll } from "../startAndPoll"; +import { runAndStream } from "../runAndStream"; import { createUserHeadersAction } from "@/actions/auth"; /** @@ -33,10 +33,9 @@ const getLens2 = async (lensRequest: Lens2Request): Promise => { include_entropy: lensRequest.completion.includeEntropy ?? true, }; - return await startAndPoll( - config.endpoints.startLens2, + return await runAndStream( + config.endpoints.runLens2, request, - config.endpoints.resultsLens2, headers, ); }; diff --git a/workbench/_web/src/lib/api/modelsApi.ts b/workbench/_web/src/lib/api/modelsApi.ts index 3601dbd5..64394fa1 100644 --- a/workbench/_web/src/lib/api/modelsApi.ts +++ b/workbench/_web/src/lib/api/modelsApi.ts @@ -1,10 +1,9 @@ import config from "@/lib/config"; import type { LensConfigData } from "@/types/lens"; import type { Model, Token } from "@/types/models"; -import { startAndPoll } from "../startAndPoll"; +import { runAndStream } from "../runAndStream"; import { useMutation } from "@tanstack/react-query"; import { toast } from "sonner"; -import { useWorkspace } from "@/stores/useWorkspace"; import { createUserHeadersAction } from "@/actions/auth"; interface Prediction { @@ -16,10 +15,9 @@ interface Prediction { const getPrediction = async (request: LensConfigData): Promise => { const headers = await createUserHeadersAction(); - return await startAndPoll( - config.endpoints.startPrediction, + return await runAndStream( + config.endpoints.runPrediction, request, - config.endpoints.resultsPrediction, headers, ); }; @@ -46,10 +44,9 @@ export interface GenerationResponse { const generate = async (request: Completion): Promise => { const headers = await createUserHeadersAction(); - return await startAndPoll( - config.endpoints.startGenerate, + return await runAndStream( + config.endpoints.runGenerate, request, - config.endpoints.resultsGenerate, headers, ); }; diff --git a/workbench/_web/src/lib/config.ts b/workbench/_web/src/lib/config.ts index fdf94077..5c435199 100644 --- a/workbench/_web/src/lib/config.ts +++ b/workbench/_web/src/lib/config.ts @@ -2,33 +2,17 @@ const config = { backendUrl: process.env.NEXT_PUBLIC_BACKEND_URL || "http://localhost:8000", - ndifUrl: - process.env.NEXT_PUBLIC_LOCAL_NDIF === "true" - ? "http://localhost:5001" - : "https://api.ndif.us", endpoints: { - startLensLine: "/lens/start-line", - resultsLensLine: (jobId: string) => `/lens/results-line/${jobId}`, - - startLensGrid: "/lens/start-grid", - resultsLensGrid: (jobId: string) => `/lens/results-grid/${jobId}`, - - startLens2: "/logit_lens/start", - resultsLens2: (jobId: string) => `/logit_lens/results/${jobId}`, - - startActivationPatching: "/activation_patching/start", - resultsActivationPatching: (jobId: string) => `/activation_patching/results/${jobId}`, - - startPrediction: "/models/start-prediction", - resultsPrediction: (jobId: string) => `/models/results-prediction/${jobId}`, - - startGenerate: "/models/start-generate", - resultsGenerate: (jobId: string) => `/models/results-generate/${jobId}`, + runLensLine: "/lens/run-line", + runLensGrid: "/lens/run-grid", + runLens2: "/logit_lens/run", + runActivationPatching: "/activation_patching/run", + runPrediction: "/models/run-prediction", + runGenerate: "/models/run-generate", models: "/models/", }, getApiUrl: (endpoint: string) => `${config.backendUrl}${endpoint}`, - ndifStatusUrl: (jobId: string) => `${config.ndifUrl}/response/${jobId}`, } as const; export default config; diff --git a/workbench/_web/src/lib/runAndStream.ts b/workbench/_web/src/lib/runAndStream.ts new file mode 100644 index 00000000..2dd1f7e3 --- /dev/null +++ b/workbench/_web/src/lib/runAndStream.ts @@ -0,0 +1,101 @@ +import config from "./config"; +import { useWorkspace } from "@/stores/useWorkspace"; + +type SSEEvent = { event: string; data: string }; + +// Parse a stream of SSE events from a fetch ReadableStream body. +async function* parseSSE(body: ReadableStream): AsyncGenerator { + const reader = body.getReader(); + const decoder = new TextDecoder(); + let buf = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + buf += decoder.decode(value, { stream: true }); + + let sep: number; + while ((sep = buf.indexOf("\n\n")) !== -1) { + const raw = buf.slice(0, sep); + buf = buf.slice(sep + 2); + + let eventName = "message"; + const dataLines: string[] = []; + for (const line of raw.split("\n")) { + if (line.startsWith("event:")) { + eventName = line.slice(6).trim(); + } else if (line.startsWith("data:")) { + dataLines.push(line.slice(5).replace(/^ /, "")); + } + } + if (dataLines.length === 0) continue; + yield { event: eventName, data: dataLines.join("\n") }; + } + } + } finally { + reader.releaseLock(); + } +} + +/** + * POST to an SSE endpoint that emits `status` events during execution, a single + * `data` event with the final payload, and `error` events on failure. + * Resolves with the parsed `data` payload or throws on `error` / protocol issues. + */ +export async function runAndStream( + endpoint: string, + body: unknown, + headers?: Record, +): Promise { + const { setJobStatus } = useWorkspace.getState(); + const url = config.getApiUrl(endpoint); + + const response = await fetch(url, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + ...headers, + }, + body: JSON.stringify(body), + }); + + if (!response.ok || !response.body) { + setJobStatus("Error"); + throw new Error(`Request failed: ${response.status} ${response.statusText}`); + } + + let finalData: T | null = null; + let sseError: string | null = null; + + for await (const evt of parseSSE(response.body)) { + if (evt.event === "status") { + try { + const parsed = JSON.parse(evt.data); + if (parsed?.status) setJobStatus(parsed.status); + } catch { + /* ignore malformed status frame */ + } + } else if (evt.event === "data") { + finalData = JSON.parse(evt.data) as T; + setJobStatus("Idle"); + } else if (evt.event === "error") { + try { + sseError = (JSON.parse(evt.data) as { error?: string }).error ?? evt.data; + } catch { + sseError = evt.data; + } + } + } + + if (sseError !== null) { + setJobStatus("Error"); + throw new Error(sseError); + } + if (finalData === null) { + setJobStatus("Error"); + throw new Error("Stream ended without a data event"); + } + return finalData; +} diff --git a/workbench/_web/src/lib/startAndPoll.ts b/workbench/_web/src/lib/startAndPoll.ts deleted file mode 100644 index e3c4d104..00000000 --- a/workbench/_web/src/lib/startAndPoll.ts +++ /dev/null @@ -1,99 +0,0 @@ -import config from "./config"; -import { useWorkspace } from "@/stores/useWorkspace"; - -const POLL_TIMEOUT_MS = 60000; -const POLL_INTERVAL_MS = 1000; - -async function awaitNDIFJob(jobId: string): Promise { - const startedAt = Date.now(); - const { setJobStatus } = useWorkspace.getState(); - while (true) { - if (Date.now() - startedAt > POLL_TIMEOUT_MS) { - setJobStatus("timeout"); - throw new Error("Timed out waiting for job to complete"); - } - - const pollResp = await fetch(config.ndifStatusUrl(jobId)); - if (!pollResp.ok) throw new Error("Polling failed"); - const data = await pollResp.json(); - const status = data?.status as string | undefined; - - if (status === "COMPLETED") { - setJobStatus("Idle"); - return; - } - if (status === "ERROR" || status === "NNSIGHT_ERROR") { - setJobStatus("Error"); - console.error(data); - throw new Error("Job failed"); - } - - if (status) { - setJobStatus(status); - } - - // For non-terminal statuses, wait and try again - await new Promise((r) => setTimeout(r, POLL_INTERVAL_MS)); - } -} - -type JobStartResponse = { job_id: string | null } & { data?: T } & Record; - -async function startJob( - url: string, - body: unknown, - headers?: Record, -): Promise> { - const response = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...headers, - }, - body: JSON.stringify(body), - }); - if (!response.ok) throw new Error("Failed to start job"); - return await response.json(); -} - -async function fetchResults( - url: string, - body: unknown, - headers?: Record, -): Promise { - const resp = await fetch(url, { - method: "POST", - headers: { - "Content-Type": "application/json", - ...headers, - }, - body: JSON.stringify(body), - }); - if (!resp.ok) throw new Error("Failed to fetch results"); - return resp.json() as Promise; -} - -export async function startAndPoll( - startEndpoint: string, - body: unknown, - resultsEndpoint: (jobId: string) => string, - headers?: Record, -): Promise { - const startUrl = config.getApiUrl(startEndpoint); - const response = await startJob(startUrl, body, headers); - const jobId = response?.job_id ?? null; - if (jobId) { - await awaitNDIFJob(jobId); - - const resultsUrl = config.getApiUrl(resultsEndpoint(jobId)); - const results = await fetchResults(resultsUrl, body, headers); - if (results && typeof results === "object" && "data" in results) { - return (results as { data: T }).data; - } - return results as T; - } - if ("data" in response) { - return (response as { data: T | null }).data as T; - } - return response as unknown as T; -}