diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index efe5e83d..689ce0f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,7 +8,7 @@ ci: - eslint repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -28,11 +28,11 @@ repos: types_or: [javascript, jsx, ts, tsx, json, scss, sass, css, yaml, markdown] - repo: https://github.com/scop/pre-commit-shfmt - rev: v3.10.0-1 + rev: v3.13.1-1 hooks: - id: shfmt - repo: https://github.com/adrienverge/yamllint.git - rev: v1.35.1 + rev: v1.38.0 hooks: - id: yamllint args: [--format, parsable, -d, relaxed] @@ -65,7 +65,7 @@ repos: - "config/keycloak/realms/ol-local-realm.json" additional_dependencies: ["gibberish-detector"] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.7.2" + rev: "v0.15.10" hooks: - id: ruff-format - id: ruff @@ -84,7 +84,7 @@ repos: additional_dependencies: - eslint@8 - repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.10.0.1 + rev: v0.11.0.1 hooks: - id: shellcheck args: ["--severity=warning"] diff --git a/ai_chatbots/api.py b/ai_chatbots/api.py index 83bccbc8..378d1d0d 100644 --- a/ai_chatbots/api.py +++ b/ai_chatbots/api.py @@ -2,7 +2,7 @@ import json import logging -from typing import Any, Optional, Union +from typing import Any, Union from uuid import uuid4 from channels.db import database_sync_to_async @@ -96,14 +96,14 @@ def serialize_tool_calls(tool_calls: list[dict]) -> list[dict]: @database_sync_to_async -def query_tutorbot_output(thread_id: str) -> Optional[TutorBotOutput]: +def query_tutorbot_output(thread_id: str) -> TutorBotOutput | None: """Return the latest TutorBotOutput for a given thread_id""" return TutorBotOutput.objects.filter(thread_id=thread_id).last() @database_sync_to_async def create_tutorbot_output_and_checkpoints( - thread_id: str, chat_json: Union[str, dict], edx_module_id: Optional[str] + thread_id: str, chat_json: Union[str, dict], edx_module_id: str | None ) -> tuple[TutorBotOutput, list[DjangoCheckpoint]]: """Atomically create both TutorBotOutput and DjangoCheckpoint objects""" with transaction.atomic(): @@ -133,7 +133,7 @@ def _should_create_checkpoint(msg: dict) -> bool: def _identify_new_messages( - filtered_messages: list[dict], previous_chat_json: Optional[Union[str, dict]] + filtered_messages: list[dict], previous_chat_json: Union[str, dict] | None ) -> list[dict]: """Identify which messages are new by comparing with previous chat data.""" if not previous_chat_json: @@ -222,7 +222,7 @@ def _create_checkpoint_metadata( def create_tutor_checkpoints( thread_id: str, chat_json: Union[str, dict], - previous_chat_json: Optional[Union[str, dict]] = None, + previous_chat_json: Union[str, dict] | None = None, ) -> list[DjangoCheckpoint]: """Create DjangoCheckpoint records from tutor chat data (synchronous)""" # Get the associated session diff --git a/ai_chatbots/chatbots.py b/ai_chatbots/chatbots.py index b4a75816..7e3b167c 100644 --- a/ai_chatbots/chatbots.py +++ b/ai_chatbots/chatbots.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from operator import add -from typing import Annotated, Any, Optional +from typing import Annotated, Any from uuid import uuid4 import posthog @@ -73,10 +73,10 @@ def __init__( # noqa: PLR0913 checkpointer: BaseCheckpointSaver, *, name: str = "MIT Open Learning Chatbot", - model: Optional[str] = None, - temperature: Optional[float] = None, - instructions: Optional[str] = None, - thread_id: Optional[str] = None, + model: str | None = None, + temperature: float | None = None, + instructions: str | None = None, + thread_id: str | None = None, ): """Initialize the AI chat agent service""" self.bot_name = name @@ -224,7 +224,7 @@ async def validate_and_clean_checkpoint(self) -> None: except Exception: log.exception("Error while cleaning checkpoint") - async def _get_latest_checkpoint_id(self) -> Optional[str]: + async def _get_latest_checkpoint_id(self) -> str | None: """Get the most recent assistant response checkpoint""" checkpoint = ( await DjangoCheckpoint.objects.prefetch_related("session", "session__user") @@ -237,7 +237,7 @@ async def _get_latest_checkpoint_id(self) -> Optional[str]: return checkpoint.id if checkpoint else None async def set_callbacks( - self, properties: Optional[dict] = None + self, properties: dict | None = None ) -> list[CallbackHandler]: """Set callbacks for the agent LLM""" if settings.POSTHOG_PROJECT_API_KEY and settings.POSTHOG_API_HOST: @@ -282,7 +282,7 @@ async def get_completion( self, message: str, *, - extra_state: Optional[dict[str, Any]] = None, + extra_state: dict[str, Any] | None = None, debug: bool = settings.AI_DEBUG, ) -> AsyncGenerator[str, None]: """ @@ -424,13 +424,13 @@ class ResourceRecommendationBot(TruncatingChatbot): def __init__( # noqa: PLR0913 self, user_id: str, - checkpointer: Optional[BaseCheckpointSaver] = None, + checkpointer: BaseCheckpointSaver | None = None, *, name: str = "MIT Open Learning Chatbot", - model: Optional[str] = None, - temperature: Optional[float] = None, - instructions: Optional[str] = None, - thread_id: Optional[str] = None, + model: str | None = None, + temperature: float | None = None, + instructions: str | None = None, + thread_id: str | None = None, ): """Initialize the AI search agent service""" super().__init__( @@ -466,7 +466,7 @@ class SyllabusAgentState(SummaryState): related_courses: Annotated[list[str], add] # str representation of a boolean value, because the # langgraph JsonPlusSerializer can't handle booleans - exclude_canvas: Annotated[Optional[list[str]], add] + exclude_canvas: Annotated[list[str] | None, add] class SyllabusBot(TruncatingChatbot): @@ -483,11 +483,11 @@ def __init__( # noqa: PLR0913 checkpointer: BaseCheckpointSaver, *, name: str = "MIT Open Learning Syllabus Chatbot", - model: Optional[str] = None, - temperature: Optional[float] = None, - instructions: Optional[str] = None, - thread_id: Optional[str] = None, - enable_related_courses: Optional[bool] = False, + model: str | None = None, + temperature: float | None = None, + instructions: str | None = None, + thread_id: str | None = None, + enable_related_courses: bool | None = False, ): self.enable_related_courses = enable_related_courses super().__init__( @@ -546,16 +546,16 @@ class TutorBot(BaseChatbot): def __init__( # noqa: PLR0913 self, user_id: str, - checkpointer: Optional[BaseCheckpointSaver] = BaseCheckpointSaver, + checkpointer: BaseCheckpointSaver | None = BaseCheckpointSaver, *, name: str = "MIT Open Learning Tutor Chatbot", - model: Optional[str] = None, - temperature: Optional[float] = None, - thread_id: Optional[str] = None, - block_siblings: Optional[list[str]] = None, - edx_module_id: Optional[str] = None, - run_readable_id: Optional[str] = None, - problem_set_title: Optional[str] = None, + model: str | None = None, + temperature: float | None = None, + thread_id: str | None = None, + block_siblings: list[str] | None = None, + edx_module_id: str | None = None, + run_readable_id: str | None = None, + problem_set_title: str | None = None, ): super().__init__( user_id, @@ -600,7 +600,7 @@ async def get_completion( self, message: str, *, - extra_state: Optional[dict[str, Any]] = None, # noqa: ARG002 + extra_state: dict[str, Any] | None = None, # noqa: ARG002 debug: bool = settings.AI_DEBUG, ) -> AsyncGenerator[str, None]: """Call message_tutor with the user query and return the response""" @@ -810,10 +810,10 @@ def __init__( # noqa: PLR0913 checkpointer: BaseCheckpointSaver, *, name: str = "MIT Open Learning VideoGPT Chatbot", - model: Optional[str] = None, - temperature: Optional[float] = None, - instructions: Optional[str] = None, - thread_id: Optional[str] = None, + model: str | None = None, + temperature: float | None = None, + instructions: str | None = None, + thread_id: str | None = None, ): super().__init__( user_id, diff --git a/ai_chatbots/chatbots_test.py b/ai_chatbots/chatbots_test.py index e1abab84..3b65f5f8 100644 --- a/ai_chatbots/chatbots_test.py +++ b/ai_chatbots/chatbots_test.py @@ -975,9 +975,11 @@ async def test_tutor_get_completion(posthog_settings, mocker, variant): assert "Let's start by thinking about the problem. " in results checkpoint = await database_sync_to_async( - lambda: DjangoCheckpoint.objects.select_related("session") - .filter(thread_id=thread_id) - .last() + lambda: ( + DjangoCheckpoint.objects.select_related("session") + .filter(thread_id=thread_id) + .last() + ) )() history = await database_sync_to_async( lambda: TutorBotOutput.objects.filter(thread_id=thread_id).last() diff --git a/ai_chatbots/checkpointers.py b/ai_chatbots/checkpointers.py index c080df88..5cec1be8 100644 --- a/ai_chatbots/checkpointers.py +++ b/ai_chatbots/checkpointers.py @@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator from typing import ( Any, - Optional, ) from django.conf import settings @@ -106,8 +105,8 @@ def _load_writes( def _parse_checkpoint_data( serde: JsonPlusSerializer, data: DjangoCheckpoint, - pending_writes: Optional[list[PendingWrite]] = None, -) -> Optional[CheckpointTuple]: + pending_writes: list[PendingWrite] | None = None, +) -> CheckpointTuple | None: """ Parse checkpoint data retrieved from the database. """ @@ -163,9 +162,9 @@ async def create_with_session( # noqa: PLR0913 thread_id: str, message: str, agent: str, - user: Optional[USER_MODEL] = None, - dj_session_key: Optional[str] = "", - object_id: Optional[str] = "", + user: USER_MODEL | None = None, + dj_session_key: str | None = "", + object_id: str | None = "", ): """ Initialize the DjangoSaver and create a UserChatSession if applicable. @@ -317,7 +316,7 @@ async def aput_writes( }, ) - async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None: """Get a checkpoint tuple from the database asynchronously. This method retrieves a checkpoint tuple from the database based on the @@ -362,11 +361,11 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: async def alist( self, - config: Optional[RunnableConfig], + config: RunnableConfig | None, *, - filter: Optional[dict[str, Any]] = None, # noqa: ARG002, A002 - before: Optional[RunnableConfig] = None, - limit: Optional[int] = None, + filter: dict[str, Any] | None = None, # noqa: ARG002, A002 + before: RunnableConfig | None = None, + limit: int | None = None, ) -> AsyncGenerator[CheckpointTuple, None]: """List checkpoints from the database asynchronously. diff --git a/ai_chatbots/constants.py b/ai_chatbots/constants.py index 404564e2..9a3f796d 100644 --- a/ai_chatbots/constants.py +++ b/ai_chatbots/constants.py @@ -2,7 +2,6 @@ import dataclasses import datetime -from typing import Optional from named_enum import ExtendedEnum @@ -57,7 +56,7 @@ class ChatbotCookie: name: str value: str path: str = "/" - max_age: Optional[datetime.datetime] = None + max_age: datetime.datetime | None = None def __str__(self) -> str: """ diff --git a/ai_chatbots/consumers.py b/ai_chatbots/consumers.py index a1b696c7..13dc62e6 100644 --- a/ai_chatbots/consumers.py +++ b/ai_chatbots/consumers.py @@ -2,7 +2,6 @@ import logging from abc import ABC, abstractmethod from http.cookies import SimpleCookie -from typing import Optional from uuid import uuid4 import litellm @@ -104,9 +103,9 @@ async def assign_thread_cookies( self, user: User, *, - clear_history: Optional[bool] = False, - thread_id: Optional[str] = None, - object_id: Optional[str] = None, + clear_history: bool | None = False, + thread_id: str | None = None, + object_id: str | None = None, ) -> tuple[str, list[str]]: """ Extract and update separate cookie values for logged in vs anonymous users. @@ -220,7 +219,7 @@ async def assign_thread_cookies( return current_thread_id, cookies async def prepare_response( - self, serializer: ChatRequestSerializer, object_id_field: Optional[str] = None + self, serializer: ChatRequestSerializer, object_id_field: str | None = None ) -> tuple[str, list[str]]: """Prepare consumer for the API response""" if object_id_field: @@ -249,9 +248,9 @@ def process_extra_state(self, data: dict) -> dict: # noqa: ARG002 async def start_response( self, - thread_id: Optional[str] = None, - status: Optional[int] = HTTP_200_OK, - cookies: Optional[list[str]] = None, + thread_id: str | None = None, + status: int | None = HTTP_200_OK, + cookies: list[str] | None = None, ): headers = ( [ @@ -507,7 +506,7 @@ def process_extra_state(self, data: dict) -> dict: def prepare_response( self, serializer: SyllabusChatRequestSerializer, - object_id_field: Optional[str] = None, + object_id_field: str | None = None, ) -> tuple[str, list[str]]: """Set the course id as the default object id field""" object_id_field = object_id_field or "course_id" @@ -620,7 +619,7 @@ def create_chatbot( def prepare_response( self, serializer: TutorChatRequestSerializer, - object_id_field: Optional[str] = None, + object_id_field: str | None = None, ) -> tuple[str, list[str]]: """Set the edx_module_id as the default object id field""" object_id_field = object_id_field or "edx_module_id" @@ -673,7 +672,7 @@ def create_chatbot( def prepare_response( self, serializer: TutorChatRequestSerializer, - object_id_field: Optional[str] = None, + object_id_field: str | None = None, ) -> tuple[str, list[str]]: """Set the edx_module_id as the default object id field""" object_id_field = "object_id" @@ -752,7 +751,7 @@ def process_extra_state(self, data: dict) -> dict: def prepare_response( self, serializer: VideoGPTRequestSerializer, - object_id_field: Optional[str] = None, + object_id_field: str | None = None, ) -> tuple[str, list[str]]: """Set the problem code as the default object id field""" object_id_field = object_id_field or "transcript_asset_id" diff --git a/ai_chatbots/consumers_test.py b/ai_chatbots/consumers_test.py index be00a5e3..7516b166 100644 --- a/ai_chatbots/consumers_test.py +++ b/ai_chatbots/consumers_test.py @@ -1017,7 +1017,7 @@ async def test_anonymous_user_login_session_association( # noqa: PLR0913 for i in range(3): anon_consumer = anonymous_consumer_setup(test_session_key) payload = { - "message": f"Anonymous question {i+1}", + "message": f"Anonymous question {i + 1}", "course_id": "MITx+6.00.1x", } await anon_consumer.handle(json.dumps(payload)) diff --git a/ai_chatbots/evaluation/base.py b/ai_chatbots/evaluation/base.py index 5cce8b79..42979085 100644 --- a/ai_chatbots/evaluation/base.py +++ b/ai_chatbots/evaluation/base.py @@ -8,7 +8,7 @@ from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path -from typing import Any, Optional +from typing import Any from asgiref.sync import sync_to_async from deepeval.test_case import LLMTestCase, ToolCall @@ -21,9 +21,9 @@ class TestCaseSpec: """Specification for a test case.""" question: str - expected_output: Optional[str] = None - expected_tools: Optional[list[str]] = None - metadata: Optional[dict[str, Any]] = None + expected_output: str | None = None + expected_tools: list[str] | None = None + metadata: dict[str, Any] | None = None def to_dict(self) -> dict[str, Any]: """Convert to dictionary representation.""" @@ -43,7 +43,7 @@ class EvaluationConfig: evaluation_model: str metrics: list[Any] metric_thresholds: dict[str, float] - confident_api_key: Optional[str] = None + confident_api_key: str | None = None class BaseBotEvaluator(ABC): @@ -54,9 +54,9 @@ def __init__( bot_class, bot_name: str, *, - data_file: Optional[str] = None, + data_file: str | None = None, stdout=None, - error_log_file: Optional[str] = None, + error_log_file: str | None = None, ): self.bot_class = bot_class self.bot_name = bot_name @@ -67,13 +67,13 @@ def __init__( def _log_error( self, error_msg: str, - test_case: Optional[TestCaseSpec] = None, - exception: Optional[Exception] = None, + test_case: TestCaseSpec | None = None, + exception: Exception | None = None, ): """Log error to file with timestamp and details.""" try: timestamp = datetime.now(tz=UTC).isoformat() - log_entry = f"\n{'='*80}\n" + log_entry = f"\n{'=' * 80}\n" log_entry += f"[{timestamp}] Error in {self.bot_name}\n" if test_case: @@ -86,7 +86,7 @@ def _log_error( if exception: log_entry += f"\nFull traceback:\n{traceback.format_exc()}\n" - log_entry += f"{'='*80}\n" + log_entry += f"{'=' * 80}\n" # Append to log file log_path = Path(self.error_log_file) @@ -107,7 +107,7 @@ def validate_test_case(self, test_case: TestCaseSpec) -> bool: @abstractmethod def create_bot_instance( - self, model: str, test_case: TestCaseSpec, instructions: Optional[str] = None + self, model: str, test_case: TestCaseSpec, instructions: str | None = None ): """Create a bot instance configured for the given test case.""" @@ -232,7 +232,7 @@ async def evaluate_model( self, model: str, test_cases: list[TestCaseSpec], - instructions: Optional[str] = None, + instructions: str | None = None, prompt_label: str = "default", max_concurrent: int = 10, ) -> list[LLMTestCase]: @@ -254,7 +254,7 @@ async def evaluate_model( # Create a semaphore to limit concurrency semaphore = asyncio.Semaphore(max_concurrent) - async def process_test_case(test_case: TestCaseSpec) -> Optional[LLMTestCase]: + async def process_test_case(test_case: TestCaseSpec) -> LLMTestCase | None: """Process a single test case with semaphore limiting.""" async with semaphore: try: diff --git a/ai_chatbots/evaluation/base_test.py b/ai_chatbots/evaluation/base_test.py index 7793d5e9..1500e43f 100644 --- a/ai_chatbots/evaluation/base_test.py +++ b/ai_chatbots/evaluation/base_test.py @@ -1,6 +1,5 @@ """Tests for evaluation.base module.""" -from typing import Optional from unittest.mock import Mock import pytest @@ -94,7 +93,7 @@ class ConcreteBotEvaluator(BaseBotEvaluator): """Concrete implementation for testing abstract methods.""" def __init__( - self, bot_class, bot_name: str, *, data_file: Optional[str] = None, mocker=None + self, bot_class, bot_name: str, *, data_file: str | None = None, mocker=None ): """Initialize with optional mocker for testing.""" super().__init__(bot_class, bot_name, data_file=data_file) @@ -114,7 +113,7 @@ def validate_test_case(self, test_case): return test_case.question is not None def create_bot_instance( - self, model: str, test_case: TestCaseSpec, instructions: Optional[str] = None + self, model: str, test_case: TestCaseSpec, instructions: str | None = None ): """Create mock bot instance.""" _ = model, test_case, instructions # Unused parameters for testing diff --git a/ai_chatbots/evaluation/evaluators.py b/ai_chatbots/evaluation/evaluators.py index e9e1eba3..3119be29 100644 --- a/ai_chatbots/evaluation/evaluators.py +++ b/ai_chatbots/evaluation/evaluators.py @@ -1,6 +1,6 @@ """Bot-specific evaluators for RAG evaluation.""" -from typing import Any, Optional +from typing import Any from langchain_core.messages import AIMessage @@ -37,7 +37,7 @@ def validate_test_case(self, test_case: TestCaseSpec) -> bool: return "search_url" in extra_state def create_bot_instance( - self, model: str, test_case: TestCaseSpec, instructions: Optional[str] = None + self, model: str, test_case: TestCaseSpec, instructions: str | None = None ): """Create recommendation bot instance.""" metadata = test_case.metadata or {} @@ -95,7 +95,7 @@ def validate_test_case(self, test_case: TestCaseSpec) -> bool: return "course_id" in extra_state def create_bot_instance( - self, model: str, test_case: TestCaseSpec, instructions: Optional[str] = None + self, model: str, test_case: TestCaseSpec, instructions: str | None = None ): """Create syllabus bot instance.""" metadata = test_case.metadata or {} @@ -153,7 +153,7 @@ def validate_test_case(self, test_case: TestCaseSpec) -> bool: return "transcript_asset_id" in extra_state def create_bot_instance( - self, model: str, test_case: TestCaseSpec, instructions: Optional[str] = None + self, model: str, test_case: TestCaseSpec, instructions: str | None = None ): """Create video GPT bot instance.""" metadata = test_case.metadata or {} @@ -214,7 +214,7 @@ def create_bot_instance( self, model: str, test_case: TestCaseSpec, - instructions: Optional[str] = None, # noqa: ARG002 + instructions: str | None = None, # noqa: ARG002 ): """Create tutor bot instance.""" metadata = test_case.metadata or {} diff --git a/ai_chatbots/evaluation/orchestrator.py b/ai_chatbots/evaluation/orchestrator.py index 43537b14..31930d40 100644 --- a/ai_chatbots/evaluation/orchestrator.py +++ b/ai_chatbots/evaluation/orchestrator.py @@ -4,7 +4,6 @@ from datetime import UTC, datetime from json import JSONDecodeError from pathlib import Path -from typing import Optional import deepeval from deepeval.evaluate import AsyncConfig, ErrorConfig @@ -46,7 +45,7 @@ def create_evaluation_config( # noqa: PLR0913 self, models: list[str], evaluation_model: str, - metric_thresholds: Optional[dict[str, float]] = None, + metric_thresholds: dict[str, float] | None = None, timeout_seconds: int = 360, max_retries: int = 3, retry_delay: float = 5.0, @@ -183,9 +182,9 @@ async def _collect_and_evaluate_bot( # noqa: C901, PLR0912, PLR0913, PLR0915 bot_name: str, config: EvaluationConfig, *, - data_file: Optional[str], - use_prompts: Optional[bool], - prompts_data: Optional[dict], + data_file: str | None, + use_prompts: bool | None, + prompts_data: dict | None, max_concurrent: int, batch_size: int, error_log_file: str, @@ -396,10 +395,10 @@ def _initialize_error_log(self, error_log_file: str): log_path = Path(error_log_file) header = f""" -{'#'*80} +{"#" * 80} RAG EVALUATION ERROR LOG Started: {timestamp} -{'#'*80} +{"#" * 80} """ with log_path.open("w", encoding="utf-8") as f: @@ -413,13 +412,13 @@ async def run_evaluation( # noqa: PLR0913 self, config: EvaluationConfig, *, - bot_names: Optional[list[str]] = None, - data_file: Optional[str] = None, - use_prompts: Optional[bool] = True, - prompts_file: Optional[str] = None, - max_concurrent: Optional[int] = 10, - batch_size: Optional[int] = 0, - error_log_file: Optional[str] = None, + bot_names: list[str] | None = None, + data_file: str | None = None, + use_prompts: bool | None = True, + prompts_file: str | None = None, + max_concurrent: int | None = 10, + batch_size: int | None = 0, + error_log_file: str | None = None, ) -> EvaluationResult: """Run evaluation across specified bots and models. @@ -454,9 +453,9 @@ async def run_evaluation( # noqa: PLR0913 # Process each bot one at a time, accumulating results all_bot_results = [] for bot_name in bot_names: - self.stdout.write(f"\n{'='*60}") + self.stdout.write(f"\n{'=' * 60}") self.stdout.write(f"Processing bot: {bot_name}") - self.stdout.write(f"{'='*60}") + self.stdout.write(f"{'=' * 60}") bot_result = await self._collect_and_evaluate_bot( bot_name, @@ -474,9 +473,9 @@ async def run_evaluation( # noqa: PLR0913 results = self._merge_evaluation_results(all_bot_results) # Log final summary - self.stdout.write(f"\n{'='*60}") + self.stdout.write(f"\n{'=' * 60}") self.stdout.write("EVALUATION COMPLETE - FINAL SUMMARY") - self.stdout.write(f"{'='*60}") + self.stdout.write(f"{'=' * 60}") self.stdout.write( f"\nTotal test results across all bots: {len(results.test_results)}" ) diff --git a/ai_chatbots/evaluation/reporting.py b/ai_chatbots/evaluation/reporting.py index 7ae00696..addc3643 100644 --- a/ai_chatbots/evaluation/reporting.py +++ b/ai_chatbots/evaluation/reporting.py @@ -2,7 +2,7 @@ from datetime import UTC, datetime from pathlib import Path -from typing import Optional, TextIO +from typing import TextIO import pandas as pd from deepeval.evaluate.types import EvaluationResult @@ -16,10 +16,10 @@ class DualOutputWrapper: """Wrapper that writes to both stdout and a file simultaneously.""" - def __init__(self, stdout: OutputWrapper, file_path: Optional[str] = None): + def __init__(self, stdout: OutputWrapper, file_path: str | None = None): self.stdout = stdout self.file_path = file_path - self.file: Optional[TextIO] = None + self.file: TextIO | None = None if file_path: file_path_obj = Path(file_path) @@ -454,8 +454,7 @@ def composite_leaderboard( self.stdout.write("\nRanking by Model + Bot:") for rank, (_, row) in enumerate(composite.iterrows(), 1): self.stdout.write( - f" {rank}. {row['model']} / {row['bot']}: " - f"{row['composite_score']:.3f}" + f" {rank}. {row['model']} / {row['bot']}: {row['composite_score']:.3f}" ) # Per-model aggregate diff --git a/ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py b/ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py index 4cbbf087..d1384373 100644 --- a/ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py +++ b/ai_chatbots/management/commands/backpopulate_tutor_checkpoints.py @@ -25,7 +25,7 @@ def add_message_ids(thread_id: str, output_id: int, messages: list[dict]) -> lis # Handle both dict and LangChain message objects message_id = str( uuid5( - UUID(thread_id), f'{output_id}_{message["type"]}_{message["content"]}' + UUID(thread_id), f"{output_id}_{message['type']}_{message['content']}" ) ) if isinstance(message, dict): diff --git a/ai_chatbots/management/commands/rag_evaluation.py b/ai_chatbots/management/commands/rag_evaluation.py index e4101769..a201e18c 100644 --- a/ai_chatbots/management/commands/rag_evaluation.py +++ b/ai_chatbots/management/commands/rag_evaluation.py @@ -94,8 +94,7 @@ def add_arguments(self, parser): dest="error_log_file", required=False, help=( - "Specify a file to save error logs " - "(default: rag_evaluation_errors.log)" + "Specify a file to save error logs (default: rag_evaluation_errors.log)" ), default="rag_evaluation_errors.log", ) diff --git a/ai_chatbots/migrations/0007_convert_tutorbot_to_checkpoints.py b/ai_chatbots/migrations/0007_convert_tutorbot_to_checkpoints.py index 18675bf3..f5ef8df8 100644 --- a/ai_chatbots/migrations/0007_convert_tutorbot_to_checkpoints.py +++ b/ai_chatbots/migrations/0007_convert_tutorbot_to_checkpoints.py @@ -22,7 +22,7 @@ def add_message_ids(thread_id: str, output_id: int, messages: list[dict]) -> lis # Handle both dict and LangChain message objects message_id = str( uuid5( - UUID(thread_id), f'{output_id}_{message["type"]}_{message["content"]}' + UUID(thread_id), f"{output_id}_{message['type']}_{message['content']}" ) ) if isinstance(message, dict): diff --git a/ai_chatbots/posthog.py b/ai_chatbots/posthog.py index e693e226..cc65d9fe 100644 --- a/ai_chatbots/posthog.py +++ b/ai_chatbots/posthog.py @@ -1,7 +1,7 @@ """PostHog serialization and callback handler for AI chatbots.""" import logging -from typing import Any, Optional +from typing import Any from uuid import UUID import litellm @@ -95,7 +95,7 @@ def on_chat_model_start( messages: list[list[BaseMessage]], *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ): """Format messages and estimate input tokens""" @@ -142,7 +142,7 @@ def on_llm_end( response: LLMResult, *, run_id: UUID, - parent_run_id: Optional[UUID] = None, + parent_run_id: UUID | None = None, **kwargs: Any, ): # Calculate output tokens using LiteLLM's token_counter @@ -193,7 +193,7 @@ def on_llm_end( ) def _pop_run_and_capture_trace_or_span( - self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any + self, run_id: UUID, parent_run_id: UUID | None, outputs: Any ): """Override to serialize outputs before passing to parent.""" serialized_outputs = serialize_for_posthog(outputs) diff --git a/ai_chatbots/tasks_test.py b/ai_chatbots/tasks_test.py index e1b91d5f..f4c2a410 100644 --- a/ai_chatbots/tasks_test.py +++ b/ai_chatbots/tasks_test.py @@ -21,8 +21,9 @@ def test_delete_stale_sessions(): """Test delete_stale_sessions""" with freeze_time( - lambda: now_in_utc() - - timedelta(days=settings.AI_CHATBOTS_SESSION_EXPIRY_DAYS + 1) + lambda: ( + now_in_utc() - timedelta(days=settings.AI_CHATBOTS_SESSION_EXPIRY_DAYS + 1) + ) ): expired_chats = UserChatSessionFactory.create_batch( 4, user=None, dj_session_key=uuid4().hex diff --git a/ai_chatbots/tools.py b/ai_chatbots/tools.py index 409e00e5..442668b5 100644 --- a/ai_chatbots/tools.py +++ b/ai_chatbots/tools.py @@ -2,7 +2,7 @@ import json import logging -from typing import Annotated, Optional +from typing import Annotated import pydantic from asgiref.sync import sync_to_async @@ -68,7 +68,7 @@ class SearchToolSchema(pydantic.BaseModel): description="The agent state, including the search url to use" ) - resource_type: Optional[list[enum_zip("resource_type", LearningResourceType)]] = ( + resource_type: list[enum_zip("resource_type", LearningResourceType)] | None = ( Field( default=None, description=( @@ -85,7 +85,7 @@ class SearchToolSchema(pydantic.BaseModel): ), ) ) - free: Optional[bool] = Field( + free: bool | None = Field( default=None, description=( """ @@ -96,7 +96,7 @@ class SearchToolSchema(pydantic.BaseModel): """ ), ) - certification: Optional[bool] = Field( + certification: bool | None = Field( default=None, description=( """ @@ -107,7 +107,7 @@ class SearchToolSchema(pydantic.BaseModel): """ ), ) - offered_by: Optional[list[enum_zip("resource_type", OfferedBy)]] = Field( + offered_by: list[enum_zip("resource_type", OfferedBy)] | None = Field( default=None, description=""" If a user asks for resources "offered by" or "from" an institution, @@ -129,7 +129,7 @@ class SearchToolSchema(pydantic.BaseModel): @tool(args_schema=SearchToolSchema) async def search_courses( - q: str, state: Optional[Annotated[dict, InjectedState]], **kwargs + q: str, state: Annotated[dict, InjectedState] | None, **kwargs ) -> str: """ Query the MIT API for learning resources, and @@ -203,7 +203,7 @@ class SearchContentFilesToolSchema(pydantic.BaseModel): description=("Query to find requested information about a learning resource.") ) - readable_id: Optional[str] = Field( + readable_id: str | None = Field( description=("The readable_id of the learning resource."), default=None, ) diff --git a/conftest.py b/conftest.py index 64a74173..a0a2d4b2 100644 --- a/conftest.py +++ b/conftest.py @@ -11,7 +11,7 @@ @pytest.fixture(autouse=True) -def prevent_requests(mocker, request): # noqa: PT004 +def prevent_requests(mocker, request): """Patch requests to error on request by default""" if "mocked_responses" in request.fixturenames: return diff --git a/fixtures/common.py b/fixtures/common.py index 8a581ea2..f81822ac 100644 --- a/fixtures/common.py +++ b/fixtures/common.py @@ -14,13 +14,13 @@ @pytest.fixture(autouse=True) -def silence_factory_logging(): # noqa: PT004 +def silence_factory_logging(): """Only show factory errors""" logging.getLogger("factory").setLevel(logging.ERROR) @pytest.fixture(autouse=True) -def warnings_as_errors(): # noqa: PT004 +def warnings_as_errors(): """ Convert warnings to errors. This should only affect unit tests, letting pylint and other plugins raise DeprecationWarnings without erroring. @@ -54,7 +54,7 @@ def warnings_as_errors(): # noqa: PT004 @pytest.fixture -def randomness(): # noqa: PT004 +def randomness(): """Ensure a fixed seed for factoryboy""" factory.fuzzy.reseed_random("happy little clouds") diff --git a/main/features.py b/main/features.py index 94ad14ff..832264b1 100644 --- a/main/features.py +++ b/main/features.py @@ -5,7 +5,6 @@ import logging from enum import StrEnum from functools import wraps -from typing import Optional import posthog from django.conf import settings @@ -76,7 +75,7 @@ def generate_cache_key(key: str, unique_id: str, person_properties: dict) -> str ) -def get_all_feature_flags(opt_unique_id: Optional[str] = None): +def get_all_feature_flags(opt_unique_id: str | None = None): """ Get the set of all feature flags """ @@ -98,8 +97,8 @@ def get_all_feature_flags(opt_unique_id: Optional[str] = None): def is_enabled( name: str, - default: Optional[bool] = None, - opt_unique_id: Optional[str] = None, + default: bool | None = None, + opt_unique_id: str | None = None, ) -> bool: """ Return True if the feature flag is enabled @@ -149,7 +148,7 @@ def is_enabled( ) -def if_feature_enabled(name: str, default: Optional[bool] = None): +def if_feature_enabled(name: str, default: bool | None = None): """ Wrapper that results in a no-op if the given feature isn't enabled, and otherwise runs the wrapped function as normal. diff --git a/main/utils_test.py b/main/utils_test.py index 5d80cc9c..0b81e4c3 100644 --- a/main/utils_test.py +++ b/main/utils_test.py @@ -35,7 +35,7 @@ def test_is_near_now(): def test_chunks(): """ - test for chunks + Test for chunks """ input_list = list(range(113)) output_list = [] @@ -56,7 +56,7 @@ def test_chunks(): def test_chunks_iterable(): """ - test that chunks works on non-list iterables too + Test that chunks works on non-list iterables too """ count = 113 input_range = range(count)