diff --git a/changelog.d/18968.feature b/changelog.d/18968.feature new file mode 100644 index 00000000000..30368b23fd9 --- /dev/null +++ b/changelog.d/18968.feature @@ -0,0 +1 @@ +Implement support for MSC4354: Sticky Events. diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 94e74df9d11..cb963c04ec5 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -135,6 +135,8 @@ experimental_features: msc4155_enabled: true # Thread Subscriptions msc4306_enabled: true + # Sticky Events + msc4354_enabled: true server_notices: system_mxid_localpart: _server diff --git a/scripts-dev/complement.sh b/scripts-dev/complement.sh index 2447e0dc7bb..cf7c1884abc 100755 --- a/scripts-dev/complement.sh +++ b/scripts-dev/complement.sh @@ -229,6 +229,7 @@ main() { ./tests/msc4140 ./tests/msc4155 ./tests/msc4306 + ./tests/msc4354 ) # Export the list of test packages as a space-separated environment variable, so other diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 79b2a0c528e..4fa3def624e 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -132,6 +132,7 @@ "has_known_state", "is_encrypted", ], + "sticky_events": ["soft_failed"], "thread_subscriptions": ["subscribed", "automatic"], "users": ["shadow_banned", "approved", "locked", "suspended"], "un_partial_stated_event_stream": ["rejection_status_changed"], diff --git a/synapse/api/constants.py b/synapse/api/constants.py index d41e44b1541..ea801f08f3b 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -24,7 +24,7 @@ """Contains constants from the specification.""" import enum -from typing import Final +from typing import Final, TypedDict # the max size of a (canonical-json-encoded) event MAX_PDU_SIZE = 65536 @@ -279,6 +279,8 @@ class EventUnsignedContentFields: # Requesting user's membership, per MSC4115 MEMBERSHIP: Final = "membership" + STICKY_TTL: Final = "msc4354_sticky_duration_ttl_ms" + class MTextFields: """Fields found inside m.text content blocks.""" @@ -364,3 +366,18 @@ class Direction(enum.Enum): class ProfileFields: DISPLAYNAME: Final = "displayname" AVATAR_URL: Final = "avatar_url" + + +class StickyEventField(TypedDict): + duration_ms: int + + +class StickyEvent: + QUERY_PARAM_NAME: Final = "org.matrix.msc4354.sticky_duration_ms" + FIELD_NAME: Final = "msc4354_sticky" + MAX_DURATION_MS: Final = 3600000 # 1 hour + """ + Maximum stickiness duration as specified in MSC4354. + Ensures that data in the /sync response can go down and not grow unbounded. 
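+    Durations above this value are clamped down to the maximum (see EventBase.sticky_duration) rather than rejected.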
+ """ + MAX_EVENTS_IN_SYNC: Final = 100 diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 9939c0fe7d7..959fe39e70e 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -101,6 +101,7 @@ from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.storage.databases.main.stats import StatsStore +from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore @@ -136,6 +137,7 @@ class GenericWorkerStore( RoomWorkerStore, DirectoryWorkerStore, ThreadSubscriptionsWorkerStore, + StickyEventsWorkerStore, PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index dc5e096791a..a1a537545d7 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -597,5 +597,8 @@ def read_config( # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) + # MSC4354: Sticky Events + self.msc4354_enabled: bool = experimental.get("msc4354_enabled", False) + # MSC4380: Invite blocking self.msc4380_enabled: bool = experimental.get("msc4380_enabled", False) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index ec8ab9506bc..42f11b15d8f 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -127,7 +127,7 @@ class WriterLocations: """Specifies the instances that write various streams. Attributes: - events: The instances that write to the event and backfill streams. + events: The instances that write to the event, backfill and sticky events streams. typing: The instances that write to the typing stream. Currently can only be a single instance. to_device: The instances that write to the to_device stream. Currently diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 66f50115e32..4b354dd6fbd 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -66,6 +66,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( MutableStateMap, + StateKey, StateMap, StrCollection, UserID, @@ -1200,8 +1201,8 @@ def get_public_keys(invite_event: "EventBase") -> list[dict[str, Any]]: def auth_types_for_event( room_version: RoomVersion, event: Union["EventBase", "EventBuilder"] -) -> set[tuple[str, str]]: - """Given an event, return a list of (EventType, StateKey) that may be +) -> set[StateKey]: + """Given an event, return a list of (state event type, state key) that may be needed to auth the event. The returned list may be a superset of what would actually be required depending on the full state of the room. 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index c7eaf7eda2b..83916211af8 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -36,7 +36,12 @@ import attr from unpaddedbase64 import encode_base64 -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import ( + EventContentFields, + EventTypes, + RelationTypes, + StickyEvent, +) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata from synapse.types import ( @@ -318,6 +323,23 @@ def freeze(self) -> None: # this will be a no-op if the event dict is already frozen. self._dict = freeze(self._dict) + def sticky_duration(self) -> int | None: + """ + Returns the effective sticky duration of this event, or None + if the event does not have a sticky duration. + (Sticky Events are a MSC4354 feature.) + + Clamps the sticky duration to the maximum allowed duration. + """ + sticky_obj = self.get_dict().get(StickyEvent.FIELD_NAME, None) + if type(sticky_obj) is not dict: + return None + sticky_duration_ms = sticky_obj.get("duration_ms", None) + # MSC: Valid values are the integer range 0-MAX_DURATION_MS + if type(sticky_duration_ms) is int and sticky_duration_ms >= 0: + return min(sticky_duration_ms, StickyEvent.MAX_DURATION_MS) + return None + def __str__(self) -> str: return self.__repr__() diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 6a2812109d0..1490aa57527 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -24,7 +24,7 @@ import attr from signedjson.types import SigningKey -from synapse.api.constants import MAX_DEPTH, EventTypes +from synapse.api.constants import MAX_DEPTH, EventTypes, StickyEvent, StickyEventField from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, EventFormatVersions, @@ -89,6 +89,10 @@ class EventBuilder: content: JsonDict = attr.Factory(dict) unsigned: JsonDict = attr.Factory(dict) + sticky: StickyEventField | None = None + """ + Fields for MSC4354: Sticky Events + """ # These only exist on a subset of events, so they raise AttributeError if # someone tries to get them when they don't exist. @@ -269,6 +273,9 @@ async def build( if self._origin_server_ts is not None: event_dict["origin_server_ts"] = self._origin_server_ts + if self.sticky is not None: + event_dict[StickyEvent.FIELD_NAME] = self.sticky + return create_local_event_from_event_dict( clock=self._clock, hostname=self._hostname, @@ -318,6 +325,7 @@ def for_room_version( unsigned=key_values.get("unsigned", {}), redacts=key_values.get("redacts", None), origin_server_ts=key_values.get("origin_server_ts", None), + sticky=key_values.get(StickyEvent.FIELD_NAME, None), ) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 04ba5b86db0..26aac8793a6 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -194,6 +194,8 @@ async def _check_sigs_and_hash( # using the event in prev_events). redacted_event = prune_event(pdu) redacted_event.internal_metadata.soft_failed = True + # Mark this as spam so we don't re-evaluate soft-failure status. 
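+            # (MSC4354 re-evaluates soft-failed sticky events when new events are persisted; an event that failed signature/hash checks must never be un-soft-failed that way.)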
+ redacted_event.internal_metadata.policy_server_spammy = True return redacted_event return pdu diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 4a6d155217c..2c801f3d2ff 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -212,6 +212,11 @@ def notify_new_events(self, max_token: RoomStreamToken) -> None: # This should never get called. raise NotImplementedError() + def notify_new_server_joined(self, server: str, room_id: str) -> None: + """As per FederationSender""" + # This should never get called. + raise NotImplementedError() + def build_and_send_edu( self, destination: str, diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index f7240c2f7f9..daeca71c7fb 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -177,6 +177,7 @@ from synapse.util.duration import Duration from synapse.util.metrics import Measure from synapse.util.retryutils import filter_destinations_by_retry_limiter +from synapse.visibility import filter_events_for_server if TYPE_CHECKING: from synapse.events.presence_router import PresenceRouter @@ -240,6 +241,13 @@ def notify_new_events(self, max_token: RoomStreamToken) -> None: """ raise NotImplementedError() + @abc.abstractmethod + def notify_new_server_joined(self, server: str, room_id: str) -> None: + """This gets called when a new server has joined a room. We might + want to send out some events to this server. + """ + raise NotImplementedError() + @abc.abstractmethod async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room @@ -502,6 +510,66 @@ def _get_per_destination_queue( self._per_destination_queues[destination] = queue return queue + def notify_new_server_joined(self, server: str, room_id: str) -> None: + # We currently only use this notification for MSC4354: Sticky Events.
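+ # A newly joined server will have missed any sticky events we sent before it joined, so push those that are still active to it.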
+ if not self.hs.config.experimental.msc4354_enabled: + return + # fire off a processing loop in the background + self.hs.run_as_background_process( + "process_new_server_joined_over_federation", + self._process_new_server_joined_over_federation, + server, + room_id, + ) + + async def _process_new_server_joined_over_federation( + self, new_server: str, room_id: str + ) -> None: + sticky_event_ids = await self.store.get_sticky_event_ids_sent_by_self( + room_id, + from_stream_pos=0, + ) + sticky_events = await self.store.get_events_as_list(sticky_event_ids) + + # We must not send events that are outliers / lack a stream ordering, else we won't be able to + # satisfy /get_missing_events requests + sticky_events = [ + ev + for ev in sticky_events + if ev.internal_metadata.stream_ordering is not None + and not ev.internal_metadata.is_outlier() + ] + # order by stream ordering so we present things in the right timeline order on the receiver + sticky_events.sort( + key=lambda ev: ev.internal_metadata.stream_ordering + or 0, # not possible to be 0 + ) + + sticky_events = await filter_events_for_server( + self._storage_controllers, + new_server, + self.server_name, + sticky_events, + redact=False, + filter_out_erased_senders=True, + filter_out_remote_partial_state_events=True, + ) + + if not sticky_events: + return + + logger.info( + "sending %d sticky events to newly joined server %s in room %s", + len(sticky_events), + new_server, + room_id, + ) + # we don't track that we sent up to this stream position since it won't make any difference + # since notify_new_server_joined is only called initially. + await self._transaction_manager.send_new_transaction( + new_server, sticky_events, [] + ) + def notify_new_events(self, max_token: RoomStreamToken) -> None: """This gets called when we have some new events we might want to send out to other servers. diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index cdacf16d725..65ae1a56df3 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -105,6 +105,7 @@ def __init__( self._instance_name = hs.get_instance_name() self._federation_shard_config = hs.config.worker.federation_shard_config self._state = hs.get_state_handler() + self._msc4354_enabled = hs.config.experimental.msc4354_enabled self._should_send_on_this_instance = True if not self._federation_shard_config.should_handle( @@ -583,6 +584,33 @@ async def _catch_up_transmission_loop(self) -> None: # send. extrem_events = await self._store.get_events_as_list(extrems) + if self._msc4354_enabled: + # we also want to send sticky events that are still active in this room + sticky_event_ids = ( + await self._store.get_sticky_event_ids_sent_by_self( + pdu.room_id, + last_successful_stream_ordering, + ) + ) + # skip any that are actually the forward extremities we want to send anyway + sticky_events = await self._store.get_events_as_list( + [ + event_id + for event_id in sticky_event_ids + if event_id not in extrems + ] + ) + if sticky_events: + # *prepend* these to the extrem list, so they are processed first. 
+ # This ensures they will show up before the forward extrem in stream order + extrem_events = sticky_events + extrem_events + logger.info( + "Sending %d missed sticky events to %s: %r", + len(sticky_events), + self._destination, + pdu.room_id, + ) + new_pdus = [] for p in extrem_events: # We pulled this from the DB, so it'll be non-null diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index cb0a4dd6b23..864dc02db35 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -17,7 +17,7 @@ from twisted.internet.interfaces import IDelayedCall -from synapse.api.constants import EventTypes +from synapse.api.constants import EventTypes, StickyEvent from synapse.api.errors import ShadowBanError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME @@ -333,6 +333,7 @@ async def add( origin_server_ts: int | None, content: JsonDict, delay: int, + sticky_duration_ms: int | None, ) -> str: """ Creates a new delayed event and schedules its delivery. @@ -346,7 +347,9 @@ async def add( If None, the timestamp will be the actual time when the event is sent. content: The content of the event to be sent. delay: How long (in milliseconds) to wait before automatically sending the event. - + sticky_duration_ms: If an MSC4354 sticky event: the sticky duration (in milliseconds). + The event will be attempted to be reliably delivered to clients and remote servers + during its sticky period. Returns: The ID of the added delayed event. Raises: @@ -382,6 +385,7 @@ async def add( origin_server_ts=origin_server_ts, content=content, delay=delay, + sticky_duration_ms=sticky_duration_ms, ) if self._repl_client is not None: @@ -570,7 +574,10 @@ async def _send_event( if event.state_key is not None: event_dict["state_key"] = event.state_key - + if event.sticky_duration_ms is not None: + event_dict[StickyEvent.FIELD_NAME] = { + "duration_ms": event.sticky_duration_ms, + } ( sent_event, _, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 7808f8928b0..f1c6039f425 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -61,6 +61,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError +from synapse.federation.federation_server import _INBOUND_EVENT_HANDLING_LOCK_NAME from synapse.handlers.pagination import PURGE_PAGINATION_LOCK_NAME from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import nested_logging_context @@ -68,6 +69,7 @@ from synapse.metrics import SERVER_NAME_LABEL from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.events_worker import EventRedactBehaviour +from synapse.storage.databases.main.lock import Lock from synapse.storage.invite_rule import InviteRule from synapse.types import JsonDict, StrCollection, get_domain_from_id from synapse.types.state import StateFilter @@ -639,125 +641,158 @@ async def do_invite_join( except ValueError: pass + lock: Lock | None = None async with self._is_partial_state_room_linearizer.queue(room_id): - already_partial_state_room = await self.store.is_partial_state_room( - room_id - ) + try: + # MSC4354: Sticky Events causes existing servers in the room to send sticky events + # to the newly joined server as soon as they realise the new server is in the room. 
+ # If they do this before we've persisted the /send_join response we will be unable to + # process those PDUs. Therefore, we take a lock out now for this room, and release it + # once we have processed the /send_join response, to buffer up these inbound messages. + # This may be useful to do even without MSC4354, but it's gated behind an + # experimental flag check to reduce the chance of this having unintended side-effects + # e.g accidental deadlocks. Once we're confident of this behaviour, we can probably + # drop the flag check. We take the lock AFTER we have been queued by the linearizer + # else we would just hold the lock for no reason whilst in the queue: we want to hold + # the lock for the smallest amount of time possible. + if self.config.experimental.msc4354_enabled: + lock = await self.store.try_acquire_lock( + _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id + ) + # Insert the room into the rooms table now so we can process potential incoming + # /send transactions enough to be able to insert into the federation staging + # area. We won't process the staging area until we release the lock above. + await self.store.upsert_room_on_join( + room_id=room_id, + room_version=room_version_obj, + state_events=None, + ) - ret = await self.federation_client.send_join( - host_list, - event, - room_version_obj, - # Perform a full join when we are already in the room and it is a - # full state room, since we are not allowed to persist a partial - # state join event in a full state room. In the future, we could - # optimize this by always performing a partial state join and - # computing the state ourselves or retrieving it from the remote - # homeserver if necessary. - # - # There's a race where we leave the room, then perform a full join - # anyway. This should end up being fast anyway, since we would - # already have the full room state and auth chain persisted. - partial_state=not is_host_joined or already_partial_state_room, - ) + already_partial_state_room = await self.store.is_partial_state_room( + room_id + ) - event = ret.event - origin = ret.origin - state = ret.state - auth_chain = ret.auth_chain - auth_chain.sort(key=lambda e: e.depth) + ret = await self.federation_client.send_join( + host_list, + event, + room_version_obj, + # Perform a full join when we are already in the room and it is a + # full state room, since we are not allowed to persist a partial + # state join event in a full state room. In the future, we could + # optimize this by always performing a partial state join and + # computing the state ourselves or retrieving it from the remote + # homeserver if necessary. + # + # There's a race where we leave the room, then perform a full join + # anyway. This should end up being fast anyway, since we would + # already have the full room state and auth chain persisted. + partial_state=not is_host_joined or already_partial_state_room, + ) - logger.debug("do_invite_join auth_chain: %s", auth_chain) - logger.debug("do_invite_join state: %s", state) + event = ret.event + origin = ret.origin + state = ret.state + auth_chain = ret.auth_chain + auth_chain.sort(key=lambda e: e.depth) - logger.debug("do_invite_join event: %s", event) + logger.debug("do_invite_join auth_chain: %s", auth_chain) + logger.debug("do_invite_join state: %s", state) - # if this is the first time we've joined this room, it's time to add - # a row to `rooms` with the correct room version. 
If there's already a - # row there, we should override it, since it may have been populated - # based on an invite request which lied about the room version. - # - # federation_client.send_join has already checked that the room - # version in the received create event is the same as room_version_obj, - # so we can rely on it now. - # - await self.store.upsert_room_on_join( - room_id=room_id, - room_version=room_version_obj, - state_events=state, - ) + logger.debug("do_invite_join event: %s", event) - if ret.partial_state and not already_partial_state_room: - # Mark the room as having partial state. - # The background process is responsible for unmarking this flag, - # even if the join fails. - # TODO(faster_joins): - # We may want to reset the partial state info if it's from an - # old, failed partial state join. - # https://github.com/matrix-org/synapse/issues/13000 - - # FIXME: Ideally, we would store the full stream token here - # not just the minimum stream ID, so that we can compute an - # accurate list of device changes when un-partial-ing the - # room. The only side effect of this is that we may send - # extra unecessary device list outbound pokes through - # federation, which is harmless. - device_lists_stream_id = self.store.get_device_stream_token().stream - - await self.store.store_partial_state_room( + # if this is the first time we've joined this room, it's time to add + # a row to `rooms` with the correct room version. If there's already a + # row there, we should override it, since it may have been populated + # based on an invite request which lied about the room version. + # + # federation_client.send_join has already checked that the room + # version in the received create event is the same as room_version_obj, + # so we can rely on it now. + # + await self.store.upsert_room_on_join( room_id=room_id, - servers=ret.servers_in_room, - device_lists_stream_id=device_lists_stream_id, - joined_via=origin, + room_version=room_version_obj, + state_events=state, ) - try: - max_stream_id = ( - await self._federation_event_handler.process_remote_join( - origin, - room_id, - auth_chain, - state, - event, - room_version_obj, - partial_state=ret.partial_state, - ) - ) - except PartialStateConflictError: - # This should be impossible, since we hold the lock on the room's - # partial statedness. - logger.error( - "Room %s was un-partial stated while processing remote join.", - room_id, - ) - raise - else: - # Record the join event id for future use (when we finish the full - # join). We have to do this after persisting the event to keep - # foreign key constraints intact. if ret.partial_state and not already_partial_state_room: + # Mark the room as having partial state. + # The background process is responsible for unmarking this flag, + # even if the join fails. # TODO(faster_joins): - # We may want to reset the partial state info if it's from - # an old, failed partial state join. + # We may want to reset the partial state info if it's from an + # old, failed partial state join. # https://github.com/matrix-org/synapse/issues/13000 - await self.store.write_partial_state_rooms_join_event_id( - room_id, event.event_id + + # FIXME: Ideally, we would store the full stream token here + # not just the minimum stream ID, so that we can compute an + # accurate list of device changes when un-partial-ing the + # room. The only side effect of this is that we may send + # extra unecessary device list outbound pokes through + # federation, which is harmless. 
+ device_lists_stream_id = ( + self.store.get_device_stream_token().stream ) - finally: - # Always kick off the background process that asynchronously fetches - # state for the room. - # If the join failed, the background process is responsible for - # cleaning up — including unmarking the room as a partial state - # room. - if ret.partial_state: - # Kick off the process of asynchronously fetching the state for - # this room. - self._start_partial_state_room_sync( - initial_destination=origin, - other_destinations=ret.servers_in_room, + + await self.store.store_partial_state_room( room_id=room_id, + servers=ret.servers_in_room, + device_lists_stream_id=device_lists_stream_id, + joined_via=origin, ) + try: + max_stream_id = ( + await self._federation_event_handler.process_remote_join( + origin, + room_id, + auth_chain, + state, + event, + room_version_obj, + partial_state=ret.partial_state, + ) + ) + except PartialStateConflictError: + # This should be impossible, since we hold the lock on the room's + # partial statedness. + logger.error( + "Room %s was un-partial stated while processing remote join.", + room_id, + ) + raise + else: + # Record the join event id for future use (when we finish the full + # join). We have to do this after persisting the event to keep + # foreign key constraints intact. + if ret.partial_state and not already_partial_state_room: + # TODO(faster_joins): + # We may want to reset the partial state info if it's from + # an old, failed partial state join. + # https://github.com/matrix-org/synapse/issues/13000 + await self.store.write_partial_state_rooms_join_event_id( + room_id, event.event_id + ) + finally: + # Always kick off the background process that asynchronously fetches + # state for the room. + # If the join failed, the background process is responsible for + # cleaning up — including unmarking the room as a partial state + # room. + if ret.partial_state: + # Kick off the process of asynchronously fetching the state for + # this room. + self._start_partial_state_room_sync( + initial_destination=origin, + other_destinations=ret.servers_in_room, + room_id=room_id, + ) + finally: + # allow inbound events which happened during the join to be processed. + # Also ensures we release the lock on unexpected errors e.g db errors from + # upsert_room_on_join or network errors from send_join. + if lock: + await lock.release() # We wait here until this instance has seen the events come down # replication (if we're using replication) as the below uses caches. 
await self._replication.wait_for_stream_position( diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index d076bec51a9..8128a2e881b 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -51,6 +51,7 @@ concurrently_execute, gather_optional_coroutines, ) +from synapse.visibility import filter_events_for_client _ThreadSubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription @@ -73,7 +74,10 @@ def __init__(self, hs: "HomeServer"): self.event_sources = hs.get_event_sources() self.device_handler = hs.get_device_handler() self.push_rules_handler = hs.get_push_rules_handler() + self.clock = hs.get_clock() + self._storage_controllers = hs.get_storage_controllers() self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled + self._enable_sticky_events = hs.config.experimental.msc4354_enabled @trace async def get_extensions_response( @@ -174,6 +178,19 @@ async def get_extensions_response( from_token=from_token, ) + sticky_events_coro = None + if ( + sync_config.extensions.sticky_events is not None + and self._enable_sticky_events + ): + sticky_events_coro = self.get_sticky_events_extension_response( + sync_config=sync_config, + sticky_events_request=sync_config.extensions.sticky_events, + actual_room_ids=actual_room_ids, + to_token=to_token, + from_token=from_token, + ) + ( to_device_response, e2ee_response, @@ -181,6 +198,7 @@ async def get_extensions_response( receipts_response, typing_response, thread_subs_response, + sticky_events_response, ) = await gather_optional_coroutines( to_device_coro, e2ee_coro, @@ -188,6 +206,7 @@ async def get_extensions_response( receipts_coro, typing_coro, thread_subs_coro, + sticky_events_coro, ) return SlidingSyncResult.Extensions( @@ -197,6 +216,7 @@ async def get_extensions_response( receipts=receipts_response, typing=typing_response, thread_subscriptions=thread_subs_response, + sticky_events=sticky_events_response, ) def find_relevant_room_ids_for_extension( @@ -967,3 +987,47 @@ async def get_thread_subscriptions_extension_response( unsubscribed=unsubscribed_threads, prev_batch=prev_batch, ) + + async def get_sticky_events_extension_response( + self, + sync_config: SlidingSyncConfig, + sticky_events_request: SlidingSyncConfig.Extensions.StickyEventsExtension, + actual_room_ids: set[str], + to_token: StreamToken, + from_token: SlidingSyncStreamToken | None, + ) -> SlidingSyncResult.Extensions.StickyEventsExtension | None: + if not sticky_events_request.enabled: + return None + now = self.clock.time_msec() + from_id = from_token.stream_token.sticky_events_key if from_token else 0 + _, room_to_event_ids = await self.store.get_sticky_events_in_rooms( + actual_room_ids, + from_id=from_id, + to_id=to_token.sticky_events_key, + now=now, + # We set no limit here because the client can control when they get sticky events. + # Furthermore, it doesn't seem possible to set a limit with the internal API shape + # as given, as we cannot manipulate the to_token.sticky_events_key sent to the client... 
+ limit=None, + ) + all_sticky_event_ids = { + ev_id for evs in room_to_event_ids.values() for ev_id in evs + } + unfiltered_events = await self.store.get_events_as_list(all_sticky_event_ids) + filtered_events = await filter_events_for_client( + self._storage_controllers, + sync_config.user.to_string(), + unfiltered_events, + always_include_ids=frozenset(all_sticky_event_ids), + ) + filtered_event_map = {ev.event_id: ev for ev in filtered_events} + return SlidingSyncResult.Extensions.StickyEventsExtension( + room_id_to_sticky_events={ + room_id: { + filtered_event_map[event_id] + for event_id in sticky_event_ids + if event_id in filtered_event_map + } + for room_id, sticky_event_ids in room_to_event_ids.items() + } + ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 60d88274255..dacfab5d593 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ EventContentFields, EventTypes, Membership, + StickyEvent, ) from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState @@ -146,6 +147,7 @@ class JoinedSyncResult: state: StateMap[EventBase] ephemeral: list[JsonDict] account_data: list[JsonDict] + sticky: list[EventBase] unread_notifications: JsonDict unread_thread_notifications: JsonDict summary: JsonDict | None @@ -156,7 +158,11 @@ def __bool__(self) -> bool: to tell if room needs to be part of the sync result. """ return bool( - self.timeline or self.state or self.ephemeral or self.account_data + self.timeline + or self.state + or self.ephemeral + or self.account_data + or self.sticky # nb the notification count does not, er, count: if there's nothing # else in the result, we don't need to send it. ) @@ -596,6 +602,41 @@ async def ephemeral_by_room( return now_token, ephemeral_by_room + async def sticky_events_by_room( + self, + sync_result_builder: "SyncResultBuilder", + now_token: StreamToken, + since_token: StreamToken | None = None, + ) -> tuple[StreamToken, dict[str, set[str]]]: + """Get the sticky events for each room the user is in Args: + sync_result_builder + now_token: Where the server is currently up to. + since_token: Where the server was when the client last synced. + Returns: + A tuple of the now StreamToken, updated to reflect which sticky + events are included, and a dict mapping from room_id to a set of + sticky event IDs for that room. + """ + now = self.clock.time_msec() + with Measure( + self.clock, name="sticky_events_by_room", server_name=self.server_name + ): + from_id = since_token.sticky_events_key if since_token else 0 + + room_ids = sync_result_builder.joined_room_ids + + to_id, sticky_by_room = await self.store.get_sticky_events_in_rooms( + room_ids, + from_id=from_id, + to_id=now_token.sticky_events_key, + now=now, + limit=StickyEvent.MAX_EVENTS_IN_SYNC, + ) + now_token = now_token.copy_and_replace(StreamKeyType.STICKY_EVENTS, to_id) + + return now_token, sticky_by_room + async def _load_filtered_recents( self, room_id: str, @@ -2163,6 +2204,13 @@ async def _generate_sync_entry_for_rooms( ) sync_result_builder.now_token = now_token + sticky_by_room: dict[str, set[str]] = {} + if self.hs_config.experimental.msc4354_enabled: + now_token, sticky_by_room = await self.sticky_events_by_room( + sync_result_builder, now_token, since_token + ) + sync_result_builder.now_token = now_token + # 2. We check up front if anything has changed, if it hasn't then there is # no point in going further.
if not sync_result_builder.full_state: @@ -2173,7 +2221,7 @@ async def _generate_sync_entry_for_rooms( tags_by_room = await self.store.get_updated_tags( user_id, since_token.account_data_key ) - if not tags_by_room: + if not tags_by_room and not sticky_by_room: logger.debug("no-oping sync") return set(), set() @@ -2193,7 +2241,6 @@ async def _generate_sync_entry_for_rooms( tags_by_room = await self.store.get_tags_for_user(user_id) log_kv({"rooms_changed": len(room_changes.room_entries)}) - room_entries = room_changes.room_entries invited = room_changes.invited knocked = room_changes.knocked @@ -2211,6 +2258,7 @@ async def handle_room_entries(room_entry: "RoomSyncResultBuilder") -> None: ephemeral=ephemeral_by_room.get(room_entry.room_id, []), tags=tags_by_room.get(room_entry.room_id), account_data=account_data_by_room.get(room_entry.room_id, {}), + sticky_event_ids=sticky_by_room.get(room_entry.room_id, set()), always_include=sync_result_builder.full_state, ) logger.debug("Generated room entry for %s", room_entry.room_id) @@ -2597,6 +2645,7 @@ async def _generate_room_entry( ephemeral: list[JsonDict], tags: Mapping[str, JsonMapping] | None, account_data: Mapping[str, JsonMapping], + sticky_event_ids: set[str], always_include: bool = False, ) -> None: """Populates the `joined` and `archived` section of `sync_result_builder` @@ -2626,6 +2675,7 @@ async def _generate_room_entry( tags: List of *all* tags for room, or None if there has been no change. account_data: List of new account data for room + sticky_event_ids: MSC4354 sticky events in the room, if any. always_include: Always include this room in the sync response, even if empty. """ @@ -2636,7 +2686,13 @@ async def _generate_room_entry( events = room_builder.events # We want to shortcut out as early as possible. - if not (always_include or account_data or ephemeral or full_state): + if not ( + always_include + or account_data + or ephemeral + or full_state + or sticky_event_ids + ): if events == [] and tags is None: return @@ -2728,6 +2784,7 @@ async def _generate_room_entry( or account_data_events or ephemeral or full_state + or sticky_event_ids ): return @@ -2774,6 +2831,22 @@ async def _generate_room_entry( if room_builder.rtype == "joined": unread_notifications: dict[str, int] = {} + sticky_events: list[EventBase] = [] + if sticky_event_ids: + # remove sticky events that are in the timeline, else we will needlessly duplicate + # events. This is particularly important given the risk of sticky events spam since + # anyone can send sticky events, so halving the bandwidth on average for each sticky + # event is helpful. 
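+ # (Such events are still delivered in the timeline itself; only the duplicate copy in the sticky section is dropped.)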
+ timeline = {ev.event_id for ev in batch.events} + sticky_event_ids = sticky_event_ids.difference(timeline) + if sticky_event_ids: + sticky_event_map = await self.store.get_events(sticky_event_ids) + sticky_events = await filter_events_for_client( + self._storage_controllers, + sync_result_builder.sync_config.user.to_string(), + list(sticky_event_map.values()), + always_include_ids=frozenset(sticky_event_ids), + ) room_sync = JoinedSyncResult( room_id=room_id, timeline=batch, @@ -2784,6 +2857,7 @@ async def _generate_room_entry( unread_thread_notifications={}, summary=summary, unread_count=0, + sticky=sticky_events, ) if room_sync or always_include: diff --git a/synapse/notifier.py b/synapse/notifier.py index d8d2db17f12..57b11e90214 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -526,6 +526,7 @@ def on_new_event( StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, + StreamKeyType.STICKY_EVENTS, ], new_token: int, users: Collection[str | UserID] | None = None, @@ -932,6 +933,11 @@ def notify_remote_server_up(self, server: str) -> None: # that any in flight requests can be immediately retried. self._federation_client.wake_destination(server) + def notify_new_server_joined(self, server: str, room_id: str) -> None: + # Inform the federation_sender that it may need to send events to the new server. + if self.federation_sender: + self.federation_sender.notify_new_server_joined(server, room_id) + def add_lock_released_callback( self, callback: Callable[[str, str, str], None] ) -> None: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index fdda932ead2..bc7e46d4c92 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -43,7 +43,10 @@ UnPartialStatedEventStream, UnPartialStatedRoomStream, ) -from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream +from synapse.replication.tcp.streams._base import ( + StickyEventsStream, + ThreadSubscriptionsStream, +) from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamEventRow, @@ -262,6 +265,12 @@ async def on_rdata( token, users=[row.user_id for row in rows], ) + elif stream_name == StickyEventsStream.NAME: + self.notifier.on_new_event( + StreamKeyType.STICKY_EVENTS, + token, + rooms=[row.room_id for row in rows], + ) await self._presence_handler.process_replication_rows( stream_name, instance_name, token, rows diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index 79194f72750..c1bb3ba2c82 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -462,6 +462,32 @@ class RemoteServerUpCommand(_SimpleCommand): NAME = "REMOTE_SERVER_UP" +class NewServerJoinedCommand(Command): + """Sent when a worker has detected that a new remote server has joined a room. + + Format:: + + NEW_SERVER_JOINED + """ + + NAME = "NEW_SERVER_JOINED" + __slots__ = ["server", "room_id"] + + def __init__(self, server: str, room_id: str): + self.server = server + self.room_id = room_id + + @classmethod + def from_line( + cls: type["NewServerJoinedCommand"], line: str + ) -> "NewServerJoinedCommand": + server, room_id = line.split(" ") + return cls(server, room_id) + + def to_line(self) -> str: + return "%s %s" % (self.server, self.room_id) + + class LockReleasedCommand(Command): """Sent to inform other instances that a given lock has been dropped. 
@@ -517,6 +543,7 @@ class NewActiveTaskCommand(_SimpleCommand): FederationAckCommand, UserIpCommand, RemoteServerUpCommand, + NewServerJoinedCommand, ClearUserSyncsCommand, LockReleasedCommand, NewActiveTaskCommand, @@ -533,6 +560,7 @@ class NewActiveTaskCommand(_SimpleCommand): ErrorCommand.NAME, PingCommand.NAME, RemoteServerUpCommand.NAME, + NewServerJoinedCommand.NAME, LockReleasedCommand.NAME, ) @@ -547,6 +575,7 @@ class NewActiveTaskCommand(_SimpleCommand): UserIpCommand.NAME, ErrorCommand.NAME, RemoteServerUpCommand.NAME, + NewServerJoinedCommand.NAME, LockReleasedCommand.NAME, ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 05370045e6b..ac5053647c9 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -40,6 +40,7 @@ FederationAckCommand, LockReleasedCommand, NewActiveTaskCommand, + NewServerJoinedCommand, PositionCommand, RdataCommand, RemoteServerUpCommand, @@ -66,6 +67,7 @@ ) from synapse.replication.tcp.streams._base import ( DeviceListsStream, + StickyEventsStream, ThreadSubscriptionsStream, ) from synapse.util.background_queue import BackgroundQueue @@ -216,6 +218,12 @@ def __init__(self, hs: "HomeServer"): continue + if isinstance(stream, StickyEventsStream): + if hs.get_instance_name() in hs.config.worker.writers.events: + self._streams_to_replicate.append(stream) + + continue + if isinstance(stream, DeviceListsStream): if hs.get_instance_name() in hs.config.worker.writers.device_lists: self._streams_to_replicate.append(stream) @@ -732,6 +740,12 @@ def on_REMOTE_SERVER_UP( """Called when get a new REMOTE_SERVER_UP command.""" self._notifier.notify_remote_server_up(cmd.data) + def on_NEW_SERVER_JOINED( + self, conn: IReplicationConnection, cmd: NewServerJoinedCommand + ) -> None: + """Called when get a new NEW_SERVER_JOINED command.""" + self._notifier.notify_new_server_joined(cmd.server, cmd.room_id) + def on_LOCK_RELEASED( self, conn: IReplicationConnection, cmd: LockReleasedCommand ) -> None: @@ -854,6 +868,9 @@ def send_user_ip( def send_remote_server_up(self, server: str) -> None: self.send_command(RemoteServerUpCommand(server)) + def send_new_server_joined(self, server: str, room_id: str) -> None: + self.send_command(NewServerJoinedCommand(server, room_id)) + def stream_update(self, stream_name: str, token: int | None, data: Any) -> None: """Called when a new update is available to stream to Redis subscribers. 
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 87ac0a5ae17..067847617fa 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -40,6 +40,7 @@ PushersStream, PushRulesStream, ReceiptsStream, + StickyEventsStream, Stream, ThreadSubscriptionsStream, ToDeviceStream, @@ -68,6 +69,7 @@ ToDeviceStream, FederationStream, AccountDataStream, + StickyEventsStream, ThreadSubscriptionsStream, UnPartialStatedRoomStream, UnPartialStatedEventStream, @@ -90,6 +92,7 @@ "ToDeviceStream", "FederationStream", "AccountDataStream", + "StickyEventsStream", "ThreadSubscriptionsStream", "UnPartialStatedRoomStream", "UnPartialStatedEventStream", diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 4fb2aac2029..336b50160b5 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -763,3 +763,48 @@ async def _update_function( return [], to_token, False return rows, rows[-1][0], len(updates) == limit + + +@attr.s(slots=True, auto_attribs=True) +class StickyEventsStreamRow: + """Stream to inform workers about changes to sticky events.""" + + room_id: str + + event_id: str + """The sticky event ID""" + + +class StickyEventsStream(_StreamFromIdGen): + """A sticky event was changed.""" + + NAME = "sticky_events" + ROW_TYPE = StickyEventsStreamRow + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastores().main + super().__init__( + hs.get_instance_name(), + self._update_function, + self.store._sticky_events_id_gen, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + updates = await self.store.get_updated_sticky_events( + from_id=from_token, to_id=to_token, limit=limit + ) + rows = [ + ( + stream_id, + # These are the args to `StickyEventsStreamRow` + (room_id, event_id), + ) + for stream_id, room_id, event_id, _ in updates + ] + + if not rows: + return [], to_token, False + + return rows, rows[-1][0], len(updates) == limit diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 5e7dcb01911..d09425a1869 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -34,7 +34,7 @@ from twisted.web.server import Request from synapse import event_auth -from synapse.api.constants import Direction, EventTypes, Membership +from synapse.api.constants import Direction, EventTypes, Membership, StickyEvent from synapse.api.errors import ( AuthError, Codes, @@ -210,6 +210,7 @@ def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self._max_event_delay_ms = hs.config.server.max_event_delay_ms self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker + self._msc4354_enabled = hs.config.experimental.msc4354_enabled def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/state/$eventtype @@ -331,6 +332,10 @@ async def on_PUT( if requester.app_service: origin_server_ts = parse_integer(request, "ts") + sticky_duration_ms: int | None = None + if self._msc4354_enabled: + sticky_duration_ms = parse_integer(request, StickyEvent.QUERY_PARAM_NAME) + delay = _parse_request_delay(request, self._max_event_delay_ms) if delay is not None: delay_id = await self.delayed_events_handler.add( @@ -341,6 +346,7 @@ async def on_PUT( origin_server_ts=origin_server_ts, content=content, delay=delay, + sticky_duration_ms=sticky_duration_ms, ) set_tag("delay_id", delay_id) 
@@ -368,6 +374,10 @@ async def on_PUT( "room_id": room_id, "sender": requester.user.to_string(), } + if sticky_duration_ms is not None: + event_dict[StickyEvent.FIELD_NAME] = { + "duration_ms": sticky_duration_ms, + } if state_key is not None: event_dict["state_key"] = state_key @@ -400,6 +410,7 @@ def __init__(self, hs: "HomeServer"): self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() self._max_event_delay_ms = hs.config.server.max_event_delay_ms + self._msc4354_enabled = hs.config.experimental.msc4354_enabled def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] @@ -420,6 +431,10 @@ async def _do( if requester.app_service: origin_server_ts = parse_integer(request, "ts") + sticky_duration_ms: int | None = None + if self._msc4354_enabled: + sticky_duration_ms = parse_integer(request, StickyEvent.QUERY_PARAM_NAME) + delay = _parse_request_delay(request, self._max_event_delay_ms) if delay is not None: delay_id = await self.delayed_events_handler.add( @@ -430,6 +445,7 @@ async def _do( origin_server_ts=origin_server_ts, content=content, delay=delay, + sticky_duration_ms=sticky_duration_ms, ) set_tag("delay_id", delay_id) @@ -446,6 +462,11 @@ async def _do( if origin_server_ts is not None: event_dict["origin_server_ts"] = origin_server_ts + if sticky_duration_ms is not None: + event_dict[StickyEvent.FIELD_NAME] = { + "duration_ms": sticky_duration_ms, + } + try: ( event, diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 458bf08a19f..4cbd85779b2 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -617,6 +617,12 @@ async def encode_room( ephemeral_events = room.ephemeral result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications + if room.sticky: + # TODO Are we meant to peel out events from the timeline here? + serialized_sticky = await self._event_serializer.serialize_events( + room.sticky, time_now, config=serialize_options + ) + result["msc4354_sticky"] = {"events": serialized_sticky} if room.unread_thread_notifications: result["unread_thread_notifications"] = room.unread_thread_notifications if self._msc3773_enabled: @@ -646,6 +652,7 @@ class SlidingSyncRestServlet(RestServlet): - receipts (MSC3960) - account data (MSC3959) - thread subscriptions (MSC4308) + - sticky events (MSC4354) Request query parameters: timeout: How long to wait for new events in milliseconds. @@ -1089,8 +1096,36 @@ async def encode_extensions( _serialise_thread_subscriptions(extensions.thread_subscriptions) ) + if extensions.sticky_events: + serialized_extensions[ + "org.matrix.msc4354.sticky_events" + ] = await self._serialise_sticky_events(requester, extensions.sticky_events) + return serialized_extensions + async def _serialise_sticky_events( + self, + requester: Requester, + sticky_events: SlidingSyncResult.Extensions.StickyEventsExtension, + ) -> JsonDict: + time_now = self.clock.time_msec() + # Same as SSS timelines. 
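+ # (i.e. serialise in the client v2 format without a room_id, since the response is already grouped by room.)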
+ # + serialize_options = SerializeEventConfig( + event_format=format_event_for_client_v2_without_room_id, + requester=requester, + ) + return { + "rooms": { + room_id: { + "events": await self.event_serializer.serialize_events( + sticky_events, time_now, config=serialize_options + ) + } + for room_id, sticky_events in sticky_events.room_id_to_sticky_events.items() + }, + } + def _serialise_thread_subscriptions( thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 75f27c98dea..89458495311 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -182,6 +182,8 @@ async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: "org.matrix.msc4306": self.config.experimental.msc4306_enabled, # MSC4169: Backwards-compatible redaction sending using `/send` "com.beeper.msc4169": self.config.experimental.msc4169_enabled, + # MSC4354: Sticky events + "org.matrix.msc4354": self.config.experimental.msc4354_enabled, # MSC4380: Invite blocking "org.matrix.msc4380": self.config.experimental.msc4380_enabled, }, diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 2948227807f..3b133364779 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -658,6 +658,29 @@ async def _persist_event_batch( async with self._state_deletion_store.persisting_state_group_references( events_and_contexts ): + new_servers: set[str] | None = None + if self.hs.config.experimental.msc4354_enabled and state_delta_for_room: + # We specifically only consider events in `chunk` to reduce the risk of state rollbacks + # causing servers to appear to repeatedly rejoin rooms. This works because we only + # persist events once, whereas the state delta may unreliably flap between joined members + # on unrelated events. This means we may miss cases where the /first/ join event for a server + # is as a result of a state rollback and not as a result of a new join event. That is fine + # because the chance of that happening is vanishingly rare because the join event would need to be + # persisted without it affecting the current state (e.g there's a concurrent ban for that user) + # which is then revoked concurrently by a later event (e.g the user is unbanned). + # If state resolution were more reliable (in terms of state resets) then we could feasibly only + # consider the events in the state_delta_for_room, but we aren't there yet. + new_event_ids_in_current_state = set( + state_delta_for_room.to_insert.values() + ) + new_servers = await self._check_new_servers_joined( + room_id, + [ + ev + for (ev, _) in chunk + if ev.event_id in new_event_ids_in_current_state + ], + ) await self.persist_events_store._persist_events_and_state_updates( room_id, chunk, @@ -667,9 +690,71 @@ async def _persist_event_batch( inhibit_local_membership_updates=backfilled, new_event_links=new_event_links, ) + if new_servers: + # Notify other workers after the server has joined so they can take into account + # the latest events that are in `chunk`. 
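+ # The local notifier reaches a federation sender running in this process; the replication command reaches senders running on other workers.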
+ for server_name in new_servers: + self.hs.get_notifier().notify_new_server_joined( + server_name, room_id + ) + self.hs.get_replication_command_handler().send_new_server_joined( + server_name, room_id + ) return replaced_events + async def _check_new_servers_joined( + self, room_id: str, new_events_in_current_state: list[EventBase] + ) -> set[str] | None: + """Check if new servers have joined the given room. + + Assumes this function is called BEFORE the current_state_events table is updated. + + A new server is "joined" if this is the first join event seen from this domain. + + Args: + room_id: The room in question + new_events_in_current_state: A list of events that will become part of the current state, + but have not yet been persisted. + """ + # filter to only join events from other servers. We're obviously joined if we are getting full events + # so needn't consider ourselves. + join_events = [ + ev + for ev in new_events_in_current_state + if ev.type == EventTypes.Member + and ev.is_state() + and not self.is_mine_id(ev.state_key) + and ev.membership == Membership.JOIN + ] + if not join_events: + return None + + joining_domains = {get_domain_from_id(ev.state_key) for ev in join_events} + + # load all joined members from the current_state_events table as this table is fast and has what we want. + # This is the current state prior to applying the update. + joined_members: list[ + tuple[str] + ] = await self.main_store.db_pool.simple_select_list( + "current_state_events", + { + "room_id": room_id, + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + retcols=["state_key"], + desc="_check_new_servers_joined", + ) + joined_domains = { + get_domain_from_id(state_key) for (state_key,) in joined_members + } + + newly_joined_domains = joining_domains.difference(joined_domains) + if not newly_joined_domains: + return None + return newly_joined_domains + async def _calculate_new_forward_extremities_and_state_delta( self, room_id: str, ev_ctx_rm: list[EventPersistencePair] ) -> tuple[set[str] | None, DeltaState | None]: diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 12593094f18..9f8d4debbe0 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -34,6 +34,7 @@ ) from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.stats import UserSortOrder +from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore from synapse.storage.databases.main.thread_subscriptions import ( ThreadSubscriptionsWorkerStore, ) @@ -144,6 +145,7 @@ class DataStore( TagsStore, AccountDataStore, ThreadSubscriptionsWorkerStore, + StickyEventsWorkerStore, PushRulesWorkerStore, StreamWorkerStore, OpenIdStore, diff --git a/synapse/storage/databases/main/delayed_events.py b/synapse/storage/databases/main/delayed_events.py index 55471505154..1727f589e2a 100644 --- a/synapse/storage/databases/main/delayed_events.py +++ b/synapse/storage/databases/main/delayed_events.py @@ -54,6 +54,7 @@ class EventDetails: origin_server_ts: Timestamp | None content: JsonDict device_id: DeviceID | None + sticky_duration_ms: int | None @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -122,6 +123,7 @@ async def add_delayed_event( origin_server_ts: int | None, content: JsonDict, delay: int, + sticky_duration_ms: int | None, ) -> tuple[DelayID, Timestamp]: """ Inserts a new delayed event in the DB. 
@@ -148,6 +150,7 @@ def add_delayed_event_txn(txn: LoggingTransaction) -> Timestamp: "state_key": state_key, "origin_server_ts": origin_server_ts, "content": json_encoder.encode(content), + "sticky_duration_ms": sticky_duration_ms, }, ) @@ -299,6 +302,7 @@ def process_timeout_delayed_events_txn( "send_ts", "content", "device_id", + "sticky_duration_ms", ) ) sql_update = "UPDATE delayed_events SET is_processed = TRUE" @@ -344,6 +348,7 @@ def process_timeout_delayed_events_txn( Timestamp(row[5] if row[5] is not None else row[6]), db_to_json(row[7]), DeviceID(row[8]) if row[8] is not None else None, + int(row[9]) if row[9] is not None else None, DelayID(row[0]), UserLocalpart(row[1]), ) @@ -392,6 +397,7 @@ def process_target_delayed_event_txn( origin_server_ts, content, device_id, + sticky_duration_ms, user_localpart """, (delay_id,), @@ -407,8 +413,9 @@ def process_target_delayed_event_txn( Timestamp(row[3]) if row[3] is not None else None, db_to_json(row[4]), DeviceID(row[5]) if row[5] is not None else None, + int(row[6]) if row[6] is not None else None, DelayID(delay_id), - UserLocalpart(row[6]), + UserLocalpart(row[7]), ) return event, self._get_next_delayed_event_send_ts_txn(txn) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 60fc884c3a6..b4de2e9bb36 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -264,6 +264,7 @@ def __init__( self.database_engine = db.engine self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() + self._msc4354_enabled = hs.config.experimental.msc4354_enabled self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id @@ -383,6 +384,21 @@ async def _persist_events_and_state_updates( len(events_and_contexts) ) + # TODO: are we guaranteed to call the below code if we were to die now? + # On startup we will already think we have persisted the events? + + # This was originally in _persist_events_txn but it relies on non-txn functions like + # get_events_as_list and get_partial_filtered_current_state_ids to handle soft-failure + # re-evaluation, so it can't do that without leaking out the txn currently, hence it + # now just lives outside. + if self._msc4354_enabled: + # re-evaluate soft-failed sticky events. + await self.store.reevaluate_soft_failed_sticky_events( + room_id, + events_and_contexts, + state_delta_for_room, + ) + if not use_negative_stream_ordering: # we don't want to set the event_persisted_position to a negative # stream_ordering. @@ -1185,6 +1201,11 @@ def _persist_events_txn( sliding_sync_table_changes, ) + if self._msc4354_enabled: + self.store.insert_sticky_events_txn( + txn, [ev for ev, _ in events_and_contexts] + ) + # We only update the sliding sync tables for non-backfilled events. self._update_sliding_sync_tables_with_new_persisted_events_txn( txn, room_id, events_and_contexts @@ -2646,6 +2667,11 @@ def _update_outliers_txn( # event isn't an outlier any more. self._update_backward_extremeties(txn, [event]) + if self._msc4354_enabled and event.sticky_duration(): + # The de-outliered event is sticky. Update the sticky events table to ensure + # we delivery this down /sync. 
+ self.store.insert_sticky_events_txn(txn, [event]) + return [ec for ec in events_and_contexts if ec[0] not in to_remove] def _store_event_txn( diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index ae6ee50dc24..e13f807148f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -68,6 +68,10 @@ wrap_as_background_process, ) from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream +from synapse.replication.tcp.streams._base import ( + StickyEventsStream, + StickyEventsStreamRow, +) from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -459,6 +463,11 @@ def process_replication_rows( # If the partial-stated event became rejected or unrejected # when it wasn't before, we need to invalidate this cache. self._invalidate_local_get_event_cache(row.event_id) + elif stream_name == StickyEventsStream.NAME: + for row in rows: + assert isinstance(row, StickyEventsStreamRow) + # In case soft-failure status changed, invalidate the cache. + self._invalidate_local_get_event_cache(row.event_id) super().process_replication_rows(stream_name, instance_name, token, rows) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 633df077367..a68d3cb4825 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -2454,7 +2454,10 @@ def __init__( self._instance_name = hs.get_instance_name() async def upsert_room_on_join( - self, room_id: str, room_version: RoomVersion, state_events: list[EventBase] + self, + room_id: str, + room_version: RoomVersion, + state_events: list[EventBase] | None, ) -> None: """Ensure that the room is stored in the table @@ -2466,36 +2469,46 @@ async def upsert_room_on_join( # mark the room as having an auth chain cover index. has_auth_chain_index = await self.has_auth_chain_index(room_id) - create_event = None - for e in state_events: - if (e.type, e.state_key) == (EventTypes.Create, ""): - create_event = e - break + # We may want to insert a row into the rooms table BEFORE having the state events in the + # room, in order to correctly handle the race condition where the /send_join is processed + # remotely which causes remote servers to send us events before we've processed the /send_join + # response. Therefore, we allow state_events (and thus the creator column) to be optional. + # When we get the /send_join response, we'll patch this up. + room_creator: str | None = None + if state_events: + create_event = None + for e in state_events: + if (e.type, e.state_key) == (EventTypes.Create, ""): + create_event = e + break + + if create_event is None: + # If the state doesn't have a create event then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No create event in state") - if create_event is None: - # If the state doesn't have a create event then the room is - # invalid, and it would fail auth checks anyway. - raise StoreError(400, "No create event in state") + # Before MSC2175, the room creator was a separate field. + if not room_version.implicit_room_creator: + room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) - # Before MSC2175, the room creator was a separate field. 
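The `upsert_room_on_join` rewrite that continues below makes `state_events` optional so a `rooms` row can exist before the `/send_join` response has been processed; the creator column is only filled in once the create event is known. A hedged sketch of the resulting upsert values (names simplified, not the actual store API):

```python
def room_upsert_values(room_version: str, room_creator: str | None) -> dict:
    # Mirror of the update_with construction below: always set the room version,
    # only set the creator once we actually know it.
    values = {"room_version": room_version}
    if room_creator:
        values["creator"] = room_creator
    return values


assert room_upsert_values("11", None) == {"room_version": "11"}
assert room_upsert_values("10", "@alice:example.org") == {
    "room_version": "10",
    "creator": "@alice:example.org",
}
```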
- if not room_version.implicit_room_creator: - room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) + if not isinstance(room_creator, str): + # If the create event does not have a creator then the room is + # invalid, and it would fail auth checks anyway. + raise StoreError(400, "No creator defined on the create event") + else: + room_creator = create_event.sender - if not isinstance(room_creator, str): - # If the create event does not have a creator then the room is - # invalid, and it would fail auth checks anyway. - raise StoreError(400, "No creator defined on the create event") - else: - room_creator = create_event.sender + update_with = {"room_version": room_version.identifier} + if room_creator: + update_with["creator"] = room_creator await self.db_pool.simple_upsert( desc="upsert_room_on_join", table="rooms", keyvalues={"room_id": room_id}, - values={"room_version": room_version.identifier}, + values=update_with, insertion_values={ "is_public": False, - "creator": room_creator, "has_auth_chain_index": has_auth_chain_index, }, ) diff --git a/synapse/storage/databases/main/sticky_events.py b/synapse/storage/databases/main/sticky_events.py new file mode 100644 index 00000000000..c1b3a14e3ec --- /dev/null +++ b/synapse/storage/databases/main/sticky_events.py @@ -0,0 +1,623 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +import logging +from itertools import chain +from typing import ( + TYPE_CHECKING, + Collection, + cast, +) + +from twisted.internet.defer import Deferred + +from synapse import event_auth +from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError +from synapse.events import EventBase +from synapse.events.snapshot import EventPersistencePair +from synapse.replication.tcp.streams._base import StickyEventsStream +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.databases.main.events import DeltaState +from synapse.storage.databases.main.state import StateGroupWorkerStore +from synapse.storage.engines import PostgresEngine +from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.types import StateKey +from synapse.types.state import StateFilter +from synapse.util.duration import Duration +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# Remove entries from the sticky_events table at this frequency. +# Note: this does NOT mean we don't honour shorter expiration timeouts. +# Consumers call 'get_sticky_events_in_rooms' which has `WHERE expires_at > ?` +# to filter out expired sticky events that have yet to be deleted. 
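A small illustration of the point made in the comment above: expiry is enforced at read time by the `expires_at > ?` filter, so the hourly cleanup interval defined just below only controls when dead rows are physically deleted, not when they stop being served. The values here are made up for the example:

```python
now_ms = 1_700_000_000_000
rows = [
    ("$fresh", now_ms + 30_000),   # expires in 30s -> still returned
    ("$stale", now_ms - 60_000),   # expired a minute ago -> filtered out immediately
]
visible = [event_id for event_id, expires_at in rows if expires_at > now_ms]
assert visible == ["$fresh"]
```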
+DELETE_EXPIRED_STICKY_EVENTS_INTERVAL = Duration(hours=1) + + +class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._can_write_to_sticky_events = ( + self._instance_name in hs.config.worker.writers.events + ) + + # Technically this means we will cleanup N times, once per event persister, maybe put on master? + if self._can_write_to_sticky_events: + self.clock.looping_call( + self._run_background_cleanup, DELETE_EXPIRED_STICKY_EVENTS_INTERVAL + ) + + self._sticky_events_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="sticky_events", + server_name=self.server_name, + instance_name=self._instance_name, + tables=[ + ("sticky_events", "instance_name", "stream_id"), + ], + sequence_name="sticky_events_sequence", + writers=hs.config.worker.writers.events, + ) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == StickyEventsStream.NAME: + self._sticky_events_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + + def get_max_sticky_events_stream_id(self) -> int: + """Get the current maximum stream_id for thread subscriptions. + + Returns: + The maximum stream_id + """ + return self._sticky_events_id_gen.get_current_token() + + def get_sticky_events_stream_id_generator(self) -> MultiWriterIdGenerator: + return self._sticky_events_id_gen + + async def get_sticky_events_in_rooms( + self, + room_ids: Collection[str], + *, + from_id: int, + to_id: int, + now: int, + limit: int | None, + ) -> tuple[int, dict[str, set[str]]]: + """ + Fetch all the sticky events in the given rooms, from the given sticky stream ID. + + Args: + room_ids: The room IDs to return sticky events in. + from_id: The sticky stream ID that sticky events should be returned from (exclusive). + to_id: The sticky stream ID that sticky events should end at (inclusive). + now: The current time in unix millis, used for skipping expired events. + limit: Max sticky events to return, or None to apply no limit. + Returns: + to_id, map[room_id, event_ids] + """ + sticky_events_rows = await self.db_pool.runInteraction( + "get_sticky_events_in_rooms", + self._get_sticky_events_in_rooms_txn, + room_ids, + from_id, + to_id, + now, + limit, + ) + + if not sticky_events_rows: + return from_id, {} + + # Get stream_id of the last row, which is the highest + new_to_id, _, _ = sticky_events_rows[-1] + + # room ID -> event IDs + room_to_events: dict[str, set[str]] = {} + for _, room_id, event_id in sticky_events_rows: + events = room_to_events.setdefault(room_id, set()) + events.add(event_id) + + return (new_to_id, room_to_events) + + def _get_sticky_events_in_rooms_txn( + self, + txn: LoggingTransaction, + room_ids: Collection[str], + from_id: int, + to_id: int, + now: int, + limit: int | None, + ) -> list[tuple[int, str, str]]: + if len(room_ids) == 0: + return [] + room_id_in_list_clause, room_id_in_list_values = make_in_list_sql_clause( + txn.database_engine, "room_id", room_ids + ) + limit_clause = "" + limit_params: tuple[int, ...] = () + if limit is not None: + limit_clause = "LIMIT ?" 
+            limit_params = (limit,)
+        txn.execute(
+            f"""
+            SELECT stream_id, room_id, event_id
+            FROM sticky_events
+            WHERE
+                NOT soft_failed
+                AND expires_at > ?
+                AND stream_id > ?
+                AND stream_id <= ?
+                AND {room_id_in_list_clause}
+            ORDER BY stream_id ASC
+            {limit_clause}
+            """,
+            (now, from_id, to_id, *room_id_in_list_values, *limit_params),
+        )
+        return cast(list[tuple[int, str, str]], txn.fetchall())
+
+    async def get_updated_sticky_events(
+        self, from_id: int, to_id: int, limit: int
+    ) -> list[tuple[int, str, str, bool]]:
+        """Get updates to sticky events between two stream IDs.
+
+        Args:
+            from_id: The starting stream ID (exclusive)
+            to_id: The ending stream ID (inclusive)
+            limit: The maximum number of rows to return
+
+        Returns:
+            list of (stream_id, room_id, event_id, soft_failed) tuples
+        """
+        return await self.db_pool.runInteraction(
+            "get_updated_sticky_events",
+            self._get_updated_sticky_events_txn,
+            from_id,
+            to_id,
+            limit,
+        )
+
+    def _get_updated_sticky_events_txn(
+        self, txn: LoggingTransaction, from_id: int, to_id: int, limit: int
+    ) -> list[tuple[int, str, str, bool]]:
+        txn.execute(
+            """
+            SELECT stream_id, room_id, event_id, soft_failed
+            FROM sticky_events
+            WHERE ? < stream_id AND stream_id <= ?
+            LIMIT ?
+            """,
+            (from_id, to_id, limit),
+        )
+        return cast(list[tuple[int, str, str, bool]], txn.fetchall())
+
+    async def get_sticky_event_ids_sent_by_self(
+        self, room_id: str, from_stream_pos: int
+    ) -> list[str]:
+        """Get unexpired sticky event IDs which have been sent by users on this homeserver.
+
+        Used when sending sticky events eagerly to newly joined servers, or when catching up over federation.
+
+        Args:
+            room_id: The room to fetch sticky events in.
+            from_stream_pos: The stream position to return events from. May be 0 for newly joined servers.
+                Exclusive.
+        Returns:
+            A list of event IDs, which may be empty.
+        """
+        return await self.db_pool.runInteraction(
+            "get_sticky_event_ids_sent_by_self",
+            self._get_sticky_event_ids_sent_by_self_txn,
+            room_id,
+            from_stream_pos,
+        )
+
+    def _get_sticky_event_ids_sent_by_self_txn(
+        self, txn: LoggingTransaction, room_id: str, from_stream_pos: int
+    ) -> list[str]:
+        now_ms = self.clock.time_msec()
+        sender_is_mine_like = "%:" + self.hs.hostname
+        txn.execute(
+            """
+            SELECT event_id
+            FROM sticky_events
+            INNER JOIN events USING (event_id)
+            WHERE
+                NOT soft_failed
+                AND expires_at > ?
+                AND sticky_events.room_id = ?
+                AND sticky_events.sender LIKE ?
+                AND events.stream_ordering > ?
+            """,
+            # NB: bind parameters must follow the placeholder order above
+            # (the sender LIKE placeholder comes before stream_ordering).
+            (now_ms, room_id, sender_is_mine_like, from_stream_pos),
+        )
+        return [cast(str, event_id) for (event_id,) in txn]
+
+    async def reevaluate_soft_failed_sticky_events(
+        self,
+        room_id: str,
+        events_and_contexts: list[EventPersistencePair],
+        state_delta_for_room: DeltaState | None,
+    ) -> None:
+        """Re-evaluate soft failed events in the room provided.
+
+        Args:
+            room_id: The room that all of the events belong to
+            events_and_contexts: The events just persisted. These are not eligible for re-evaluation.
+            state_delta_for_room: The changes to the current state, used to detect if we need to
+                re-evaluate soft-failed sticky events.
+        """
+        assert self._can_write_to_sticky_events
+
+        # fetch soft failed sticky events to recheck
+        event_ids_to_check = await self._get_soft_failed_sticky_events_to_recheck(
+            room_id, state_delta_for_room
+        )
+        # filter out soft-failed events in events_and_contexts as we just inserted them, so the
+        # soft failure status won't have changed for them.
+ persisting_event_ids = {ev.event_id for ev, _ in events_and_contexts} + event_ids_to_check = [ + item for item in event_ids_to_check if item not in persisting_event_ids + ] + if event_ids_to_check: + logger.info( + "_get_soft_failed_sticky_events_to_recheck => %s", event_ids_to_check + ) + # recheck them and update any that now pass soft-fail checks. + await self._recheck_soft_failed_events(room_id, event_ids_to_check) + + def insert_sticky_events_txn( + self, + txn: LoggingTransaction, + events: list[EventBase], + ) -> None: + now_ms = self.clock.time_msec() + # event, expires_at + sticky_events: list[tuple[EventBase, int]] = [] + for ev in events: + # MSC: Note: policy servers and other similar antispam techniques still apply to these events. + if ev.internal_metadata.policy_server_spammy: + continue + # We shouldn't be passed rejected events, but if we do, we filter them out too. + if ev.rejected_reason is not None: + continue + # We can't persist outlier sticky events as we don't know the room state at that event + if ev.internal_metadata.is_outlier(): + continue + sticky_duration = ev.sticky_duration() + if sticky_duration is None: + continue + # Calculate the end time as start_time + effecitve sticky duration + expires_at = min(ev.origin_server_ts, now_ms) + sticky_duration + # Filter out already expired sticky events + if expires_at > now_ms: + sticky_events.append((ev, expires_at)) + if len(sticky_events) == 0: + return + + logger.info( + "inserting %d sticky events in room %s", + len(sticky_events), + sticky_events[0][0].room_id, + ) + + # Generate stream_ids in one go + sticky_events_with_ids = zip( + sticky_events, + self._sticky_events_id_gen.get_next_mult_txn(txn, len(sticky_events)), + strict=True, + ) + + self.db_pool.simple_insert_many_txn( + txn, + "sticky_events", + keys=( + "instance_name", + "stream_id", + "room_id", + "event_id", + "sender", + "expires_at", + "soft_failed", + ), + values=[ + ( + self._instance_name, + stream_id, + ev.room_id, + ev.event_id, + ev.sender, + expires_at, + ev.internal_metadata.is_soft_failed(), + ) + for (ev, expires_at), stream_id in sticky_events_with_ids + ], + ) + + async def _get_soft_failed_sticky_events_to_recheck( + self, + room_id: str, + state_delta_for_room: DeltaState | None, + ) -> list[str]: + """Fetch soft-failed sticky events which should be rechecked against the current state. + + Soft-failed events are not rejected, so they pass auth at the state before + the event and at the auth_events in the event. Instead, soft-failed events failed auth at + the *current* state of the room. We only need to recheck soft failure if we have a reason to + believe the event may pass that check now. + + Note that we don't bother rechecking accepted events that may now be soft-failed, because + by that point it's too late as we've already sent the event to clients. + + Returns: + A list of event IDs to recheck + """ + + if state_delta_for_room is None: + # No change to current state => no way soft failure status could be different. + return [] + + # any change to critical auth events may change soft failure status. This means any changes + # to join rules, power levels or member events. If the state has changed but it isn't one + # of those events, we don't need to recheck. 
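As a companion to the comment above, here is a standalone sketch of the "critical auth types" short-circuit: only state deltas touching join rules, power levels or memberships can change soft-failure status, so anything else can be ignored. The helper is illustrative and not part of this patch; the type names are the standard Matrix event types:

```python
CRITICAL_AUTH_TYPES = {"m.room.join_rules", "m.room.power_levels", "m.room.member"}


def needs_soft_fail_recheck(changed_state_types: set[str]) -> bool:
    # Any overlap with the critical auth types means a recheck may be needed.
    return bool(changed_state_types & CRITICAL_AUTH_TYPES)


assert needs_soft_fail_recheck({"m.room.power_levels"})
assert not needs_soft_fail_recheck({"m.room.topic", "m.room.name"})
```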
+ critical_auth_types = ( + EventTypes.JoinRules, + EventTypes.PowerLevels, + EventTypes.Member, + ) + critical_auth_types_changed = { + typ + for typ, _ in chain( + state_delta_for_room.to_insert, state_delta_for_room.to_delete + ) + if typ in critical_auth_types + } + if len(critical_auth_types_changed) == 0: + # No change to critical auth events => no way soft failure status could be different. + return [] + + if critical_auth_types_changed == {EventTypes.Member}: + # the final case we want to catch is when unprivileged users join/leave rooms. These users cause + # changes in the critical auth types (the member event) but ultimately have no effect on soft + # failure status for anyone but that user themselves. + # + # Grab the set of senders that have been modified and see if any of them sent a soft-failed + # sticky event. If they did, then we need to re-evaluate. If they didn't, then we don't need to. + new_membership_changes = { + membership_user_id + for event_type, membership_user_id in chain( + state_delta_for_room.to_insert, state_delta_for_room.to_delete + ) + if event_type == EventTypes.Member + } + + # pull out sticky events that were sent in this room + # by those whose membership just changed + events_to_recheck: list[ + tuple[str] + ] = await self.db_pool.simple_select_many_batch( + table="sticky_events", + column="sender", + iterable=new_membership_changes, + keyvalues={ + "room_id": room_id, + "soft_failed": True, + }, + retcols=("event_id",), + desc="_get_soft_failed_sticky_events_to_recheck_members", + ) + return [event_id for (event_id,) in events_to_recheck] + + # otherwise one of the following must be true: + # - there was a change in PL or join rules + # - there was a change in the membership of a sender of a soft-failed sticky event. + # In both of these cases we want to re-evaluate soft failure status. + # + # NB: event auth checks are NOT recursive. We don't need to specifically handle the case where + # an admin user's membership changes which causes a PL event to be allowed, as when the PL event + # gets allowed we will re-evaluate anyway. E.g: + # + # PL(send_event=0, sender=Admin) #1 + # ^ ^_____________________ + # | | + # . PL(send_event=50, sender=Mod) #2 sticky event (sender=User) #3 + # + # In this scenario, the sticky event is soft-failed due to the Mod updating the PL event to + # set send_event=50, which User does not have. If we learn of an event which makes Mod's PL + # event invalid (say, Mod was banned by Admin concurrently to Mod setting the PL event), then + # the act of seeing the ban event will cause the old PL event to be in the state delta, meaning + # we will re-evaluate the sticky event due to the PL changing. We don't need to specially handle + # this case. + return await self.db_pool.simple_select_onecol( + table="sticky_events", + keyvalues={ + "room_id": room_id, + "soft_failed": True, + }, + retcol="event_id", + desc="_get_soft_failed_sticky_events_to_recheck", + ) + + async def _recheck_soft_failed_events( + self, + room_id: str, + soft_failed_event_ids: Collection[str], + ) -> None: + """ + Recheck authorised but soft-failed events. The provided event IDs must have already passed + all auth checks (so the event isn't rejected) except for soft-failure checks. + + Args: + txn: The SQL transaction + room_id: The room the event IDs are in. + soft_failed_event_ids: The soft-failed events to re-evaluate. 
+ """ + # Load all the soft-failed events to recheck + soft_failed_event_map = await self.get_events( + soft_failed_event_ids, allow_rejected=False + ) + # What (state event type, state key) tuples are needed as auth events for the + # soft-failed events we are reconsidering? + # e.g. [('m.room.member', '@user:example.org'), ('m.room.power_levels', ''), ...] + needed_state_tuples_for_auth: set[StateKey] = set() + for soft_failed_event in soft_failed_event_map.values(): + needed_state_tuples_for_auth.update( + event_auth.auth_types_for_event( + soft_failed_event.room_version, soft_failed_event + ) + ) + + # We know the events are otherwise authorised, so we only need to load the needed tuples from + # the current state to check if the events pass auth. + current_auth_state_map = await self.get_partial_filtered_current_state_ids( + room_id, StateFilter.from_types(needed_state_tuples_for_auth) + ) + current_auth_state_event_ids: list[str] = list(current_auth_state_map.values()) + current_auth_events = await self.get_events_as_list( + current_auth_state_event_ids + ) + passing_event_ids: set[str] = set() + for soft_failed_event in soft_failed_event_map.values(): + if soft_failed_event.internal_metadata.policy_server_spammy: + # don't re-evaluate spam. + continue + try: + # We don't need to check_state_independent_auth_rules as that doesn't depend on room state, + # so if it passed once it'll pass again. + event_auth.check_state_dependent_auth_rules( + soft_failed_event, current_auth_events + ) + passing_event_ids.add(soft_failed_event.event_id) + except AuthError: + pass + + if not passing_event_ids: + return + + logger.info( + "%s soft-failed events now pass current state checks in room %s : %s", + len(passing_event_ids), + room_id, + shortstr(passing_event_ids), + ) + # Update the DB with the new soft-failure status + await self.db_pool.runInteraction( + "_recheck_soft_failed_events", + self._update_soft_failure_status_txn, + passing_event_ids, + ) + + def _update_soft_failure_status_txn( + self, txn: LoggingTransaction, passing_event_ids: set[str] + ) -> None: + # Update the sticky events table so we notify downstream of the change in soft-failure status + new_stream_ids: list[tuple[str, int]] = list( + zip( + passing_event_ids, + self._sticky_events_id_gen.get_next_mult_txn( + txn, len(passing_event_ids) + ), + strict=True, + ) + ) + + self.db_pool.simple_update_many_txn( + txn, + table="sticky_events", + key_names=("event_id",), + key_values=[(event_id,) for event_id, _stream_id in new_stream_ids], + value_names=( + "stream_id", + "soft_failed", + ), + value_values=[ + (stream_id, False) for _event_id, stream_id in new_stream_ids + ], + ) + + # Also update the internal metadata on the event itself, so when we filter_events_for_client + # we don't filter them out. It's a bit sad internal_metadata is TEXT and not JSONB... 
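To make the intent of the SQL that follows concrete, the same transformation expressed in Python: the stored `internal_metadata` JSON is rewritten so that `soft_failed` becomes false. This is purely illustrative; the patch performs it in SQL via `jsonb_set` (Postgres) and `json_set` (SQLite):

```python
import json


def clear_soft_failed(internal_metadata_json: str) -> str:
    # Parse the stored TEXT column, flip the flag, and serialise it back.
    metadata = json.loads(internal_metadata_json)
    metadata["soft_failed"] = False
    return json.dumps(metadata)


assert json.loads(clear_soft_failed('{"soft_failed": true, "outlier": false}')) == {
    "soft_failed": False,
    "outlier": False,
}
```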
+ event_id_in_list_clause, event_id_in_list_args = make_in_list_sql_clause( + txn.database_engine, + "event_id", + passing_event_ids, + ) + if isinstance(txn.database_engine, PostgresEngine): + txn.execute( + f""" + UPDATE event_json + SET internal_metadata = ( + jsonb_set(internal_metadata::jsonb, '{{soft_failed}}', 'false'::jsonb) + )::text + WHERE {event_id_in_list_clause} + """, + event_id_in_list_args, + ) + else: + assert isinstance(txn.database_engine, Sqlite3Engine) + txn.execute( + f""" + UPDATE event_json + SET internal_metadata = json_set(internal_metadata, '$.soft_failed', json('false')) + WHERE {event_id_in_list_clause} + """, + event_id_in_list_args, + ) + # finally, invalidate caches + for event_id in passing_event_ids: + self.invalidate_get_event_cache_after_txn(txn, event_id) + + async def _delete_expired_sticky_events(self) -> None: + logger.info("delete_expired_sticky_events") + await self.db_pool.runInteraction( + "_delete_expired_sticky_events", + self._delete_expired_sticky_events_txn, + self.clock.time_msec(), + ) + + def _delete_expired_sticky_events_txn( + self, txn: LoggingTransaction, now: int + ) -> None: + txn.execute( + """ + DELETE FROM sticky_events WHERE expires_at < ? + """, + (now,), + ) + + def _run_background_cleanup(self) -> Deferred: + return self.hs.run_as_background_process( + "delete_expired_sticky_events", + self._delete_expired_sticky_events, + ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index 2fdd27d3da1..a3d75002b99 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -381,7 +381,7 @@ async def get_catch_up_room_event_ids( ) -> list[str]: """ Returns at most 50 event IDs and their corresponding stream_orderings - that correspond to the oldest events that have not yet been sent to + that correspond to the newest events that have not yet been sent to the destination. Args: diff --git a/synapse/storage/schema/main/delta/93/01_sticky_events.sql b/synapse/storage/schema/main/delta/93/01_sticky_events.sql new file mode 100644 index 00000000000..c62b6f61ce5 --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_sticky_events.sql @@ -0,0 +1,28 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE TABLE sticky_events ( + stream_id INTEGER NOT NULL PRIMARY KEY, + instance_name TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + sender TEXT NOT NULL, + expires_at BIGINT NOT NULL, + soft_failed BOOLEAN NOT NULL +); + +-- for pulling out soft failed events by room +CREATE INDEX sticky_events_room_idx ON sticky_events (room_id, soft_failed); + +-- A optional int for combining sticky events with delayed events. Used at send time. 
+ALTER TABLE delayed_events ADD COLUMN sticky_duration_ms BIGINT; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres b/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres new file mode 100644 index 00000000000..9ba72856bc9 --- /dev/null +++ b/synapse/storage/schema/main/delta/93/01_sticky_events_seq.sql.postgres @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE SEQUENCE sticky_events_sequence; +-- Synapse streams start at 2, because the default position is 1 +-- so any item inserted at position 1 is ignored. +-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712 +SELECT nextval('sticky_events_sequence'); diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 143f659499b..d2720fb9592 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -84,6 +84,7 @@ def get_current_token(self) -> StreamToken: self._instance_name ) thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id() + sticky_events_key = self.store.get_max_sticky_events_stream_id() token = StreamToken( room_key=self.sources.room.get_current_key(), @@ -98,6 +99,7 @@ def get_current_token(self) -> StreamToken: groups_key=0, un_partial_stated_rooms_key=un_partial_stated_rooms_key, thread_subscriptions_key=thread_subscriptions_key, + sticky_events_key=sticky_events_key, ) return token @@ -125,6 +127,7 @@ async def bound_future_token(self, token: StreamToken) -> StreamToken: StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(), StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(), StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(), + StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(), } for _, key in StreamKeyType.__members__.items(): diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 16892b37c0b..b9e9c21741e 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1006,6 +1006,7 @@ class StreamKeyType(Enum): DEVICE_LIST = "device_list_key" UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" THREAD_SUBSCRIPTIONS = "thread_subscriptions_key" + STICKY_EVENTS = "sticky_events_key" @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -1027,6 +1028,7 @@ class StreamToken: 9. `groups_key`: `1` (note that this key is now unused) 10. `un_partial_stated_rooms_key`: `379` 11. `thread_subscriptions_key`: 4242 + 12. 
`sticky_events_key`: 4141 You can see how many of these keys correspond to the various fields in a "/sync" response: @@ -1086,6 +1088,7 @@ class StreamToken: groups_key: int un_partial_stated_rooms_key: int thread_subscriptions_key: int + sticky_events_key: int _SEPARATOR = "_" START: ClassVar["StreamToken"] @@ -1114,6 +1117,7 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": groups_key, un_partial_stated_rooms_key, thread_subscriptions_key, + sticky_events_key, ) = keys return cls( @@ -1130,6 +1134,7 @@ async def from_string(cls, store: "DataStore", string: str) -> "StreamToken": groups_key=int(groups_key), un_partial_stated_rooms_key=int(un_partial_stated_rooms_key), thread_subscriptions_key=int(thread_subscriptions_key), + sticky_events_key=int(sticky_events_key), ) except CancelledError: raise @@ -1153,6 +1158,7 @@ async def to_string(self, store: "DataStore") -> str: str(self.groups_key), str(self.un_partial_stated_rooms_key), str(self.thread_subscriptions_key), + str(self.sticky_events_key), ] ) @@ -1218,6 +1224,7 @@ def get_field( StreamKeyType.TYPING, StreamKeyType.UN_PARTIAL_STATED_ROOMS, StreamKeyType.THREAD_SUBSCRIPTIONS, + StreamKeyType.STICKY_EVENTS, ], ) -> int: ... @@ -1274,7 +1281,7 @@ def __str__(self) -> str: f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, " f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, " f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key}," - f"thread_subscriptions: {self.thread_subscriptions_key})" + f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})" ) @@ -1290,6 +1297,7 @@ def __str__(self) -> str: groups_key=0, un_partial_stated_rooms_key=0, thread_subscriptions_key=0, + sticky_events_key=0, ) diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index 03b3bcb3caf..ce2fcb44480 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -21,6 +21,7 @@ AbstractSet, Any, Callable, + Collection, Final, Generic, Mapping, @@ -388,12 +389,26 @@ def __bool__(self) -> bool: or bool(self.prev_batch) ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class StickyEventsExtension: + """The Sticky Events extension (MSC4354) + + Attributes: + room_id_to_sticky_events: map (room_id -> [unexpired_sticky_events]) + """ + + room_id_to_sticky_events: Mapping[str, Collection[EventBase]] + + def __bool__(self) -> bool: + return bool(self.room_id_to_sticky_events) + to_device: ToDeviceExtension | None = None e2ee: E2eeExtension | None = None account_data: AccountDataExtension | None = None receipts: ReceiptsExtension | None = None typing: TypingExtension | None = None thread_subscriptions: ThreadSubscriptionsExtension | None = None + sticky_events: StickyEventsExtension | None = None def __bool__(self) -> bool: return bool( @@ -403,6 +418,7 @@ def __bool__(self) -> bool: or self.receipts or self.typing or self.thread_subscriptions + or self.sticky_events ) next_pos: SlidingSyncStreamToken diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 49782b52348..a3d47270315 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -383,6 +383,15 @@ class ThreadSubscriptionsExtension(RequestBodyModel): enabled: StrictBool | None = False limit: StrictInt = 100 + class StickyEventsExtension(RequestBodyModel): + """The Sticky Events extension (MSC4354) 
+ + Attributes: + enabled + """ + + enabled: StrictBool | None = False + to_device: ToDeviceExtension | None = None e2ee: E2eeExtension | None = None account_data: AccountDataExtension | None = None @@ -391,6 +400,9 @@ class ThreadSubscriptionsExtension(RequestBodyModel): thread_subscriptions: ThreadSubscriptionsExtension | None = Field( None, alias="io.element.msc4308.thread_subscriptions" ) + sticky_events: StickyEventsExtension | None = Field( + None, alias="org.matrix.msc4354.sticky_events" + ) conn_id: StrictStr | None = None lists: ( diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 818f8b1a69b..9136856e1e8 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -341,6 +341,7 @@ async def yieldable_gather_results_delaying_cancellation( T4 = TypeVar("T4") T5 = TypeVar("T5") T6 = TypeVar("T6") +T7 = TypeVar("T7") @overload @@ -470,6 +471,30 @@ async def gather_optional_coroutines( ) -> tuple[T1 | None, T2 | None, T3 | None, T4 | None, T5 | None, T6 | None]: ... +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + tuple[ + Coroutine[Any, Any, T1] | None, + Coroutine[Any, Any, T2] | None, + Coroutine[Any, Any, T3] | None, + Coroutine[Any, Any, T4] | None, + Coroutine[Any, Any, T5] | None, + Coroutine[Any, Any, T6] | None, + Coroutine[Any, Any, T7] | None, + ] + ], +) -> tuple[ + T1 | None, + T2 | None, + T3 | None, + T4 | None, + T5 | None, + T6 | None, + T7 | None, +]: ... + + async def gather_optional_coroutines( *coroutines: Unpack[tuple[Coroutine[Any, Any, T1] | None, ...]], ) -> tuple[T1 | None, ...]: diff --git a/synapse/visibility.py b/synapse/visibility.py index bfa0db5670d..74ea79e7df1 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -237,6 +237,15 @@ def allowed(event: EventBase) -> EventBase | None: # to the cache! cloned = clone_event(filtered) cloned.unsigned[EventUnsignedContentFields.MEMBERSHIP] = user_membership + if storage.main.config.experimental.msc4354_enabled: + sticky_duration = cloned.sticky_duration() + if sticky_duration: + now = storage.main.clock.time_msec() + expires_at = min(cloned.origin_server_ts, now) + sticky_duration + if expires_at > now: + cloned.unsigned[EventUnsignedContentFields.STICKY_TTL] = ( + expires_at - now + ) return cloned diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index fd1ef043bb8..9cc6e3ac61a 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -19,6 +19,7 @@ from synapse.util.clock import Clock from synapse.util.retryutils import NotRetryingDestination +from tests import unittest from tests.test_utils import event_injection from tests.unittest import FederatingHomeserverTestCase @@ -452,6 +453,58 @@ def wake_destination_track(destination: str) -> None: # has been successfully sent. self.assertCountEqual(woken, set(server_names[:-1])) + @unittest.override_config({"experimental_features": {"msc4354_enabled": True}}) + def test_sends_sticky_events(self) -> None: + """Test that we send sticky events in addition to the latest event in the room when catching up.""" + per_dest_queue, sent_pdus = self.make_fake_destination_queue() + + # Make a room with a local user, and two servers. One will go offline + # and one will send some events. 
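For context on the request model added above (`org.matrix.msc4354.sticky_events`), a hedged example of roughly what a sliding sync request body enabling the extension could look like. The surrounding request fields are omitted and the exact endpoint is not part of this hunk; only the unstable extension name is taken from the patch:

```python
# Illustrative sliding sync request body fragment enabling the extension.
sliding_sync_body = {
    "lists": {},
    "extensions": {
        "org.matrix.msc4354.sticky_events": {
            "enabled": True,
        },
    },
}
```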
+ self.register_user("u1", "you the one") + u1_token = self.login("u1", "you the one") + room_1 = self.helper.create_room_as("u1", tok=u1_token) + + self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join") + ) + event_1 = self.get_success( + event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join") + ) + + # now we send a sticky event that we expect to be bundled with the fwd extrem event + sticky_event_id = self.helper.send_sticky_event( + room_1, "m.room.sticky", duration_ms=60_000, tok=u1_token + )["event_id"] + # ..and other uninteresting events + self.helper.send(room_1, "you hear me!!", tok=u1_token) + + # Now simulate us receiving an event from the still online remote. + fwd_extrem_event = self.get_success( + event_injection.inject_event( + self.hs, + type=EventTypes.Message, + sender="@user:host3", + room_id=room_1, + content={"msgtype": "m.text", "body": "Hello"}, + ) + ) + + assert event_1.internal_metadata.stream_ordering is not None + self.get_success( + self.hs.get_datastores().main.set_destination_last_successful_stream_ordering( + "host2", event_1.internal_metadata.stream_ordering + ) + ) + + self.get_success(per_dest_queue._catch_up_transmission_loop()) + + # We expect the sticky event and the fwd extrem to be sent + self.assertEqual(len(sent_pdus), 2) + # We expect the sticky event to appear before the fwd extrem + self.assertEqual(sent_pdus[0].event_id, sticky_event_id) + self.assertEqual(sent_pdus[1].event_id, fwd_extrem_event.event_id) + self.assertFalse(per_dest_queue._catching_up) + def test_not_latest_event(self) -> None: """Test that we send the latest event in the room even if its not ours.""" diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 1c340efa0cd..2e38a239fb7 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2545,7 +2545,7 @@ def test_timestamp_to_event(self) -> None: def test_topo_token_is_accepted(self) -> None: """Test Topo Token is accepted.""" - token = "t1-0_0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), @@ -2559,7 +2559,7 @@ def test_topo_token_is_accepted(self) -> None: def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: """Test that stream token is accepted for forward pagination.""" - token = "s0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 926560afd6b..f85c9939ce4 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2245,7 +2245,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.room_id = self.helper.create_room_as(self.user_id) def test_topo_token_is_accepted(self) -> None: - token = "t1-0_0_0_0_0_0_0_0_0_0_0" + token = "t1-0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) @@ -2256,7 +2256,7 @@ def test_topo_token_is_accepted(self) -> None: self.assertTrue("end" in channel.json_body) def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None: - token = "s0_0_0_0_0_0_0_0_0_0_0" + token = "s0_0_0_0_0_0_0_0_0_0_0_0" channel = self.make_request( "GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token) ) diff --git 
a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 613c317b8a6..326502ba495 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -453,6 +453,40 @@ def send_event( return channel.json_body + def send_sticky_event( + self, + room_id: str, + type: str, + *, + duration_ms: int, + content: dict | None = None, + txn_id: str | None = None, + tok: str | None = None, + expect_code: int = HTTPStatus.OK, + custom_headers: Iterable[tuple[AnyStr, AnyStr]] | None = None, + ) -> JsonDict: + if txn_id is None: + txn_id = f"m{time.time()}" + + path = f"/_matrix/client/r0/rooms/{room_id}/send/{type}/{txn_id}?org.matrix.msc4354.sticky_duration_ms={duration_ms}" + if tok: + path = path + f"&access_token={tok}" + + channel = make_request( + self.reactor, + self.site, + "PUT", + path, + content or {}, + custom_headers=custom_headers, + ) + + assert channel.code == expect_code, ( + f"Expected: {expect_code}, got: {channel.code}, resp: {channel.result['body']!r}" + ) + + return channel.json_body + def get_event( self, room_id: str,