From 6106187e5cf36fa273aa146e18495ee3cfab40c0 Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Thu, 12 Mar 2026 12:46:14 +0000 Subject: [PATCH 1/3] Switch to graphql-transport-ws subprotocol --- cylc/uiserver/app.py | 3 -- cylc/uiserver/graphql/tornado_ws.py | 58 ++++++++++++----------------- cylc/uiserver/handlers.py | 36 +++++++++--------- 3 files changed, 42 insertions(+), 55 deletions(-) diff --git a/cylc/uiserver/app.py b/cylc/uiserver/app.py index 9a562e97..0b3c81df 100644 --- a/cylc/uiserver/app.py +++ b/cylc/uiserver/app.py @@ -467,8 +467,6 @@ def __init__(self, *args, **kwargs): self.log, self.max_threads, ) - # sub_status dictionary storing status of subscriptions - self.sub_statuses = {} self.resolvers = Resolvers( self, self.data_store_mgr, @@ -575,7 +573,6 @@ def initialize_handlers(self): { 'sub_server': self.subscription_server, 'resolvers': self.resolvers, - 'sub_statuses': self.sub_statuses } ), ( diff --git a/cylc/uiserver/graphql/tornado_ws.py b/cylc/uiserver/graphql/tornado_ws.py index 1ed40d4f..e6acd182 100644 --- a/cylc/uiserver/graphql/tornado_ws.py +++ b/cylc/uiserver/graphql/tornado_ws.py @@ -52,18 +52,14 @@ NO_MSG_DELAY = 1.0 -GRAPHQL_WS = "graphql-ws" -WS_PROTOCOL = GRAPHQL_WS -GQL_CONNECTION_INIT = "connection_init" # Client -> Server -GQL_CONNECTION_ACK = "connection_ack" # Server -> Client -GQL_CONNECTION_ERROR = "connection_error" # Server -> Client -GQL_CONNECTION_TERMINATE = "connection_terminate" # Client -> Server -GQL_CONNECTION_KEEP_ALIVE = "ka" # Server -> Client -GQL_START = "start" # Client -> Server -GQL_DATA = "data" # Server -> Client -GQL_ERROR = "error" # Server -> Client -GQL_COMPLETE = "complete" # Server -> Client -GQL_STOP = "stop" # Client -> Server +# See https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md +WS_PROTOCOL = 'graphql-transport-ws' +GQL_CONNECTION_INIT = 'connection_init' # Client -> Server +GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client +GQL_SUBSCRIBE = 'subscribe' # Client -> Server +GQL_NEXT = 'next' # Server -> Client +GQL_ERROR = 'error' # Server -> Client +GQL_COMPLETE = 'complete' # Bidrectional REQ_HEADER_INFO = { 'Host', @@ -189,41 +185,33 @@ async def _process_message(self, connection_context, parsed_message): connection_context, op_id, payload ) - elif op_type == GQL_CONNECTION_TERMINATE: - return self.on_connection_terminate(connection_context, op_id) - - elif op_type == GQL_START: + elif op_type == GQL_SUBSCRIBE: if not isinstance(payload, dict): raise AssertionError("The payload must be a dict") params = self.get_graphql_params(connection_context, payload) return await self.on_start(connection_context, op_id, params) - elif op_type == GQL_STOP: + elif op_type == GQL_COMPLETE: return await self.on_stop(connection_context, op_id) else: - return await self.send_error( - connection_context, - op_id, - Exception("Invalid message type: {}.".format(op_type)), + connection_context.ws.close( + 4400, f"Invalid message type: {op_type}" ) async def on_connection_init(self, connection_context, op_id, payload): try: await self.on_connect(connection_context, payload) await self.send_message( - connection_context, op_type=GQL_CONNECTION_ACK) + connection_context, op_type=GQL_CONNECTION_ACK + ) except Exception as e: - await self.send_error( - connection_context, op_id, e, GQL_CONNECTION_ERROR) + await self.send_error(connection_context, op_id, e, GQL_ERROR) await connection_context.ws.close(1011) async def on_connect(self, connection_context, payload): pass - def on_connection_terminate(self, connection_context, op_id): - return connection_context.ws.close(1011) - def get_graphql_params(self, connection_context, payload): # Create a new context object for each subscription, # which allows it to carry a unique subscription id. @@ -262,6 +250,7 @@ async def on_open(self, connection_context): async def on_stop(self, connection_context, op_id): return await connection_context.unsubscribe(op_id) + connection_context.request_context['sub_statuses'][op_id] = 'stop' async def on_close(self, connection_context): return await connection_context.unsubscribe_all() @@ -287,6 +276,8 @@ async def on_start(self, connection_context, op_id, params): # with this id. await connection_context.unsubscribe(op_id) + connection_context.request_context['sub_statuses'][op_id] = 'start' + params['kwargs']['root_value'] = op_id execution_result = await self.execute(params) iterator = None @@ -370,7 +361,7 @@ async def send_execution_result( result = execution_result.formatted return await self.send_message( - connection_context, op_id, GQL_DATA, result + connection_context, op_id, GQL_NEXT, result ) async def on_operation_complete(self, connection_context, op_id): @@ -384,13 +375,9 @@ async def send_error( if error_type is None: error_type = GQL_ERROR - if error_type not in {GQL_CONNECTION_ERROR, GQL_ERROR}: - raise AssertionError( - "error_type should be one of the allowed error messages" - " GQL_CONNECTION_ERROR or GQL_ERROR" - ) + assert error_type == GQL_ERROR, "error_type should be GQL_ERROR" - error_payload = {"message": str(error)} + error_payload = [{"message": str(error)}] with suppress(WebSocketClosedError): return await self.send_message( @@ -405,6 +392,7 @@ async def on_message(self, connection_context, message): else: parsed_message = message except Exception as e: - return await self.send_error(connection_context, None, e) + connection_context.ws.close(4400, str(e)) + return None return self.process_message(connection_context, parsed_message) diff --git a/cylc/uiserver/handlers.py b/cylc/uiserver/handlers.py index 4238a198..1177b00d 100644 --- a/cylc/uiserver/handlers.py +++ b/cylc/uiserver/handlers.py @@ -19,7 +19,13 @@ import json import os import re -from typing import TYPE_CHECKING, Callable, Dict, Awaitable, Optional +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Literal, + Optional, +) from cylc.flow import __version__ as cylc_flow_version from jupyter_server.base.handlers import JupyterHandler @@ -30,18 +36,22 @@ from tornado.ioloop import IOLoop from cylc.uiserver import __version__ -from cylc.uiserver.authorise import Authorization, AuthorizationMiddleware +from cylc.uiserver.authorise import ( + Authorization, + AuthorizationMiddleware, +) from cylc.uiserver.graphql import authenticated as websockets_authenticated from cylc.uiserver.graphql.tornado import TornadoGraphQLHandler -from cylc.uiserver.graphql.tornado_ws import GRAPHQL_WS +from cylc.uiserver.graphql.tornado_ws import WS_PROTOCOL from cylc.uiserver.utils import is_bearer_token_authenticated if TYPE_CHECKING: - from cylc.uiserver.resolvers import Resolvers - from cylc.uiserver.graphql.tornado_ws import TornadoSubscriptionServer from jupyter_server.auth.identity import User as JPSUser + from cylc.uiserver.graphql.tornado_ws import TornadoSubscriptionServer + from cylc.uiserver.resolvers import Resolvers + ME = getpass.getuser() RE_SLASH = re.compile(r'\/+') @@ -369,14 +379,15 @@ async def run(self, *args, **kwargs): class SubscriptionHandler(CylcAppHandler, websocket.WebSocketHandler): """Endpoint for performing GraphQL subscriptions.""" # No authorization decorators here, auth handled in AuthorizationMiddleware - def initialize(self, sub_server, resolvers, sub_statuses=None): + def initialize(self, sub_server, resolvers): self.queue: Queue = Queue(100) self.subscription_server: TornadoSubscriptionServer = sub_server self.resolvers: Resolvers = resolvers - self.sub_statuses: Dict = sub_statuses + # sub_status dictionary storing status of subscriptions + self.sub_statuses: dict[str, Literal["start", "stop"]] = {} def select_subprotocol(self, subprotocols): - return GRAPHQL_WS + return WS_PROTOCOL @websockets_authenticated def get(self, *args, **kwargs): @@ -392,15 +403,6 @@ def open(self, *args, **kwargs): # noqa: A003 ) async def on_message(self, message): - try: - message_dict = json.loads(message) - op_id = message_dict.get("id", None) - if (message_dict['type'] == 'start'): - self.sub_statuses[op_id] = 'start' - if (message_dict['type'] == 'stop'): - self.sub_statuses[op_id] = 'stop' - except (KeyError, ValueError): - pass await self.queue.put(message) async def recv(self): From 52ea134d900c93d405626e52521fc6b3d347c16a Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Wed, 18 Mar 2026 14:29:27 +0000 Subject: [PATCH 2/3] Tidy and add type annotations --- cylc/uiserver/authorise.py | 2 +- cylc/uiserver/graphql/tornado_ws.py | 267 +++++++++++++++------------- 2 files changed, 148 insertions(+), 121 deletions(-) diff --git a/cylc/uiserver/authorise.py b/cylc/uiserver/authorise.py index c047b967..57256819 100644 --- a/cylc/uiserver/authorise.py +++ b/cylc/uiserver/authorise.py @@ -504,7 +504,7 @@ class AuthorizationMiddleware: """ - auth = None + auth: Authorization def resolve(self, next_, root, info, **args): current_user = info.context["current_user"] diff --git a/cylc/uiserver/graphql/tornado_ws.py b/cylc/uiserver/graphql/tornado_ws.py index e6acd182..7f655874 100644 --- a/cylc/uiserver/graphql/tornado_ws.py +++ b/cylc/uiserver/graphql/tornado_ws.py @@ -28,7 +28,13 @@ from contextlib import suppress from inspect import isawaitable import json -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + Literal, + Type, +) from graphql import ( ExecutionResult, @@ -47,19 +53,35 @@ if TYPE_CHECKING: + from collections.abc import AsyncIterator + + from graphene import Schema + from graphql import ExecutionContext + from tornado.httputil import HTTPServerRequest + + from cylc.uiserver.authorise import Authorization from cylc.uiserver.handlers import SubscriptionHandler + SendOperationType = Literal[ + "connection_ack", "ping", "pong", "next", "error", "complete" + ] NO_MSG_DELAY = 1.0 # See https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md WS_PROTOCOL = 'graphql-transport-ws' -GQL_CONNECTION_INIT = 'connection_init' # Client -> Server -GQL_CONNECTION_ACK = 'connection_ack' # Server -> Client -GQL_SUBSCRIBE = 'subscribe' # Client -> Server -GQL_NEXT = 'next' # Server -> Client -GQL_ERROR = 'error' # Server -> Client -GQL_COMPLETE = 'complete' # Bidrectional +GQL_CONNECTION_INIT: Literal["connection_init"] = ( + 'connection_init' # Client -> Server +) +GQL_CONNECTION_ACK: Literal["connection_ack"] = ( + 'connection_ack' # Server -> Client +) +GQL_PING: Literal["ping"] = 'ping' # Bidirectional +GQL_PONG: Literal["pong"] = 'pong' # Bidirectional +GQL_SUBSCRIBE: Literal["subscribe"] = 'subscribe' # Client -> Server +GQL_NEXT: Literal["next"] = 'next' # Server -> Client +GQL_ERROR: Literal["error"] = 'error' # Server -> Client +GQL_COMPLETE: Literal["complete"] = 'complete' # Bidrectional REQ_HEADER_INFO = { 'Host', @@ -72,40 +94,33 @@ class TornadoConnectionContext: - - def __init__(self, ws, request_context=None): - self.ws: SubscriptionHandler = ws - self.operations = {} + def __init__(self, ws: 'SubscriptionHandler', request_context: dict): + self.ws = ws + self.operations: dict[str, 'AsyncIterator'] = {} self.request_context = request_context self.pending_tasks: set[asyncio.Task] = set() - def has_operation(self, op_id): + def has_operation(self, op_id: str): return op_id in self.operations - def register_operation(self, op_id, async_iterator): + def register_operation(self, op_id: str, async_iterator: 'AsyncIterator'): self.operations[op_id] = async_iterator - def get_operation(self, op_id): + def get_operation(self, op_id: str): return self.operations[op_id] - def remove_operation(self, op_id): - try: - return self.operations.pop(op_id) - except KeyError: - return - - async def receive(self): - return self.ws.recv_nowait() + def remove_operation(self, op_id: str): + return self.operations.pop(op_id, None) @property - def closed(self): + def closed(self) -> bool: return self.ws.close_code is not None def remember_task(self, task: asyncio.Task) -> None: self.pending_tasks.add(task) task.add_done_callback(self.pending_tasks.discard) - async def unsubscribe(self, op_id): + async def unsubscribe(self, op_id: str) -> None: async_iterator = self._unsubscribe(op_id) if ( getattr(async_iterator, "future", None) @@ -113,7 +128,7 @@ async def unsubscribe(self, op_id): ): await async_iterator.future - def _unsubscribe(self, op_id): + def _unsubscribe(self, op_id: str): async_iterator = self.remove_operation(op_id) if hasattr(async_iterator, "dispose"): async_iterator.dispose() @@ -136,14 +151,12 @@ class TornadoSubscriptionServer: def __init__( self, - schema, - loop=None, - middleware=None, - execution_context_class=None, - auth=None + schema: 'Schema', + middleware: Iterable[Type], + execution_context_class: 'Type[ExecutionContext]', + auth: 'Authorization', ): self.schema = schema - self.loop = loop self.middleware = middleware self.execution_context_class = execution_context_class self.auth = auth @@ -167,52 +180,67 @@ async def execute(self, params): **params['kwargs'] ) - def process_message(self, connection_context, parsed_message): - task = asyncio.ensure_future( - self._process_message(connection_context, parsed_message), - loop=self.loop + def process_message( + self, + connection_context: TornadoConnectionContext, + parsed_message: dict, + ) -> asyncio.Task[None]: + """Process a message from the client""" + task = asyncio.create_task( + self._process_message(connection_context, parsed_message) ) connection_context.remember_task(task) return task - async def _process_message(self, connection_context, parsed_message): + async def _process_message( + self, + connection_context: TornadoConnectionContext, + parsed_message: dict, + ) -> None: op_id = parsed_message.get("id") op_type = parsed_message.get("type") payload = parsed_message.get("payload") if op_type == GQL_CONNECTION_INIT: - return await self.on_connection_init( - connection_context, op_id, payload - ) + return await self.on_connection_init(connection_context) elif op_type == GQL_SUBSCRIBE: if not isinstance(payload, dict): raise AssertionError("The payload must be a dict") + assert op_id, "The message must have an operation ID" params = self.get_graphql_params(connection_context, payload) - return await self.on_start(connection_context, op_id, params) + return await self.on_subscribe(connection_context, op_id, params) elif op_type == GQL_COMPLETE: - return await self.on_stop(connection_context, op_id) + assert op_id, "The message must have an operation ID" + await connection_context.unsubscribe(op_id) + connection_context.request_context['sub_statuses'][op_id] = 'stop' + return + + elif op_type == GQL_PING: + return await self.send_message(connection_context, GQL_PONG) + + elif op_type == GQL_PONG: + return else: connection_context.ws.close( 4400, f"Invalid message type: {op_type}" ) - async def on_connection_init(self, connection_context, op_id, payload): + async def on_connection_init( + self, connection_context: TornadoConnectionContext + ) -> None: try: - await self.on_connect(connection_context, payload) await self.send_message( connection_context, op_type=GQL_CONNECTION_ACK ) except Exception as e: - await self.send_error(connection_context, op_id, e, GQL_ERROR) - await connection_context.ws.close(1011) + connection_context.ws.close(1011, str(e)) - async def on_connect(self, connection_context, payload): - pass - - def get_graphql_params(self, connection_context, payload): + def get_graphql_params( + self, connection_context: TornadoConnectionContext, payload: dict + ): # Create a new context object for each subscription, # which allows it to carry a unique subscription id. params = { @@ -225,14 +253,11 @@ def get_graphql_params(self, connection_context, payload): } # If middleware get instantiated here (optional), they will # be local/private to each subscription. - if self.middleware is not None: - middleware = list( - instantiate_middleware(self.middleware) - ) - else: - middleware = self.middleware - for mw in self.middleware: - if mw == AuthorizationMiddleware: + middleware = list( + instantiate_middleware(self.middleware) + ) + for mw in middleware: + if isinstance(mw, AuthorizationMiddleware): mw.auth = self.auth return { 'query': payload.get("query"), @@ -245,33 +270,36 @@ def get_graphql_params(self, connection_context, payload): ) } - async def on_open(self, connection_context): - pass - - async def on_stop(self, connection_context, op_id): - return await connection_context.unsubscribe(op_id) - connection_context.request_context['sub_statuses'][op_id] = 'stop' - - async def on_close(self, connection_context): - return await connection_context.unsubscribe_all() - - async def handle(self, ws, request_context=None): + async def handle(self, ws: 'SubscriptionHandler', request_context: dict): await asyncio.shield(self._handle(ws, request_context)) - async def _handle(self, ws, request_context=None): + async def _handle(self, ws: 'SubscriptionHandler', request_context: dict): connection_context = TornadoConnectionContext(ws, request_context) - await self.on_open(connection_context) while not connection_context.closed: try: - message = await connection_context.receive() + message = connection_context.ws.recv_nowait() except QueueEmpty: await asyncio.sleep(NO_MSG_DELAY) else: - await self.on_message(connection_context, message) + self.on_message(connection_context, message) - await self.on_close(connection_context) + await connection_context.unsubscribe_all() - async def on_start(self, connection_context, op_id, params): + async def on_subscribe( + self, + connection_context: TornadoConnectionContext, + op_id: str, + params: dict, + ) -> None: + """Run when the client starts a subscription. + + Execute the GraphQL subscription and send to the client each resulting + delta in turn. + + This will not return until the subscription ends (e.g. workflow stops), + at which point it will send a complete message to the client and + clean up. + """ # Attempt to unsubscribe first in case we already have a subscription # with this id. await connection_context.unsubscribe(op_id) @@ -298,26 +326,40 @@ async def on_start(self, connection_context, op_id, params): except (GeneratorExit, asyncio.CancelledError): raise except Exception as e: - await self.send_error(connection_context, op_id, e) + with suppress(WebSocketClosedError): + await self.send_message( + connection_context, + GQL_ERROR, + op_id, + payload=[{"message": str(e)}], + ) finally: if iterator: await iterator.aclose() + + # Complete the subscription from the server side: with suppress(WebSocketClosedError): - await self.send_message(connection_context, op_id, GQL_COMPLETE) + await self.send_message(connection_context, GQL_COMPLETE, op_id) await connection_context.unsubscribe(op_id) - await self.on_operation_complete(connection_context, op_id) + connection_context.request_context['sub_statuses'].pop(op_id, None) async def send_message( - self, connection_context, op_id=None, op_type=None, payload=None - ): - message = self.build_message(op_id, op_type, payload) + self, + connection_context: TornadoConnectionContext, + op_type: 'SendOperationType', + op_id: str | None = None, + payload=None, + ) -> None: + message = self.build_message(op_type, op_id, payload) try: return await connection_context.ws.write_message(message) except WebSocketClosedError: resolvers = connection_context.request_context.get('resolvers') if resolvers is not None: - request = connection_context.request_context.get('request') - headers = {} + request: HTTPServerRequest = ( + connection_context.request_context['request'] + ) + headers: dict[str, Any] = {} headers.update(getattr(request, 'headers', {})) resolvers.log.warning( '[GraphQL WS] Websocket closed on send' @@ -337,23 +379,28 @@ async def send_message( # Raise exception, in order to exit the on_start subscription loop. raise - def build_message(self, _id, op_type, payload): - message = {} - if _id is not None: - message["id"] = _id - if op_type is not None: - message["type"] = op_type + @staticmethod + def build_message( + op_type: 'SendOperationType', op_id: str | None, payload + ): + assert op_type, "Message must have a type" + message: dict[str, Any] = {"type": op_type} + if op_id is not None: + message["id"] = op_id if payload is not None: message["payload"] = payload - if not message: - raise AssertionError("You need to send at least one thing") return message async def send_execution_result( - self, connection_context, op_id, execution_result): + self, + connection_context: TornadoConnectionContext, + op_id: str, + execution_result: ExecutionResult, + ): # Resolve any pending promises if is_awaitable(execution_result.data): - await execution_result.data + # TODO: never hit? + await execution_result.data # type: ignore[misc] if execution_result.data and 'logs' not in execution_result.data: request_context = connection_context.request_context await request_context['resolvers'].flow_delta_processed( @@ -361,38 +408,18 @@ async def send_execution_result( result = execution_result.formatted return await self.send_message( - connection_context, op_id, GQL_NEXT, result + connection_context, GQL_NEXT, op_id, result ) - async def on_operation_complete(self, connection_context, op_id): - # remove the subscription from the sub_statuses dict - with suppress(KeyError): - connection_context.request_context['sub_statuses'].pop(op_id) - - async def send_error( - self, connection_context, op_id, error, error_type=None - ): - if error_type is None: - error_type = GQL_ERROR - - assert error_type == GQL_ERROR, "error_type should be GQL_ERROR" - - error_payload = [{"message": str(error)}] - - with suppress(WebSocketClosedError): - return await self.send_message( - connection_context, op_id, error_type, error_payload) - - async def on_message(self, connection_context, message): + def on_message( + self, connection_context: TornadoConnectionContext, message + ) -> None: try: - if not isinstance(message, dict): - parsed_message = json.loads(message) - if not isinstance(parsed_message, dict): - raise AssertionError("Payload must be an object.") - else: - parsed_message = message + parsed_message = json.loads(message) + if not isinstance(parsed_message, dict): + raise AssertionError("Message must be an object") except Exception as e: connection_context.ws.close(4400, str(e)) return None - return self.process_message(connection_context, parsed_message) + self.process_message(connection_context, parsed_message) From feb91ec2307ee1a78f9e7a623885613e0a403627 Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Mon, 23 Mar 2026 17:32:32 +0000 Subject: [PATCH 3/3] Tidy using `Enum` and `match...case` --- cylc/uiserver/graphql/tornado_ws.py | 109 ++++++++++++++++------------ 1 file changed, 63 insertions(+), 46 deletions(-) diff --git a/cylc/uiserver/graphql/tornado_ws.py b/cylc/uiserver/graphql/tornado_ws.py index 7f655874..0e7446d9 100644 --- a/cylc/uiserver/graphql/tornado_ws.py +++ b/cylc/uiserver/graphql/tornado_ws.py @@ -26,6 +26,10 @@ import asyncio from asyncio.queues import QueueEmpty from contextlib import suppress +from enum import ( + StrEnum, + auto, +) from inspect import isawaitable import json from typing import ( @@ -62,26 +66,35 @@ from cylc.uiserver.authorise import Authorization from cylc.uiserver.handlers import SubscriptionHandler - SendOperationType = Literal[ - "connection_ack", "ping", "pong", "next", "error", "complete" - ] NO_MSG_DELAY = 1.0 -# See https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md WS_PROTOCOL = 'graphql-transport-ws' -GQL_CONNECTION_INIT: Literal["connection_init"] = ( - 'connection_init' # Client -> Server -) -GQL_CONNECTION_ACK: Literal["connection_ack"] = ( - 'connection_ack' # Server -> Client -) -GQL_PING: Literal["ping"] = 'ping' # Bidirectional -GQL_PONG: Literal["pong"] = 'pong' # Bidirectional -GQL_SUBSCRIBE: Literal["subscribe"] = 'subscribe' # Client -> Server -GQL_NEXT: Literal["next"] = 'next' # Server -> Client -GQL_ERROR: Literal["error"] = 'error' # Server -> Client -GQL_COMPLETE: Literal["complete"] = 'complete' # Bidrectional + + +class OperationType(StrEnum): + """graphql-transport-ws message types. + + See https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md + """ + CONNECTION_INIT = auto() # Client -> Server + CONNECTION_ACK = auto() # Server -> Client + PING = auto() # Bidirectional + PONG = auto() # Bidirectional + SUBSCRIBE = auto() # Client -> Server + NEXT = auto() # Server -> Client + ERROR = auto() # Server -> Client + COMPLETE = auto() # Bidrectional + + +SendOperationType = Literal[ + OperationType.CONNECTION_ACK, + OperationType.PING, + OperationType.PONG, + OperationType.NEXT, + OperationType.ERROR, + OperationType.COMPLETE, +] REQ_HEADER_INFO = { 'Host', @@ -201,39 +214,41 @@ async def _process_message( op_type = parsed_message.get("type") payload = parsed_message.get("payload") - if op_type == GQL_CONNECTION_INIT: - return await self.on_connection_init(connection_context) - - elif op_type == GQL_SUBSCRIBE: - if not isinstance(payload, dict): - raise AssertionError("The payload must be a dict") - assert op_id, "The message must have an operation ID" - params = self.get_graphql_params(connection_context, payload) - return await self.on_subscribe(connection_context, op_id, params) - - elif op_type == GQL_COMPLETE: - assert op_id, "The message must have an operation ID" - await connection_context.unsubscribe(op_id) - connection_context.request_context['sub_statuses'][op_id] = 'stop' - return + match op_type: + case OperationType.CONNECTION_INIT.value: + await self.on_connection_init(connection_context) + + case OperationType.SUBSCRIBE.value: + if not isinstance(payload, dict): + raise AssertionError("The payload must be a dict") + assert op_id, "The message must have an operation ID" + params = self.get_graphql_params(connection_context, payload) + await self.on_subscribe(connection_context, op_id, params) + + case OperationType.COMPLETE.value: + assert op_id, "The message must have an operation ID" + await connection_context.unsubscribe(op_id) + connection_context.request_context['sub_statuses'][op_id] = ( + 'stop' + ) - elif op_type == GQL_PING: - return await self.send_message(connection_context, GQL_PONG) + case OperationType.PING.value: + await self.send_message(connection_context, OperationType.PONG) - elif op_type == GQL_PONG: - return + case OperationType.PONG.value: + pass - else: - connection_context.ws.close( - 4400, f"Invalid message type: {op_type}" - ) + case _: + connection_context.ws.close( + 4400, f"Invalid message type: {op_type}" + ) async def on_connection_init( self, connection_context: TornadoConnectionContext ) -> None: try: await self.send_message( - connection_context, op_type=GQL_CONNECTION_ACK + connection_context, op_type=OperationType.CONNECTION_ACK ) except Exception as e: connection_context.ws.close(1011, str(e)) @@ -329,7 +344,7 @@ async def on_subscribe( with suppress(WebSocketClosedError): await self.send_message( connection_context, - GQL_ERROR, + OperationType.ERROR, op_id, payload=[{"message": str(e)}], ) @@ -339,14 +354,16 @@ async def on_subscribe( # Complete the subscription from the server side: with suppress(WebSocketClosedError): - await self.send_message(connection_context, GQL_COMPLETE, op_id) + await self.send_message( + connection_context, OperationType.COMPLETE, op_id + ) await connection_context.unsubscribe(op_id) connection_context.request_context['sub_statuses'].pop(op_id, None) async def send_message( self, connection_context: TornadoConnectionContext, - op_type: 'SendOperationType', + op_type: SendOperationType, op_id: str | None = None, payload=None, ) -> None: @@ -381,10 +398,10 @@ async def send_message( @staticmethod def build_message( - op_type: 'SendOperationType', op_id: str | None, payload + op_type: SendOperationType, op_id: str | None, payload ): assert op_type, "Message must have a type" - message: dict[str, Any] = {"type": op_type} + message: dict[str, Any] = {"type": str(op_type)} if op_id is not None: message["id"] = op_id if payload is not None: @@ -408,7 +425,7 @@ async def send_execution_result( result = execution_result.formatted return await self.send_message( - connection_context, GQL_NEXT, op_id, result + connection_context, OperationType.NEXT, op_id, result ) def on_message(