diff --git a/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py new file mode 100644 index 000000000..25ba9a3c0 --- /dev/null +++ b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py @@ -0,0 +1,105 @@ +"""Add prefilter columns and pipeline stage-machine columns to assessment_run + +Revision ID: 064 +Revises: 063 +Create Date: 2026-05-27 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision = "064" +down_revision = "063" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "assessment_run", + sa.Column( + "prefilter_object_store_url", + sa.String(), + nullable=True, + comment="S3 URL of prefilter results JSON", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "prefilter_total_rows", + sa.Integer(), + nullable=True, + comment="Total rows fed into the prefilter stages", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "prefilter_total_passed", + sa.Integer(), + nullable=True, + comment="Rows that passed the go/no-go gates and went to L2", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "prefilter_total_rejected", + sa.Integer(), + nullable=True, + comment="Rows rejected by a go/no-go gate", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "stage", + sa.String(), + nullable=True, + comment=( + "Current pipeline stage: PRE_FILTER_TOPIC_RELEVANCE, " + "PRE_FILTER_DUPLICATE_DETECTION, L2_ASSESSMENT, COMPLETED, FAILED" + ), + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "stage_status", + sa.String(), + nullable=True, + comment="Status of stage: PENDING, PROCESSING, COMPLETED, FAILED", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "pipeline", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Ordered stage config driving execution: {'stages': [...]}", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "stage_batches", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + comment="Map of stage name -> batch_job id, for per-stage result lookup", + ), + ) + + +def downgrade() -> None: + op.drop_column("assessment_run", "stage_batches") + op.drop_column("assessment_run", "pipeline") + op.drop_column("assessment_run", "stage_status") + op.drop_column("assessment_run", "stage") + op.drop_column("assessment_run", "prefilter_total_rejected") + op.drop_column("assessment_run", "prefilter_total_passed") + op.drop_column("assessment_run", "prefilter_total_rows") + op.drop_column("assessment_run", "prefilter_object_store_url") diff --git a/backend/app/api/docs/assessment/resume_run.md b/backend/app/api/docs/assessment/resume_run.md new file mode 100644 index 000000000..a7dc2713f --- /dev/null +++ b/backend/app/api/docs/assessment/resume_run.md @@ -0,0 +1,5 @@ +Resume a failed assessment run from its failed stage. + +Re-runs the same child run in place, starting at the stage that failed. +Stages that already completed are reused (their batch results are not +recomputed). Only valid when the run is in a failed state. diff --git a/backend/app/api/docs/assessment/update_post_processing.md b/backend/app/api/docs/assessment/update_post_processing.md new file mode 100644 index 000000000..0d6f3278a --- /dev/null +++ b/backend/app/api/docs/assessment/update_post_processing.md @@ -0,0 +1,15 @@ +Save post-processing config for a single assessment run. + +Stores the config inside the run's `input` JSON blob (key +`post_processing_config`). It is applied at export/preview time and never +re-runs the LLM, so it can be edited after the run completes. + +The config has three optional sections: + +- `computed_columns`: derived columns from formulas, e.g. + `{"name": "Total_Score", "formula": "@Novelty_score + @Usefulness_score"}`. + Formulas reference columns with `@` and support `+ - * /` and parentheses. +- `filter`: row filters combined with AND logic. +- `sort`: sort rules applied in priority order. + +Pass `null` (or an empty body) to clear post-processing for the run. diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 18a9be60e..3ed8305ef 100644 --- a/backend/app/api/routes/assessment/runs.py +++ b/backend/app/api/routes/assessment/runs.py @@ -3,14 +3,19 @@ import logging from typing import Any, Literal -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Body, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission from app.crud.assessment import ( get_assessment_by_id, + update_run_post_processing_config, +) +from app.crud.assessment import ( get_assessment_run_by_id as get_run_by_id, +) +from app.crud.assessment import ( list_assessment_runs as list_runs, ) from app.models.assessment import ( @@ -21,6 +26,9 @@ AssessmentRunPublic, ) from app.models.evaluation import EvaluationDataset +from app.services.assessment.service import ( + resume_assessment_run as resume_run, +) from app.services.assessment.service import ( retry_assessment_run as retry_run, ) @@ -33,6 +41,7 @@ load_export_rows_for_run, sort_export_rows, ) +from app.services.assessment.utils.post_processing import apply_post_processing from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -65,6 +74,13 @@ def _build_run_public( total_items=run.total_items, error_message=run.error_message, input=run.input, + prefilter_total_rows=run.prefilter_total_rows, + prefilter_total_passed=run.prefilter_total_passed, + prefilter_total_rejected=run.prefilter_total_rejected, + stage=run.stage, + stage_status=run.stage_status, + pipeline=run.pipeline, + post_processing_config=(run.input or {}).get("post_processing_config"), inserted_at=run.inserted_at, updated_at=run.updated_at, ) @@ -127,6 +143,34 @@ def retry_assessment_run( return APIResponse.success_response(data=result) +@router.post( + "/runs/{run_id}/resume", + description=load_description("assessment/resume_run.md"), + response_model=APIResponse[AssessmentResponse], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def resume_assessment_run( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, +) -> APIResponse[AssessmentResponse]: + """Resume a failed child run from its failed stage, reusing completed stages.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + + result = resume_run( + session=session, + run=run, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + return APIResponse.success_response(data=result) + + @router.get( "/runs", description=load_description("assessment/list_runs.md"), @@ -212,12 +256,44 @@ def export_assessment_run_results( ) ) + post_processing_config = (run.input or {}).get("post_processing_config") or None base_label = assessment.experiment_name if assessment else f"run_{run.id}" + if export_format != "json": return build_export_response( export_rows=export_rows, export_format=export_format, base_name=f"{base_label}_run_{run.id}_results", + post_processing_config=post_processing_config, ) - return APIResponse.success_response(data=build_json_export_rows(export_rows)) + rows = build_json_export_rows(export_rows) + rows = apply_post_processing(rows, post_processing_config) + return APIResponse.success_response(data=rows) + + +@router.patch( + "/runs/{run_id}/post-processing", + description=load_description("assessment/update_post_processing.md"), + response_model=APIResponse[AssessmentRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def update_post_processing( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, + config: dict[str, Any] | None = Body(default=None), +) -> APIResponse[AssessmentRunPublic]: + """Save post-processing config (computed columns, sort, filter) for a run.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + if run is None: + raise HTTPException(status_code=404, detail="Run not found") + + run = update_run_post_processing_config(session=session, run=run, config=config) + + return APIResponse.success_response(data=_build_run_public(session, run)) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index adadf1c9c..34a3f5878 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -232,6 +232,29 @@ def run_tts_batch_submission( ) +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_assessment_pipeline") +def run_assessment_pipeline( + self, + run_id: int, + organization_id: int, + project_id: int, + trace_id: str, + **kwargs, +): + from app.services.assessment.tasks import execute_assessment_pipeline + + _set_trace(trace_id) + return _run_with_otel_parent( + self, + lambda: execute_assessment_pipeline( + run_id=run_id, + organization_id=organization_id, + project_id=project_id, + ), + ) + + @celery_app.task(bind=True, queue="low_priority", priority=1) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_tts_result_processing") def run_tts_result_processing( diff --git a/backend/app/crud/assessment/__init__.py b/backend/app/crud/assessment/__init__.py index cd71bff91..2f5f6f217 100644 --- a/backend/app/crud/assessment/__init__.py +++ b/backend/app/crud/assessment/__init__.py @@ -13,7 +13,9 @@ list_assessment_runs, list_assessments, recompute_assessment_status, + update_assessment_run_prefilter_stats, update_assessment_run_status, + update_run_post_processing_config, ) from app.crud.assessment.dataset import ( create_assessment_dataset, @@ -42,5 +44,7 @@ "list_assessment_datasets", "list_assessments", "recompute_assessment_status", + "update_assessment_run_prefilter_stats", "update_assessment_run_status", + "update_run_post_processing_config", ] diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index b45603853..4918a97d0 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -13,7 +13,8 @@ from openpyxl.utils.exceptions import InvalidFileException from sqlmodel import Session -from app.core.batch import BATCH_KEY, start_batch_job +from app.core.batch import BATCH_KEY, GeminiBatchProvider, start_batch_job +from app.core.batch.client import GeminiClient from app.core.batch.openai import OpenAIBatchProvider from app.core.cloud import get_cloud_storage from app.models.assessment import ( @@ -30,13 +31,12 @@ normalize_llm_text, ) from app.services.assessment.utils.attachments import ( + attachment_type_for_row, + build_gemini_attachment_parts, resolve_attachment_values, - resolve_image_mime_and_payload, - split_attachment_urls, - split_data_url, - to_direct_attachment_url, ) from app.services.llm.providers.registry import LLMProvider +from app.utils import get_openai_client logger = logging.getLogger(__name__) @@ -161,6 +161,7 @@ def build_openai_jsonl( attachments: list[AssessmentAttachment], prompt_template: str | None, openai_params: dict, + row_indices: list[int] | None = None, ) -> list[dict[str, Any]]: """Build OpenAI batch JSONL data from dataset rows. @@ -174,7 +175,8 @@ def build_openai_jsonl( """ jsonl_data = [] - for idx, row in enumerate(rows): + for i, row in enumerate(rows): + idx = row_indices[i] if row_indices is not None else i # Build input array input_parts: list[dict[str, Any]] = [] @@ -186,7 +188,13 @@ def build_openai_jsonl( # Attachments for att in attachments: cell_value = row.get(att.column, "") - input_parts.extend(resolve_attachment_values(cell_value, att)) + input_parts.extend( + resolve_attachment_values( + cell_value, + att, + type_override=attachment_type_for_row(att, row), + ) + ) if not input_parts: logger.warning("[build_openai_jsonl] Skipping empty row | idx=%s", idx) @@ -219,6 +227,7 @@ def build_google_jsonl( attachments: list[AssessmentAttachment], prompt_template: str | None, google_params: dict, + row_indices: list[int] | None = None, ) -> list[dict[str, Any]]: """Build Google (Gemini) batch JSONL data from dataset rows. @@ -230,7 +239,8 @@ def build_google_jsonl( """ jsonl_data = [] - for idx, row in enumerate(rows): + for i, row in enumerate(rows): + idx = row_indices[i] if row_indices is not None else i parts: list[dict[str, Any]] = [] # Text prompt @@ -240,64 +250,14 @@ def build_google_jsonl( # Attachments (Gemini uses file_data for inline content) for att in attachments: - cell_value = row.get(att.column, "").strip() - if not cell_value: - continue - - cell_values = ( - split_attachment_urls(cell_value) - if att.format == "url" - else [cell_value] - ) - - for item_value in cell_values: - normalized_value = ( - to_direct_attachment_url(item_value, att.type) - if att.format == "url" - else item_value + cell_value = row.get(att.column, "") + parts.extend( + build_gemini_attachment_parts( + cell_value, + att, + type_override=attachment_type_for_row(att, row), ) - if att.type == "image": - mime_type, payload = resolve_image_mime_and_payload( - normalized_value, - att.format, - ) - if att.format == "url": - parts.append( - { - "fileData": { - "mimeType": mime_type, - "fileUri": normalized_value, - } - } - ) - else: - parts.append( - { - "inlineData": { - "mimeType": mime_type, - "data": payload, - } - } - ) - elif att.type == "pdf": - if att.format == "url": - parts.append( - { - "fileData": { - "mimeType": "application/pdf", - "fileUri": normalized_value, - } - } - ) - else: - parts.append( - { - "inlineData": { - "mimeType": "application/pdf", - "data": split_data_url(normalized_value)[1], - } - } - ) + ) if not parts: logger.warning("[build_google_jsonl] Skipping empty row | idx=%s", idx) @@ -349,6 +309,8 @@ def submit_assessment_batch( assessment_input: dict[str, Any], organization_id: int, project_id: int, + preloaded_rows: list[dict[str, str]] | None = None, + row_indices: list[int] | None = None, ) -> BatchJob: """Build JSONL and submit a batch for one assessment run. @@ -371,8 +333,11 @@ def submit_assessment_batch( output_schema = assessment_input.get("output_schema") attachments = [AssessmentAttachment(**a) for a in attachments_raw] - # Load dataset rows - rows = _load_dataset_rows(session, dataset) + # Use preloaded rows (post-prefilter filtered) if provided, else load from dataset. + if preloaded_rows is not None: + rows = preloaded_rows + else: + rows = _load_dataset_rows(session, dataset) if not rows: raise ValueError(f"Dataset {dataset.id} has no rows") @@ -412,11 +377,9 @@ def submit_assessment_batch( attachments=attachments, prompt_template=prompt_template, openai_params=mapped_params, + row_indices=row_indices, ) - # Get OpenAI client and submit - from app.utils import get_openai_client - openai_client = get_openai_client( session=session, org_id=organization_id, @@ -452,12 +415,9 @@ def submit_assessment_batch( attachments=attachments, prompt_template=prompt_template, google_params=mapped_params, + row_indices=row_indices, ) - # Get Gemini client and submit - from app.core.batch import GeminiBatchProvider - from app.core.batch.client import GeminiClient - gemini_client = GeminiClient.from_credentials( session=session, org_id=organization_id, diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index c91626660..eb3b529a9 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -5,6 +5,7 @@ from uuid import UUID from fastapi import HTTPException +from sqlalchemy.orm.attributes import flag_modified from sqlmodel import Session, select from app.core.util import now @@ -129,6 +130,30 @@ def create_assessment_run( return run +def update_run_post_processing_config( + session: Session, + run: AssessmentRun, + config: dict[str, Any] | None, +) -> AssessmentRun: + """Set post_processing_config inside the run's input JSON blob and persist.""" + run.input = {**(run.input or {}), "post_processing_config": config} + flag_modified(run, "input") + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + logger.error( + f"[update_run_post_processing_config] Failed for run id={run.id}: {e}", + exc_info=True, + ) + raise + + logger.info(f"[update_run_post_processing_config] Updated run id={run.id}") + return run + + def get_assessment_run_by_id( session: Session, run_id: int, @@ -223,16 +248,58 @@ def update_assessment_run_status( return run +def update_assessment_run_prefilter_stats( + session: Session, + run: AssessmentRun, + prefilter_object_store_url: str | None = None, + prefilter_total_rows: int | None = None, + prefilter_total_passed: int | None = None, + prefilter_total_rejected: int | None = None, +) -> AssessmentRun: + """Persist prefilter result stats (rows/passed/rejected + S3 URL) on a run.""" + run.updated_at = now() + + if prefilter_object_store_url is not None: + run.prefilter_object_store_url = prefilter_object_store_url + if prefilter_total_rows is not None: + run.prefilter_total_rows = prefilter_total_rows + if prefilter_total_passed is not None: + run.prefilter_total_passed = prefilter_total_passed + if prefilter_total_rejected is not None: + run.prefilter_total_rejected = prefilter_total_rejected + + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + logger.error( + f"[update_assessment_run_prefilter_stats] Failed: {e}", exc_info=True + ) + raise + + return run + + +_ACTIVE_RUN_STATUSES = { + "prefilter_processing", + "l2_processing", + "processing", + "in_progress", +} +_FAILED_RUN_STATUSES = {"failed", "prefilter_failed"} +_COMPLETED_RUN_STATUSES = {"completed", "completed_with_errors"} + + def compute_run_counts(runs: list[AssessmentRun]) -> AssessmentRunCounts: """Aggregate child run statuses into counters.""" return AssessmentRunCounts( total=len(runs), pending=sum(1 for run in runs if run.status == "pending"), - processing=sum( - 1 for run in runs if run.status in {"processing", "in_progress"} - ), - completed=sum(1 for run in runs if run.status == "completed"), - failed=sum(1 for run in runs if run.status == "failed"), + processing=sum(1 for run in runs if run.status in _ACTIVE_RUN_STATUSES), + completed=sum(1 for run in runs if run.status in _COMPLETED_RUN_STATUSES), + failed=sum(1 for run in runs if run.status in _FAILED_RUN_STATUSES), ) @@ -267,6 +334,11 @@ def build_run_stats(runs: list[AssessmentRun]) -> list[AssessmentRunStat]: total_items=run.total_items, error_message=run.error_message, updated_at=run.updated_at, + prefilter_total_rows=run.prefilter_total_rows, + prefilter_total_passed=run.prefilter_total_passed, + prefilter_total_rejected=run.prefilter_total_rejected, + stage=run.stage, + stage_status=run.stage_status, ) for run in runs ] diff --git a/backend/app/crud/assessment/cron.py b/backend/app/crud/assessment/cron.py index c69b3157e..397554607 100644 --- a/backend/app/crud/assessment/cron.py +++ b/backend/app/crud/assessment/cron.py @@ -12,10 +12,10 @@ update_assessment_run_status, ) from app.crud.assessment.processing import ( - check_and_process_assessment, format_assessment_failure_message, + process_run_batches, ) -from app.models.assessment import Assessment, AssessmentRun +from app.models.assessment import Assessment, AssessmentRun, StageStatus logger = logging.getLogger(__name__) @@ -78,7 +78,9 @@ async def poll_all_pending_assessment_evaluations( runs = get_assessment_runs_for_assessment( session=session, assessment_id=assessment.id ) - active_runs = [run for run in runs if run.status == "processing"] + active_runs = [ + run for run in runs if run.stage_status == StageStatus.PROCESSING + ] if not active_runs: refreshed = recompute_assessment_status( @@ -100,7 +102,7 @@ async def poll_all_pending_assessment_evaluations( for run in active_runs: try: - result = await check_and_process_assessment( + result = await process_run_batches( run=run, session=session, ) @@ -114,52 +116,44 @@ async def poll_all_pending_assessment_evaluations( else: still_processing += 1 - except Exception as e: - error_msg = format_assessment_failure_message(e) + except ValueError as e: + session.rollback() + message = format_assessment_failure_message(e) logger.error( - "[poll_all_pending_assessment_evaluations] Failed run %s | " - "experiment=%s | assessment_id=%s | config_id=%s | config_version=%s | error=%s", + "[poll_all_pending_assessment_evaluations] deterministic error on " + "run %s (assessment %s), marking failed: %s", run.id, - assessment.experiment_name, run.assessment_id, - run.config_id, - run.config_version, - error_msg, - exc_info=True, + message, ) try: + run.stage_status = StageStatus.FAILED update_assessment_run_status( session=session, run=run, status="failed", - error_message=error_msg, - ) - recompute_assessment_status( - session=session, assessment_id=assessment.id + error_message=message, ) - failure_result = { - "assessment_id": run.assessment_id, - "run_id": run.id, - "experiment_name": assessment.experiment_name, - "config_id": str(run.config_id) if run.config_id else None, - "config_version": run.config_version, - "action": "failed", - "error": error_msg, - "current_status": "failed", - } - all_results.append(failure_result) failed += 1 - except Exception as cleanup_exc: + except Exception: + session.rollback() logger.error( - "[poll_all_pending_assessment_evaluations] Cleanup failed for run %s | " - "assessment_id=%s | experiment=%s | error=%s", + "[poll_all_pending_assessment_evaluations] could not mark run " + "%s failed", run.id, - run.assessment_id, - assessment.experiment_name, - cleanup_exc, exc_info=True, ) - failed += 1 + still_processing += 1 + except Exception as e: + session.rollback() + logger.warning( + "[poll_all_pending_assessment_evaluations] transient error polling " + "run %s (assessment %s), will retry: %s", + run.id, + run.assessment_id, + format_assessment_failure_message(e), + ) + still_processing += 1 logger.info( "[poll_all_pending_assessment_evaluations] Summary | processed=%s | failed=%s | still_processing=%s", diff --git a/backend/app/crud/assessment/processing.py b/backend/app/crud/assessment/processing.py index f2a27455c..dbf792bae 100644 --- a/backend/app/crud/assessment/processing.py +++ b/backend/app/crud/assessment/processing.py @@ -8,27 +8,28 @@ from typing import Any from fastapi import HTTPException +from sqlalchemy.orm.attributes import flag_modified from sqlmodel import Session -from app.core.batch import ( - BATCH_KEY, - GeminiBatchProvider, - OpenAIBatchProvider, - download_batch_results, - poll_batch_status, - upload_batch_results_to_object_store, -) +from app.celery.tasks.job_execution import run_assessment_pipeline +from app.core.batch import BATCH_KEY, poll_batch_status, process_completed_batch from app.core.batch.base import BatchProvider -from app.core.batch.client import GeminiClient from app.core.batch.gemini import BatchJobState, extract_text_from_response_dict from app.crud.assessment import ( recompute_assessment_status, + update_assessment_run_prefilter_stats, update_assessment_run_status, ) from app.crud.job import get_batch_job -from app.models.assessment import Assessment, AssessmentRun +from app.models.assessment import Assessment, AssessmentRun, StageStatus +from app.services.assessment.stages import ( + GATE_STAGES, + STAGE_PARSERS, + _get_batch_provider, + advance_or_finalize, + load_raw_batch_results, +) from app.services.llm.providers.registry import LLMProvider -from app.utils import get_openai_client logger = logging.getLogger(__name__) @@ -87,32 +88,6 @@ def _sanitize_json_output(raw: str) -> str: return "".join(result) -def _get_batch_provider( - session: Session, - provider_name: str, - organization_id: int, - project_id: int, -) -> BatchProvider: - """Get the appropriate batch provider instance.""" - if provider_name in (LLMProvider.OPENAI, LLMProvider.OPENAI_NATIVE): - openai_client = get_openai_client( - session=session, - org_id=organization_id, - project_id=project_id, - ) - return OpenAIBatchProvider(client=openai_client) - - if provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE): - gemini_client = GeminiClient.from_credentials( - session=session, - org_id=organization_id, - project_id=project_id, - ) - return GeminiBatchProvider(client=gemini_client.client) - - raise ValueError(f"Unsupported provider for assessment polling: {provider_name}") - - def parse_assessment_output( raw_results: list[dict[str, Any]], provider_name: str, @@ -262,234 +237,170 @@ def parse_assessment_output( return results -async def check_and_process_assessment( - run: AssessmentRun, - session: Session, -) -> dict[str, Any]: - """Check assessment batch status and process if completed. +_PROVIDER_SUCCESS = {"completed", BatchJobState.SUCCEEDED.value} +_PROVIDER_FAILED = { + "failed", + "expired", + "cancelled", + BatchJobState.FAILED.value, + BatchJobState.CANCELLED.value, + BatchJobState.EXPIRED.value, +} - Args: - run: AssessmentRun to check - session: Database session - Returns: - Dict with status information - """ - log_prefix = f"[check_and_process_assessment][assessment_run={run.id}]" - previous_status = run.status - parent_pre = session.get(Assessment, run.assessment_id) - experiment_name_pre = parent_pre.experiment_name if parent_pre else None +def _poll_stage_outcome(session: Session, provider: BatchProvider, batch_job) -> str: + """Poll one stage's batch; on success download+persist. Returns the outcome.""" + status_result = poll_batch_status( + session=session, provider=provider, batch_job=batch_job + ) + session.refresh(batch_job) + status = batch_job.provider_status - try: - if not run.batch_job_id: - raise ValueError(f"Assessment run {run.id} has no batch_job_id") + if status in _PROVIDER_SUCCESS: + if batch_job.provider_output_file_id: + process_completed_batch( + session=session, provider=provider, batch_job=batch_job + ) + return "completed" + counts = status_result.get("request_counts") or {} + if counts.get("completed", 0) == 0 and ( + counts.get("failed", 0) > 0 or status_result.get("error_file_id") + ): + return "failed" + return "no_change" # output genuinely not ready yet — retry next cycle + if status in _PROVIDER_FAILED: + return "failed" + return "no_change" - batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) - if not batch_job: - raise ValueError(f"BatchJob {run.batch_job_id} not found") - parent = parent_pre - if not parent: - raise ValueError(f"Parent assessment {run.assessment_id} not found") +def _record_gate_stats( + session: Session, run: AssessmentRun, stage: str, batch_job, project_id: int +) -> None: + """For a go/no-go stage, persist passed/rejected counts and accepted row indices. - # Get provider and poll status - provider = _get_batch_provider( + The accepted indices are stored on ``run.pipeline`` so the next stage's batch + build reads them directly instead of re-downloading and re-parsing this batch. + """ + try: + raw = load_raw_batch_results(session, batch_job, project_id) + outputs = parse_assessment_output(raw, batch_job.provider) + parsed = STAGE_PARSERS[stage](outputs) + total = len(parsed) + passed = sum(1 for r in parsed.values() if r.get("verdict")) + update_assessment_run_prefilter_stats( session=session, - provider_name=batch_job.provider, - organization_id=parent.organization_id, - project_id=parent.project_id, + run=run, + prefilter_total_rows=total, + prefilter_total_passed=passed, + prefilter_total_rejected=total - passed, ) - status_result = poll_batch_status( - session=session, - provider=provider, - batch_job=batch_job, + + # Persist the cumulative accepted set (intersect with prior gates). + accepted = {idx for idx, r in parsed.items() if r.get("verdict")} + prev = (run.pipeline or {}).get("accepted_indices") + if prev is not None: + accepted &= set(prev) + pipeline = dict(run.pipeline or {}) + pipeline["accepted_indices"] = sorted(accepted) + run.pipeline = pipeline + flag_modified(run, "pipeline") + except Exception as exc: + logger.warning( + "[_record_gate_stats] run_id=%s stage=%s — %s", run.id, stage, exc ) - session.refresh(batch_job) - provider_status = batch_job.provider_status - if ( - provider_status == "completed" - or provider_status == BatchJobState.SUCCEEDED.value - ): - if not batch_job.provider_output_file_id: - request_counts = status_result.get("request_counts") or {} - error_file_id = status_result.get("error_file_id") - failed_count = request_counts.get("failed", 0) - completed_count = request_counts.get("completed", 0) - total_count = request_counts.get("total", 0) - - if error_file_id and failed_count > 0 and completed_count == 0: - error_msg = ( - f"Batch completed with {failed_count} failed request(s)" - f" and no successful outputs" - ) - if total_count: - error_msg += f" out of {total_count}" - error_msg += f" (error_file_id: {error_file_id})" - - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=error_msg, - ) - recompute_assessment_status( - session=session, assessment_id=run.assessment_id - ) +def _fail_run_stage( + session: Session, run: AssessmentRun, message: str +) -> dict[str, Any]: + # Keep run.stage at the failed stage so a resume knows where to restart; + # stage_status == FAILED is the failure marker. + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, run=run, status="failed", error_message=message + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + return {"run_id": run.id, "current_status": "failed", "action": "failed"} - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": "failed", - "provider_status": provider_status, - "action": "failed", - "error": error_msg, - } - logger.info( - f"{log_prefix} Batch completed but output file is not ready yet | " - f"batch_job_id={batch_job.id} | provider_status={provider_status}" - ) - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": run.status, - "provider_status": provider_status, - "action": "no_change", - } +async def process_run_batches(run: AssessmentRun, session: Session) -> dict[str, Any]: + """Poll the run's current-stage batch; on completion advance to the next stage.""" + parent = session.get(Assessment, run.assessment_id) + if not parent: + raise ValueError(f"Parent assessment {run.assessment_id} not found") - # Download and process results - raw_results = download_batch_results(provider=provider, batch_job=batch_job) + stage = run.stage + if not stage or run.stage_status != StageStatus.PROCESSING: + return {"run_id": run.id, "current_status": run.status, "action": "no_change"} - # Upload raw results to object store - object_store_url = None - try: - object_store_url = upload_batch_results_to_object_store( - session=session, batch_job=batch_job, results=raw_results - ) - except Exception as e: - logger.error( - "%s Object store upload failed — results may be unrecoverable " - "if the provider deletes the output file before next poll: %s", - log_prefix, - e, - exc_info=True, - ) + batch_id = (run.stage_batches or {}).get(stage) + batch_job = ( + get_batch_job(session=session, batch_job_id=batch_id) if batch_id else None + ) + if not batch_job: + return _fail_run_stage(session, run, f"Stage {stage} batch not found") - # Parse results - parsed = parse_assessment_output(raw_results, batch_job.provider) - error_count = sum(1 for result in parsed if result.get("error")) - success_count = sum(1 for result in parsed if not result.get("error")) - - # Update run status - error_msg = f"{error_count} item(s) failed" if error_count > 0 else None - run_status = ( - "failed" - if parsed and success_count == 0 and error_count > 0 - else "completed" - ) + # Transient errors here (DNS, network, provider hiccup) must NOT fail the run — + # the batch is still running. Skip this cycle; the cron retries next tick. + try: + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=parent.organization_id, + project_id=parent.project_id, + ) + outcome = _poll_stage_outcome(session, provider, batch_job) + except Exception as exc: + logger.warning( + "[process_run_batches] run_id=%s stage=%s poll error, will retry: %s", + run.id, + stage, + exc, + ) + return {"run_id": run.id, "current_status": run.status, "action": "no_change"} - if not parsed: - run_status = "failed" - error_msg = "Batch completed but no valid results were produced" + if outcome == "no_change": + return {"run_id": run.id, "current_status": run.status, "action": "no_change"} + if outcome == "failed": + return _fail_run_stage( + session, run, batch_job.error_message or f"Stage {stage} failed" + ) - update_assessment_run_status( - session=session, - run=run, - status=run_status, - error_message=error_msg, - object_store_url=object_store_url, + run.stage_status = StageStatus.COMPLETED + if stage in GATE_STAGES: + _record_gate_stats(session, run, stage, batch_job, parent.project_id) + + nxt = advance_or_finalize(run) + session.add(run) + session.commit() + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + + if nxt: + try: + run_assessment_pipeline.delay( + run_id=run.id, + organization_id=parent.organization_id, + project_id=parent.project_id, + trace_id="", ) - recompute_assessment_status( - session=session, assessment_id=run.assessment_id + except Exception as exc: + logger.error( + "[process_run_batches] run_id=%s stage=%s enqueue failed — marking failed for resume: %s", + run.id, + run.stage, + exc, + exc_info=True, ) - - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": run_status, - "provider_status": provider_status, - "action": "processed" if run_status == "completed" else "failed", - "total_results": len(parsed), - "errors": error_count, - } - - elif provider_status in ( - "failed", - "expired", - "cancelled", - BatchJobState.FAILED.value, - BatchJobState.CANCELLED.value, - BatchJobState.EXPIRED.value, - ): - error_msg = batch_job.error_message or f"Batch {provider_status}" - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=error_msg, + return _fail_run_stage( + session, + run, + "Failed to enqueue the next pipeline stage. Resume the run to retry.", ) - recompute_assessment_status( - session=session, assessment_id=run.assessment_id - ) - - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": "failed", - "provider_status": provider_status, - "action": "failed", - "error": error_msg, - } - else: - # Still processing - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": run.status, - "provider_status": provider_status, - "action": "no_change", - } - - except Exception as e: - error_msg = format_assessment_failure_message(e) - logger.error( - f"{log_prefix} Error checking assessment: {error_msg}", - exc_info=True, - ) - update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message=error_msg, - ) - recompute_assessment_status(session=session, assessment_id=run.assessment_id) - return { - "run_id": run.id, - "assessment_id": run.assessment_id, - "experiment_name": experiment_name_pre, - "previous_status": previous_status, - "current_status": "failed", - "provider_status": "unknown", - "action": "failed", - "error": error_msg, - } - - -async def poll_all_pending_assessments(session: Session) -> dict[str, Any]: - """Backward-compatible wrapper for parent-first assessment polling.""" - from app.crud.assessment.cron import poll_all_pending_assessment_evaluations - - return await poll_all_pending_assessment_evaluations(session=session) + return { + "run_id": run.id, + "assessment_id": run.assessment_id, + "experiment_name": parent.experiment_name, + "current_status": run.status, + "action": "processed", + } diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index 25ac0f00e..fee792d8a 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -1,10 +1,11 @@ """Assessment models — DB tables, Pydantic schemas, and LLM param wrappers.""" from datetime import datetime +from enum import StrEnum from typing import TYPE_CHECKING, Any, Literal, Optional from uuid import UUID -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sqlalchemy import Column, Index, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField @@ -17,6 +18,25 @@ from app.models.batch_job import BatchJob +class Stage(StrEnum): + """Pipeline stages, in execution order. Business step only (status is separate).""" + + PRE_FILTER_TOPIC_RELEVANCE = "PRE_FILTER_TOPIC_RELEVANCE" + PRE_FILTER_DUPLICATE_DETECTION = "PRE_FILTER_DUPLICATE_DETECTION" + L2_ASSESSMENT = "L2_ASSESSMENT" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +class StageStatus(StrEnum): + """Execution status of the current stage.""" + + PENDING = "PENDING" + PROCESSING = "PROCESSING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + class Assessment(SQLModel, table=True): """Parent assessment — one experiment over a dataset, grouping N config runs.""" @@ -108,9 +128,45 @@ class AssessmentRun(SQLModel, table=True): status: str = SQLField( default="pending", sa_column_kwargs={ - "comment": "Run status: pending, processing, completed, failed" + "comment": ( + "Overall run status: pending, processing, completed, " + "completed_with_errors, failed" + ) + }, + ) + stage: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": ( + "Current pipeline stage (Stage enum): PRE_FILTER_TOPIC_RELEVANCE, " + "PRE_FILTER_DUPLICATE_DETECTION, L2_ASSESSMENT, COMPLETED, FAILED" + ) + }, + ) + stage_status: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "StageStatus of stage: PENDING, PROCESSING, COMPLETED, FAILED" }, ) + pipeline: dict[str, Any] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="Ordered stage config driving execution: {'stages': [...]}", + ), + ) + stage_batches: dict[str, int] | None = SQLField( + default=None, + sa_column=Column( + JSONB, + nullable=True, + comment="Map of stage name -> batch_job id, for per-stage result lookup", + ), + ) batch_job_id: int | None = SQLField( default=None, foreign_key="batch_job.id", @@ -136,7 +192,29 @@ class AssessmentRun(SQLModel, table=True): object_store_url: str | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "S3 URL of processed batch results"}, + sa_column_kwargs={"comment": "S3 URL of processed L2 batch results"}, + ) + prefilter_object_store_url: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "S3 URL of stored prefilter filter results JSON"}, + ) + prefilter_total_rows: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Total rows fed into prefilter pipeline"}, + ) + prefilter_total_passed: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Rows that passed topic relevance and went to L2"}, + ) + prefilter_total_rejected: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={ + "comment": "Rows rejected by topic relevance, stopped at prefilter" + }, ) error_message: str | None = SQLField( default=None, @@ -185,6 +263,11 @@ class AssessmentRunStat(BaseModel): total_items: int error_message: str | None = None updated_at: datetime | None = None + prefilter_total_rows: int | None = None + prefilter_total_passed: int | None = None + prefilter_total_rejected: int | None = None + stage: str | None = None + stage_status: str | None = None class AssessmentPublic(BaseModel): @@ -224,6 +307,13 @@ class AssessmentRunPublic(BaseModel): "text_columns, attachments, output_schema" ), ) + prefilter_total_rows: int | None = None + prefilter_total_passed: int | None = None + prefilter_total_rejected: int | None = None + stage: str | None = None + stage_status: str | None = None + pipeline: dict[str, Any] | None = None + post_processing_config: dict[str, Any] | None = None inserted_at: datetime updated_at: datetime @@ -245,8 +335,39 @@ class AssessmentAttachment(BaseModel): """Attachment column configuration.""" column: str = Field(..., description="Column name containing the attachment data") - type: Literal["image", "pdf"] = Field(..., description="Attachment type") + type: Literal["image", "pdf", "mixed"] = Field( + ..., + description=( + "Attachment type. 'image'/'pdf' force the type for every row. 'mixed' " + "resolves the per-row type from type_column via type_value_map." + ), + ) format: Literal["url", "base64"] = Field(..., description="Data format") + type_column: str | None = Field( + None, + description=( + "For 'mixed': the dataset column whose value decides each row's type." + ), + ) + type_value_map: dict[str, Literal["image", "pdf"]] | None = Field( + None, + description=( + "For 'mixed': maps a type_column value to 'image' or 'pdf' " + "(e.g. {'Photo': 'image', 'Report': 'pdf'})." + ), + ) + + @model_validator(mode="after") + def _validate_mixed_config(self) -> "AssessmentAttachment": + """A 'mixed' column must carry the per-row routing fields; others must not.""" + if self.type == "mixed": + if self.type_column is None or self.type_value_map is None: + raise ValueError( + "type='mixed' requires both 'type_column' and 'type_value_map'." + ) + if not self.type_value_map: + raise ValueError("type_value_map must not be empty for type='mixed'.") + return self class AssessmentConfigRef(BaseModel): @@ -286,6 +407,20 @@ class AssessmentCreate(BaseModel): configs: list[AssessmentConfigRef] = Field( ..., min_length=1, max_length=4, description="Config versions to run" ) + prefilter_config: dict[str, Any] | None = Field( + None, + description=( + "prefilter pipeline config. Keys: topic_relevance (columns, prompt), " + "duplicate_detection (columns). Omit to skip prefilter." + ), + ) + post_processing_config: dict[str, Any] | None = Field( + None, + description=( + "Post-processing config applied at export. " + "Keys: computed_columns, sort, filter." + ), + ) class AssessmentRunSummary(BaseModel): @@ -324,6 +459,8 @@ class AssessmentExportRow(BaseModel): row_id: str result_status: str input_data: dict[str, str] | None = None + topic_relevance: str | None = None + duplicate_detection: str | None = None output: str | None = None error: str | None = None response_id: str | None = None diff --git a/backend/app/services/assessment/prefilter/__init__.py b/backend/app/services/assessment/prefilter/__init__.py new file mode 100644 index 000000000..2e763bd4f --- /dev/null +++ b/backend/app/services/assessment/prefilter/__init__.py @@ -0,0 +1,3 @@ +from app.services.assessment.prefilter.pipeline import resolve_prefilter_settings + +__all__ = ["resolve_prefilter_settings"] diff --git a/backend/app/services/assessment/prefilter/constants.py b/backend/app/services/assessment/prefilter/constants.py new file mode 100644 index 000000000..1fe54ba85 --- /dev/null +++ b/backend/app/services/assessment/prefilter/constants.py @@ -0,0 +1,14 @@ +"""Static config for the assessment prefilter stages. +""" + +from typing import Literal + +# Provider + model that run the batch prefilter stages (topic relevance, dup check). +ASSESSMENT_PREFILTER_PROVIDER: Literal["openai", "google"] = "openai" +ASSESSMENT_PREFILTER_MODEL: str = "gpt-5-mini" +# ASSESSMENT_PREFILTER_MODEL: str = "gemini-3.1-flash-lite" + + +# File-search/vector store holding the corpus for duplicate detection. +ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = "vs_6a20339fbc148191867fd06d29133278" +# ASSESSMENT_PREFILTER_DUPLICATE_STORE: str = "fileSearchStores/inquilabcorpus-782mxjcwisaz" diff --git a/backend/app/services/assessment/prefilter/duplicate_detection.py b/backend/app/services/assessment/prefilter/duplicate_detection.py new file mode 100644 index 000000000..7512616db --- /dev/null +++ b/backend/app/services/assessment/prefilter/duplicate_detection.py @@ -0,0 +1,117 @@ +"""Duplicate detection stage: build per-record file_search batch requests, parse verdicts.""" + +import json +import logging +from typing import Any + +from app.services.assessment.prefilter import constants +from app.services.assessment.prefilter.request_builder import build_request_line + +logger = logging.getLogger(__name__) + +_DUP_SYS = """ +You are a strict duplicate-detection judge for an innovation competition corpus. + +If the submission is too vague for corpus matching (no problem/target/domain AND no +solution mechanism, or empty/gibberish), use verdict VAGUE. + +Otherwise search the corpus and compare precisely. Focus on the MECHANISM of the +solution, not category or theme: +- DUPLICATE: problem AND solution mechanism substantially match a corpus entry. +- OVERLAP: problem OR solution mechanism matches; the other side clearly differs. +- PARTIAL_MATCH: thematic/conceptual similarity only — same domain, different mechanism. +- UNIQUE: neither problem nor solution substantially matches anything in the corpus. + +Return JSON with keys: verdict, match_title, source_url, matching_sentence, reason. +For UNIQUE or VAGUE, set match_title, source_url and matching_sentence to "" and give a +short reason. Otherwise fill match_title, source_url (the SOURCE_URL verbatim from the +retrieved chunk), matching_sentence (the exact sentence) and a one-sentence reason. +Never invent or construct URLs or filenames. +""" + +_DUP_SCHEMA: dict[str, Any] = { + "type": "object", + "properties": { + "verdict": { + "type": "string", + "enum": ["DUPLICATE", "OVERLAP", "PARTIAL_MATCH", "UNIQUE", "VAGUE"], + }, + "match_title": {"type": "string"}, + "source_url": {"type": "string"}, + "matching_sentence": {"type": "string"}, + "reason": {"type": "string"}, + }, + "required": [ + "verdict", + "match_title", + "source_url", + "matching_sentence", + "reason", + ], +} + + +def _combined_text(row: dict[str, str], columns: list[str]) -> str: + parts = [ + f"{col}:\n{row.get(col, '')}" for col in columns if row.get(col, "").strip() + ] + return "\n\n".join(parts) or "(empty submission)" + + +def build_duplicate_detection_requests( + rows: list[tuple[int, dict[str, str]]], + columns: list[str], +) -> list[dict[str, Any]]: + """Build one batch JSONL line per record, grounded on the provider's corpus store.""" + store = constants.ASSESSMENT_PREFILTER_DUPLICATE_STORE or None + return [ + build_request_line( + key=f"dup_{idx}", + system=_DUP_SYS, + user_text=f"Submitted idea to check:\n\n{_combined_text(row, columns)}", + response_schema=_DUP_SCHEMA, + file_search_store=store, + ) + for idx, row in rows + ] + + +def parse_duplicate_detection_results( + outputs: list[dict[str, Any]], +) -> dict[int, dict[str, Any]]: + """Parse extracted batch outputs into {row_id: {verdict, match_title, ...}}.""" + parsed: dict[int, dict[str, Any]] = {} + for out in outputs: + key = str(out.get("row_id", "")) + if not key.startswith("dup_"): + continue + try: + idx = int(key.split("_", 1)[1]) + except (ValueError, IndexError): + continue + if out.get("error") or not out.get("output"): + parsed[idx] = _error_record(out.get("error") or "Empty response") + continue + try: + data = json.loads(out["output"]) + parsed[idx] = { + "verdict": str(data.get("verdict") or "UNKNOWN"), + "match_title": data.get("match_title") or None, + "source_url": data.get("source_url") or None, + "matching_sentence": data.get("matching_sentence") or None, + "reason": data.get("reason") or None, + } + except Exception as exc: + logger.warning("[parse_duplicate_detection_results] %s — %s", key, exc) + parsed[idx] = _error_record(str(exc)[:200]) + return parsed + + +def _error_record(reason: str) -> dict[str, Any]: + return { + "verdict": "ERROR", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": reason, + } diff --git a/backend/app/services/assessment/prefilter/pipeline.py b/backend/app/services/assessment/prefilter/pipeline.py new file mode 100644 index 000000000..55bfb4f71 --- /dev/null +++ b/backend/app/services/assessment/prefilter/pipeline.py @@ -0,0 +1,22 @@ +"""Prefilter config helpers shared by the batch pipeline stages.""" + +from typing import Any + + +def resolve_prefilter_settings(prefilter_config: dict[str, Any]) -> dict[str, Any]: + """Flatten the prefilter config into the values the stage builders need.""" + tr_config = prefilter_config.get("topic_relevance") or {} + dup_config = prefilter_config.get("duplicate_detection") or {} + + tr_columns = tr_config.get("columns") or [] + tr_prompt = tr_config.get("prompt") or "" + dup_columns = dup_config.get("columns") or [] + + return { + "tr_columns": tr_columns, + "tr_prompt": tr_prompt, + "tr_attachment_columns": tr_config.get("attachment_columns"), + "dup_columns": dup_columns, + "tr_enabled": bool(tr_columns and tr_prompt), + "dup_enabled": bool(dup_columns), + } diff --git a/backend/app/services/assessment/prefilter/request_builder.py b/backend/app/services/assessment/prefilter/request_builder.py new file mode 100644 index 000000000..3e438ffed --- /dev/null +++ b/backend/app/services/assessment/prefilter/request_builder.py @@ -0,0 +1,73 @@ +"""Provider-aware batch request line builder for prefilter stages.""" + +from typing import Any + +from app.services.assessment.mappers import _ensure_openai_strict_schema +from app.services.assessment.prefilter import constants + + +def build_request_line( + key: str, + system: str, + user_text: str, + *, + attachment_parts: list[dict[str, Any]] | None = None, + response_schema: dict[str, Any] | None = None, + file_search_store: str | None = None, +) -> dict[str, Any]: + """Build one batch JSONL line shaped for the configured prefilter provider. + + ``attachment_parts`` are provider-shaped content parts (from the OpenAI/Gemini + attachment resolvers) appended after the text part. + """ + model = constants.ASSESSMENT_PREFILTER_MODEL + + if constants.ASSESSMENT_PREFILTER_PROVIDER == "openai": + content: list[dict[str, Any]] = [{"type": "input_text", "text": user_text}] + content.extend(attachment_parts or []) + body: dict[str, Any] = { + "model": model, + "instructions": system, + "input": [{"role": "user", "content": content}], + } + if response_schema is not None: + body["text"] = { + "format": { + "type": "json_schema", + "name": "result", + "strict": True, + "schema": _ensure_openai_strict_schema(response_schema), + } + } + if file_search_store: + body["tools"] = [ + { + "type": "file_search", + "vector_store_ids": [file_search_store], + "max_num_results": 20, + } + ] + return { + "custom_id": key, + "method": "POST", + "url": "/v1/responses", + "body": body, + } + + parts: list[dict[str, Any]] = [{"text": user_text}] + parts.extend(attachment_parts or []) + request: dict[str, Any] = { + "contents": [{"role": "user", "parts": parts}], + "systemInstruction": {"parts": [{"text": system}]}, + "model": f"models/{model}", + } + if response_schema is not None: + request["generationConfig"] = { + "responseMimeType": "application/json", + "responseSchema": response_schema, + } + if file_search_store: + request["tools"] = [ + {"fileSearch": {"fileSearchStoreNames": [file_search_store]}} + ] + return {"key": key, "request": request} diff --git a/backend/app/services/assessment/prefilter/topic_relevance.py b/backend/app/services/assessment/prefilter/topic_relevance.py new file mode 100644 index 000000000..1fa6acb43 --- /dev/null +++ b/backend/app/services/assessment/prefilter/topic_relevance.py @@ -0,0 +1,126 @@ +"""Topic relevance go/no-go gate: one batch request per row (text + attachments). + +Each request returns a per-column relevance boolean for every text and attachment +column plus a final ACCEPT/REJECT verdict. +""" + +import json +import logging +from typing import Any + +from app.models.assessment import AssessmentAttachment +from app.services.assessment.prefilter import constants +from app.services.assessment.prefilter.request_builder import build_request_line +from app.services.assessment.utils.attachments import ( + attachment_type_for_row, + build_gemini_attachment_parts, + resolve_attachment_values, +) + +logger = logging.getLogger(__name__) + +_INSTRUCTIONS = ( + "\n\nJudge whether this submission is relevant to the topic. For EACH listed " + "column (including any attached document/image columns) set its value to true if " + "that column's content is relevant to the topic, else false. Then give a final " + "decision: ACCEPT if relevant enough to proceed, otherwise REJECT." +) + + +def _build_schema(columns: list[str]) -> dict[str, Any]: + """Output schema: decision + reasoning + a boolean per column.""" + props: dict[str, Any] = { + "decision": {"type": "string", "enum": ["ACCEPT", "REJECT"]}, + "reasoning": {"type": "string"}, + } + for col in columns: + props[col] = {"type": "boolean"} + return { + "type": "object", + "properties": props, + "required": ["decision", "reasoning", *columns], + } + + +def _record_text(row: dict[str, str], columns: list[str]) -> str: + return "\n\n".join(f"{col}:\n{row.get(col, '') or ''}" for col in columns) + + +def build_topic_relevance_requests( + rows: list[tuple[int, dict[str, str]]], + columns: list[str], + user_prompt: str, + attachments: list[AssessmentAttachment] | None = None, +) -> list[dict[str, Any]]: + """Build one batch JSONL line per row, with text columns + attachment parts.""" + attachments = attachments or [] + is_openai = constants.ASSESSMENT_PREFILTER_PROVIDER == "openai" + schema = _build_schema(columns + [a.column for a in attachments]) + system = user_prompt.strip() + _INSTRUCTIONS + + lines: list[dict[str, Any]] = [] + for idx, row in rows: + attachment_parts: list[dict[str, Any]] = [] + for att in attachments: + cell = row.get(att.column, "") + if not cell.strip(): + continue + override = attachment_type_for_row(att, row) + attachment_parts.extend( + resolve_attachment_values(cell, att, type_override=override) + if is_openai + else build_gemini_attachment_parts(cell, att, type_override=override) + ) + lines.append( + build_request_line( + key=f"tr_{idx}", + system=system, + user_text=_record_text(row, columns), + attachment_parts=attachment_parts or None, + response_schema=schema, + ) + ) + return lines + + +def parse_topic_relevance_results( + outputs: list[dict[str, Any]], +) -> dict[int, dict[str, Any]]: + """Parse outputs into {row_id: {verdict, decision, reasoning, column_relevance}}.""" + parsed: dict[int, dict[str, Any]] = {} + for out in outputs: + key = str(out.get("row_id", "")) + if not key.startswith("tr_"): + continue + try: + idx = int(key.split("_", 1)[1]) + except (ValueError, IndexError): + continue + try: + data = json.loads(out.get("output") or "") + decision = str(data.get("decision", "ACCEPT")).upper() + column_relevance = { + k: bool(v) + for k, v in data.items() + if k not in ("decision", "reasoning") + } + parsed[idx] = { + "verdict": decision == "ACCEPT", + "decision": decision, + "reasoning": str(data.get("reasoning", "")), + "column_relevance": column_relevance, + } + except Exception as exc: + logger.warning("[parse_topic_relevance_results] %s — %s", key, exc) + parsed[idx] = _accept_on_error() + return parsed + + +def _accept_on_error() -> dict[str, Any]: + """Fail-open gate record for a row whose topic-relevance output was unparseable.""" + return { + "verdict": True, + "decision": "", + "reasoning": "", + "column_relevance": {}, + } diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index 45a283ea5..a0bcf7487 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -1,9 +1,9 @@ """Assessment run orchestration service.""" import logging -from typing import Any from uuid import UUID +from asgi_correlation_id import correlation_id from fastapi import HTTPException from sqlmodel import Session @@ -13,9 +13,7 @@ get_assessment_dataset_by_id, get_assessment_runs_for_assessment, recompute_assessment_status, - update_assessment_run_status, ) -from app.crud.assessment.batch import submit_assessment_batch from app.crud.config import ConfigCrud from app.crud.evaluations.core import resolve_evaluation_config from app.models.assessment import ( @@ -26,6 +24,7 @@ AssessmentResponse, AssessmentRun, AssessmentRunSummary, + StageStatus, ) from app.models.config.config import ConfigTag from app.services.llm.providers.registry import LLMProvider @@ -81,6 +80,8 @@ def _build_retry_request( attachments=[AssessmentAttachment.model_validate(item) for item in attachments], output_schema=assessment_input.get("output_schema"), configs=configs, + prefilter_config=assessment_input.get("prefilter_config"), + post_processing_config=assessment_input.get("post_processing_config"), ) @@ -90,11 +91,13 @@ def start_assessment( organization_id: int, project_id: int, ) -> AssessmentResponse: - """Start an assessment run request. + """Validate, create Assessment + AssessmentRun records, dispatch Celery tasks. - Validates the dataset, resolves each config, creates one AssessmentRun per config, - and kicks off batch processing for each. + Each run is created with status='pending' and handed off to a Celery worker + that runs prefilter filtering then submits the L2 batch. """ + from app.celery.tasks.job_execution import run_assessment_pipeline + logger.info( "[start_assessment] Starting | experiment=%s | dataset_id=%s | configs=%s | org_id=%s", request.experiment_name, @@ -110,7 +113,7 @@ def start_assessment( project_id=project_id, ) - assessment_input: dict[str, Any] = { + assessment_input: dict = { "prompt_template": request.prompt_template, "system_instruction": request.system_instruction, "text_columns": request.text_columns, @@ -118,12 +121,15 @@ def start_assessment( } if request.output_schema: assessment_input["output_schema"] = request.output_schema + if request.prefilter_config: + assessment_input["prefilter_config"] = request.prefilter_config + if request.post_processing_config: + assessment_input["post_processing_config"] = request.post_processing_config config_crud = ConfigCrud(session=session, project_id=project_id) resolved_configs = [] for cfg in request.configs: - # Assessment runs must use configs explicitly tagged for assessment use. parent_config = config_crud.read_one(cfg.config_id) if parent_config is not None and parent_config.tag != ConfigTag.ASSESSMENT: tag_value = ( @@ -165,7 +171,7 @@ def start_assessment( f"Supported providers: {sorted(_SUPPORTED_BATCH_PROVIDERS)}" ), ) - resolved_configs.append((cfg, config_blob)) + resolved_configs.append(cfg) assessment = create_assessment( session=session, @@ -176,54 +182,30 @@ def start_assessment( ) runs: list[AssessmentRun] = [] - try: - for cfg, config_blob in resolved_configs: - run = create_assessment_run( - session=session, - assessment_id=assessment.id, - config_id=cfg.config_id, - config_version=cfg.config_version, - assessment_input=assessment_input, - ) + trace_id = correlation_id.get() or "" - try: - batch_job = submit_assessment_batch( - session=session, - run=run, - assessment=assessment, - dataset=dataset, - config_blob=config_blob, - assessment_input=assessment_input, - organization_id=organization_id, - project_id=project_id, - ) - - run = update_assessment_run_status( - session=session, - run=run, - status="processing", - batch_job_id=batch_job.id, - total_items=batch_job.total_items, - ) - - except Exception as e: - logger.error( - "[start_assessment] Failed to submit batch for run %s: %s", - run.id, - e, - exc_info=True, - ) - run = update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message="Batch submission failed. Please try again or contact support.", - ) - - runs.append(run) - except Exception: - recompute_assessment_status(session=session, assessment_id=assessment.id) - raise + for cfg in resolved_configs: + run = create_assessment_run( + session=session, + assessment_id=assessment.id, + config_id=cfg.config_id, + config_version=cfg.config_version, + assessment_input=assessment_input, + ) + runs.append(run) + + run_assessment_pipeline.delay( + run_id=run.id, + organization_id=organization_id, + project_id=project_id, + trace_id=trace_id, + ) + + logger.info( + "[start_assessment] Dispatched Celery task | run_id=%s | config_id=%s", + run.id, + cfg.config_id, + ) recompute_assessment_status(session=session, assessment_id=assessment.id) @@ -242,13 +224,13 @@ def start_assessment( num_configs=len(runs), runs=[ AssessmentRunSummary( - run_id=completed_run.id, - assessment_id=completed_run.assessment_id, - config_id=str(completed_run.config_id), - config_version=completed_run.config_version, - status=completed_run.status, + run_id=run.id, + assessment_id=run.assessment_id, + config_id=str(run.config_id), + config_version=run.config_version, + status=run.status, ) - for completed_run in runs + for run in runs ], ) @@ -302,3 +284,77 @@ def retry_assessment_run( organization_id=organization_id, project_id=project_id, ) + + +def resume_assessment_run( + session: Session, + run: AssessmentRun, + organization_id: int, + project_id: int, +) -> AssessmentResponse: + """Re-run a failed run from its failed stage, reusing completed upstream batches.""" + from app.celery.tasks.job_execution import run_assessment_pipeline + from app.services.assessment.stages import ordered_stages + + if run.stage_status != StageStatus.FAILED: + raise HTTPException( + status_code=400, + detail=f"Run {run.id} is not in a failed state and cannot be resumed", + ) + if run.stage not in ordered_stages(run.pipeline): + raise HTTPException( + status_code=400, + detail=f"Run {run.id} has no resumable failed stage", + ) + + parent = getattr(run, "assessment", None) or session.get( + Assessment, run.assessment_id + ) + if not parent: + raise HTTPException( + status_code=404, + detail=f"Parent assessment {run.assessment_id} not found", + ) + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=parent.dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + run.stage_status = StageStatus.PENDING + run.status = "processing" + run.error_message = None + session.add(run) + session.commit() + session.refresh(run) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + + logger.info( + "[resume_assessment_run] Resuming run_id=%s from stage=%s", + run.id, + run.stage, + ) + run_assessment_pipeline.delay( + run_id=run.id, + organization_id=organization_id, + project_id=project_id, + trace_id=correlation_id.get() or "", + ) + + return AssessmentResponse( + assessment_id=parent.id, + experiment_name=parent.experiment_name, + dataset_id=parent.dataset_id, + dataset_name=dataset.name if dataset else None, + num_configs=1, + runs=[ + AssessmentRunSummary( + run_id=run.id, + assessment_id=run.assessment_id, + config_id=str(run.config_id), + config_version=run.config_version, + status=run.status, + ) + ], + ) diff --git a/backend/app/services/assessment/stages.py b/backend/app/services/assessment/stages.py new file mode 100644 index 000000000..7afc276da --- /dev/null +++ b/backend/app/services/assessment/stages.py @@ -0,0 +1,194 @@ +"""Stage registry, pipeline ordering, and Batch API executor.""" + +import logging +from collections.abc import Callable +from typing import Any + +from sqlmodel import Session + +from app.core.batch import ( + GeminiBatchProvider, + OpenAIBatchProvider, + download_batch_results, + start_batch_job, +) +from app.core.batch.base import BatchProvider +from app.core.batch.client import GeminiClient +from app.core.cloud import get_cloud_storage +from app.models.assessment import AssessmentRun, Stage, StageStatus +from app.models.batch_job import BatchJob, BatchJobType +from app.services.assessment.prefilter import constants, resolve_prefilter_settings +from app.services.assessment.prefilter.duplicate_detection import ( + build_duplicate_detection_requests, + parse_duplicate_detection_results, +) +from app.services.assessment.prefilter.topic_relevance import ( + build_topic_relevance_requests, + parse_topic_relevance_results, +) +from app.services.llm.providers.registry import LLMProvider +from app.utils import get_openai_client + +logger = logging.getLogger(__name__) + +# Stages that gate the pipeline (only ACCEPTed rows continue). Others annotate. +GATE_STAGES = {Stage.PRE_FILTER_TOPIC_RELEVANCE} + +# Result parser per stage: raw batch results -> {row_id: result dict}. +STAGE_PARSERS: dict[str, Callable[[list[dict]], dict[int, dict[str, Any]]]] = { + Stage.PRE_FILTER_TOPIC_RELEVANCE: parse_topic_relevance_results, + Stage.PRE_FILTER_DUPLICATE_DETECTION: parse_duplicate_detection_results, +} + + +def build_pipeline(assessment_input: dict[str, Any]) -> dict[str, Any]: + """Build the ordered stage config; prefilter stages added only when configured.""" + cfg = resolve_prefilter_settings(assessment_input.get("prefilter_config") or {}) + stages: list[dict[str, Any]] = [] + if cfg["tr_enabled"]: + stages.append({"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "type": "GO_NO_GO"}) + if cfg["dup_enabled"]: + stages.append( + {"stage": Stage.PRE_FILTER_DUPLICATE_DETECTION, "type": "ANNOTATIVE"} + ) + stages.append({"stage": Stage.L2_ASSESSMENT, "type": "ASSESSMENT"}) + + for order, entry in enumerate(stages, start=1): + entry["order"] = order + return {"stages": stages} + + +def ordered_stages(pipeline: dict[str, Any] | None) -> list[str]: + """The stage names in execution order.""" + return [s["stage"] for s in (pipeline or {}).get("stages", [])] + + +def next_stage( + pipeline: dict[str, Any] | None, current: str | None = None +) -> str | None: + """First stage when ``current`` is None, else the stage after it (None if last).""" + stages = ordered_stages(pipeline) + if current is None: + return stages[0] if stages else None + if current in stages and stages.index(current) + 1 < len(stages): + return stages[stages.index(current) + 1] + return None + + +def submit_prefilter_batch( + session: Session, + organization_id: int, + project_id: int, + jsonl_data: list[dict[str, Any]], + display_name: str, +) -> BatchJob: + """Submit a prefilter batch on the configured provider and return the BatchJob.""" + base = constants.ASSESSMENT_PREFILTER_PROVIDER + provider = _get_batch_provider( + session=session, + provider_name=base, + organization_id=organization_id, + project_id=project_id, + ) + if base == "openai": + config = { + "endpoint": "/v1/responses", + "completion_window": "24h", + "description": display_name, + } + else: + config = { + "display_name": display_name, + "model": f"models/{constants.ASSESSMENT_PREFILTER_MODEL}", + } + return start_batch_job( + session=session, + provider=provider, + provider_name=base, + job_type=BatchJobType.ASSESSMENT, + organization_id=organization_id, + project_id=project_id, + jsonl_data=jsonl_data, + config=config, + ) + + +def build_prefilter_requests( + stage: str, + rows: list[tuple[int, dict[str, str]]], + cfg: dict[str, Any], + attachments: list | None = None, +) -> list[dict[str, Any]]: + """Build the JSONL request lines for a prefilter stage.""" + if stage == Stage.PRE_FILTER_TOPIC_RELEVANCE: + return build_topic_relevance_requests( + rows, cfg["tr_columns"], cfg["tr_prompt"], attachments + ) + if stage == Stage.PRE_FILTER_DUPLICATE_DETECTION: + return build_duplicate_detection_requests(rows, cfg["dup_columns"]) + raise ValueError(f"Unknown prefilter stage: {stage}") + + +def _get_batch_provider( + session: Session, + provider_name: str, + organization_id: int, + project_id: int, +) -> BatchProvider: + """Build the batch provider instance for a given provider name.""" + if provider_name in (LLMProvider.OPENAI, LLMProvider.OPENAI_NATIVE): + return OpenAIBatchProvider( + client=get_openai_client( + session=session, org_id=organization_id, project_id=project_id + ) + ) + if provider_name in (LLMProvider.GOOGLE, LLMProvider.GOOGLE_NATIVE): + gemini_client = GeminiClient.from_credentials( + session=session, org_id=organization_id, project_id=project_id + ) + return GeminiBatchProvider(client=gemini_client.client) + raise ValueError(f"Unsupported batch provider: {provider_name}") + + +def load_raw_batch_results( + session: Session, batch_job: BatchJob, project_id: int +) -> list[dict[str, Any]]: + """Load a completed batch's raw result lines (object store first, else provider).""" + # Lazy import: app.services.assessment.utils.__init__ pulls in export, which + # imports this module's package — a top-level import would be circular. + from app.services.assessment.utils.parsing import parse_stored_results + + if batch_job.raw_output_url: + try: + storage = get_cloud_storage(session, project_id=project_id) + raw = parse_stored_results( + storage.stream(batch_job.raw_output_url).read().decode("utf-8") + ) + if raw: + return raw + except Exception as exc: + logger.warning( + "[load_raw_batch_results] S3 read failed batch %s — %s", + batch_job.id, + exc, + ) + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=batch_job.organization_id, + project_id=project_id, + ) + return download_batch_results(provider=provider, batch_job=batch_job) + + +def advance_or_finalize(run: AssessmentRun) -> str | None: + """Advance the run to the next stage (returned) or finalize it (returns None).""" + nxt = next_stage(run.pipeline, run.stage) + if nxt: + run.stage = nxt + run.stage_status = StageStatus.PENDING + return nxt + run.stage = Stage.COMPLETED + run.stage_status = StageStatus.COMPLETED + run.status = "completed" + return None diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py new file mode 100644 index 000000000..d21d86eeb --- /dev/null +++ b/backend/app/services/assessment/tasks.py @@ -0,0 +1,304 @@ +"""Orchestrator: submit the run's current PENDING stage as a batch, then exit.""" + +import logging + +from asgi_correlation_id import correlation_id +from celery.exceptions import SoftTimeLimitExceeded +from sqlalchemy.orm.attributes import flag_modified +from sqlmodel import Session + +from app.celery.tasks.job_execution import run_assessment_pipeline +from app.core.db import engine +from app.crud.assessment import ( + get_assessment_dataset_by_id, + recompute_assessment_status, + update_assessment_run_status, +) +from app.crud.assessment.batch import _load_dataset_rows, submit_assessment_batch +from app.crud.assessment.processing import parse_assessment_output +from app.crud.evaluations.core import resolve_evaluation_config +from app.crud.job import get_batch_job +from app.models.assessment import ( + Assessment, + AssessmentAttachment, + AssessmentRun, + Stage, + StageStatus, +) +from app.models.config.config import ConfigTag +from app.services.assessment.prefilter import resolve_prefilter_settings +from app.services.assessment.stages import ( + GATE_STAGES, + STAGE_PARSERS, + advance_or_finalize, + build_pipeline, + build_prefilter_requests, + load_raw_batch_results, + next_stage, + ordered_stages, + submit_prefilter_batch, +) + +logger = logging.getLogger(__name__) + +_PREFILTER_STAGES = { + Stage.PRE_FILTER_TOPIC_RELEVANCE, + Stage.PRE_FILTER_DUPLICATE_DETECTION, +} + + +def _mark_run_failed(run_id: int, error_message: str) -> None: + """Fail a run from a fresh session so a killed task leaves no dangling run.""" + try: + with Session(engine) as session: + run = session.get(AssessmentRun, run_id) + if ( + run is None + or run.stage == Stage.COMPLETED + or run.stage_status == StageStatus.FAILED + ): + return + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, run=run, status="failed", error_message=error_message + ) + recompute_assessment_status( + session=session, assessment_id=run.assessment_id + ) + logger.info("[_mark_run_failed] run_id=%s marked failed", run_id) + except Exception: + logger.error( + "[_mark_run_failed] could not mark run_id=%s failed", run_id, exc_info=True + ) + + +def execute_assessment_pipeline( + run_id: int, organization_id: int, project_id: int +) -> None: + """Guarded entrypoint: submit the run's current stage, never leave it dangling.""" + try: + _orchestrate(run_id, organization_id, project_id) + except SoftTimeLimitExceeded: + logger.error("[execute_assessment_pipeline] soft time limit run_id=%s", run_id) + _mark_run_failed(run_id, "Assessment run exceeded the time limit.") + raise + except Exception: + logger.error( + "[execute_assessment_pipeline] unexpected failure run_id=%s", + run_id, + exc_info=True, + ) + _mark_run_failed(run_id, "Assessment run failed unexpectedly.") + raise + + +def _dispatch(run_id: int, organization_id: int, project_id: int) -> None: + run_assessment_pipeline.delay( + run_id=run_id, + organization_id=organization_id, + project_id=project_id, + trace_id=correlation_id.get() or "", + ) + + +def _resolve_run_context( + session: Session, run: AssessmentRun, organization_id: int, project_id: int +): + """Load the assessment, dataset, and resolved config; ``error`` set on failure.""" + assessment = session.get(Assessment, run.assessment_id) + if assessment is None: + return None, None, None, "Parent assessment not found." + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=assessment.dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + config_blob, error = resolve_evaluation_config( + session=session, + config_id=run.config_id, + config_version=run.config_version, + project_id=project_id, + tag=ConfigTag.ASSESSMENT, + ) + if error or config_blob is None: + return assessment, dataset, None, f"Config resolution failed: {error}" + return assessment, dataset, config_blob, None + + +def _accepted_indices( + session: Session, run: AssessmentRun, total_rows: int, project_id: int +) -> list[int]: + """Row indices that passed every gate stage before the current one. + + Prefers the accepted set persisted by the gate stage on ``run.pipeline`` + (set in ``_record_gate_stats``), avoiding a re-download + re-parse of the + gate batch at the memory-heavy prefilter -> assessment transition. Falls back + to recomputing from the gate batches only if nothing was persisted. + """ + stored = (run.pipeline or {}).get("accepted_indices") + if stored is not None: + return [i for i in sorted(stored) if 0 <= i < total_rows] + + accepted = set(range(total_rows)) + for stage in ordered_stages(run.pipeline): + if stage == run.stage: + break + if stage not in GATE_STAGES: + continue + batch_id = (run.stage_batches or {}).get(stage) + if batch_id is None: + continue + batch_job = get_batch_job(session=session, batch_job_id=batch_id) + if not batch_job: + continue + raw = load_raw_batch_results(session, batch_job, project_id) + outputs = parse_assessment_output(raw, batch_job.provider) + parsed = STAGE_PARSERS[stage](outputs) + accepted &= {idx for idx, r in parsed.items() if r.get("verdict")} + return sorted(accepted) + + +def _orchestrate(run_id: int, organization_id: int, project_id: int) -> None: + with Session(engine) as session: + run = session.get(AssessmentRun, run_id) + if run is None: + logger.error("[execute_assessment_pipeline] run_id=%s not found", run_id) + return + if run.stage == Stage.COMPLETED or run.stage_status == StageStatus.FAILED: + return + + if not run.pipeline: + run.pipeline = build_pipeline(run.input or {}) + flag_modified(run, "pipeline") + if run.stage is None: + run.stage = next_stage(run.pipeline) + run.stage_status = StageStatus.PENDING + run.status = "processing" + if run.stage_status != StageStatus.PENDING: + session.add(run) + session.commit() + return + session.add(run) + session.commit() + session.refresh(run) + + _submit_stage(session, run, organization_id, project_id) + + +def _submit_stage( + session: Session, run: AssessmentRun, organization_id: int, project_id: int +) -> None: + assessment, dataset, config_blob, error = _resolve_run_context( + session, run, organization_id, project_id + ) + if error: + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, run=run, status="failed", error_message=error + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + return + + all_rows = _load_dataset_rows(session, dataset) + if not all_rows: + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Dataset has no rows.", + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + return + + accepted = _accepted_indices(session, run, len(all_rows), project_id) + rows_with_idx = [(i, all_rows[i]) for i in accepted] + stage = run.stage + + if not rows_with_idx: + # Nothing left for this stage (all rows rejected upstream) — advance. + _persist_advance(session, run, organization_id, project_id) + return + + if stage in _PREFILTER_STAGES: + cfg = resolve_prefilter_settings(run.input.get("prefilter_config") or {}) + attachments = [ + AssessmentAttachment(**a) for a in (run.input.get("attachments") or []) + ] + selected = cfg.get("tr_attachment_columns") + if selected is not None: + attachments = [a for a in attachments if a.column in set(selected)] + jsonl = build_prefilter_requests(stage, rows_with_idx, cfg, attachments) + batch_job = submit_prefilter_batch( + session=session, + organization_id=organization_id, + project_id=project_id, + jsonl_data=jsonl, + display_name=f"assessment-{run.id}-{stage}", + ) + elif stage == Stage.L2_ASSESSMENT: + batch_job = submit_assessment_batch( + session=session, + run=run, + assessment=assessment, + dataset=dataset, + config_blob=config_blob, + assessment_input=run.input or {}, + organization_id=organization_id, + project_id=project_id, + preloaded_rows=[r for _, r in rows_with_idx], + row_indices=[i for i, _ in rows_with_idx], + ) + run.total_items = batch_job.total_items + else: + raise ValueError(f"Unknown stage: {stage}") + + stage_batches = dict(run.stage_batches or {}) + stage_batches[stage] = batch_job.id + run.stage_batches = stage_batches + flag_modified(run, "stage_batches") + run.stage_status = StageStatus.PROCESSING + run.status = "processing" + session.add(run) + session.commit() + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + + logger.info( + "[execute_assessment_pipeline] run_id=%s | stage=%s submitted | batch=%s | rows=%s", + run.id, + stage, + batch_job.id, + len(rows_with_idx), + ) + + +def _persist_advance( + session: Session, run: AssessmentRun, organization_id: int, project_id: int +) -> None: + nxt = advance_or_finalize(run) + session.add(run) + session.commit() + recompute_assessment_status(session=session, assessment_id=run.assessment_id) + if not nxt: + return + # Commit precedes dispatch (the worker only acts on a committed PENDING run). + # If the broker call fails the run would otherwise sit at PENDING forever — the + # cron only re-polls PROCESSING runs — so mark it failed (resumable) instead. + try: + _dispatch(run.id, organization_id, project_id) + except Exception: + logger.error( + "[_persist_advance] run_id=%s stage=%s enqueue failed — marking failed for resume", + run.id, + run.stage, + exc_info=True, + ) + run.stage_status = StageStatus.FAILED + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Failed to enqueue the next pipeline stage. Resume the run to retry.", + ) + recompute_assessment_status(session=session, assessment_id=run.assessment_id) diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 5a141a757..26f82ad75 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -1,17 +1,20 @@ """Attachment resolution utilities for assessment batch builds. -Handles MIME type detection, base64 decoding, Google Drive URL normalization, -data-URL parsing, and conversion of dataset cell values into provider input objects. +URL-only: dataset cells hold attachment URLs. Handles Google Drive URL +normalization and conversion of cell values into provider input objects. +Attachments are passed to providers by reference (URL), never inlined as base64, +to keep the batch build memory-light. """ -import base64 -import binascii +import logging import re from typing import Any from urllib.parse import urlparse from app.models.assessment import AssessmentAttachment +logger = logging.getLogger(__name__) + _IMAGE_MIME_BY_EXT = { ".png": "image/png", ".jpg": "image/jpeg", @@ -60,18 +63,6 @@ def to_direct_attachment_url(url: str, attachment_type: str) -> str: return f"https://drive.google.com/uc?export=download&id={file_id}" -def split_data_url(value: str) -> tuple[str | None, str]: - """Return (mime_type, base64_payload) for a data URL; otherwise (None, value).""" - match = re.match( - r"^data:([^;]+);base64,(.+)$", - value.strip(), - flags=re.IGNORECASE | re.DOTALL, - ) - if not match: - return None, value.strip() - return match.group(1).strip().lower(), match.group(2).strip() - - def _guess_image_mime_from_url(url: str) -> str | None: path = urlparse(url).path or "" for ext, mime in _IMAGE_MIME_BY_EXT.items(): @@ -80,104 +71,112 @@ def _guess_image_mime_from_url(url: str) -> str | None: return None -def _decode_base64_prefix(payload: str, max_chars: int = 256) -> bytes | None: - compact = re.sub(r"\s+", "", payload) - if not compact: - return None - sample = compact[:max_chars] - padding = "=" * (-len(sample) % 4) - try: - return base64.b64decode(sample + padding, validate=False) - except (binascii.Error, ValueError): - return None +def resolve_item_type(declared: str, type_override: str | None = None) -> str | None: + """Resolve an attachment item as 'image' or 'pdf' from the user-declared type. + A per-row ``type_override`` (for 'mixed' columns) wins, else the column's declared + ``type``. Returns None when the type stays unresolved (e.g. a 'mixed' row whose + value didn't map to a concrete type) so callers can skip rather than guess. + """ + item_type = type_override or declared + return item_type if item_type in ("image", "pdf") else None -def _guess_image_mime_from_base64(payload: str) -> str | None: - blob = _decode_base64_prefix(payload) - if not blob: - return None - if blob.startswith(b"\x89PNG\r\n\x1a\n"): - return "image/png" - if blob.startswith(b"\xff\xd8\xff"): - return "image/jpeg" - if blob.startswith((b"GIF87a", b"GIF89a")): - return "image/gif" - if blob.startswith(b"BM"): - return "image/bmp" - if len(blob) >= 12 and blob[:4] == b"RIFF" and blob[8:12] == b"WEBP": - return "image/webp" - if blob.startswith((b"II*\x00", b"MM\x00*")): - return "image/tiff" - return None +def _normalize_type_value(value: str) -> str: + return re.sub(r"\s+", " ", value).strip().casefold() -def resolve_image_mime_and_payload( - value: str, - format_type: str, -) -> tuple[str, str]: - """Resolve image mime type and raw base64 payload (for base64 format).""" - if format_type == "url": - return _guess_image_mime_from_url(value) or "image/png", value - data_url_mime, payload = split_data_url(value) - if data_url_mime and data_url_mime.startswith("image/"): - return data_url_mime, payload +def _split_type_values(value: str) -> list[str]: + return [ + normalized + for part in re.split(r"[\n,]+", value) + if (normalized := _normalize_type_value(part)) + ] + + +def attachment_type_for_row( + att: AssessmentAttachment, row: dict[str, str] +) -> str | None: + """For a 'mixed' column, resolve this row's type from type_column + type_value_map. + + Returns 'image'/'pdf', or None to let normal detection (extension/declared) decide. + """ + type_column = getattr(att, "type_column", None) + type_value_map = getattr(att, "type_value_map", None) + if att.type != "mixed" or not type_column or not type_value_map: + return None + + normalized_map: dict[str, str] = {} + for raw_values, mapped_type in type_value_map.items(): + if mapped_type not in ("image", "pdf"): + continue + for value in _split_type_values(raw_values): + normalized_map[value] = mapped_type + + row_values = _split_type_values(row.get(type_column) or "") + if not row_values: + return None - return _guess_image_mime_from_base64(payload) or "image/png", payload + mapped_values = { + normalized_map[value] for value in row_values if value in normalized_map + } + return mapped_values.pop() if len(mapped_values) == 1 else None def resolve_attachment_values( value: str, att: AssessmentAttachment, + type_override: str | None = None, ) -> list[dict[str, Any]]: - """Convert one dataset cell into one or more OpenAI-style input objects.""" + """Convert one dataset cell into one or more OpenAI-style input objects (by URL).""" value = value.strip() if not value: return [] - if att.format == "url": - values = split_attachment_urls(value) - else: - values = [value] - - resolved: list[dict[str, Any]] = [] - for item_value in values: - normalized_value = ( - to_direct_attachment_url(item_value, att.type) - if att.format == "url" - else item_value + item_type = resolve_item_type(att.type, type_override) + if item_type is None: + logger.warning( + "[resolve_attachment_values] Unresolved type for column=%s — skipping", + att.column, ) + return [] + resolved: list[dict[str, Any]] = [] + for item_value in split_attachment_urls(value): + url = to_direct_attachment_url(item_value, item_type) + if item_type == "image": + resolved.append({"type": "input_image", "image_url": url}) + else: + resolved.append({"type": "input_file", "file_url": url}) + return resolved - if att.type == "image": - if att.format == "url": - resolved.append({"type": "input_image", "image_url": normalized_value}) - else: - mime_type, payload = resolve_image_mime_and_payload( - normalized_value, - "base64", - ) - resolved.append( - { - "type": "input_image", - "image_url": f"data:{mime_type};base64,{payload}", - } - ) - elif att.type == "pdf": - if att.format == "url": - resolved.append( - { - "type": "input_file", - "file_url": normalized_value, - } - ) - else: - _, payload = split_data_url(normalized_value) - resolved.append( - { - "type": "input_file", - "file_data": f"data:application/pdf;base64,{payload}", - "filename": "document.pdf", - } - ) - return resolved +def build_gemini_attachment_parts( + value: str, + att: AssessmentAttachment, + type_override: str | None = None, +) -> list[dict[str, Any]]: + """Convert one dataset cell into one or more Gemini content parts (by URL). + + Mirrors the per-item type routing used for the L2 batch so the same + image/pdf handling applies to prefilter (topic relevance) calls. + """ + value = value.strip() + if not value: + return [] + + item_type = resolve_item_type(att.type, type_override) + if item_type is None: + logger.warning( + "[build_gemini_attachment_parts] Unresolved type for column=%s — skipping", + att.column, + ) + return [] + parts: list[dict[str, Any]] = [] + for item_value in split_attachment_urls(value): + url = to_direct_attachment_url(item_value, item_type) + if item_type == "image": + mime_type = _guess_image_mime_from_url(url) or "image/png" + parts.append({"fileData": {"mimeType": mime_type, "fileUri": url}}) + else: + parts.append({"fileData": {"mimeType": "application/pdf", "fileUri": url}}) + return parts diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index ca273afc6..4dded4bab 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -12,16 +12,33 @@ from fastapi.responses import StreamingResponse from sqlmodel import Session +from app.core.batch import download_batch_results from app.core.cloud import get_cloud_storage from app.core.storage_utils import generate_timestamped_filename from app.crud.assessment.processing import parse_assessment_output from app.crud.job import get_batch_job -from app.models.assessment import Assessment, AssessmentExportRow, AssessmentRun +from app.models.assessment import ( + Assessment, + AssessmentExportRow, + AssessmentRun, + Stage, +) from app.models.batch_job import BatchJob from app.models.evaluation import EvaluationDataset +from app.services.assessment.prefilter.duplicate_detection import ( + parse_duplicate_detection_results, +) +from app.services.assessment.prefilter.topic_relevance import ( + parse_topic_relevance_results, +) +from app.services.assessment.stages import _get_batch_provider, load_raw_batch_results from app.services.assessment.utils.parsing import parse_stored_results, usage_totals +from app.services.assessment.utils.post_processing import apply_post_processing from app.utils import APIResponse +_PREFILTER_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] +_XLSX_ILLEGAL_RE = re.compile("[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f\ud800-\udfff﷐-﷯￾￿]") + logger = logging.getLogger(__name__) @@ -29,11 +46,61 @@ def _load_dataset_rows( session: Session, dataset: EvaluationDataset, ) -> list[dict[str, str]]: + # Imported lazily: app.crud.assessment.batch pulls this module via + # app.services.assessment.utils, so a top-level import would be circular. from app.crud.assessment.batch import _load_dataset_rows as load_dataset_rows return load_dataset_rows(session, dataset) +def _stage_batch_job( + session: Session, run: AssessmentRun, stage: str +) -> BatchJob | None: + """The batch job a run produced for a given stage, via stage_batches.""" + batch_id = (run.stage_batches or {}).get(stage) + return get_batch_job(session=session, batch_job_id=batch_id) if batch_id else None + + +def _load_prefilter_results( + session: Session, + run: AssessmentRun, + assessment: Assessment, +) -> dict[str, dict[str, Any]]: + """Build per-row prefilter annotations from the TR + dup stage batches.""" + out: dict[str, dict[str, Any]] = {} + + tr_job = _stage_batch_job(session, run, Stage.PRE_FILTER_TOPIC_RELEVANCE) + if tr_job: + try: + raw = load_raw_batch_results(session, tr_job, assessment.project_id) + outputs = parse_assessment_output(raw, tr_job.provider) + for idx, r in parse_topic_relevance_results(outputs).items(): + out.setdefault(f"row_{idx}", {})["prefilter_passed"] = r["verdict"] + out[f"row_{idx}"]["topic_relevance"] = { + "decision": r["decision"], + "reasoning": r["reasoning"], + "column_relevance": r.get("column_relevance") or {}, + } + except Exception as exc: + logger.warning( + "[_load_prefilter_results] TR load failed run=%s: %s", run.id, exc + ) + + dup_job = _stage_batch_job(session, run, Stage.PRE_FILTER_DUPLICATE_DETECTION) + if dup_job: + try: + raw = load_raw_batch_results(session, dup_job, assessment.project_id) + outputs = parse_assessment_output(raw, dup_job.provider) + for idx, r in parse_duplicate_detection_results(outputs).items(): + out.setdefault(f"row_{idx}", {})["duplicate_detection"] = r + except Exception as exc: + logger.warning( + "[_load_prefilter_results] dup load failed run=%s: %s", run.id, exc + ) + + return out + + def _safe_filename_part(value: str) -> str: """Build a filesystem-safe filename component.""" sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") @@ -113,111 +180,144 @@ def _drop_empty_columns( return pruned, non_empty_fields +def _parse_json_col(raw: Any) -> dict[str, Any] | None: + if raw is None: + return None + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else None + except (json.JSONDecodeError, TypeError): + return None + return None + + def _expand_output_columns( row_payload: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], list[str]]: - """Expand the ``output`` field into separate columns when it contains valid JSON. +) -> tuple[list[dict[str, Any]], list[str], list[str], list[str], list[str]]: + """Expand ``output``, ``topic_relevance``, and ``duplicate_detection`` JSON columns + into separate flat columns when they contain valid JSON objects. Returns: (expanded_rows, ordered_fieldnames) """ - # First expand input columns row_payload, input_col_names = _expand_input_columns(row_payload) + json_expand_cols = {"output", "input_data"} | set(_PREFILTER_JSON_COLUMNS) base_fields = [ field for field in AssessmentExportRow.model_fields.keys() - if field not in ("output", "input_data") + if field not in json_expand_cols ] - parsed_outputs: list[dict[str, Any] | None] = [] - output_keys: list[str] = [] - seen_keys: dict[str, None] = {} # ordered set + # prefilter columns are prefixed with their parent name to avoid key collisions + parsed_cols: dict[str, list[dict[str, Any] | None]] = { + col: [] for col in ["output"] + _PREFILTER_JSON_COLUMNS + } + col_keys: dict[str, list[str]] = { + col: [] for col in ["output"] + _PREFILTER_JSON_COLUMNS + } + col_seen: dict[str, dict[str, None]] = { + col: {} for col in ["output"] + _PREFILTER_JSON_COLUMNS + } has_unparsed_output = False for row in row_payload: - raw = row.get("output") - if raw is None: - parsed_outputs.append(None) - continue - - if isinstance(raw, str): - try: - parsed = json.loads(raw) - except (json.JSONDecodeError, TypeError): - parsed = None - elif isinstance(raw, dict): - parsed = raw - else: - parsed = None - - if not isinstance(parsed, dict): - has_unparsed_output = True - parsed_outputs.append(None) - continue - - parsed_outputs.append(parsed) - for output_key in parsed: - if output_key not in seen_keys: - seen_keys[output_key] = None - output_keys.append(output_key) - - if not output_keys: - # Keep original layout with output as a single column - fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) - fieldnames = [field for field in fieldnames if field != "input_data"] - return row_payload, fieldnames + for col in ["output"] + _PREFILTER_JSON_COLUMNS: + parsed = _parse_json_col(row.get(col)) + if parsed is None and col == "output" and row.get(col) is not None: + has_unparsed_output = True + parsed_cols[col].append(parsed) + if parsed: + for k in parsed: + prefixed = f"{col}_{k}" if col in _PREFILTER_JSON_COLUMNS else k + if prefixed not in col_seen[col]: + col_seen[col][prefixed] = None + col_keys[col].append(prefixed) + + def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: + if not parsed: + return {} + if col in _PREFILTER_JSON_COLUMNS: + return {f"{col}_{k}": v for k, v in parsed.items()} + return parsed # Build expanded rows expanded: list[dict[str, Any]] = [] - for row, parsed in zip(row_payload, parsed_outputs, strict=True): - new_row = {col: val for col, val in row.items() if col != "output"} - if parsed: - for output_key in output_keys: - new_row[output_key] = parsed.get(output_key) - else: - for output_key in output_keys: - new_row[output_key] = None - if row.get("output") is not None: - new_row["output_raw"] = row.get("output") + for i, row in enumerate(row_payload): + new_row = {k: v for k, v in row.items() if k not in json_expand_cols} + for col in ["output"] + _PREFILTER_JSON_COLUMNS: + parsed = parsed_cols[col][i] + keys = col_keys[col] + prefixed_vals = _get_prefixed(parsed, col) + if prefixed_vals: + for k in keys: + new_row[k] = prefixed_vals.get(k) + else: + for k in keys: + new_row[k] = None + if col == "output" and row.get("output") is not None: + new_row["output_raw"] = row.get("output") expanded.append(new_row) - # Build fieldnames: input columns + base fields + output columns - output_idx = base_fields.index("result_status") + 1 # after result_status - fieldnames = ( - input_col_names - + base_fields[:output_idx] - + output_keys - + base_fields[output_idx:] - ) + prefilter_keys = col_keys["topic_relevance"] + col_keys["duplicate_detection"] + output_keys = col_keys["output"] + + all_output_keys = prefilter_keys + output_keys + if not all_output_keys: + fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) + fieldnames = [f for f in fieldnames if f != "input_data"] + return row_payload, fieldnames, input_col_names, [], [] + + fieldnames = input_col_names + prefilter_keys + output_keys + base_fields if has_unparsed_output: fieldnames.insert( - len(input_col_names) + output_idx + len(output_keys), "output_raw" + len(input_col_names) + len(prefilter_keys) + len(output_keys), "output_raw" ) - return expanded, fieldnames + return expanded, fieldnames, input_col_names, prefilter_keys, output_keys def serialize_export_rows( export_rows: list[AssessmentExportRow], export_format: Literal["json", "csv", "xlsx"], + post_processing_config: dict[str, Any] | None = None, ) -> tuple[bytes, str]: """Serialize export rows into the requested file format.""" row_payload = [row.model_dump(mode="json") for row in export_rows] if export_format == "json": - expanded, _ = _expand_output_columns(row_payload) + expanded, *_ = _expand_output_columns(row_payload) + expanded = apply_post_processing(expanded, post_processing_config) return ( json.dumps(expanded, ensure_ascii=False, indent=2).encode("utf-8"), "application/json", ) - # For CSV/XLSX, expand output keys into separate columns - expanded, fieldnames = _expand_output_columns(row_payload) + ( + expanded, + fieldnames, + input_col_names, + prefilter_keys, + output_keys, + ) = _expand_output_columns(row_payload) + expanded = apply_post_processing(expanded, post_processing_config) + + # Add any new computed columns to fieldnames so they appear in output + existing = set(fieldnames) + computed_names = [ + c["name"] + for c in (post_processing_config or {}).get("computed_columns") or [] + if c.get("name") and c["name"] not in existing + ] + if computed_names: + fieldnames = fieldnames + computed_names if export_format == "csv": output = io.StringIO() - writer = csv.DictWriter(output, fieldnames=fieldnames) + writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() writer.writerows(expanded) return output.getvalue().encode("utf-8"), "text/csv" @@ -230,19 +330,18 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - # XLSX shows input columns + output columns only (no metadata fields). - metadata_fields = { - field - for field in AssessmentExportRow.model_fields.keys() - if field not in ("output", "input_data") - } - excel_fields = [field for field in fieldnames if field not in metadata_fields] + # Explicit ordering: inputs → prefilter → L2 → computed columns + excel_fields = input_col_names + prefilter_keys + output_keys + computed_names if not excel_fields: - excel_fields = ["output"] + excel_fields = output_keys or ["output"] - # Drop columns where every row is null/empty expanded, excel_fields = _drop_empty_columns(expanded, excel_fields) + def _clean(value: Any) -> Any: + return _XLSX_ILLEGAL_RE.sub("", value) if isinstance(value, str) else value + + expanded = [{k: _clean(v) for k, v in row.items()} for row in expanded] + buf = io.BytesIO() data_frame = pd.DataFrame(expanded, columns=excel_fields) with pd.ExcelWriter(buf) as writer: @@ -258,17 +357,20 @@ def build_json_export_rows( ) -> list[dict[str, Any]]: """Return JSON rows with structured output expanded into top-level keys.""" row_payload = [row.model_dump(mode="json") for row in export_rows] - expanded, _ = _expand_output_columns(row_payload) - return expanded + expanded, fieldnames, *_ = _expand_output_columns(row_payload) + return [{k: row.get(k) for k in fieldnames if k in row} for row in expanded] def build_export_response( export_rows: list[AssessmentExportRow], export_format: Literal["json", "csv", "xlsx"], base_name: str, + post_processing_config: dict[str, Any] | None = None, ) -> StreamingResponse: """Return a file download response for assessment exports.""" - payload, media_type = serialize_export_rows(export_rows, export_format) + payload, media_type = serialize_export_rows( + export_rows, export_format, post_processing_config + ) filename = generate_timestamped_filename( _safe_filename_part(base_name), extension=export_format, @@ -320,9 +422,6 @@ def _load_parsed_results_for_run( # 2. Fallback: download directly from batch provider if batch_job.provider_output_file_id: try: - from app.core.batch import download_batch_results - from app.crud.assessment.processing import _get_batch_provider - provider = _get_batch_provider( session=session, provider_name=batch_job.provider, @@ -376,26 +475,114 @@ def _load_dataset_rows_for_run( return [] +def _extract_prefilter_json_columns( + prefilter_item: dict[str, Any] | None, +) -> dict[str, Any]: + """Return topic_relevance and duplicate_detection as JSON strings for export expansion.""" + if not prefilter_item: + return {"topic_relevance": None, "duplicate_detection": None} + + tr = prefilter_item.get("topic_relevance") + dup = prefilter_item.get("duplicate_detection") + + tr_flat: dict[str, Any] | None = None + if tr: + tr_flat = {} + for col, val in (tr.get("column_relevance") or {}).items(): + tr_flat[col] = val + tr_flat["decision"] = tr.get("decision") + tr_flat["reasoning"] = tr.get("reasoning") + + dup_flat: dict[str, Any] | None = None + if dup: + dup_flat = {k: v for k, v in dup.items() if k != "row_id"} + + return { + "topic_relevance": json.dumps(tr_flat, ensure_ascii=False) if tr_flat else None, + "duplicate_detection": json.dumps(dup_flat, ensure_ascii=False) + if dup_flat + else None, + } + + +def _load_parsed_results_for_batch_job( + session: Session, + batch_job: BatchJob, + assessment: Assessment, +) -> list[dict[str, Any]] | None: + """Parse one chunk batch's stored results (object store first, provider fallback).""" + if batch_job.raw_output_url: + try: + storage = get_cloud_storage(session, project_id=assessment.project_id) + raw = parse_stored_results( + storage.stream(batch_job.raw_output_url).read().decode("utf-8") + ) + if raw: + return parse_assessment_output(raw, batch_job.provider) + except Exception as exc: + logger.warning( + "[_load_parsed_results_for_batch_job] S3 read failed for batch %s: %s", + batch_job.id, + exc, + ) + + if batch_job.provider_output_file_id: + try: + provider = _get_batch_provider( + session=session, + provider_name=batch_job.provider, + organization_id=assessment.organization_id, + project_id=assessment.project_id, + ) + raw = download_batch_results(provider=provider, batch_job=batch_job) + return parse_assessment_output(raw, batch_job.provider) + except Exception as exc: + logger.error( + "[_load_parsed_results_for_batch_job] Provider download failed for " + "batch %s: %s", + batch_job.id, + exc, + exc_info=True, + ) + return None + + +def _load_l2_results_for_run( + session: Session, + run: AssessmentRun, + assessment: Assessment, +) -> dict[str, dict[str, Any]]: + """L2 results keyed by row_id, from the run's L2 stage batch ({} if not done).""" + merged: dict[str, dict[str, Any]] = {} + batch_job = _stage_batch_job(session, run, Stage.L2_ASSESSMENT) + if batch_job: + for item in ( + _load_parsed_results_for_batch_job(session, batch_job, assessment) or [] + ): + if "row_id" in item: + merged[str(item["row_id"])] = item + return merged + + +def _row_result_status( + prefilter_passed: bool, + l2_item: dict[str, Any] | None, + run_status: str, +) -> str: + """Per-row status: rejected, failed, passed, or processing (batch not done).""" + if not prefilter_passed: + return "prefilter_rejected" + if l2_item is None: + return "failed" if run_status == "failed" else "processing" + return "failed" if l2_item.get("error") else "passed" + + def load_export_rows_for_run( session: Session, run: AssessmentRun, assessment: Assessment | None = None, ) -> list[AssessmentExportRow]: - """Load flattened export rows for a single child assessment run.""" - if not run.batch_job_id: - logger.warning( - "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id - ) - return [] - - batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) - if not batch_job: - logger.warning( - "[load_export_rows_for_run] Missing batch job for run id=%s", - run.id, - ) - return [] - + """Flatten one run's rows, merging prefilter annotations + L2 results by row_id.""" if assessment is None: assessment = session.get(Assessment, run.assessment_id) if assessment is None: @@ -405,64 +592,91 @@ def load_export_rows_for_run( ) return [] - parsed_results = _load_parsed_results_for_run( - session=session, - run=run, - batch_job=batch_job, - ) - if parsed_results is None: - return [] - - if not parsed_results: - logger.warning( - "[load_export_rows_for_run] Parsed results empty for run id=%s", run.id - ) - return [] - - dataset_rows = _load_dataset_rows_for_run(session, run, assessment) dataset = session.get(EvaluationDataset, assessment.dataset_id) dataset_name = dataset.name if dataset else None + dataset_rows = _load_dataset_rows_for_run(session, run, assessment) - export_rows: list[AssessmentExportRow] = [] - for item in parsed_results: - input_tokens, output_tokens, total_tokens = usage_totals(item.get("usage")) - - # Correlate with original input row via row_id (format: "row_{idx}") - input_data: dict[str, str] | None = None - row_id_str = str(item.get("row_id", "")) - if dataset_rows and row_id_str.startswith("row_"): - try: - row_idx = int(row_id_str.split("_", 1)[1]) - if 0 <= row_idx < len(dataset_rows): - input_data = dataset_rows[row_idx] - except (ValueError, IndexError): - pass - - export_rows.append( - AssessmentExportRow( - assessment_id=run.assessment_id, - experiment_name=assessment.experiment_name, - dataset_id=assessment.dataset_id, + prefilter_by_row_id = _load_prefilter_results(session, run, assessment) + l2_by_row_id = _load_l2_results_for_run(session, run, assessment) + has_prefilter = bool(prefilter_by_row_id) + + if dataset_rows: + rows = [ + _build_export_row( + run=run, + assessment=assessment, dataset_name=dataset_name, - run_id=run.id, - run_name=assessment.experiment_name, - run_status=run.status, - config_id=run.config_id, - config_version=run.config_version, - row_id=row_id_str, - result_status="failed" if item.get("error") else "passed", + row_id=f"row_{row_idx}", input_data=input_data, - output=item.get("output"), - error=item.get("error"), - response_id=item.get("response_id"), - input_tokens=input_tokens, - output_tokens=output_tokens, - total_tokens=total_tokens, - updated_at=run.updated_at, + prefilter_item=prefilter_by_row_id.get(f"row_{row_idx}"), + l2_item=l2_by_row_id.get(f"row_{row_idx}"), + has_prefilter=has_prefilter, ) + for row_idx, input_data in enumerate(dataset_rows) + ] + return rows + + # Dataset unavailable — emit whatever results we have, indexed by row_id. + all_row_ids = sorted( + {str(rid) for rid in l2_by_row_id} | {str(rid) for rid in prefilter_by_row_id} + ) + return [ + _build_export_row( + run=run, + assessment=assessment, + dataset_name=dataset_name, + row_id=row_id, + input_data=None, + prefilter_item=prefilter_by_row_id.get(row_id), + l2_item=l2_by_row_id.get(row_id), + has_prefilter=has_prefilter, ) + for row_id in all_row_ids + ] - return export_rows + +def _build_export_row( + run: AssessmentRun, + assessment: Assessment, + dataset_name: str | None, + row_id: str, + input_data: dict[str, str] | None, + prefilter_item: dict[str, Any] | None, + l2_item: dict[str, Any] | None, + has_prefilter: bool, +) -> AssessmentExportRow: + prefilter_cols = ( + _extract_prefilter_json_columns(prefilter_item) + if has_prefilter + else {"topic_relevance": None, "duplicate_detection": None} + ) + prefilter_passed = (prefilter_item or {}).get("prefilter_passed", True) + input_tokens, output_tokens, total_tokens = usage_totals( + l2_item.get("usage") if l2_item else None + ) + return AssessmentExportRow( + assessment_id=run.assessment_id, + experiment_name=assessment.experiment_name, + dataset_id=assessment.dataset_id, + dataset_name=dataset_name, + run_id=run.id, + run_name=assessment.experiment_name, + run_status=run.status, + config_id=run.config_id, + config_version=run.config_version, + row_id=row_id, + result_status=_row_result_status(prefilter_passed, l2_item, run.status), + input_data=input_data, + topic_relevance=prefilter_cols.get("topic_relevance"), + duplicate_detection=prefilter_cols.get("duplicate_detection"), + output=l2_item.get("output") if l2_item else None, + error=l2_item.get("error") if l2_item else None, + response_id=l2_item.get("response_id") if l2_item else None, + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + updated_at=run.updated_at, + ) def sort_export_rows( diff --git a/backend/app/services/assessment/utils/post_processing.py b/backend/app/services/assessment/utils/post_processing.py new file mode 100644 index 000000000..9b0d36c45 --- /dev/null +++ b/backend/app/services/assessment/utils/post_processing.py @@ -0,0 +1,184 @@ +"""Post-processing engine for assessment exports. +""" + +import ast +import logging +import operator +import re +from typing import Any + +logger = logging.getLogger(__name__) + +# Safe formula evaluator +_SAFE_OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.USub: operator.neg, +} + + +def _eval_node(node: ast.AST) -> float: + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return float(node.value) + if isinstance(node, ast.BinOp) and type(node.op) in _SAFE_OPS: + return _SAFE_OPS[type(node.op)](_eval_node(node.left), _eval_node(node.right)) + if isinstance(node, ast.UnaryOp) and type(node.op) in _SAFE_OPS: + return _SAFE_OPS[type(node.op)](_eval_node(node.operand)) + raise ValueError(f"Unsupported operation in formula: {ast.dump(node)}") + + +def evaluate_formula(formula: str, row: dict[str, Any]) -> float | None: + """Evaluate a formula like '@Novelty_score + @Feasibility_score * 0.5'. + + Returns None if the formula fails or references missing columns. + """ + + def resolve(match: re.Match) -> str: + col = match.group(1) + val = row.get(col) + if val is None: + return "0" + try: + return str(float(val)) + except (TypeError, ValueError): + return "0" + + expr = re.sub(r"@([\w]+)", resolve, formula) + + try: + tree = ast.parse(expr, mode="eval") + return _eval_node(tree.body) + except Exception as exc: + logger.warning("[evaluate_formula] Failed to evaluate %r: %s", formula, exc) + return None + + +# Filter + +_FILTER_OPS = { + "eq": lambda a, b: str(a).strip().lower() == str(b).strip().lower(), + "ne": lambda a, b: str(a).strip().lower() != str(b).strip().lower(), + "contains": lambda a, b: str(b).lower() in str(a).lower(), + "not_contains": lambda a, b: str(b).lower() not in str(a).lower(), + "in": lambda a, b: str(a).strip().lower() in {str(v).lower() for v in b}, + "not_in": lambda a, b: str(a).strip().lower() not in {str(v).lower() for v in b}, + "is_empty": lambda a, _: a is None or str(a).strip() == "", + "is_not_empty": lambda a, _: a is not None and str(a).strip() != "", +} + + +def _numeric_filter(op: str, a: Any, b: Any) -> bool: + try: + fa, fb = float(a), float(b) + if op == "gt": + return fa > fb + if op == "lt": + return fa < fb + if op == "gte": + return fa >= fb + if op == "lte": + return fa <= fb + except (TypeError, ValueError): + pass + return False + + +def _row_matches_filter(row: dict[str, Any], rule: dict[str, Any]) -> bool: + col = rule["column"] + op = rule["op"] + value = rule.get("value") + cell = row.get(col) + + if op in ("gt", "lt", "gte", "lte"): + return _numeric_filter(op, cell, value) + if op in _FILTER_OPS: + return _FILTER_OPS[op](cell, value) + return True + + +def apply_computed_columns( + rows: list[dict[str, Any]], + computed_columns: list[dict[str, Any]], +) -> None: + """Add computed columns to each row in-place.""" + for row in rows: + for col_def in computed_columns: + name = col_def.get("name", "").strip() + formula = col_def.get("formula", "").strip() + if not name or not formula: + continue + row[name] = evaluate_formula(formula, row) + + +def apply_filter( + rows: list[dict[str, Any]], + filter_rules: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Return only rows that match ALL filter rules (AND logic).""" + if not filter_rules: + return rows + return [ + row + for row in rows + if all(_row_matches_filter(row, rule) for rule in filter_rules) + ] + + +def apply_sort( + rows: list[dict[str, Any]], + sort_rules: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Sort rows by priority-ordered rules. First rule has highest priority.""" + if not sort_rules: + return rows + + # Build sort key: iterate rules in reverse (lowest priority first) + # so that highest priority rule is the final (dominant) tiebreaker. + result = rows + for rule in reversed(sort_rules): + col = rule.get("column", "") + desc = str(rule.get("direction", "asc")).lower() == "desc" + + def sort_key(row: dict[str, Any], _col: str = col) -> tuple: + val = row.get(_col) + if val is None: + return (1, 0, "") + try: + return (0, -float(val) if desc else float(val), "") + except (TypeError, ValueError): + s = str(val).lower() + return ( + (0, 0, s) + if not desc + else (0, 0, "".join(chr(0x10FFFF - ord(c)) for c in s)) + ) + + result = sorted(result, key=sort_key) + + return result + + +def apply_post_processing( + rows: list[dict[str, Any]], + config: dict[str, Any] | None, +) -> list[dict[str, Any]]: + """Apply full post-processing pipeline: computed columns → filter → sort. + + Safe to call with config=None (no-op). + """ + if not config: + return rows + + computed_columns = config.get("computed_columns") or [] + filter_rules = config.get("filter") or [] + sort_rules = config.get("sort") or [] + + if computed_columns: + apply_computed_columns(rows, computed_columns) + + rows = apply_filter(rows, filter_rules) + rows = apply_sort(rows, sort_rules) + + return rows diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index 6d524e81f..c025ba98c 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -18,13 +18,12 @@ ) from app.models.assessment import AssessmentAttachment from app.services.assessment.utils.attachments import ( - _decode_base64_prefix, - _guess_image_mime_from_base64, _guess_image_mime_from_url, + attachment_type_for_row, + build_gemini_attachment_parts, resolve_attachment_values, - resolve_image_mime_and_payload, + resolve_item_type, split_attachment_urls, - split_data_url, to_direct_attachment_url, ) @@ -77,7 +76,7 @@ def test_openai_native_routes_to_openai_batch(self) -> None: return_value=[{"custom_id": "row_0"}], ), patch( - "app.utils.get_openai_client", + "app.crud.assessment.batch.get_openai_client", return_value=MagicMock(), ), patch( @@ -139,7 +138,7 @@ def test_config_instruction_is_not_used_without_request_instruction(self) -> Non return_value=[{"custom_id": "row_0"}], ), patch( - "app.utils.get_openai_client", + "app.crud.assessment.batch.get_openai_client", return_value=MagicMock(), ), patch( @@ -194,9 +193,9 @@ def test_google_native_routes_to_google_batch(self) -> None: "app.crud.assessment.batch.build_google_jsonl", return_value=[{"key": "row_0"}], ), - patch("app.core.batch.client.GeminiClient") as gemini_cls, + patch("app.crud.assessment.batch.GeminiClient") as gemini_cls, patch( - "app.core.batch.GeminiBatchProvider", + "app.crud.assessment.batch.GeminiBatchProvider", return_value=MagicMock(), ), patch( @@ -346,56 +345,24 @@ def test_split_and_direct_urls(self) -> None: ) assert "drive.google.com/uc" in pdf_url - def test_data_url_and_mime_guessers(self) -> None: - mime, payload = split_data_url("data:image/png;base64,AAAA") - assert mime == "image/png" - assert payload == "AAAA" - none_mime, raw = split_data_url("rawbase64") - assert none_mime is None - assert raw == "rawbase64" + def test_url_mime_guessers(self) -> None: assert _guess_image_mime_from_url("https://x/y/file.jpeg") == "image/jpeg" assert _guess_image_mime_from_url("https://x/y/file.unknown") is None - def test_base64_guess_and_decode(self) -> None: - png_head = "iVBORw0KGgoAAAANSUhEUg==" - assert _guess_image_mime_from_base64(png_head) == "image/png" - assert _decode_base64_prefix("###") == b"" - - def testresolve_image_mime_and_payload(self) -> None: - mime, payload = resolve_image_mime_and_payload("https://x/y/file.webp", "url") - assert mime == "image/webp" - assert payload.endswith("file.webp") - mime2, payload2 = resolve_image_mime_and_payload( - "data:image/jpeg;base64,AAAA", "base64" - ) - assert mime2 == "image/jpeg" - assert payload2 == "AAAA" - def testresolve_attachment_values(self) -> None: image_url_att = AssessmentAttachment(column="img", type="image", format="url") - image_b64_att = AssessmentAttachment( - column="img", type="image", format="base64" - ) pdf_url_att = AssessmentAttachment(column="pdf", type="pdf", format="url") - pdf_b64_att = AssessmentAttachment(column="pdf", type="pdf", format="base64") values = resolve_attachment_values( "https://example.com/a.png,https://example.com/b.png", image_url_att ) assert len(values) == 2 assert values[0]["type"] == "input_image" - - values = resolve_attachment_values("data:image/png;base64,AAAA", image_b64_att) - assert values[0]["image_url"].startswith("data:image/png;base64,") + assert values[0]["image_url"] == "https://example.com/a.png" values = resolve_attachment_values("https://example.com/a.pdf", pdf_url_att) assert values[0]["type"] == "input_file" - assert "file_url" in values[0] - - values = resolve_attachment_values( - "data:application/pdf;base64,AAAA", pdf_b64_att - ) - assert values[0]["file_data"].startswith("data:application/pdf;base64,") + assert values[0]["file_url"] == "https://example.com/a.pdf" def test_build_openai_and_google_jsonl(self) -> None: rows = [{"q": "What is 2+2?", "img": "https://example.com/a.png"}] @@ -423,3 +390,178 @@ def test_build_openai_and_google_jsonl(self) -> None: assert google_jsonl[0]["request"]["systemInstruction"] == { "parts": [{"text": "system"}] } + + +class TestResolveItemType: + """Image/pdf routing now trusts the user-declared type (no detection).""" + + def test_declared_image(self) -> None: + assert resolve_item_type("image") == "image" + + def test_declared_pdf(self) -> None: + assert resolve_item_type("pdf") == "pdf" + + def test_override_wins(self) -> None: + assert resolve_item_type("image", "pdf") == "pdf" + assert resolve_item_type("pdf", "image") == "image" + + def test_mixed_without_override_is_unresolved(self) -> None: + assert resolve_item_type("mixed") is None + + def test_unknown_declared_is_unresolved(self) -> None: + assert resolve_item_type("whatever") is None + + def test_column_uses_single_declared_type(self) -> None: + """One column, many URLs -> all routed by the declared type.""" + att = AssessmentAttachment(column="docs", type="pdf", format="url") + value = "https://x.com/a/photo.jpg, https://x.com/b/report" + resolved = resolve_attachment_values(value, att) + types = [obj["type"] for obj in resolved] + assert types == ["input_file", "input_file"] + + +class TestAttachmentMime: + def test_guess_image_mime_from_url_variants(self) -> None: + assert _guess_image_mime_from_url("http://x/a.PNG") == "image/png" + assert _guess_image_mime_from_url("http://x/a.jpeg") == "image/jpeg" + assert _guess_image_mime_from_url("http://x/a.webp") == "image/webp" + assert _guess_image_mime_from_url("http://x/a.txt") is None + + +class TestAttachmentTypeForRow: + def test_mixed_resolves_from_type_column(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Photo": "image", "Report": "pdf"}, + ) + assert attachment_type_for_row(att, {"DOC type": "Photo"}) == "image" + assert attachment_type_for_row(att, {"DOC type": "Report"}) == "pdf" + assert attachment_type_for_row(att, {"DOC type": "Unknown"}) is None + + def test_mixed_resolves_comma_separated_value_lists(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Img-Prototype, Img-Handtext": "image", "Pdf": "pdf"}, + ) + + assert attachment_type_for_row(att, {"DOC type": "Img-Prototype"}) == "image" + assert attachment_type_for_row(att, {"DOC type": "Img-Handtext"}) == "image" + assert attachment_type_for_row(att, {"DOC type": "pdf"}) == "pdf" + + def test_mixed_resolves_row_value_lists_when_same_type(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Img-Prototype, Img-Handtext": "image", "Pdf": "pdf"}, + ) + + assert ( + attachment_type_for_row( + att, + {"DOC type": "Img-Prototype, Img-Handtext"}, + ) + == "image" + ) + assert attachment_type_for_row(att, {"DOC type": "Img-Prototype, Pdf"}) is None + + def test_mixed_missing_type_mapping_fields_returns_none(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = SimpleNamespace(column="Docs", type="mixed", format="url") + + assert attachment_type_for_row(att, {"Docs": "x"}) is None + + def test_non_mixed_returns_none(self) -> None: + from app.services.assessment.utils.attachments import attachment_type_for_row + + att = AssessmentAttachment(column="Docs", type="image", format="url") + assert attachment_type_for_row(att, {"Docs": "x"}) is None + + def test_mixed_config_missing_routing_fields_is_rejected(self) -> None: + import pytest + from pydantic import ValidationError + + with pytest.raises(ValidationError): + AssessmentAttachment(column="Docs", type="mixed", format="url") + + def test_mixed_config_invalid_map_value_is_rejected(self) -> None: + import pytest + from pydantic import ValidationError + + with pytest.raises(ValidationError): + AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Report": "spreadsheet"}, + ) + + def test_override_forces_part_type(self) -> None: + from app.services.assessment.utils.attachments import resolve_attachment_values + + att = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Report": "pdf"}, + ) + url = "https://drive.google.com/file/d/ID/view" + parts = resolve_attachment_values(url, att, type_override="pdf") + assert parts[0]["type"] == "input_file" + + +class TestAttachmentResolutionBranches: + _IMG = AssessmentAttachment(column="Docs", type="image", format="url") + _PDF = AssessmentAttachment(column="Docs", type="pdf", format="url") + _MIXED = AssessmentAttachment( + column="Docs", + type="mixed", + format="url", + type_column="DOC type", + type_value_map={"Report": "pdf"}, + ) + + def test_blank_value_returns_empty(self) -> None: + assert resolve_attachment_values(" ", self._IMG) == [] + assert build_gemini_attachment_parts(" ", self._IMG) == [] + + def test_unresolved_mixed_is_skipped(self) -> None: + url = "https://x.com/a.jpg" + # No override and declared 'mixed' -> unresolved -> skip rather than guess. + assert resolve_attachment_values(url, self._MIXED) == [] + assert build_gemini_attachment_parts(url, self._MIXED) == [] + + def test_gemini_image_and_pdf_parts(self) -> None: + img = build_gemini_attachment_parts("https://x.com/a.png", self._IMG)[0] + pdf = build_gemini_attachment_parts("https://x.com/a.pdf", self._PDF)[0] + assert img["fileData"]["mimeType"] == "image/png" + assert pdf["fileData"]["mimeType"] == "application/pdf" + + def test_type_for_row_blank_value_returns_none(self) -> None: + assert attachment_type_for_row(self._MIXED, {"DOC type": " "}) is None + + def test_type_for_row_ignores_invalid_map_value(self) -> None: + # SimpleNamespace bypasses the model validator to exercise the guard that + # skips map entries whose target type isn't 'image'/'pdf'. + att = SimpleNamespace( + type="mixed", + type_column="DOC type", + type_value_map={"Report": "spreadsheet"}, + ) + assert attachment_type_for_row(att, {"DOC type": "Report"}) is None diff --git a/backend/app/tests/assessment/test_cron.py b/backend/app/tests/assessment/test_cron.py index c9407bd5c..77797bb88 100644 --- a/backend/app/tests/assessment/test_cron.py +++ b/backend/app/tests/assessment/test_cron.py @@ -103,7 +103,7 @@ async def test_no_active_runs_recompute(self) -> None: "app.crud.assessment.cron.recompute_assessment_status", return_value=refreshed, ), patch( - "app.crud.assessment.cron.check_and_process_assessment", new=AsyncMock() + "app.crud.assessment.cron.process_run_batches", new=AsyncMock() ): result = await poll_all_pending_assessment_evaluations(session=session) @@ -115,14 +115,14 @@ async def test_active_run_processed(self) -> None: session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.stage_status = "PROCESSING" session.exec.return_value.all.return_value = [assessment] with patch( "app.crud.assessment.cron.get_assessment_runs_for_assessment", return_value=[run], ), patch( - "app.crud.assessment.cron.check_and_process_assessment", + "app.crud.assessment.cron.process_run_batches", new=AsyncMock( return_value={ "action": "processed", @@ -136,55 +136,46 @@ async def test_active_run_processed(self) -> None: assert result["processed"] == 1 @pytest.mark.asyncio - async def test_active_run_failure_and_cleanup_failure(self) -> None: + async def test_transient_poll_exception_does_not_fail_run(self) -> None: + """A transient error while polling leaves the run active for retry.""" session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.stage_status = "PROCESSING" session.exec.return_value.all.return_value = [assessment] with patch( "app.crud.assessment.cron.get_assessment_runs_for_assessment", return_value=[run], ), patch( - "app.crud.assessment.cron.check_and_process_assessment", - new=AsyncMock(side_effect=RuntimeError("boom")), - ), patch( - "app.crud.assessment.cron.update_assessment_run_status", - side_effect=RuntimeError("cleanup-failed"), - ), patch( - "app.crud.assessment.cron.recompute_assessment_status", + "app.crud.assessment.cron.process_run_batches", + new=AsyncMock(side_effect=RuntimeError("nodename nor servname provided")), ): result = await poll_all_pending_assessment_evaluations(session=session) - assert result["failed"] == 1 + assert result["failed"] == 0 + assert result["still_processing"] == 1 @pytest.mark.asyncio - async def test_active_run_failure_updates_db_with_same_error_message(self) -> None: + async def test_deterministic_error_marks_run_failed(self) -> None: + """A deterministic ValueError fails the run instead of retrying forever.""" session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.stage_status = "PROCESSING" session.exec.return_value.all.return_value = [assessment] with patch( "app.crud.assessment.cron.get_assessment_runs_for_assessment", return_value=[run], ), patch( - "app.crud.assessment.cron.check_and_process_assessment", - new=AsyncMock(side_effect=RuntimeError("gemini quota exceeded")), + "app.crud.assessment.cron.process_run_batches", + new=AsyncMock(side_effect=ValueError("Parent assessment 1 not found")), ), patch( - "app.crud.assessment.cron.update_assessment_run_status", - ) as update_run, patch( - "app.crud.assessment.cron.recompute_assessment_status", - ): + "app.crud.assessment.cron.update_assessment_run_status" + ) as mark_failed: result = await poll_all_pending_assessment_evaluations(session=session) assert result["failed"] == 1 - assert result["details"][0]["error"] == "gemini quota exceeded" - update_run.assert_called_once_with( - session=session, - run=run, - status="failed", - error_message="gemini quota exceeded", - ) + assert result["still_processing"] == 0 + assert mark_failed.call_args.kwargs["status"] == "failed" diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index 2bc076342..e68feb813 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -2,7 +2,7 @@ from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from uuid import UUID import pytest @@ -25,7 +25,9 @@ list_assessments, recompute_assessment_status, update_assessment_run_status, + update_run_post_processing_config, ) +from app.crud.assessment.core import update_assessment_run_prefilter_stats from app.models.stt_evaluation import EvaluationType @@ -232,12 +234,18 @@ def test_build_run_stats(self) -> None: total_items=2, error_message=None, updated_at=datetime(2024, 1, 1), + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, + stage="COMPLETED", + stage_status="COMPLETED", ), ] stats = build_run_stats(runs) assert len(stats) == 1 assert stats[0].run_id == 1 assert stats[0].status == "completed" + assert stats[0].stage == "COMPLETED" def test_derive_aggregate_error(self) -> None: assert derive_aggregate_error(_counts(total=2, completed=2)) is None @@ -297,3 +305,79 @@ def test_recompute_commit_failure_rolls_back(self) -> None: with pytest.raises(RuntimeError): recompute_assessment_status(session=session, assessment_id=1) session.rollback.assert_called_once() + + +class TestUpdateRunPostProcessingConfig: + def test_sets_config_in_input_blob(self) -> None: + session = MagicMock() + run = SimpleNamespace(id=5, input={"text_columns": ["q"]}) + cfg = {"computed_columns": [{"name": "T", "formula": "@a"}]} + with patch("app.crud.assessment.core.flag_modified") as flag: + out = update_run_post_processing_config( + session=session, run=run, config=cfg + ) + assert out.input["post_processing_config"] == cfg + assert out.input["text_columns"] == ["q"] + flag.assert_called_once_with(run, "input") + session.commit.assert_called_once() + + def test_none_input_handled(self) -> None: + session = MagicMock() + run = SimpleNamespace(id=6, input=None) + with patch("app.crud.assessment.core.flag_modified"): + out = update_run_post_processing_config( + session=session, run=run, config=None + ) + assert out.input == {"post_processing_config": None} + + def test_commit_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + run = SimpleNamespace(id=7, input={}) + with patch("app.crud.assessment.core.flag_modified"): + with pytest.raises(RuntimeError): + update_run_post_processing_config(session=session, run=run, config={}) + session.rollback.assert_called_once() + + +class TestUpdateAssessmentRunL1Stats: + def test_sets_stats_fields(self) -> None: + session = MagicMock() + run = SimpleNamespace( + id=8, + updated_at=None, + prefilter_object_store_url=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, + ) + out = update_assessment_run_prefilter_stats( + session=session, + run=run, + prefilter_object_store_url="s3://x", + prefilter_total_rows=10, + prefilter_total_passed=7, + prefilter_total_rejected=3, + ) + assert out.prefilter_object_store_url == "s3://x" + assert out.prefilter_total_rows == 10 + assert out.prefilter_total_passed == 7 + assert out.prefilter_total_rejected == 3 + session.commit.assert_called_once() + + def test_commit_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + run = SimpleNamespace( + id=9, + updated_at=None, + prefilter_object_store_url=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, + ) + with pytest.raises(RuntimeError): + update_assessment_run_prefilter_stats( + session=session, run=run, prefilter_total_rows=1 + ) + session.rollback.assert_called_once() diff --git a/backend/app/tests/assessment/test_duplicate_detection.py b/backend/app/tests/assessment/test_duplicate_detection.py new file mode 100644 index 000000000..b9b0c033a --- /dev/null +++ b/backend/app/tests/assessment/test_duplicate_detection.py @@ -0,0 +1,85 @@ +"""Tests for the duplicate-detection batch request builder and result parser.""" + +from unittest.mock import patch + +from app.services.assessment.prefilter import constants +from app.services.assessment.prefilter.duplicate_detection import ( + build_duplicate_detection_requests, + parse_duplicate_detection_results, +) + + +class TestBuildRequests: + def test_one_request_per_record(self) -> None: + rows = [(0, {"Problem": "p0", "Solution": "s0"}), (1, {"Problem": "p1"})] + lines = build_duplicate_detection_requests(rows, ["Problem", "Solution"]) + # key (gemini) or custom_id (openai) depending on configured provider. + keys = [ln.get("key") or ln.get("custom_id") for ln in lines] + assert keys == ["dup_0", "dup_1"] + + def test_openai_request_grounds_on_file_search_store(self) -> None: + with patch.object( + constants, "ASSESSMENT_PREFILTER_PROVIDER", "openai" + ), patch.object(constants, "ASSESSMENT_PREFILTER_DUPLICATE_STORE", "vs_corpus"): + lines = build_duplicate_detection_requests( + [(0, {"Problem": "p"})], ["Problem"] + ) + tool = lines[0]["body"]["tools"][0] + assert tool["type"] == "file_search" + assert tool["vector_store_ids"] == ["vs_corpus"] + + +class TestParseResults: + def test_parses_structured_verdict_per_row(self) -> None: + import json + + outputs = [ + { + "row_id": "dup_0", + "output": json.dumps( + { + "verdict": "UNIQUE", + "match_title": "", + "source_url": "", + "matching_sentence": "", + "reason": "novel", + } + ), + "error": None, + }, + { + "row_id": "dup_1", + "output": json.dumps( + { + "verdict": "DUPLICATE", + "match_title": "T", + "source_url": "http://x", + "matching_sentence": "s", + "reason": "same mechanism", + } + ), + "error": None, + }, + ] + parsed = parse_duplicate_detection_results(outputs) + assert parsed[0]["verdict"] == "UNIQUE" + assert parsed[0]["source_url"] is None # "" -> None + assert parsed[1]["verdict"] == "DUPLICATE" + assert parsed[1]["source_url"] == "http://x" + + def test_empty_response_records_error(self) -> None: + parsed = parse_duplicate_detection_results( + [{"row_id": "dup_3", "output": None, "error": None}] + ) + assert parsed[3]["verdict"] == "ERROR" + + def test_bad_json_records_error_and_foreign_keys_skipped(self) -> None: + parsed = parse_duplicate_detection_results( + [ + {"row_id": "tr_0", "output": "{}", "error": None}, # not a dup key + {"row_id": "dup_x", "output": "{}", "error": None}, # bad index + {"row_id": "dup_4", "output": "{not json", "error": None}, + ] + ) + assert set(parsed) == {4} + assert parsed[4]["verdict"] == "ERROR" diff --git a/backend/app/tests/assessment/test_export.py b/backend/app/tests/assessment/test_export.py index 3ace89dbd..a13a1e03a 100644 --- a/backend/app/tests/assessment/test_export.py +++ b/backend/app/tests/assessment/test_export.py @@ -2,21 +2,127 @@ import json from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch from app.models.assessment import AssessmentExportRow +from app.services.assessment.utils import export as export_mod from app.services.assessment.utils.export import ( + _build_export_row, _drop_empty_columns, _expand_input_columns, _expand_output_columns, _load_dataset_rows_for_run, + _load_l2_results_for_run, + _load_parsed_results_for_batch_job, _load_parsed_results_for_run, + _load_prefilter_results, _safe_filename_part, + _stage_batch_job, build_json_export_rows, load_export_rows_for_run, serialize_export_rows, sort_export_rows, ) +from app.models.assessment import Stage + + +def _run_ns(status: str = "processing") -> SimpleNamespace: + return SimpleNamespace( + id=5, + assessment_id=9, + status=status, + config_id="00000000-0000-0000-0000-000000000001", + config_version=1, + updated_at=datetime(2026, 1, 1), + ) + + +def _assessment_ns() -> SimpleNamespace: + return SimpleNamespace(experiment_name="exp", dataset_id=3) + + +class TestBuildExportRow: + def test_prefilter_rejected_with_annotations(self) -> None: + prefilter_item = { + "prefilter_passed": False, + "topic_relevance": { + "decision": "REJECT", + "reasoning": "off-topic", + "column_relevance": {"Problem": False}, + }, + "duplicate_detection": {"row_id": "dup_0", "verdict": "UNIQUE"}, + } + row = _build_export_row( + run=_run_ns(), + assessment=_assessment_ns(), + dataset_name="ds", + row_id="row_0", + input_data={"Problem": "p"}, + prefilter_item=prefilter_item, + l2_item=None, + has_prefilter=True, + ) + assert row.result_status == "prefilter_rejected" + assert json.loads(row.topic_relevance)["decision"] == "REJECT" + assert json.loads(row.duplicate_detection)["verdict"] == "UNIQUE" + + def test_passed_with_l2_output(self) -> None: + row = _build_export_row( + run=_run_ns(), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_1", + input_data=None, + prefilter_item={"prefilter_passed": True}, + l2_item={"output": "{}", "error": None}, + has_prefilter=True, + ) + assert row.result_status == "passed" + + def test_l2_error_is_failed_and_no_prefilter_cols(self) -> None: + row = _build_export_row( + run=_run_ns(), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_2", + input_data=None, + prefilter_item=None, + l2_item={"output": None, "error": "boom"}, + has_prefilter=False, + ) + assert row.result_status == "failed" + assert row.topic_relevance is None + + def test_no_l2_processing_vs_failed(self) -> None: + processing = _build_export_row( + run=_run_ns(status="processing"), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_3", + input_data=None, + prefilter_item=None, + l2_item=None, + has_prefilter=False, + ) + failed = _build_export_row( + run=_run_ns(status="failed"), + assessment=_assessment_ns(), + dataset_name=None, + row_id="row_4", + input_data=None, + prefilter_item=None, + l2_item=None, + has_prefilter=False, + ) + assert processing.result_status == "processing" + assert failed.result_status == "failed" + + +def _named_dataset() -> MagicMock: + ds = MagicMock() + ds.name = "ds" + return ds def _make_row( @@ -144,14 +250,14 @@ def test_all_empty_drops_all(self) -> None: class TestExpandOutputColumns: def test_plain_string_output_not_expanded(self) -> None: rows = [{"output": "plain text", "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "output" in fieldnames def test_json_dict_output_expanded(self) -> None: rows = [ {"output": json.dumps({"score": 5, "reason": "good"}), "input_data": None} ] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "score" in fieldnames assert "reason" in fieldnames assert expanded[0]["score"] == 5 @@ -161,14 +267,14 @@ def test_mixed_parsed_and_unparsed_adds_output_raw(self) -> None: {"output": json.dumps({"score": 3}), "input_data": None}, {"output": "not json", "input_data": None}, ] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "output_raw" in fieldnames # Second row that didn't parse should get output_raw assert expanded[1].get("output_raw") == "not json" def test_none_output_handled(self) -> None: rows = [{"output": None, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert expanded[0].get("output") is None @@ -253,13 +359,13 @@ class TestExpandOutputColumnsDictOutput: def test_dict_output_expanded_directly(self) -> None: # raw output is already a dict (not a JSON string) rows = [{"output": {"score": 9, "label": "good"}, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "score" in fieldnames assert expanded[0]["score"] == 9 def test_non_dict_non_string_output_treated_as_unparsed(self) -> None: rows = [{"output": 42, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) # 42 is not a dict/string, treated as unparsed → output stays as-is assert "output" in fieldnames @@ -415,10 +521,11 @@ def test_s3_failure_falls_back_to_provider_download(self) -> None: "app.services.assessment.utils.export.get_cloud_storage", side_effect=Exception("S3 down"), ), patch( - "app.crud.assessment.processing._get_batch_provider", + "app.services.assessment.utils.export._get_batch_provider", return_value=MagicMock(), ), patch( - "app.core.batch.download_batch_results", return_value=raw + "app.services.assessment.utils.export.download_batch_results", + return_value=raw, ): result = _load_parsed_results_for_run( session=session, run=run, batch_job=batch_job @@ -523,64 +630,48 @@ def _make_assessment(self) -> MagicMock: assessment.dataset_id = 2 return assessment - def test_no_batch_job_id_returns_empty(self) -> None: - session = MagicMock() - run = self._make_run() - run.batch_job_id = None - result = load_export_rows_for_run(session=session, run=run) - assert result == [] - - def test_batch_job_not_found_returns_empty(self) -> None: - session = MagicMock() - run = self._make_run() - with patch( - "app.services.assessment.utils.export.get_batch_job", return_value=None - ): - result = load_export_rows_for_run( - session=session, run=run, assessment=self._make_assessment() - ) - assert result == [] + def _patches(self, *, l2, prefilter=None, dataset_rows=None): + return [ + patch( + "app.services.assessment.utils.export._load_l2_results_for_run", + return_value=l2, + ), + patch( + "app.services.assessment.utils.export._load_prefilter_results", + return_value=prefilter or {}, + ), + patch( + "app.services.assessment.utils.export._load_dataset_rows_for_run", + return_value=dataset_rows if dataset_rows is not None else [], + ), + ] - def test_no_parsed_results_returns_empty(self) -> None: + def test_no_results_no_dataset_returns_empty(self) -> None: session = MagicMock() + session.get.return_value = _named_dataset() run = self._make_run() - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - "app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=None, - ): + p1, p2, p3 = self._patches(l2={}) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) assert result == [] - def test_parsed_results_build_export_rows(self) -> None: + def test_merged_results_build_export_rows(self) -> None: session = MagicMock() - dataset = MagicMock() - dataset.name = "ds" - session.get.return_value = dataset + session.get.return_value = _named_dataset() run = self._make_run() - parsed = [ - { + l2 = { + "row_0": { "row_id": "row_0", "output": '{"score": 5}', "error": None, "usage": None, "response_id": "r1", } - ] - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - "app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=parsed, - ), patch( - "app.services.assessment.utils.export._load_dataset_rows_for_run", - return_value=[], - ): + } + p1, p2, p3 = self._patches(l2=l2) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) @@ -590,61 +681,175 @@ def test_parsed_results_build_export_rows(self) -> None: def test_error_result_sets_failed_status(self) -> None: session = MagicMock() - dataset = MagicMock() - dataset.name = "ds" - session.get.return_value = dataset + session.get.return_value = _named_dataset() run = self._make_run() - parsed = [ - { + l2 = { + "row_0": { "row_id": "row_0", "output": None, "error": "timeout", "usage": None, "response_id": None, } - ] - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - "app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=parsed, - ), patch( - "app.services.assessment.utils.export._load_dataset_rows_for_run", - return_value=[], - ): + } + p1, p2, p3 = self._patches(l2=l2) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) assert result[0].result_status == "failed" - def test_input_data_correlated_via_row_id(self) -> None: + def test_dataset_rows_include_pending_and_correlate_input(self) -> None: session = MagicMock() - dataset = MagicMock() - dataset.name = "ds" - session.get.return_value = dataset + session.get.return_value = _named_dataset() run = self._make_run() - parsed = [ - { + run.status = "l2_processing" + l2 = { + "row_1": { "row_id": "row_1", "output": "x", "error": None, "usage": None, "response_id": None, } - ] + } dataset_rows = [{"q": "first"}, {"q": "second"}] - with patch( - "app.services.assessment.utils.export.get_batch_job", - return_value=MagicMock(), - ), patch( - "app.services.assessment.utils.export._load_parsed_results_for_run", - return_value=parsed, - ), patch( - "app.services.assessment.utils.export._load_dataset_rows_for_run", - return_value=dataset_rows, - ): + p1, p2, p3 = self._patches(l2=l2, dataset_rows=dataset_rows) + with p1, p2, p3: result = load_export_rows_for_run( session=session, run=run, assessment=self._make_assessment() ) - assert result[0].input_data == {"q": "second"} + assert len(result) == 2 + assert result[0].result_status == "processing" # row_0 not done yet + assert result[1].input_data == {"q": "second"} + assert result[1].result_status == "passed" + + +class TestStageBatchJob: + def test_returns_job_for_stage(self) -> None: + run = SimpleNamespace(stage_batches={Stage.L2_ASSESSMENT: 7}) + with patch.object(export_mod, "get_batch_job", return_value="JOB") as g: + assert _stage_batch_job(MagicMock(), run, Stage.L2_ASSESSMENT) == "JOB" + assert g.call_args.kwargs["batch_job_id"] == 7 + + def test_none_when_no_batch(self) -> None: + run = SimpleNamespace(stage_batches=None) + assert _stage_batch_job(MagicMock(), run, Stage.L2_ASSESSMENT) is None + + +class TestLoadPrefilterResults: + def test_merges_tr_and_dup_annotations(self) -> None: + run = SimpleNamespace(id=5) + assessment = SimpleNamespace(project_id=1) + with patch.object( + export_mod, + "_stage_batch_job", + return_value=SimpleNamespace(provider="openai"), + ), patch.object( + export_mod, "load_raw_batch_results", return_value=[] + ), patch.object( + export_mod, "parse_assessment_output", return_value=[] + ), patch.object( + export_mod, + "parse_topic_relevance_results", + return_value={ + 0: { + "verdict": True, + "decision": "ACCEPT", + "reasoning": "ok", + "column_relevance": {"a": True}, + } + }, + ), patch.object( + export_mod, + "parse_duplicate_detection_results", + return_value={0: {"verdict": "UNIQUE"}}, + ): + out = _load_prefilter_results(MagicMock(), run, assessment) + assert out["row_0"]["prefilter_passed"] is True + assert out["row_0"]["topic_relevance"]["decision"] == "ACCEPT" + assert out["row_0"]["duplicate_detection"]["verdict"] == "UNIQUE" + + def test_tr_load_failure_is_swallowed(self) -> None: + run = SimpleNamespace(id=5) + assessment = SimpleNamespace(project_id=1) + with patch.object( + export_mod, + "_stage_batch_job", + return_value=SimpleNamespace(provider="openai"), + ), patch.object( + export_mod, "load_raw_batch_results", side_effect=RuntimeError("s3 down") + ): + out = _load_prefilter_results(MagicMock(), run, assessment) + assert out == {} + + +class TestLoadParsedResultsForBatchJob: + def test_object_store_path(self) -> None: + job = SimpleNamespace( + id=1, + provider="openai", + raw_output_url="s3://x", + provider_output_file_id=None, + ) + assessment = SimpleNamespace(project_id=1, organization_id=1) + storage = MagicMock() + storage.stream.return_value.read.return_value.decode.return_value = "raw" + with patch.object( + export_mod, "get_cloud_storage", return_value=storage + ), patch.object( + export_mod, "parse_stored_results", return_value=[{"k": 1}] + ), patch.object( + export_mod, "parse_assessment_output", return_value=[{"row_id": "row_0"}] + ) as parse: + result = _load_parsed_results_for_batch_job(MagicMock(), job, assessment) + assert result == [{"row_id": "row_0"}] + parse.assert_called_once() + + def test_provider_fallback_path(self) -> None: + job = SimpleNamespace( + id=1, + provider="openai", + raw_output_url=None, + provider_output_file_id="f1", + organization_id=1, + ) + assessment = SimpleNamespace(project_id=1, organization_id=1) + with patch.object( + export_mod, "_get_batch_provider", return_value=MagicMock() + ), patch.object( + export_mod, "download_batch_results", return_value=[{"k": 1}] + ), patch.object( + export_mod, "parse_assessment_output", return_value=[{"row_id": "row_1"}] + ): + result = _load_parsed_results_for_batch_job(MagicMock(), job, assessment) + assert result == [{"row_id": "row_1"}] + + def test_returns_none_without_outputs(self) -> None: + job = SimpleNamespace( + id=1, provider="openai", raw_output_url=None, provider_output_file_id=None + ) + assessment = SimpleNamespace(project_id=1, organization_id=1) + assert _load_parsed_results_for_batch_job(MagicMock(), job, assessment) is None + + +class TestLoadL2ResultsForRun: + def test_keys_by_row_id(self) -> None: + run = SimpleNamespace() + assessment = SimpleNamespace() + with patch.object( + export_mod, "_stage_batch_job", return_value=SimpleNamespace() + ), patch.object( + export_mod, + "_load_parsed_results_for_batch_job", + return_value=[{"row_id": "row_0", "output": "x"}, {"no_row": 1}], + ): + merged = _load_l2_results_for_run(MagicMock(), run, assessment) + assert set(merged) == {"row_0"} + + def test_empty_when_no_batch(self) -> None: + with patch.object(export_mod, "_stage_batch_job", return_value=None): + merged = _load_l2_results_for_run( + MagicMock(), SimpleNamespace(), SimpleNamespace() + ) + assert merged == {} diff --git a/backend/app/tests/assessment/test_pipeline.py b/backend/app/tests/assessment/test_pipeline.py new file mode 100644 index 000000000..c807d415d --- /dev/null +++ b/backend/app/tests/assessment/test_pipeline.py @@ -0,0 +1,106 @@ +"""Tests for prefilter settings + pipeline stage ordering.""" + +from types import SimpleNamespace + +import pytest + +from app.models.assessment import Stage, StageStatus +from app.services.assessment.prefilter import resolve_prefilter_settings +from app.services.assessment.stages import ( + advance_or_finalize, + build_pipeline, + build_prefilter_requests, + next_stage, + ordered_stages, +) + +_FULL_INPUT = { + "prefilter_config": { + "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"}, + "duplicate_detection": {"columns": ["Problem"]}, + } +} + + +class TestResolvePrefilterSettings: + def test_both_enabled(self) -> None: + cfg = resolve_prefilter_settings(_FULL_INPUT["prefilter_config"]) + assert cfg["tr_enabled"] is True + assert cfg["dup_enabled"] is True + + def test_disabled_when_empty(self) -> None: + cfg = resolve_prefilter_settings({}) + assert cfg["tr_enabled"] is False + assert cfg["dup_enabled"] is False + + +class TestPipeline: + def test_full_pipeline_order(self) -> None: + pipeline = build_pipeline(_FULL_INPUT) + assert ordered_stages(pipeline) == [ + Stage.PRE_FILTER_TOPIC_RELEVANCE, + Stage.PRE_FILTER_DUPLICATE_DETECTION, + Stage.L2_ASSESSMENT, + ] + + def test_no_prefilter_is_l2_only(self) -> None: + pipeline = build_pipeline({}) + assert ordered_stages(pipeline) == [Stage.L2_ASSESSMENT] + assert next_stage(pipeline) == Stage.L2_ASSESSMENT + + def test_next_stage(self) -> None: + pipeline = build_pipeline(_FULL_INPUT) + assert next_stage(pipeline, Stage.PRE_FILTER_TOPIC_RELEVANCE) == ( + Stage.PRE_FILTER_DUPLICATE_DETECTION + ) + assert next_stage(pipeline, Stage.L2_ASSESSMENT) is None + + +class TestAdvanceOrFinalize: + def test_advances_to_next_pending_stage(self) -> None: + run = SimpleNamespace( + pipeline=build_pipeline(_FULL_INPUT), + stage=Stage.PRE_FILTER_TOPIC_RELEVANCE, + stage_status=StageStatus.COMPLETED, + status="processing", + ) + nxt = advance_or_finalize(run) + assert nxt == Stage.PRE_FILTER_DUPLICATE_DETECTION + assert run.stage == Stage.PRE_FILTER_DUPLICATE_DETECTION + assert run.stage_status == StageStatus.PENDING + + def test_finalizes_after_last_stage(self) -> None: + run = SimpleNamespace( + pipeline=build_pipeline({}), + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.COMPLETED, + status="processing", + ) + assert advance_or_finalize(run) is None + assert run.stage == Stage.COMPLETED + assert run.stage_status == StageStatus.COMPLETED + assert run.status == "completed" + + +class TestBuildPrefilterRequests: + _CFG = { + "tr_columns": ["Problem"], + "tr_prompt": "rubric", + "dup_columns": ["Problem"], + } + + def test_topic_relevance_stage(self) -> None: + lines = build_prefilter_requests( + Stage.PRE_FILTER_TOPIC_RELEVANCE, [(0, {"Problem": "p"})], self._CFG + ) + assert len(lines) == 1 + + def test_duplicate_detection_stage(self) -> None: + lines = build_prefilter_requests( + Stage.PRE_FILTER_DUPLICATE_DETECTION, [(0, {"Problem": "p"})], self._CFG + ) + assert len(lines) == 1 + + def test_unknown_stage_raises(self) -> None: + with pytest.raises(ValueError): + build_prefilter_requests("BOGUS", [(0, {"Problem": "p"})], self._CFG) diff --git a/backend/app/tests/assessment/test_post_processing.py b/backend/app/tests/assessment/test_post_processing.py new file mode 100644 index 000000000..0ee7b81cc --- /dev/null +++ b/backend/app/tests/assessment/test_post_processing.py @@ -0,0 +1,212 @@ +"""Tests for the assessment export post-processing engine.""" + +from app.services.assessment.utils.post_processing import ( + apply_computed_columns, + apply_filter, + apply_post_processing, + apply_sort, + evaluate_formula, +) + + +class TestEvaluateFormula: + def test_addition(self) -> None: + assert evaluate_formula("@a + @b", {"a": 2, "b": 3}) == 5.0 + + def test_all_operators(self) -> None: + row = {"a": 10, "b": 4} + assert evaluate_formula("@a - @b", row) == 6.0 + assert evaluate_formula("@a * @b", row) == 40.0 + assert evaluate_formula("@a / @b", row) == 2.5 + assert evaluate_formula("-@a", row) == -10.0 + + def test_precedence_and_constants(self) -> None: + assert evaluate_formula("@a + @b * 0.5", {"a": 1, "b": 4}) == 3.0 + + def test_string_numeric_values_coerced(self) -> None: + assert evaluate_formula("@a + @b", {"a": "2", "b": "3"}) == 5.0 + + def test_missing_column_is_zero(self) -> None: + assert evaluate_formula("@a + @b", {"a": 5}) == 5.0 + + def test_non_numeric_value_is_zero(self) -> None: + assert evaluate_formula("@a + @b", {"a": 5, "b": "abc"}) == 5.0 + + def test_unsupported_operation_returns_none(self) -> None: + # Power operator is not in the safe-ops allowlist. + assert evaluate_formula("@a ** @b", {"a": 2, "b": 3}) is None + + def test_syntax_error_returns_none(self) -> None: + assert evaluate_formula("@a +", {"a": 1}) is None + + +class TestApplyComputedColumns: + def test_adds_column_in_place(self) -> None: + rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + apply_computed_columns(rows, [{"name": "total", "formula": "@a + @b"}]) + assert rows[0]["total"] == 3.0 + assert rows[1]["total"] == 7.0 + + def test_skips_empty_name_or_formula(self) -> None: + rows = [{"a": 1}] + apply_computed_columns( + rows, + [ + {"name": "", "formula": "@a"}, + {"name": "x", "formula": ""}, + ], + ) + assert rows[0] == {"a": 1} + + +class TestApplyFilter: + def test_no_rules_returns_all(self) -> None: + rows = [{"a": 1}, {"a": 2}] + assert apply_filter(rows, []) == rows + + def test_eq_ne(self) -> None: + rows = [{"x": "Yes"}, {"x": "no"}] + assert apply_filter(rows, [{"column": "x", "op": "eq", "value": "yes"}]) == [ + {"x": "Yes"} + ] + assert apply_filter(rows, [{"column": "x", "op": "ne", "value": "yes"}]) == [ + {"x": "no"} + ] + + def test_contains_not_contains(self) -> None: + rows = [{"x": "hello world"}, {"x": "bye"}] + assert apply_filter( + rows, [{"column": "x", "op": "contains", "value": "world"}] + ) == [{"x": "hello world"}] + assert apply_filter( + rows, [{"column": "x", "op": "not_contains", "value": "world"}] + ) == [{"x": "bye"}] + + def test_in_not_in(self) -> None: + rows = [{"x": "a"}, {"x": "b"}] + assert apply_filter( + rows, [{"column": "x", "op": "in", "value": ["a", "c"]}] + ) == [{"x": "a"}] + assert apply_filter( + rows, [{"column": "x", "op": "not_in", "value": ["a", "c"]}] + ) == [{"x": "b"}] + + def test_is_empty_is_not_empty(self) -> None: + rows = [{"x": ""}, {"x": "v"}, {"x": None}] + assert apply_filter(rows, [{"column": "x", "op": "is_empty"}]) == [ + {"x": ""}, + {"x": None}, + ] + assert apply_filter(rows, [{"column": "x", "op": "is_not_empty"}]) == [ + {"x": "v"} + ] + + def test_numeric_comparisons(self) -> None: + rows = [{"n": 1}, {"n": 5}, {"n": 10}] + assert apply_filter(rows, [{"column": "n", "op": "gt", "value": 4}]) == [ + {"n": 5}, + {"n": 10}, + ] + assert apply_filter(rows, [{"column": "n", "op": "lt", "value": 5}]) == [ + {"n": 1} + ] + assert apply_filter(rows, [{"column": "n", "op": "gte", "value": 5}]) == [ + {"n": 5}, + {"n": 10}, + ] + assert apply_filter(rows, [{"column": "n", "op": "lte", "value": 5}]) == [ + {"n": 1}, + {"n": 5}, + ] + + def test_numeric_filter_non_numeric_excluded(self) -> None: + rows = [{"n": "abc"}, {"n": 5}] + assert apply_filter(rows, [{"column": "n", "op": "gt", "value": 1}]) == [ + {"n": 5} + ] + + def test_unknown_op_keeps_row(self) -> None: + rows = [{"x": "a"}] + assert apply_filter(rows, [{"column": "x", "op": "weird", "value": 1}]) == rows + + def test_and_logic_across_rules(self) -> None: + rows = [{"n": 5, "x": "yes"}, {"n": 5, "x": "no"}, {"n": 1, "x": "yes"}] + out = apply_filter( + rows, + [ + {"column": "n", "op": "gte", "value": 5}, + {"column": "x", "op": "eq", "value": "yes"}, + ], + ) + assert out == [{"n": 5, "x": "yes"}] + + +class TestApplySort: + def test_no_rules_returns_input(self) -> None: + rows = [{"n": 2}, {"n": 1}] + assert apply_sort(rows, []) == rows + + def test_numeric_asc_desc(self) -> None: + rows = [{"n": 3}, {"n": 1}, {"n": 2}] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "asc"}]) + ] == [1, 2, 3] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "desc"}]) + ] == [3, 2, 1] + + def test_none_values_sort_last(self) -> None: + rows = [{"n": None}, {"n": 2}, {"n": 1}] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "asc"}]) + ] == [1, 2, None] + + def test_string_asc_desc(self) -> None: + rows = [{"s": "banana"}, {"s": "apple"}, {"s": "cherry"}] + assert [ + r["s"] for r in apply_sort(rows, [{"column": "s", "direction": "asc"}]) + ] == ["apple", "banana", "cherry"] + assert [ + r["s"] for r in apply_sort(rows, [{"column": "s", "direction": "desc"}]) + ] == ["cherry", "banana", "apple"] + + def test_multi_rule_priority(self) -> None: + rows = [ + {"grp": "a", "n": 2}, + {"grp": "b", "n": 1}, + {"grp": "a", "n": 1}, + ] + out = apply_sort( + rows, + [ + {"column": "grp", "direction": "asc"}, + {"column": "n", "direction": "desc"}, + ], + ) + assert out == [ + {"grp": "a", "n": 2}, + {"grp": "a", "n": 1}, + {"grp": "b", "n": 1}, + ] + + +class TestApplyPostProcessing: + def test_none_config_is_noop(self) -> None: + rows = [{"a": 1}] + assert apply_post_processing(rows, None) is rows + + def test_full_pipeline(self) -> None: + rows = [ + {"Novelty": 3, "Feasibility": 4}, + {"Novelty": 9, "Feasibility": 8}, + {"Novelty": 1, "Feasibility": 1}, + ] + config = { + "computed_columns": [ + {"name": "Total", "formula": "@Novelty + @Feasibility"} + ], + "filter": [{"column": "Total", "op": "gt", "value": 5}], + "sort": [{"column": "Total", "direction": "desc"}], + } + out = apply_post_processing(rows, config) + assert [r["Total"] for r in out] == [17.0, 7.0] diff --git a/backend/app/tests/assessment/test_prefilter_batching.py b/backend/app/tests/assessment/test_prefilter_batching.py new file mode 100644 index 000000000..4ecdf52e9 --- /dev/null +++ b/backend/app/tests/assessment/test_prefilter_batching.py @@ -0,0 +1,375 @@ +"""Tests for the single pipeline orchestrator (state-machine submit step).""" + +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from celery.exceptions import SoftTimeLimitExceeded + +from app.models.assessment import Stage, StageStatus +from app.services.assessment import tasks + + +@contextmanager +def _session_cm(session): + yield session + + +def _run(**kw): + base = { + "id": 5, + "assessment_id": 9, + "input": { + "prefilter_config": {"topic_relevance": {"columns": ["a"], "prompt": "p"}} + }, + "config_id": "c", + "config_version": 1, + "pipeline": None, + "stage": None, + "stage_status": None, + "status": "pending", + "stage_batches": None, + "total_items": 0, + } + base.update(kw) + return SimpleNamespace(**base) + + +class TestOrchestrate: + def test_inits_pipeline_and_submits_first_stage(self) -> None: + run = _run() + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "flag_modified"), patch.object( + tasks, "_submit_stage" + ) as submit: + tasks._orchestrate(5, 1, 1) + assert run.stage == Stage.PRE_FILTER_TOPIC_RELEVANCE + assert run.stage_status == StageStatus.PENDING + submit.assert_called_once() + + def test_skips_when_not_pending(self) -> None: + run = _run( + pipeline={"stages": [{"stage": Stage.L2_ASSESSMENT, "order": 1}]}, + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PROCESSING, + ) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "_submit_stage") as submit: + tasks._orchestrate(5, 1, 1) + submit.assert_not_called() + + def test_terminal_stage_returns(self) -> None: + run = _run(stage=Stage.COMPLETED) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "_submit_stage") as submit: + tasks._orchestrate(5, 1, 1) + submit.assert_not_called() + + +class TestSubmitCurrentStage: + def _ctx(self, accepted): + return [ + patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), + patch.object(tasks, "_load_dataset_rows", return_value=[{"a": "1"}] * 3), + patch.object(tasks, "_accepted_indices", return_value=accepted), + patch.object(tasks, "recompute_assessment_status"), + ] + + def test_submits_prefilter_batch(self) -> None: + run = _run( + stage=Stage.PRE_FILTER_TOPIC_RELEVANCE, + stage_status=StageStatus.PENDING, + stage_batches={}, + ) + session = MagicMock() + batch_job = SimpleNamespace(id=7, total_items=3) + p = self._ctx([0, 1, 2]) + with p[0], p[1], p[2], p[3], patch.object(tasks, "flag_modified"), patch.object( + tasks, "build_prefilter_requests", return_value=[{"key": "tr_0"}] + ), patch.object(tasks, "submit_prefilter_batch", return_value=batch_job): + tasks._submit_stage(session, run, 1, 1) + assert run.stage_batches[Stage.PRE_FILTER_TOPIC_RELEVANCE] == 7 + assert run.stage_status == StageStatus.PROCESSING + + def test_zero_accepted_advances(self) -> None: + run = _run( + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PENDING, + stage_batches={}, + ) + session = MagicMock() + p = self._ctx([]) + with p[0], p[1], p[2], p[3], patch.object(tasks, "_persist_advance") as advance: + tasks._submit_stage(session, run, 1, 1) + advance.assert_called_once() + + +class TestAcceptedIndices: + def test_uses_persisted_indices_without_downloading(self) -> None: + """Stored accepted set is read directly — no gate batch re-download.""" + run = _run( + pipeline={ + "stages": [ + {"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "order": 1}, + {"stage": Stage.L2_ASSESSMENT, "order": 2}, + ], + "accepted_indices": [0, 2, 5], + }, + stage=Stage.L2_ASSESSMENT, + ) + with patch.object(tasks, "load_raw_batch_results") as load: + result = tasks._accepted_indices( + MagicMock(), run, total_rows=10, project_id=1 + ) + assert result == [0, 2, 5] + load.assert_not_called() + + def test_persisted_indices_clamped_to_total_rows(self) -> None: + run = _run( + pipeline={"stages": [], "accepted_indices": [0, 3, 99]}, + stage=Stage.L2_ASSESSMENT, + ) + result = tasks._accepted_indices(MagicMock(), run, total_rows=4, project_id=1) + assert result == [0, 3] + + def test_falls_back_to_full_range_when_nothing_persisted(self) -> None: + run = _run( + pipeline={"stages": [{"stage": Stage.L2_ASSESSMENT, "order": 1}]}, + stage=Stage.L2_ASSESSMENT, + ) + result = tasks._accepted_indices(MagicMock(), run, total_rows=3, project_id=1) + assert result == [0, 1, 2] + + +class TestGuardEntrypoint: + def test_unexpected_exception_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=RuntimeError("boom") + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(RuntimeError): + tasks.execute_assessment_pipeline(5, 1, 1) + mark.assert_called_once() + + def test_soft_timeout_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=SoftTimeLimitExceeded() + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(SoftTimeLimitExceeded): + tasks.execute_assessment_pipeline(5, 1, 1) + mark.assert_called_once() + + +class TestMarkRunFailed: + def test_marks_non_terminal_run_failed(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT, stage_status=StageStatus.PROCESSING) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "update_assessment_run_status") as upd, patch.object( + tasks, "recompute_assessment_status" + ): + tasks._mark_run_failed(5, "dead") + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() + + def test_skips_terminal_run(self) -> None: + run = _run(stage=Stage.COMPLETED) + session = MagicMock() + session.get.return_value = run + with patch.object( + tasks, "Session", return_value=_session_cm(session) + ), patch.object(tasks, "update_assessment_run_status") as upd: + tasks._mark_run_failed(5, "dead") + upd.assert_not_called() + + +class TestDispatch: + def test_dispatch_enqueues_task(self) -> None: + with patch.object(tasks, "run_assessment_pipeline") as task: + tasks._dispatch(5, 1, 2) + task.delay.assert_called_once() + assert task.delay.call_args.kwargs["run_id"] == 5 + + +class TestResolveRunContext: + def test_success(self) -> None: + session = MagicMock() + run = _run() + session.get.return_value = SimpleNamespace(dataset_id=3) + with patch.object( + tasks, "get_assessment_dataset_by_id", return_value=MagicMock() + ), patch.object( + tasks, "resolve_evaluation_config", return_value=({"x": 1}, None) + ): + _a, _d, blob, err = tasks._resolve_run_context(session, run, 1, 1) + assert blob == {"x": 1} + assert err is None + + def test_missing_parent(self) -> None: + session = MagicMock() + session.get.return_value = None + _a, _d, blob, err = tasks._resolve_run_context(session, _run(), 1, 1) + assert blob is None + assert "not found" in err + + def test_config_error(self) -> None: + session = MagicMock() + session.get.return_value = SimpleNamespace(dataset_id=3) + with patch.object( + tasks, "get_assessment_dataset_by_id", return_value=MagicMock() + ), patch.object( + tasks, "resolve_evaluation_config", return_value=(None, "bad config") + ): + _a, _d, blob, err = tasks._resolve_run_context(session, _run(), 1, 1) + assert blob is None + assert "bad config" in err + + +class TestAcceptedIndicesFallback: + def test_recomputes_from_gate_batch(self) -> None: + run = _run( + pipeline={ + "stages": [ + {"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "order": 1}, + {"stage": Stage.L2_ASSESSMENT, "order": 2}, + ] + }, + stage=Stage.L2_ASSESSMENT, + stage_batches={Stage.PRE_FILTER_TOPIC_RELEVANCE: 1}, + ) + with patch.object( + tasks, "get_batch_job", return_value=SimpleNamespace(provider="openai") + ), patch.object(tasks, "load_raw_batch_results", return_value=[]), patch.object( + tasks, "parse_assessment_output", return_value=[] + ), patch.dict( + tasks.STAGE_PARSERS, + { + Stage.PRE_FILTER_TOPIC_RELEVANCE: lambda outs: { + 0: {"verdict": True}, + 1: {"verdict": False}, + } + }, + ): + result = tasks._accepted_indices( + MagicMock(), run, total_rows=3, project_id=1 + ) + # Only row 0 passed the gate. + assert result == [0] + + +class TestSubmitStageBranches: + def test_config_error_fails_run(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT, stage_status=StageStatus.PENDING) + with patch.object( + tasks, "_resolve_run_context", return_value=(None, None, None, "boom") + ), patch.object(tasks, "update_assessment_run_status") as upd, patch.object( + tasks, "recompute_assessment_status" + ): + tasks._submit_stage(MagicMock(), run, 1, 1) + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() + + def test_empty_dataset_fails_run(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT, stage_status=StageStatus.PENDING) + with patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), patch.object(tasks, "_load_dataset_rows", return_value=[]), patch.object( + tasks, "update_assessment_run_status" + ) as upd, patch.object( + tasks, "recompute_assessment_status" + ): + tasks._submit_stage(MagicMock(), run, 1, 1) + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() + + def test_submits_l2_batch(self) -> None: + run = _run( + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PENDING, + stage_batches={}, + ) + batch_job = SimpleNamespace(id=8, total_items=2) + with patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), patch.object( + tasks, "_load_dataset_rows", return_value=[{"a": "1"}] * 3 + ), patch.object( + tasks, "_accepted_indices", return_value=[0, 1] + ), patch.object( + tasks, "flag_modified" + ), patch.object( + tasks, "submit_assessment_batch", return_value=batch_job + ), patch.object( + tasks, "recompute_assessment_status" + ): + tasks._submit_stage(MagicMock(), run, 1, 1) + assert run.total_items == 2 + assert run.stage_batches[Stage.L2_ASSESSMENT] == 8 + + def test_unknown_stage_raises(self) -> None: + run = _run(stage="BOGUS", stage_status=StageStatus.PENDING, stage_batches={}) + with patch.object( + tasks, + "_resolve_run_context", + return_value=(SimpleNamespace(), MagicMock(), SimpleNamespace(), None), + ), patch.object( + tasks, "_load_dataset_rows", return_value=[{"a": "1"}] + ), patch.object( + tasks, "_accepted_indices", return_value=[0] + ): + with pytest.raises(ValueError): + tasks._submit_stage(MagicMock(), run, 1, 1) + + +class TestPersistAdvance: + def test_dispatches_next_stage(self) -> None: + run = _run() + with patch.object( + tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT + ), patch.object(tasks, "recompute_assessment_status"), patch.object( + tasks, "_dispatch" + ) as dispatch: + tasks._persist_advance(MagicMock(), run, 1, 1) + dispatch.assert_called_once() + + def test_finalize_does_not_dispatch(self) -> None: + run = _run() + with patch.object( + tasks, "advance_or_finalize", return_value=None + ), patch.object(tasks, "recompute_assessment_status"), patch.object( + tasks, "_dispatch" + ) as dispatch: + tasks._persist_advance(MagicMock(), run, 1, 1) + dispatch.assert_not_called() + + def test_enqueue_failure_marks_failed(self) -> None: + run = _run(stage=Stage.L2_ASSESSMENT) + with patch.object( + tasks, "advance_or_finalize", return_value=Stage.L2_ASSESSMENT + ), patch.object(tasks, "recompute_assessment_status"), patch.object( + tasks, "_dispatch", side_effect=RuntimeError("broker down") + ), patch.object( + tasks, "update_assessment_run_status" + ) as upd: + tasks._persist_advance(MagicMock(), run, 1, 1) + assert run.stage_status == StageStatus.FAILED + upd.assert_called_once() diff --git a/backend/app/tests/assessment/test_processing.py b/backend/app/tests/assessment/test_processing.py index 958ab3019..84bace368 100644 --- a/backend/app/tests/assessment/test_processing.py +++ b/backend/app/tests/assessment/test_processing.py @@ -1,17 +1,20 @@ """Tests for assessment/processing.py pure functions.""" import json -from unittest.mock import AsyncMock, MagicMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, patch import pytest +from app.crud.assessment import processing as processing_mod from app.crud.assessment.processing import ( _get_batch_provider, + _record_gate_stats, _sanitize_json_output, - check_and_process_assessment, parse_assessment_output, - poll_all_pending_assessments, + process_run_batches, ) +from app.models.assessment import Stage, StageStatus class TestSanitizeJsonOutput: @@ -210,10 +213,51 @@ def test_google_native_provider_accepted(self) -> None: assert results[0]["output"] == "out" +class TestRecordGateStats: + def _patches(self, parsed): + return [ + patch.object(processing_mod, "load_raw_batch_results", return_value=[]), + patch.object(processing_mod, "parse_assessment_output", return_value=[]), + patch.dict( + processing_mod.STAGE_PARSERS, + {Stage.PRE_FILTER_TOPIC_RELEVANCE: lambda _outputs: parsed}, + ), + patch.object(processing_mod, "update_assessment_run_prefilter_stats"), + patch.object(processing_mod, "flag_modified"), + ] + + def test_persists_accepted_indices_to_pipeline(self) -> None: + run = SimpleNamespace(id=1, assessment_id=2, pipeline={"stages": []}) + parsed = { + 0: {"verdict": True}, + 1: {"verdict": False}, + 2: {"verdict": True}, + } + p = self._patches(parsed) + with p[0], p[1], p[2], p[3], p[4]: + _record_gate_stats( + MagicMock(), run, Stage.PRE_FILTER_TOPIC_RELEVANCE, MagicMock(), 1 + ) + assert run.pipeline["accepted_indices"] == [0, 2] + + def test_intersects_with_prior_gate(self) -> None: + run = SimpleNamespace( + id=1, assessment_id=2, pipeline={"accepted_indices": [2, 3]} + ) + parsed = {2: {"verdict": True}, 3: {"verdict": False}, 4: {"verdict": True}} + p = self._patches(parsed) + with p[0], p[1], p[2], p[3], p[4]: + _record_gate_stats( + MagicMock(), run, Stage.PRE_FILTER_TOPIC_RELEVANCE, MagicMock(), 1 + ) + # 2 passes this gate and was in the prior accepted set; 4 wasn't; 3 rejected. + assert run.pipeline["accepted_indices"] == [2] + + class TestGetBatchProvider: def test_unsupported_provider_raises(self) -> None: session = MagicMock() - with pytest.raises(ValueError, match="Unsupported provider"): + with pytest.raises(ValueError, match="Unsupported batch provider"): _get_batch_provider( session=session, provider_name="anthropic", @@ -225,8 +269,8 @@ def test_openai_provider_returned(self) -> None: session = MagicMock() mock_client = MagicMock() with patch( - "app.crud.assessment.processing.get_openai_client", return_value=mock_client - ), patch("app.crud.assessment.processing.OpenAIBatchProvider") as mock_cls: + "app.services.assessment.stages.get_openai_client", return_value=mock_client + ), patch("app.services.assessment.stages.OpenAIBatchProvider") as mock_cls: _get_batch_provider( session=session, provider_name="openai", @@ -238,8 +282,8 @@ def test_openai_provider_returned(self) -> None: def test_google_provider_returned(self) -> None: session = MagicMock() mock_gemini = MagicMock() - with patch("app.crud.assessment.processing.GeminiClient") as mock_cls, patch( - "app.crud.assessment.processing.GeminiBatchProvider" + with patch("app.services.assessment.stages.GeminiClient") as mock_cls, patch( + "app.services.assessment.stages.GeminiBatchProvider" ) as mock_batch_cls: mock_cls.from_credentials.return_value = mock_gemini _get_batch_provider( @@ -251,186 +295,159 @@ def test_google_provider_returned(self) -> None: mock_batch_cls.assert_called_once_with(client=mock_gemini.client) -class TestPollAllPendingAssessments: - @pytest.mark.asyncio - async def test_delegates_to_cron(self) -> None: - session = MagicMock() - expected = {"processed": 2, "failed": 0} - with patch( - "app.crud.assessment.cron.poll_all_pending_assessment_evaluations", - new=AsyncMock(return_value=expected), - ): - result = await poll_all_pending_assessments(session=session) - assert result == expected - - -class TestCheckAndProcessAssessment: - def _make_run(self) -> MagicMock: - run = MagicMock() - run.id = 1 - run.batch_job_id = 99 - run.status = "processing" - run.assessment_id = 10 - run.organization_id = 1 - run.project_id = 1 - run.run_name = "exp" - return run +class TestProcessRunBatches: + def _parent(self): + return SimpleNamespace(organization_id=1, project_id=1, experiment_name="exp") + + def _run(self): + return SimpleNamespace( + id=1, + assessment_id=10, + status="processing", + stage=Stage.L2_ASSESSMENT, + stage_status=StageStatus.PROCESSING, + stage_batches={Stage.L2_ASSESSMENT: 5}, + ) @pytest.mark.asyncio - async def test_completed_with_no_output_file_and_failed_counts(self) -> None: + async def test_completes_stage_and_finalizes(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "completed" - batch_job.provider_output_file_id = None - batch_job.id = 99 + session.get.return_value = self._parent() + run = self._run() with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + "app.crud.assessment.processing.get_batch_job", return_value=MagicMock() ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", - return_value={ - "request_counts": {"failed": 3, "completed": 0, "total": 3}, - "error_file_id": "err-1", - }, - ), patch( - "app.crud.assessment.processing.update_assessment_run_status" + "app.crud.assessment.processing._poll_stage_outcome", + return_value="completed", ), patch( + "app.crud.assessment.processing.advance_or_finalize", return_value=None + ) as advance, patch( "app.crud.assessment.processing.recompute_assessment_status" ): - result = await check_and_process_assessment(run=run, session=session) + result = await process_run_batches(run=run, session=session) - assert result["action"] == "failed" - assert result["current_status"] == "failed" + advance.assert_called_once() + assert result["action"] == "processed" + assert run.stage_status == StageStatus.COMPLETED @pytest.mark.asyncio - async def test_completed_with_no_output_file_not_ready(self) -> None: + async def test_advances_and_dispatches_next_stage(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "completed" - batch_job.provider_output_file_id = None - batch_job.id = 99 + session.get.return_value = self._parent() + run = self._run() + run.stage = Stage.PRE_FILTER_TOPIC_RELEVANCE + run.stage_batches = {Stage.PRE_FILTER_TOPIC_RELEVANCE: 5} with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + "app.crud.assessment.processing.get_batch_job", return_value=MagicMock() ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", - return_value={"request_counts": {"failed": 0, "completed": 1, "total": 1}}, - ): - result = await check_and_process_assessment(run=run, session=session) + "app.crud.assessment.processing._poll_stage_outcome", + return_value="completed", + ), patch( + "app.crud.assessment.processing._record_gate_stats" + ) as gate_stats, patch( + "app.crud.assessment.processing.advance_or_finalize", + return_value=Stage.L2_ASSESSMENT, + ), patch( + "app.crud.assessment.processing.recompute_assessment_status" + ), patch( + "app.crud.assessment.processing.run_assessment_pipeline" + ) as dispatch: + result = await process_run_batches(run=run, session=session) - assert result["action"] == "no_change" + gate_stats.assert_called_once() # TR is a gate stage + dispatch.delay.assert_called_once() + assert result["action"] == "processed" @pytest.mark.asyncio - async def test_completed_with_output_file_processes_results(self) -> None: + async def test_no_change_while_in_progress(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "completed" - batch_job.provider_output_file_id = "file-1" - batch_job.id = 99 + session.get.return_value = self._parent() + run = self._run() with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + "app.crud.assessment.processing.get_batch_job", return_value=MagicMock() ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", - return_value={}, - ), patch( - "app.crud.assessment.processing.download_batch_results", - return_value=[{"custom_id": "row_0"}], - ), patch( - "app.crud.assessment.processing.upload_batch_results_to_object_store", - return_value="s3://results", - ), patch( - "app.crud.assessment.processing.parse_assessment_output", - return_value=[{"row_id": "row_0", "error": None}], - ), patch( - "app.crud.assessment.processing.update_assessment_run_status" - ), patch( - "app.crud.assessment.processing.recompute_assessment_status" + "app.crud.assessment.processing._poll_stage_outcome", + return_value="no_change", ): - result = await check_and_process_assessment(run=run, session=session) + result = await process_run_batches(run=run, session=session) - assert result["action"] == "processed" + assert result["action"] == "no_change" @pytest.mark.asyncio - async def test_terminal_provider_status_marks_failed(self) -> None: + async def test_failed_stage_fails_run(self) -> None: session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "failed" - batch_job.error_message = "provider failed" + session.get.return_value = self._parent() + run = self._run() with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job + "app.crud.assessment.processing.get_batch_job", + return_value=MagicMock(error_message="boom"), ), patch( "app.crud.assessment.processing._get_batch_provider", return_value=MagicMock(), ), patch( - "app.crud.assessment.processing.poll_batch_status", return_value={} + "app.crud.assessment.processing._poll_stage_outcome", return_value="failed" ), patch( "app.crud.assessment.processing.update_assessment_run_status" ), patch( "app.crud.assessment.processing.recompute_assessment_status" ): - result = await check_and_process_assessment(run=run, session=session) + result = await process_run_batches(run=run, session=session) assert result["action"] == "failed" - assert result["provider_status"] == "failed" + # Failed stage preserved (so a resume knows where to restart); only status flips. + assert run.stage == Stage.L2_ASSESSMENT + assert run.stage_status == StageStatus.FAILED - @pytest.mark.asyncio - async def test_still_processing_returns_no_change(self) -> None: - session = MagicMock() - run = self._make_run() - batch_job = MagicMock() - batch_job.provider = "openai" - batch_job.provider_status = "in_progress" + +class TestPollStageOutcome: + def _job(self, **kw): + base = dict(provider_status="completed", provider_output_file_id=None) + base.update(kw) + return SimpleNamespace(**base) + + def test_all_failed_no_output_is_failed(self) -> None: + from app.crud.assessment.processing import _poll_stage_outcome with patch( - "app.crud.assessment.processing.get_batch_job", return_value=batch_job - ), patch( - "app.crud.assessment.processing._get_batch_provider", - return_value=MagicMock(), - ), patch( - "app.crud.assessment.processing.poll_batch_status", return_value={} + "app.crud.assessment.processing.poll_batch_status", + return_value={ + "request_counts": {"completed": 0, "failed": 3}, + "error_file_id": "err", + }, ): - result = await check_and_process_assessment(run=run, session=session) - - assert result["action"] == "no_change" + outcome = _poll_stage_outcome(MagicMock(), MagicMock(), self._job()) + assert outcome == "failed" - @pytest.mark.asyncio - async def test_exception_path_marks_failed(self) -> None: - session = MagicMock() - run = self._make_run() - run.batch_job_id = None + def test_no_output_not_ready_is_no_change(self) -> None: + from app.crud.assessment.processing import _poll_stage_outcome with patch( - "app.crud.assessment.processing.update_assessment_run_status" - ) as update_run, patch( - "app.crud.assessment.processing.recompute_assessment_status" + "app.crud.assessment.processing.poll_batch_status", + return_value={"request_counts": {"completed": 0, "failed": 0}}, ): - result = await check_and_process_assessment(run=run, session=session) + outcome = _poll_stage_outcome(MagicMock(), MagicMock(), self._job()) + assert outcome == "no_change" - assert result["action"] == "failed" - assert result["provider_status"] == "unknown" - assert result["error"] == "Assessment run 1 has no batch_job_id" - update_run.assert_called_once_with( - session=session, - run=run, - status="failed", - error_message="Assessment run 1 has no batch_job_id", - ) + def test_output_ready_is_completed(self) -> None: + from app.crud.assessment.processing import _poll_stage_outcome + + with patch( + "app.crud.assessment.processing.poll_batch_status", return_value={} + ), patch("app.crud.assessment.processing.process_completed_batch"): + outcome = _poll_stage_outcome( + MagicMock(), MagicMock(), self._job(provider_output_file_id="file_1") + ) + assert outcome == "completed" diff --git a/backend/app/tests/assessment/test_service.py b/backend/app/tests/assessment/test_service.py index b3654fa9b..e22c50e90 100644 --- a/backend/app/tests/assessment/test_service.py +++ b/backend/app/tests/assessment/test_service.py @@ -7,10 +7,16 @@ import pytest from fastapi import HTTPException -from app.models.assessment import AssessmentConfigRef, AssessmentCreate +from app.models.assessment import ( + AssessmentConfigRef, + AssessmentCreate, + Stage, + StageStatus, +) from app.models.config.config import ConfigTag from app.services.assessment.service import ( _build_retry_request, + resume_assessment_run, retry_assessment, retry_assessment_run, start_assessment, @@ -142,9 +148,6 @@ def test_google_provider_is_supported(self) -> None: config_blob = SimpleNamespace( completion=SimpleNamespace(provider="google", params={"model": "gemini"}) ) - batch_job = MagicMock() - batch_job.id = 101 - batch_job.total_items = 3 with ( patch( @@ -163,14 +166,7 @@ def test_google_provider_is_supported(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch( - "app.services.assessment.service.submit_assessment_batch", - return_value=batch_job, - ) as submit_batch, - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ), + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -181,8 +177,10 @@ def test_google_provider_is_supported(self) -> None: project_id=1, ) + # Google is an accepted provider — no rejection, one Celery task dispatched. assert response.num_configs == 1 - assert submit_batch.call_args.kwargs["config_blob"] is config_blob + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 def test_defaults_missing_provider_to_openai(self) -> None: session = MagicMock() @@ -194,9 +192,6 @@ def test_defaults_missing_provider_to_openai(self) -> None: config_blob = SimpleNamespace( completion=SimpleNamespace(provider=None, params={"model": "gpt-4.1-mini"}) ) - batch_job = MagicMock() - batch_job.id = 101 - batch_job.total_items = 3 with ( patch( @@ -215,14 +210,7 @@ def test_defaults_missing_provider_to_openai(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ) as create_run, - patch( - "app.services.assessment.service.submit_assessment_batch", - return_value=batch_job, - ) as submit_batch, - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ), + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -238,11 +226,7 @@ def test_defaults_missing_provider_to_openai(self) -> None: assert response.runs[0].run_id == 11 assessment_input = create_run.call_args.kwargs["assessment_input"] assert assessment_input["system_instruction"] == "Assess strictly" - assert ( - submit_batch.call_args.kwargs["assessment_input"]["system_instruction"] - == "Assess strictly" - ) - submit_batch.assert_called_once() + dispatch.delay.assert_called_once() def test_rejects_default_tagged_config(self) -> None: """Configs explicitly tagged 'default' must be rejected for assessment.""" @@ -278,14 +262,15 @@ def test_rejects_default_tagged_config(self) -> None: # Tag check must fire BEFORE config resolution. resolve.assert_not_called() - def test_batch_submission_failure_marks_run_failed(self) -> None: + def test_dispatches_one_celery_task_per_config(self) -> None: + """Batch submission moved to the Celery task; start_assessment only + creates runs and dispatches one task per resolved config.""" session = MagicMock() request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) dataset = _make_dataset() assessment = MagicMock() assessment.id = 21 run = _make_run() - run.status = "failed" config_blob = SimpleNamespace( completion=SimpleNamespace( provider="openai", params={"model": "gpt-4.1-mini"} @@ -309,14 +294,7 @@ def test_batch_submission_failure_marks_run_failed(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch( - "app.services.assessment.service.submit_assessment_batch", - side_effect=RuntimeError("submit failed"), - ), - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ) as update_run, + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -327,7 +305,10 @@ def test_batch_submission_failure_marks_run_failed(self) -> None: project_id=1, ) assert response.num_configs == 1 - assert update_run.called + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 + assert dispatch.delay.call_args.kwargs["organization_id"] == 1 + assert dispatch.delay.call_args.kwargs["project_id"] == 1 class TestRetryHelpers: @@ -403,3 +384,61 @@ def test_retry_assessment_wrappers(self) -> None: ): resp2 = retry_assessment_run(session, run, 1, 1) assert resp2.assessment_id == 1 + + +class TestResumeAssessmentRun: + def _failed_run(self, stage: str) -> MagicMock: + run = MagicMock() + run.id = 11 + run.assessment_id = 21 + run.config_id = UUID("00000000-0000-0000-0000-000000000001") + run.config_version = 1 + run.status = "failed" + run.stage = stage + run.stage_status = StageStatus.FAILED + run.pipeline = { + "stages": [ + {"stage": Stage.PRE_FILTER_TOPIC_RELEVANCE, "order": 1}, + {"stage": Stage.PRE_FILTER_DUPLICATE_DETECTION, "order": 2}, + {"stage": Stage.L2_ASSESSMENT, "order": 3}, + ] + } + run.assessment = SimpleNamespace(id=21, experiment_name="exp", dataset_id=7) + return run + + def test_rejects_non_failed_run(self) -> None: + run = self._failed_run(Stage.L2_ASSESSMENT) + run.stage_status = StageStatus.PROCESSING + with pytest.raises(HTTPException) as exc: + resume_assessment_run(MagicMock(), run, 1, 1) + assert exc.value.status_code == 400 + + def test_rejects_stage_not_in_pipeline(self) -> None: + run = self._failed_run(Stage.FAILED) + with pytest.raises(HTTPException) as exc: + resume_assessment_run(MagicMock(), run, 1, 1) + assert exc.value.status_code == 400 + + def test_resumes_in_place_from_failed_stage(self) -> None: + run = self._failed_run(Stage.L2_ASSESSMENT) + session = MagicMock() + + with ( + patch( + "app.services.assessment.service.get_assessment_dataset_by_id", + return_value=_make_dataset(), + ), + patch("app.services.assessment.service.recompute_assessment_status"), + patch("app.celery.tasks.job_execution.run_assessment_pipeline") as dispatch, + ): + resp = resume_assessment_run(session, run, 1, 1) + + # Same run, reset to PENDING at the same (failed) stage, re-dispatched. + assert run.stage == Stage.L2_ASSESSMENT + assert run.stage_status == StageStatus.PENDING + assert run.status == "processing" + assert run.error_message is None + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 + assert resp.assessment_id == 21 + assert resp.num_configs == 1 diff --git a/backend/app/tests/assessment/test_tasks_failure_guard.py b/backend/app/tests/assessment/test_tasks_failure_guard.py new file mode 100644 index 000000000..9e48486b3 --- /dev/null +++ b/backend/app/tests/assessment/test_tasks_failure_guard.py @@ -0,0 +1,82 @@ +"""Tests for the pipeline orchestrator failure guard (no dangling runs).""" + +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from celery.exceptions import SoftTimeLimitExceeded + +from app.models.assessment import Stage +from app.services.assessment import tasks + + +@contextmanager +def _session_cm(session): + yield session + + +def _patch_session(run): + session = MagicMock() + session.get.return_value = run + cm = patch.object(tasks, "Session", return_value=_session_cm(session)) + return cm, session + + +class TestMarkRunFailed: + def test_marks_non_terminal_run_failed(self) -> None: + run = SimpleNamespace( + stage=Stage.PRE_FILTER_TOPIC_RELEVANCE, + stage_status="PENDING", + assessment_id=7, + ) + cm, session = _patch_session(run) + with cm, patch.object( + tasks, "update_assessment_run_status" + ) as upd, patch.object(tasks, "recompute_assessment_status") as recompute: + tasks._mark_run_failed(11, "boom") + upd.assert_called_once() + assert upd.call_args.kwargs["status"] == "failed" + # Failed stage preserved for resume; only stage_status flips to FAILED. + assert run.stage == Stage.PRE_FILTER_TOPIC_RELEVANCE + assert run.stage_status == "FAILED" + recompute.assert_called_once_with(session=session, assessment_id=7) + + def test_skips_terminal_run(self) -> None: + run = SimpleNamespace(stage=Stage.COMPLETED, assessment_id=7) + cm, _ = _patch_session(run) + with cm, patch.object(tasks, "update_assessment_run_status") as upd: + tasks._mark_run_failed(11, "boom") + upd.assert_not_called() + + def test_missing_run_noop(self) -> None: + cm, _ = _patch_session(None) + with cm, patch.object(tasks, "update_assessment_run_status") as upd: + tasks._mark_run_failed(11, "boom") + upd.assert_not_called() + + +class TestExecutePipelineGuard: + def test_soft_timeout_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=SoftTimeLimitExceeded() + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(SoftTimeLimitExceeded): + tasks.execute_assessment_pipeline(11, 1, 1) + mark.assert_called_once() + assert mark.call_args.args[0] == 11 + + def test_unexpected_exception_marks_failed_and_reraises(self) -> None: + with patch.object( + tasks, "_orchestrate", side_effect=RuntimeError("kaboom") + ), patch.object(tasks, "_mark_run_failed") as mark: + with pytest.raises(RuntimeError): + tasks.execute_assessment_pipeline(11, 1, 1) + mark.assert_called_once_with(11, "Assessment run failed unexpectedly.") + + def test_success_does_not_mark_failed(self) -> None: + with patch.object(tasks, "_orchestrate", return_value=None), patch.object( + tasks, "_mark_run_failed" + ) as mark: + tasks.execute_assessment_pipeline(11, 1, 1) + mark.assert_not_called() diff --git a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py new file mode 100644 index 000000000..b06166982 --- /dev/null +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -0,0 +1,135 @@ +"""Tests for the topic-relevance per-record request builder and result parser.""" + +import json +from unittest.mock import patch + +from app.models.assessment import AssessmentAttachment +from app.services.assessment.prefilter import constants +from app.services.assessment.prefilter.topic_relevance import ( + build_topic_relevance_requests, + parse_topic_relevance_results, +) + + +def _gemini(): + return patch.object(constants, "ASSESSMENT_PREFILTER_PROVIDER", "google") + + +def _openai(): + return patch.object(constants, "ASSESSMENT_PREFILTER_PROVIDER", "openai") + + +class TestBuildRequestsOpenAI: + def test_openai_request_shape(self) -> None: + rows = [(0, {"Problem": "p0", "Docs": "https://x.com/a.png"})] + atts = [AssessmentAttachment(column="Docs", type="image", format="url")] + with _openai(): + lines = build_topic_relevance_requests(rows, ["Problem"], "rubric", atts) + line = lines[0] + assert line["custom_id"] == "tr_0" + assert line["url"] == "/v1/responses" + body = line["body"] + assert body["instructions"].startswith("rubric") + content = body["input"][0]["content"] + assert content[0] == {"type": "input_text", "text": "Problem:\np0"} + assert content[1]["type"] == "input_image" + assert body["text"]["format"]["type"] == "json_schema" + assert body["text"]["format"]["schema"]["additionalProperties"] is False + + +class TestBuildRequests: + def test_one_request_per_row_with_per_column_schema(self) -> None: + rows = [(0, {"Problem": "p0"}), (1, {"Problem": "p1"})] + with _gemini(): + lines = build_topic_relevance_requests(rows, ["Problem"], "rubric") + assert [ln["key"] for ln in lines] == ["tr_0", "tr_1"] + schema = lines[0]["request"]["generationConfig"]["responseSchema"] + # per-column boolean + decision/reasoning + assert schema["properties"]["Problem"]["type"] == "boolean" + assert set(schema["required"]) == {"decision", "reasoning", "Problem"} + assert "p0" in lines[0]["request"]["contents"][0]["parts"][0]["text"] + + def test_attachment_column_adds_part_and_schema_field(self) -> None: + rows = [ + (0, {"Problem": "p0", "Docs": "https://drive.google.com/file/d/A/view"}) + ] + atts = [AssessmentAttachment(column="Docs", type="image", format="url")] + with _gemini(): + lines = build_topic_relevance_requests(rows, ["Problem"], "rubric", atts) + schema = lines[0]["request"]["generationConfig"]["responseSchema"] + assert "Docs" in schema["properties"] # attachment column gets a verdict + parts = lines[0]["request"]["contents"][0]["parts"] + assert len(parts) >= 2 # text + at least one attachment part + + def test_empty_attachments_is_text_only(self) -> None: + with _gemini(): + lines = build_topic_relevance_requests( + [(0, {"Problem": "p"})], ["Problem"], "r" + ) + assert len(lines[0]["request"]["contents"][0]["parts"]) == 1 + + def test_blank_attachment_cell_is_skipped(self) -> None: + att = AssessmentAttachment(column="Docs", type="image", format="url") + with _gemini(): + lines = build_topic_relevance_requests( + [(0, {"Problem": "p", "Docs": " "})], ["Problem"], "r", [att] + ) + # Whitespace-only attachment cell -> only the text part survives. + assert len(lines[0]["request"]["contents"][0]["parts"]) == 1 + + +class TestParseResults: + def test_maps_decision_and_per_column_relevance(self) -> None: + outputs = [ + { + "row_id": "tr_0", + "output": json.dumps( + { + "decision": "ACCEPT", + "reasoning": "ok", + "Problem": True, + "Docs": False, + } + ), + "error": None, + }, + { + "row_id": "tr_1", + "output": json.dumps( + {"decision": "REJECT", "reasoning": "no", "Problem": False} + ), + "error": None, + }, + ] + parsed = parse_topic_relevance_results(outputs) + assert parsed[0]["verdict"] is True + assert parsed[0]["column_relevance"] == {"Problem": True, "Docs": False} + assert parsed[1]["verdict"] is False + assert parsed[1]["column_relevance"] == {"Problem": False} + + def test_unparseable_output_fails_open_accepted(self) -> None: + # A gate response we cannot parse must NOT silently drop the submission: + # it is accepted (verdict=True) so it still reaches L2 and is counted. + parsed = parse_topic_relevance_results( + [{"row_id": "tr_0", "output": "not json", "error": None}] + ) + assert parsed[0]["verdict"] is True + assert parsed[0]["decision"] == "" + assert parsed[0]["reasoning"] == "" + assert parsed[0]["column_relevance"] == {} + + def test_empty_output_fails_open_accepted(self) -> None: + parsed = parse_topic_relevance_results( + [{"row_id": "tr_0", "output": None, "error": "provider error"}] + ) + assert parsed[0]["verdict"] is True + assert parsed[0]["decision"] == "" + + def test_foreign_and_bad_index_keys_skipped(self) -> None: + parsed = parse_topic_relevance_results( + [ + {"row_id": "dup_0", "output": "{}", "error": None}, # not a tr key + {"row_id": "tr_x", "output": "{}", "error": None}, # bad index + ] + ) + assert parsed == {}