-
-
Notifications
You must be signed in to change notification settings - Fork 100
Expand file tree
/
Copy pathwebsocket.py
More file actions
148 lines (117 loc) · 5.17 KB
/
websocket.py
File metadata and controls
148 lines (117 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import asyncio
import typing as t
import nacl.encoding
import nacl.utils
from asgiref.sync import sync_to_async
from django.db.models import QuerySet
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
from redis import asyncio as aioredis
from redis.exceptions import ConnectionError
from etebase_server.django import models
from etebase_server.django.utils import CallbackContext, get_user_queryset
from etebase_server.myauth.models import UserType, get_typed_user_model
from ..dependencies import get_collection_queryset, get_item_queryset
from ..exceptions import NotSupported
from ..msgpack import MsgpackRoute, msgpack_decode, msgpack_encode
from ..redis import redisw
from ..utils import BaseModel, permission_responses
# User model as returned by get_typed_user_model(); queried through `User.objects`
# when resolving the owner of a websocket ticket.
User = get_typed_user_model()
# Router on which this module's websocket endpoint is registered.
websocket_router = APIRouter(route_class=MsgpackRoute, responses=permission_responses)
# Type alias for a queryset of collections.
CollectionQuerySet = QuerySet[models.Collection]
# Intended lifetime of an authentication ticket, in seconds.
TICKET_VALIDITY_SECONDS = 10
# Request body for obtaining a websocket ticket.
class TicketRequest(BaseModel):
    # UID of the collection to watch; it is used both to build the Redis channel
    # name ("col.<uid>") and to look the collection up for the connecting user.
    collection: str
# Response returned by get_ticket.
class TicketOut(BaseModel):
    # Random URL-safe base64 token; it is also the Redis key under which the
    # ticket payload is stored.
    ticket: str
# Ticket payload that is msgpack-encoded into Redis by get_ticket and decoded
# again by load_websocket_ticket when the websocket connects.
class TicketInner(BaseModel):
    # `id` of the user the ticket was issued to.
    user: int
    # The request the ticket was issued for.
    req: TicketRequest
async def get_ticket(
    ticket_request: TicketRequest,
    user: UserType,
):
    """Get an authentication ticket that can be used with the websocket endpoint for authentication"""
    # Tickets live in Redis only, so the feature is unavailable without it.
    if not redisw.is_active:
        raise NotSupported(detail="This end-point requires Redis to be configured")

    # 32 random bytes, URL-safe base64 encoded: used both as the Redis key and
    # as the opaque ticket handed back to the client.
    uid = nacl.encoding.URLSafeBase64Encoder.encode(nacl.utils.random(32))
    ticket_model = TicketInner(user=user.id, req=ticket_request)
    ticket_raw = msgpack_encode(ticket_model.dict())
    # `ex` is an expiry in *seconds* (`px` is the millisecond variant), so the
    # constant is passed unscaled; multiplying by 1000 kept tickets alive for
    # hours instead of seconds.
    await redisw.redis.set(uid, ticket_raw, ex=TICKET_VALIDITY_SECONDS)
    return TicketOut(ticket=uid)
async def load_websocket_ticket(websocket: WebSocket, ticket: str) -> t.Optional[TicketInner]:
    """Redeem `ticket` from Redis and return its decoded payload.

    A ticket is removed from Redis as soon as it has been read, so it can be
    redeemed only once. When no entry exists for `ticket` the websocket is
    closed with a policy-violation code and None is returned.
    """
    raw_ticket = await redisw.redis.get(ticket)
    if raw_ticket is not None:
        # Consume the ticket before handing its contents to the caller.
        await redisw.redis.delete(ticket)
        fields = msgpack_decode(raw_ticket)
        return TicketInner(**fields)

    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    return None
def get_websocket_user(websocket: WebSocket, ticket_model: t.Optional[TicketInner] = Depends(load_websocket_ticket)):
    """Resolve the user a redeemed ticket belongs to.

    Returns None when no ticket could be loaded; otherwise the user whose id
    is recorded in the ticket, looked up through the project's user queryset
    hook with the websocket's path parameters as context.
    """
    if ticket_model is not None:
        all_users = User.objects.all()
        context = CallbackContext(websocket.path_params)
        candidates = get_user_queryset(all_users, context)
        return candidates.get(id=ticket_model.user)

    return None
@websocket_router.websocket("/{ticket}/")
async def websocket_endpoint(
    websocket: WebSocket,
    stoken: t.Optional[str] = None,
    user: t.Optional[UserType] = Depends(get_websocket_user),
    ticket_model: TicketInner = Depends(load_websocket_ticket),
):
    """Websocket entry point, authenticated by the ticket in the URL path.

    When the ticket resolves to a user the connection is accepted and handed
    to `redis_connector`; otherwise nothing is done here (the ticket loader
    has already closed the socket in that case).
    """
    if user is not None:
        await websocket.accept()
        await redis_connector(websocket, ticket_model, user, stoken)
async def send_item_updates(
    websocket: WebSocket,
    collection: models.Collection,
    user: UserType,
    stoken: t.Optional[str],
):
    """Send the client every item change of `collection` that is newer than `stoken`.

    Changes are fetched in pages of at most 50 entries; each non-empty page is
    sent to `websocket` as one msgpack-encoded binary frame. The loop ends
    once a page reports `done`.
    """
    # Imported lazily rather than at module level (presumably to avoid a
    # circular import with the collection router - TODO confirm).
    from .collection import item_list_common

    done = False
    while not done:
        queryset = await sync_to_async(get_item_queryset)(collection)
        response = await sync_to_async(item_list_common)(queryset, user, stoken, limit=50, prefetch="auto")
        done = response.done
        # Advance the cursor to the end of the page just fetched. Without this
        # every iteration asks for the same page again, so a backlog larger
        # than one page would resend the first page forever.
        # NOTE(review): relies on the list response exposing `stoken` next to
        # `data` and `done` - confirm against item_list_common.
        stoken = response.stoken

        if response.data:
            await websocket.send_bytes(msgpack_encode(response.dict()))
async def redis_connector(websocket: WebSocket, ticket_model: TicketInner, user: UserType, stoken: t.Optional[str]):
    """Stream changes of the ticket's collection to `websocket`.

    The client is first brought up to date starting from `stoken`; afterwards
    every message published on the collection's Redis channel is forwarded
    as-is. The client is not expected to send anything: incoming data closes
    the socket with a policy-violation code. Losing the Redis connection
    closes the socket with the service-restart code.
    """

    async def producer_handler(r: aioredis.Redis, ws: WebSocket):
        channel_name = f"col.{ticket_model.req.collection}"

        # The context manager releases the pubsub connection on every exit
        # path (unknown collection, client gone, Redis failure, cancellation).
        async with r.pubsub() as pubsub:
            # Subscribe *before* the catch-up below so that changes published
            # while the backlog is being sent are not missed.
            await pubsub.subscribe(channel_name)

            # Send missing items if we are not up to date
            queryset: QuerySet[models.Collection] = get_collection_queryset(user)
            collection: t.Optional[models.Collection] = await sync_to_async(
                queryset.filter(uid=ticket_model.req.collection).first
            )()
            if collection is None:
                # The collection does not exist in this user's queryset.
                await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                return

            await send_item_updates(websocket, collection, user, stoken)

            async def handle_message():
                # Wait up to 20 seconds for a published message and relay it.
                msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=20)
                if msg and msg["type"] == "message":
                    await ws.send_bytes(msg["data"])

            try:
                while True:
                    # We wait on the websocket so we fail if web sockets fail or get data
                    receive = asyncio.create_task(websocket.receive())
                    handle = asyncio.create_task(handle_message())
                    try:
                        done, _ = await asyncio.wait(
                            {receive, handle},
                            return_when=asyncio.FIRST_COMPLETED,
                        )
                    finally:
                        # Stop whichever task is still running; cancelling a
                        # finished task is a no-op. Doing it in `finally` also
                        # covers this coroutine being cancelled while waiting.
                        receive.cancel()
                        handle.cancel()

                    if handle in done:
                        # asyncio.wait() never raises on behalf of its tasks;
                        # fetch the result so that Redis/websocket errors reach
                        # the handlers below instead of being dropped.
                        handle.result()

                    if receive in done:
                        event = receive.result()
                        # A client disconnect is also delivered through
                        # receive(); only actual data is a violation, and a
                        # socket that is already gone must not be closed again.
                        if event.get("type") != "websocket.disconnect":
                            # Web socket should never receive any data
                            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
                        return
            except ConnectionError:
                await websocket.close(code=status.WS_1012_SERVICE_RESTART)
            except WebSocketDisconnect:
                pass

    redis = redisw.redis
    await producer_handler(redis, websocket)