Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ ci:
- eslint
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
Expand All @@ -28,11 +28,11 @@ repos:
types_or:
[javascript, jsx, ts, tsx, json, scss, sass, css, yaml, markdown]
- repo: https://github.com/scop/pre-commit-shfmt
rev: v3.10.0-1
rev: v3.13.1-1
hooks:
- id: shfmt
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.35.1
rev: v1.38.0
hooks:
- id: yamllint
args: [--format, parsable, -d, relaxed]
Expand Down Expand Up @@ -65,7 +65,7 @@ repos:
- "config/keycloak/realms/ol-local-realm.json"
additional_dependencies: ["gibberish-detector"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.7.2"
rev: "v0.15.10"
hooks:
- id: ruff-format
- id: ruff
Expand All @@ -84,7 +84,7 @@ repos:
additional_dependencies:
- eslint@8
- repo: https://github.com/shellcheck-py/shellcheck-py
rev: v0.10.0.1
rev: v0.11.0.1
hooks:
- id: shellcheck
args: ["--severity=warning"]
Expand Down
10 changes: 5 additions & 5 deletions ai_chatbots/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import logging
from typing import Any, Optional, Union
from typing import Any, Union
from uuid import uuid4

from channels.db import database_sync_to_async
Expand Down Expand Up @@ -96,14 +96,14 @@ def serialize_tool_calls(tool_calls: list[dict]) -> list[dict]:


@database_sync_to_async
def query_tutorbot_output(thread_id: str) -> Optional[TutorBotOutput]:
def query_tutorbot_output(thread_id: str) -> TutorBotOutput | None:
"""Return the latest TutorBotOutput for a given thread_id"""
return TutorBotOutput.objects.filter(thread_id=thread_id).last()


@database_sync_to_async
def create_tutorbot_output_and_checkpoints(
thread_id: str, chat_json: Union[str, dict], edx_module_id: Optional[str]
thread_id: str, chat_json: Union[str, dict], edx_module_id: str | None
) -> tuple[TutorBotOutput, list[DjangoCheckpoint]]:
"""Atomically create both TutorBotOutput and DjangoCheckpoint objects"""
with transaction.atomic():
Expand Down Expand Up @@ -133,7 +133,7 @@ def _should_create_checkpoint(msg: dict) -> bool:


def _identify_new_messages(
filtered_messages: list[dict], previous_chat_json: Optional[Union[str, dict]]
filtered_messages: list[dict], previous_chat_json: Union[str, dict] | None
) -> list[dict]:
"""Identify which messages are new by comparing with previous chat data."""
if not previous_chat_json:
Expand Down Expand Up @@ -222,7 +222,7 @@ def _create_checkpoint_metadata(
def create_tutor_checkpoints(
thread_id: str,
chat_json: Union[str, dict],
previous_chat_json: Optional[Union[str, dict]] = None,
previous_chat_json: Union[str, dict] | None = None,
) -> list[DjangoCheckpoint]:
"""Create DjangoCheckpoint records from tutor chat data (synchronous)"""
# Get the associated session
Expand Down
64 changes: 32 additions & 32 deletions ai_chatbots/chatbots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from operator import add
from typing import Annotated, Any, Optional
from typing import Annotated, Any
from uuid import uuid4

import posthog
Expand Down Expand Up @@ -73,10 +73,10 @@ def __init__( # noqa: PLR0913
checkpointer: BaseCheckpointSaver,
*,
name: str = "MIT Open Learning Chatbot",
model: Optional[str] = None,
temperature: Optional[float] = None,
instructions: Optional[str] = None,
thread_id: Optional[str] = None,
model: str | None = None,
temperature: float | None = None,
instructions: str | None = None,
thread_id: str | None = None,
):
"""Initialize the AI chat agent service"""
self.bot_name = name
Expand Down Expand Up @@ -224,7 +224,7 @@ async def validate_and_clean_checkpoint(self) -> None:
except Exception:
log.exception("Error while cleaning checkpoint")

async def _get_latest_checkpoint_id(self) -> Optional[str]:
async def _get_latest_checkpoint_id(self) -> str | None:
"""Get the most recent assistant response checkpoint"""
checkpoint = (
await DjangoCheckpoint.objects.prefetch_related("session", "session__user")
Expand All @@ -237,7 +237,7 @@ async def _get_latest_checkpoint_id(self) -> Optional[str]:
return checkpoint.id if checkpoint else None

async def set_callbacks(
self, properties: Optional[dict] = None
self, properties: dict | None = None
) -> list[CallbackHandler]:
"""Set callbacks for the agent LLM"""
if settings.POSTHOG_PROJECT_API_KEY and settings.POSTHOG_API_HOST:
Expand Down Expand Up @@ -282,7 +282,7 @@ async def get_completion(
self,
message: str,
*,
extra_state: Optional[dict[str, Any]] = None,
extra_state: dict[str, Any] | None = None,
debug: bool = settings.AI_DEBUG,
) -> AsyncGenerator[str, None]:
"""
Expand Down Expand Up @@ -424,13 +424,13 @@ class ResourceRecommendationBot(TruncatingChatbot):
def __init__( # noqa: PLR0913
self,
user_id: str,
checkpointer: Optional[BaseCheckpointSaver] = None,
checkpointer: BaseCheckpointSaver | None = None,
*,
name: str = "MIT Open Learning Chatbot",
model: Optional[str] = None,
temperature: Optional[float] = None,
instructions: Optional[str] = None,
thread_id: Optional[str] = None,
model: str | None = None,
temperature: float | None = None,
instructions: str | None = None,
thread_id: str | None = None,
):
"""Initialize the AI search agent service"""
super().__init__(
Expand Down Expand Up @@ -466,7 +466,7 @@ class SyllabusAgentState(SummaryState):
related_courses: Annotated[list[str], add]
# str representation of a boolean value, because the
# langgraph JsonPlusSerializer can't handle booleans
exclude_canvas: Annotated[Optional[list[str]], add]
exclude_canvas: Annotated[list[str] | None, add]


class SyllabusBot(TruncatingChatbot):
Expand All @@ -483,11 +483,11 @@ def __init__( # noqa: PLR0913
checkpointer: BaseCheckpointSaver,
*,
name: str = "MIT Open Learning Syllabus Chatbot",
model: Optional[str] = None,
temperature: Optional[float] = None,
instructions: Optional[str] = None,
thread_id: Optional[str] = None,
enable_related_courses: Optional[bool] = False,
model: str | None = None,
temperature: float | None = None,
instructions: str | None = None,
thread_id: str | None = None,
enable_related_courses: bool | None = False,
):
self.enable_related_courses = enable_related_courses
super().__init__(
Expand Down Expand Up @@ -546,16 +546,16 @@ class TutorBot(BaseChatbot):
def __init__( # noqa: PLR0913
self,
user_id: str,
checkpointer: Optional[BaseCheckpointSaver] = BaseCheckpointSaver,
checkpointer: BaseCheckpointSaver | None = BaseCheckpointSaver,
*,
name: str = "MIT Open Learning Tutor Chatbot",
model: Optional[str] = None,
temperature: Optional[float] = None,
thread_id: Optional[str] = None,
block_siblings: Optional[list[str]] = None,
edx_module_id: Optional[str] = None,
run_readable_id: Optional[str] = None,
problem_set_title: Optional[str] = None,
model: str | None = None,
temperature: float | None = None,
thread_id: str | None = None,
block_siblings: list[str] | None = None,
edx_module_id: str | None = None,
run_readable_id: str | None = None,
problem_set_title: str | None = None,
):
super().__init__(
user_id,
Expand Down Expand Up @@ -600,7 +600,7 @@ async def get_completion(
self,
message: str,
*,
extra_state: Optional[dict[str, Any]] = None, # noqa: ARG002
extra_state: dict[str, Any] | None = None, # noqa: ARG002
debug: bool = settings.AI_DEBUG,
) -> AsyncGenerator[str, None]:
"""Call message_tutor with the user query and return the response"""
Expand Down Expand Up @@ -810,10 +810,10 @@ def __init__( # noqa: PLR0913
checkpointer: BaseCheckpointSaver,
*,
name: str = "MIT Open Learning VideoGPT Chatbot",
model: Optional[str] = None,
temperature: Optional[float] = None,
instructions: Optional[str] = None,
thread_id: Optional[str] = None,
model: str | None = None,
temperature: float | None = None,
instructions: str | None = None,
thread_id: str | None = None,
):
super().__init__(
user_id,
Expand Down
8 changes: 5 additions & 3 deletions ai_chatbots/chatbots_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,9 +975,11 @@ async def test_tutor_get_completion(posthog_settings, mocker, variant):
assert "Let's start by thinking about the problem. " in results

checkpoint = await database_sync_to_async(
lambda: DjangoCheckpoint.objects.select_related("session")
.filter(thread_id=thread_id)
.last()
lambda: (
DjangoCheckpoint.objects.select_related("session")
.filter(thread_id=thread_id)
.last()
)
)()
history = await database_sync_to_async(
lambda: TutorBotOutput.objects.filter(thread_id=thread_id).last()
Expand Down
21 changes: 10 additions & 11 deletions ai_chatbots/checkpointers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from collections.abc import AsyncGenerator
from typing import (
Any,
Optional,
)

from django.conf import settings
Expand Down Expand Up @@ -106,8 +105,8 @@ def _load_writes(
def _parse_checkpoint_data(
serde: JsonPlusSerializer,
data: DjangoCheckpoint,
pending_writes: Optional[list[PendingWrite]] = None,
) -> Optional[CheckpointTuple]:
pending_writes: list[PendingWrite] | None = None,
) -> CheckpointTuple | None:
"""
Parse checkpoint data retrieved from the database.
"""
Expand Down Expand Up @@ -163,9 +162,9 @@ async def create_with_session( # noqa: PLR0913
thread_id: str,
message: str,
agent: str,
user: Optional[USER_MODEL] = None,
dj_session_key: Optional[str] = "",
object_id: Optional[str] = "",
user: USER_MODEL | None = None,
dj_session_key: str | None = "",
object_id: str | None = "",
):
"""
Initialize the DjangoSaver and create a UserChatSession if applicable.
Expand Down Expand Up @@ -317,7 +316,7 @@ async def aput_writes(
},
)

async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
"""Get a checkpoint tuple from the database asynchronously.

This method retrieves a checkpoint tuple from the database based on the
Expand Down Expand Up @@ -362,11 +361,11 @@ async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]:

async def alist(
self,
config: Optional[RunnableConfig],
config: RunnableConfig | None,
*,
filter: Optional[dict[str, Any]] = None, # noqa: ARG002, A002
before: Optional[RunnableConfig] = None,
limit: Optional[int] = None,
filter: dict[str, Any] | None = None, # noqa: ARG002, A002
before: RunnableConfig | None = None,
limit: int | None = None,
) -> AsyncGenerator[CheckpointTuple, None]:
"""List checkpoints from the database asynchronously.

Expand Down
3 changes: 1 addition & 2 deletions ai_chatbots/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import dataclasses
import datetime
from typing import Optional

from named_enum import ExtendedEnum

Expand Down Expand Up @@ -57,7 +56,7 @@ class ChatbotCookie:
name: str
value: str
path: str = "/"
max_age: Optional[datetime.datetime] = None
max_age: datetime.datetime | None = None

def __str__(self) -> str:
"""
Expand Down
23 changes: 11 additions & 12 deletions ai_chatbots/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
from abc import ABC, abstractmethod
from http.cookies import SimpleCookie
from typing import Optional
from uuid import uuid4

import litellm
Expand Down Expand Up @@ -104,9 +103,9 @@ async def assign_thread_cookies(
self,
user: User,
*,
clear_history: Optional[bool] = False,
thread_id: Optional[str] = None,
object_id: Optional[str] = None,
clear_history: bool | None = False,
thread_id: str | None = None,
object_id: str | None = None,
) -> tuple[str, list[str]]:
"""
Extract and update separate cookie values for logged in vs anonymous users.
Expand Down Expand Up @@ -220,7 +219,7 @@ async def assign_thread_cookies(
return current_thread_id, cookies

async def prepare_response(
self, serializer: ChatRequestSerializer, object_id_field: Optional[str] = None
self, serializer: ChatRequestSerializer, object_id_field: str | None = None
) -> tuple[str, list[str]]:
"""Prepare consumer for the API response"""
if object_id_field:
Expand Down Expand Up @@ -249,9 +248,9 @@ def process_extra_state(self, data: dict) -> dict: # noqa: ARG002

async def start_response(
self,
thread_id: Optional[str] = None,
status: Optional[int] = HTTP_200_OK,
cookies: Optional[list[str]] = None,
thread_id: str | None = None,
status: int | None = HTTP_200_OK,
cookies: list[str] | None = None,
):
headers = (
[
Expand Down Expand Up @@ -507,7 +506,7 @@ def process_extra_state(self, data: dict) -> dict:
def prepare_response(
self,
serializer: SyllabusChatRequestSerializer,
object_id_field: Optional[str] = None,
object_id_field: str | None = None,
) -> tuple[str, list[str]]:
"""Set the course id as the default object id field"""
object_id_field = object_id_field or "course_id"
Expand Down Expand Up @@ -620,7 +619,7 @@ def create_chatbot(
def prepare_response(
self,
serializer: TutorChatRequestSerializer,
object_id_field: Optional[str] = None,
object_id_field: str | None = None,
) -> tuple[str, list[str]]:
"""Set the edx_module_id as the default object id field"""
object_id_field = object_id_field or "edx_module_id"
Expand Down Expand Up @@ -673,7 +672,7 @@ def create_chatbot(
def prepare_response(
self,
serializer: TutorChatRequestSerializer,
object_id_field: Optional[str] = None,
object_id_field: str | None = None,
) -> tuple[str, list[str]]:
"""Set the edx_module_id as the default object id field"""
object_id_field = "object_id"
Expand Down Expand Up @@ -752,7 +751,7 @@ def process_extra_state(self, data: dict) -> dict:
def prepare_response(
self,
serializer: VideoGPTRequestSerializer,
object_id_field: Optional[str] = None,
object_id_field: str | None = None,
) -> tuple[str, list[str]]:
"""Set the problem code as the default object id field"""
object_id_field = object_id_field or "transcript_asset_id"
Expand Down
Loading
Loading