diff --git a/changelog.d/19424.feature b/changelog.d/19424.feature new file mode 100644 index 00000000000..8f241a87b56 --- /dev/null +++ b/changelog.d/19424.feature @@ -0,0 +1 @@ +Add experimental support for [MSC4242](https://github.com/matrix-org/matrix-spec-proposals/pull/4242): State DAGs. Excludes federation support. \ No newline at end of file diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index ceacc10369f..a67d1aba0b7 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -6809,6 +6809,155 @@ ], "title": "Stale extremity dropping", "type": "timeseries" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "description": "For a given percentage P, the number X where P% of events were persisted to rooms with X state DAG forward extremities or fewer.", + "fieldConfig": { + "defaults": { + "links": [] + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 50 + }, + "id": 181, + "options": { + "alertThreshold": true + }, + "pluginVersion": "9.2.2", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.5, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "50%", + "refId": "A" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.75, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "75%", + "refId": "B" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.90, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "90%", + "refId": "C" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "histogram_quantile(0.99, rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0))", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "99%", + "refId": "D" + } + ], + "title": "Events persisted, by number of state DAG forward extremities in room (quantiles)", + "type": "timeseries" + }, + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "description": "Colour reflects the number of events persisted to rooms with the given number of state DAG forward extremities, or fewer.", + "fieldConfig": { + "defaults": { + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 50 + }, + "id": 127, + "options": { + "calculate": false, + "calculation": {}, + "cellGap": 1, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#5794F2", + "min": 0, + "mode": "opacity", + "reverse": false, + "scale": "exponential", + "scheme": "Oranges", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": true + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "show": true, + "yHistogram": true + }, + "yAxis": { + "axisPlacement": "left", + "decimals": 0, + "reverse": false, + "unit": "short" + } + }, + "pluginVersion": "9.2.2", + "targets": [ + { + "datasource": { + "uid": "${DS_PROMETHEUS}" + }, + "expr": "rate(synapse_storage_msc4242_state_dag_forward_extremities_persisted_bucket{server_name=\"$server_name\"}[$bucket_size]) and on (instance, job, index) (synapse_storage_events_persisted_events_total > 0)", + "format": "heatmap", + "intervalFactor": 1, + "legendFormat": "{{le}}", + "refId": "A" + } + ], + "title": "Events persisted, by number of state DAG forward extremities in room (heatmap)", + "type": "heatmap" } ], "title": "Extremities", @@ -7711,4 +7860,4 @@ "uid": "000000012", "version": 1, "weekStart": "" -} +} \ No newline at end of file diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index 21d3b8c4358..6fd3d06b00b 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -65,6 +65,7 @@ enum EventInternalMetadataData { DelayId(Box), TokenId(i64), DeviceId(Box), + CalculatedAuthEventIDs(Vec), // MSC4242: State DAGs } impl EventInternalMetadataData { @@ -140,6 +141,10 @@ impl EventInternalMetadataData { pyo3::intern!(py, "device_id"), o.into_pyobject(py).unwrap_infallible().into_any(), ), + EventInternalMetadataData::CalculatedAuthEventIDs(o) => ( + pyo3::intern!(py, "calculated_auth_event_ids"), + o.into_pyobject(py).unwrap().into_any(), + ), } } @@ -218,6 +223,11 @@ impl EventInternalMetadataData { .map(String::into_boxed_str) .with_context(|| format!("'{key_str}' has invalid type"))?, ), + "calculated_auth_event_ids" => EventInternalMetadataData::CalculatedAuthEventIDs( + value + .extract() + .with_context(|| format!("'{key_str}' has invalid type"))?, + ), _ => return Ok(None), }; @@ -395,6 +405,10 @@ impl EventInternalMetadataInner { get_property_opt!(self, DelayId).map(|s| s.deref()) } + pub fn get_calculated_auth_event_ids(&self) -> Option<&Vec> { + get_property_opt!(self, CalculatedAuthEventIDs) + } + pub fn get_token_id(&self) -> Option { get_property_opt!(self, TokenId).copied() } @@ -456,6 +470,10 @@ impl EventInternalMetadataInner { pub fn set_device_id(&mut self, obj: String) { set_property!(self, DeviceId, obj.into_boxed_str()); } + + pub fn set_calculated_auth_event_ids(&mut self, obj: Vec) { + set_property!(self, CalculatedAuthEventIDs, obj); + } } #[pyclass(frozen)] @@ -722,6 +740,21 @@ impl EventInternalMetadata { Ok(()) } + /// The calculated auth event IDs, if it was set when the event was created. + #[getter] + fn get_calculated_auth_event_ids(&self) -> PyResult> { + let guard = self.read_inner()?; + attr_err( + guard.get_calculated_auth_event_ids().cloned(), + "calculated_auth_event_ids", + ) + } + #[setter] + fn set_calculated_auth_event_ids(&self, obj: Vec) -> PyResult<()> { + self.write_inner()?.set_calculated_auth_event_ids(obj); + Ok(()) + } + /// The delay ID, set only if the event was a delayed event. #[getter] fn get_delay_id(&self) -> PyResult { diff --git a/rust/src/room_versions.rs b/rust/src/room_versions.rs index fbcc32516ab..dbc962174dd 100644 --- a/rust/src/room_versions.rs +++ b/rust/src/room_versions.rs @@ -47,6 +47,9 @@ impl EventFormatVersions { /// MSC4291 room IDs as hashes: introduced for room HydraV11 #[classattr] const ROOM_V11_HYDRA_PLUS: i32 = 4; + /// MSC4242 state DAGs: adds prev_state_events, removes auth_events + #[classattr] + const ROOM_VMSC4242: i32 = 5; } /// Enum to identify the state resolution algorithms. @@ -146,6 +149,14 @@ pub struct RoomVersion { /// /// In these room versions, we are stricter with event size validation. pub strict_event_byte_limits_room_versions: bool, + /// MSC4242: State DAGs. Creates events with prev_state_events instead of auth_events and derives + /// state from it. Events are always processed in causal order without any gaps in the DAG + /// (prev_state_events are always known), guaranteeing that processed events have a path to the + /// create event. This is an emergent property of state DAGs as asserting that there is a path + /// to the create event every time we insert an event would be prohibitively expensive. + /// This is similar to how doubly-linked lists can potentially not refer to previous items correctly + /// without verifying the list's integrity, but doing it on every insert is too expensive. + pub msc4242_state_dags: bool, } const ROOM_VERSION_V1: RoomVersion = RoomVersion { @@ -170,6 +181,7 @@ const ROOM_VERSION_V1: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V2: RoomVersion = RoomVersion { @@ -194,6 +206,7 @@ const ROOM_VERSION_V2: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V3: RoomVersion = RoomVersion { @@ -218,6 +231,7 @@ const ROOM_VERSION_V3: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V4: RoomVersion = RoomVersion { @@ -242,6 +256,7 @@ const ROOM_VERSION_V4: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V5: RoomVersion = RoomVersion { @@ -266,6 +281,7 @@ const ROOM_VERSION_V5: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V6: RoomVersion = RoomVersion { @@ -290,6 +306,7 @@ const ROOM_VERSION_V6: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V7: RoomVersion = RoomVersion { @@ -314,6 +331,7 @@ const ROOM_VERSION_V7: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V8: RoomVersion = RoomVersion { @@ -338,6 +356,7 @@ const ROOM_VERSION_V8: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V9: RoomVersion = RoomVersion { @@ -362,6 +381,7 @@ const ROOM_VERSION_V9: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V10: RoomVersion = RoomVersion { @@ -386,6 +406,7 @@ const ROOM_VERSION_V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; /// MSC3389 (Redaction changes for events with a relation) based on room version "10". @@ -411,6 +432,7 @@ const ROOM_VERSION_MSC3389V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; /// MSC1767 (Extensible Events) based on room version "10". @@ -436,6 +458,7 @@ const ROOM_VERSION_MSC1767V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; /// MSC3757 (Restricting who can overwrite a state event) based on room version "10". @@ -461,6 +484,7 @@ const ROOM_VERSION_MSC3757V10: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: false, + msc4242_state_dags: false, }; const ROOM_VERSION_V11: RoomVersion = RoomVersion { @@ -485,6 +509,7 @@ const ROOM_VERSION_V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, // Changed from v10 + msc4242_state_dags: false, }; /// MSC3757 (Restricting who can overwrite a state event) based on room version "11". @@ -510,6 +535,7 @@ const ROOM_VERSION_MSC3757V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: false, msc4291_room_ids_as_hashes: false, strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; const ROOM_VERSION_HYDRA_V11: RoomVersion = RoomVersion { @@ -534,6 +560,7 @@ const ROOM_VERSION_HYDRA_V11: RoomVersion = RoomVersion { msc4289_creator_power_enabled: true, // Changed from v11 msc4291_room_ids_as_hashes: true, // Changed from v11 strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, }; const ROOM_VERSION_V12: RoomVersion = RoomVersion { @@ -558,6 +585,32 @@ const ROOM_VERSION_V12: RoomVersion = RoomVersion { msc4289_creator_power_enabled: true, // Changed from v11 msc4291_room_ids_as_hashes: true, // Changed from v11 strict_event_byte_limits_room_versions: true, + msc4242_state_dags: false, +}; + +const ROOM_VERSION_MSC4242V12: RoomVersion = RoomVersion { + identifier: "org.matrix.msc4242.12", + disposition: RoomDisposition::UNSTABLE, + event_format: EventFormatVersions::ROOM_VMSC4242, + state_res: StateResolutionVersions::V2_1, + enforce_key_validity: true, + special_case_aliases_auth: false, + strict_canonicaljson: true, + limit_notifications_power_levels: true, + implicit_room_creator: true, + updated_redaction_rules: true, + restricted_join_rule: true, + restricted_join_rule_fix: true, + knock_join_rule: true, + msc3389_relation_redactions: false, + knock_restricted_join_rule: true, + enforce_int_power_levels: true, + msc3931_push_features: &[], + msc3757_enabled: false, + msc4289_creator_power_enabled: true, + msc4291_room_ids_as_hashes: true, + strict_event_byte_limits_room_versions: true, + msc4242_state_dags: true, }; /// Helper class for managing the known room versions, and providing dict-like @@ -800,6 +853,10 @@ impl RoomVersions { fn V12(py: Python<'_>) -> PyResult> { ROOM_VERSION_V12.into_py_any(py) } + #[classattr] + fn MSC4242v12(py: Python<'_>) -> PyResult> { + ROOM_VERSION_MSC4242V12.into_py_any(py) + } } /// Called when registering modules with python. @@ -814,11 +871,12 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> child_module.add_class::()?; // Build KNOWN_EVENT_FORMAT_VERSIONS as a frozenset - let known_ef: [i32; 4] = [ + let known_ef: [i32; 5] = [ EventFormatVersions::ROOM_V1_V2, EventFormatVersions::ROOM_V3, EventFormatVersions::ROOM_V4_PLUS, EventFormatVersions::ROOM_V11_HYDRA_PLUS, + EventFormatVersions::ROOM_VMSC4242, ]; let known_event_format_versions = PyFrozenSet::new(py, known_ef)?; child_module.add("KNOWN_EVENT_FORMAT_VERSIONS", known_event_format_versions)?; diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702c7e32468..f1a7771568c 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -479,6 +479,12 @@ def read_config( # Enable room version (and thus applicable push rules from MSC3931/3932) KNOWN_ROOM_VERSIONS.add_room_version(RoomVersions.MSC1767v10) + # MSC4242: State DAGs + self.msc4242_enabled: bool = experimental.get("msc4242_enabled", False) + if self.msc4242_enabled: + # Enable the room version + KNOWN_ROOM_VERSIONS.add_room_version(RoomVersions.MSC4242v12) + # MSC3391: Removing account data. self.msc3391_enabled = experimental.get("msc3391_enabled", False) diff --git a/synapse/event_auth.py b/synapse/event_auth.py index bf239e660de..ca528ae2358 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -61,7 +61,7 @@ EventFormatVersions, RoomVersion, ) -from synapse.events import is_creator +from synapse.events import FrozenEventVMSC4242, is_creator from synapse.state import CREATE_KEY from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( @@ -186,6 +186,70 @@ async def check_state_independent_auth_rules( # 1.5 Otherwise, allow return + # State DAGs 2. Considering the event's prev_state_events: + if event.room_version.msc4242_state_dags: + prev_state_events_ids = set(cast(FrozenEventVMSC4242, event).prev_state_events) + # Fetch all of the `prev_state_events` + prev_state_events = {} + # Try to load the `prev_state_events` from `batched_auth_events` initially as + # that can save us a database hit. + if batched_auth_events is not None: + prev_state_events = { + event_id: value + for event_id in prev_state_events_ids + if (value := batched_auth_events.get(event_id)) is not None + } + # Fetch the rest of the `prev_state_events` + missing_prev_state_events_ids = prev_state_events_ids - set( + prev_state_events.keys() + ) + fetched_prev_state_events = await store.get_events( + missing_prev_state_events_ids, + redact_behaviour=EventRedactBehaviour.as_is, + allow_rejected=True, + ) + prev_state_events.update(fetched_prev_state_events) + if len(prev_state_events) != len(prev_state_events_ids): + # we should have all the `prev_state_events` by now, so if we do not, that suggests + # a Synapse programming error + known_prev_state_event_ids = set(prev_state_events) + raise AssertionError( + f"Event {event.event_id} has unknown prev_state_events " + + f"({len(prev_state_events)}/{len(prev_state_events_ids)} known)" + + f"{prev_state_events_ids - known_prev_state_event_ids} missing " + + f"out of {prev_state_events_ids}" + ) + for prev_state_event in prev_state_events.values(): + # 2.1 If there are entries which do not belong in the same room, reject. + if prev_state_event.room_id != event.room_id: + raise AuthError( + 403, + "During auth for event %s in room %s, found event %s in prev_state_events " + "which belongs to a different room %s" + % ( + event.event_id, + event.room_id, + prev_state_event.event_id, + prev_state_event.room_id, + ), + ) + # 2.2 If there are entries which do not have a state_key, reject. + if not prev_state_event.is_state(): + raise AuthError( + 403, + f"During auth for event {event.event_id} in room {event.room_id}, event has a " + + f"prev_state_event which is not state: {prev_state_event.event_id}", + ) + # 2.3 If there are entries which were themselves rejected under the checks performed on + # receipt of a PDU, reject. + if prev_state_event.rejected_reason is not None: + raise AuthError( + 403, + f"During auth for event {event.event_id} in room {event.room_id}, event has a " + + f"prev_state_event which is rejected ({prev_state_event.rejected_reason}): " + + f"{prev_state_event.event_id}", + ) + # 2. Reject if event has auth_events that: ... auth_events: ChainMap[str, EventBase] = ChainMap() if batched_auth_events: @@ -450,6 +514,12 @@ def _check_create(event: "EventBase") -> None: if event.prev_event_ids(): raise AuthError(403, "Create event has prev events") + # State DAGs 1.2 If it has any prev_state_events, reject. + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + if len(event.prev_state_events) > 0: + raise AuthError(403, "Create event has prev state events") + if event.room_version.msc4291_room_ids_as_hashes: # 1.2 If the create event has a room_id, reject if "room_id" in event: diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f48d5c4f1d7..f4a5624d1a2 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -44,10 +44,7 @@ ) from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata -from synapse.types import ( - JsonDict, - StrCollection, -) +from synapse.types import JsonDict, StateKey, StrCollection from synapse.util.caches import intern_dict from synapse.util.duration import Duration from synapse.util.frozenutils import freeze @@ -575,9 +572,60 @@ def auth_event_ids(self) -> StrCollection: return [*self._dict["auth_events"], create_event_id] +class FrozenEventVMSC4242(FrozenEventV4): + """FrozenEventVMSC4242, which differs from FrozenEventV4 only in the addition of prev_state_events""" + + format_version = EventFormatVersions.ROOM_VMSC4242 + prev_state_events: DictProperty[list[str]] = DictProperty("prev_state_events") + + def __init__( + self, + event_dict: JsonDict, + room_version: RoomVersion, + internal_metadata_dict: JsonDict | None = None, + rejected_reason: str | None = None, + ): + # Similar to how we assert event_id isn't in V2+ events, we do the same with auth_events. + # We don't expect `auth_events` in the wire format because we calculate it from prev_state_events. + assert "auth_events" not in event_dict + super().__init__( + event_dict=event_dict, + room_version=room_version, + internal_metadata_dict=internal_metadata_dict, + rejected_reason=rejected_reason, + ) + + def auth_event_ids(self) -> StrCollection: + """Returns the list of _calculated_ auth event IDs. + + Returns: + The list of event IDs of this event's auth events + """ + # Catches cases where we accidentally call auth_event_ids() prior to calculating what they + # actually are. The exception being the m.room.create event which has no auth events. + if self.type != EventTypes.Create: + assert len(self.internal_metadata.calculated_auth_event_ids) > 0 + return self.internal_metadata.calculated_auth_event_ids + + def __repr__(self) -> str: + rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" + + return ( + f"<{self.__class__.__name__} " + f"{rejection}" + f"event_id={self.event_id}, " + f"type={self.get('type')}, " + f"state_key={self.get('state_key')}, " + f"prev_events={self.get('prev_events')}, " + f"prev_state_events={self.get('prev_state_events')}, " + f"outlier={self.internal_metadata.is_outlier()}" + ">" + ) + + def _event_type_from_format_version( format_version: int, -) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3]: +) -> type[FrozenEvent | FrozenEventV2 | FrozenEventV3 | FrozenEventVMSC4242]: """Returns the python type to use to construct an Event object for the given event format version. @@ -594,6 +642,8 @@ def _event_type_from_format_version( return FrozenEventV2 elif format_version == EventFormatVersions.ROOM_V4_PLUS: return FrozenEventV3 + elif format_version == EventFormatVersions.ROOM_VMSC4242: + return FrozenEventVMSC4242 elif format_version == EventFormatVersions.ROOM_V11_HYDRA_PLUS: return FrozenEventV4 else: @@ -655,6 +705,24 @@ def relation_from_event(event: EventBase) -> _EventRelation | None: return _EventRelation(parent_id, rel_type, aggregation_key) +def event_exists_in_state_dag( + event: Union["EventBase", "EventBuilder", "EventMetadata", "StateKey"], +) -> bool: + """Given an event, returns true if this event should form part of the state DAG. + Only valid for room versions which use a state DAG (MSC4242).""" + state_key = None + if isinstance(event, EventMetadata): + state_key = event.state_key + elif isinstance(event, tuple): # StateKey + # can't use StateKey else you get: + # "Subscripted generics cannot be used with class and instance checks" + state_key = event[1] + else: + state_key = event.state_key if event.is_state() else None + + return state_key is not None + + def is_creator(create: EventBase, user_id: str) -> bool: """ Return true if the provided user ID is the room creator. @@ -689,3 +757,13 @@ class StrippedStateEvent: state_key: str sender: str content: dict[str, Any] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventMetadata: + """Returned by `get_metadata_for_events`""" + + room_id: str + event_type: str + state_key: str | None + rejection_reason: str | None diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 2cd1bf6106f..78eb98e1e59 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -132,6 +132,7 @@ async def build( prev_event_ids: list[str], auth_event_ids: list[str] | None, depth: int | None = None, + prev_state_events: list[str] | None = None, ) -> EventBase: """Transform into a fully signed and hashed event @@ -143,10 +144,51 @@ async def build( depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. - + prev_state_events: The event IDs to use as prev_state_events. + Only applicable on MSC4242 state DAG rooms. If this is supplied, auth_event_ids + must not be specified unless this event is part of a batch such that the builder + will be unable to compute the auth_event_ids due to the events not being persisted + yet. Returns: The signed and hashed event. """ + # If the caller specifies this, make sure the room version supports it. + if prev_state_events: + assert self.room_version.msc4242_state_dags + if self.room_version.msc4242_state_dags: + assert prev_state_events is not None + if self.room_id: + state_ids = await self._state.compute_state_after_events( + self.room_id, + prev_state_events, + state_filter=StateFilter.from_types( + auth_types_for_event(self.room_version, self) + ), + await_full_state=False, + ) + # When we create rooms we only insert the create+member events, and batch the rest. + # Therefore, we may not have state_ids from compute_state_after_events as the + # prev_state_events are unknown. If this happens, the caller provides the auth events + # to use instead. + calculated_auth_event_ids: list[ + str + ] = [] # assume it's the create event which has [] + if len(state_ids) == 0 and len(prev_state_events) > 0: + # it's a batched event, so we should have been provided the auth_events + assert auth_event_ids and len(auth_event_ids) > 0 + calculated_auth_event_ids = auth_event_ids + else: + calculated_auth_event_ids = ( + self._event_auth_handler.compute_auth_events(self, state_ids) + ) + else: + # event is a state DAG event and is the create event (room_id is not provided), + # therefore there are no auth_events. + calculated_auth_event_ids = [] + assert self.type == EventTypes.Create and self.state_key == "" + self.internal_metadata.calculated_auth_event_ids = calculated_auth_event_ids + auth_event_ids = calculated_auth_event_ids + # Create events always have empty auth_events. if self.type == EventTypes.Create and self.is_state() and self.state_key == "": auth_event_ids = [] @@ -155,6 +197,8 @@ async def build( if auth_event_ids is None: # Every non-create event must have a room ID assert self.room_id is not None + # this block must not be hit for MSC4242 rooms as it resolves state with prev_events + assert not self.room_version.msc4242_state_dags state_ids = await self._state.compute_state_after_events( self.room_id, prev_event_ids, @@ -231,7 +275,6 @@ async def build( # rejected by other servers (and so that they can be persisted in # the db) depth = min(depth, MAX_DEPTH) - event_dict: dict[str, Any] = { "auth_events": auth_events, "prev_events": prev_events, @@ -241,8 +284,6 @@ async def build( "unsigned": self.unsigned, "depth": depth, } - if self.room_id is not None: - event_dict["room_id"] = self.room_id if self.room_version.msc4291_room_ids_as_hashes: # In MSC4291: the create event has no room ID as the create event ID /is/ the room ID. @@ -262,6 +303,14 @@ async def build( auth_event_ids.remove(create_event_id) event_dict["auth_events"] = auth_event_ids + if self.room_version.msc4242_state_dags: + # Auth events are removed entirely on state DAG rooms + event_dict.pop("auth_events") + assert prev_state_events is not None + event_dict["prev_state_events"] = prev_state_events + if self.room_id is not None: + event_dict["room_id"] = self.room_id + if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ff0476f5fbb..f038fb5578d 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -156,6 +156,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDic # Earlier room versions from had additional allowed keys. if not room_version.updated_redaction_rules: allowed_keys.extend(["prev_state", "membership", "origin"]) + # Custom room versions add new allowed keys and remove others + if room_version.msc4242_state_dags: + allowed_keys.extend(["prev_state_events"]) + allowed_keys.remove("auth_events") event_type = event_dict["type"] diff --git a/synapse/events/validator.py b/synapse/events/validator.py index b27f8a942af..ff22b2287f2 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -63,14 +63,17 @@ def validate_new(self, event: EventBase, config: HomeServerConfig) -> None: if event.format_version == EventFormatVersions.ROOM_V1_V2: EventID.from_string(event.event_id) - required = [ + required = { "auth_events", "content", "hashes", "prev_events", "sender", "type", - ] + } + if event.room_version.msc4242_state_dags: + required.remove("auth_events") + required.add("prev_state_events") for k in required: if k not in event: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 55151ca549c..78a1900c731 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -1108,6 +1108,11 @@ async def send_join( SynapseError: if the chosen remote server returns a 300/400 code, or no servers successfully handle the request. """ + # See related restriction in /createRoom requests in handlers/room.py + if room_version.msc4242_state_dags: + raise UnsupportedRoomVersionError( + "Homeserver does not support this room version over federation" + ) async def send_request(destination: str) -> SendJoinResult: response = await self._do_send_join( diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index d2c1f98d7c4..51a752472f1 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -32,7 +32,7 @@ from synapse.api.constants import Direction, EventTypes, Membership from synapse.api.errors import SynapseError -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.events.utils import FilteredEvent from synapse.types import ( JsonMapping, @@ -494,9 +494,16 @@ async def _redact_all_events( event_dict["redacts"] = event.event_id try: + prev_state_events = None + if room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + prev_state_events = event.prev_state_events + assert prev_state_events is not None, ( + "Parent event of redaction has no `prev_state_events` which should be impossible as `prev_state_events` is a required field in MSC4242 rooms" + ) # set the prev event to the offending message to allow for redactions # to be processed in the case where the user has been kicked/banned before - # redactions are requested + # redactions are requested. ( redaction, _, @@ -505,6 +512,7 @@ async def _redact_all_events( event_dict, prev_event_ids=[event.event_id], ratelimit=False, + prev_state_events=prev_state_events, ) except Exception as ex: logger.info( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0aa0a16127f..4032c7eca97 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -53,7 +53,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.urls import ConsentURIBuilder from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase, relation_from_event +from synapse.events import EventBase, FrozenEventVMSC4242, relation_from_event from synapse.events.builder import EventBuilder from synapse.events.snapshot import ( EventContext, @@ -589,6 +589,7 @@ async def create_event( state_map: StateMap[str] | None = None, for_batch: bool = False, current_state_group: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """ @@ -644,6 +645,10 @@ async def create_event( current_state_group: the current state group, used only for creating events for batch persisting + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). + delay_id: The delay ID of this event, if it was a delayed event. Raises: @@ -748,6 +753,7 @@ async def create_event( state_map=state_map, for_batch=for_batch, current_state_group=current_state_group, + prev_state_events=prev_state_events, ) # In an ideal world we wouldn't need the second part of this condition. However, @@ -976,6 +982,7 @@ async def create_and_send_nonmember_event( ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, int]: """ @@ -1005,6 +1012,9 @@ async def create_and_send_nonmember_event( depth: Override the depth used to order the event in the DAG. Should normally be set to None, which will cause the depth to be calculated based on the prev_events. + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). delay_id: The delay ID of this event, if it was a delayed event. Returns: @@ -1102,6 +1112,7 @@ async def create_and_send_nonmember_event( ignore_shadow_ban=ignore_shadow_ban, outlier=outlier, depth=depth, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -1116,6 +1127,7 @@ async def _create_and_send_nonmember_event_locked( ignore_shadow_ban: bool = False, outlier: bool = False, depth: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[EventBase, int]: room_id = event_dict["room_id"] @@ -1145,6 +1157,7 @@ async def _create_and_send_nonmember_event_locked( state_event_ids=state_event_ids, outlier=outlier, depth=depth, + prev_state_events=prev_state_events, delay_id=delay_id, ) context = await unpersisted_context.persist(event) @@ -1240,6 +1253,7 @@ async def create_new_client_event( state_map: StateMap[str] | None = None, for_batch: bool = False, current_state_group: int | None = None, + prev_state_events: list[str] | None = None, ) -> tuple[EventBase, UnpersistedEventContextBase]: """Create a new event for a local client. If bool for_batch is true, will create an event using the prev_event_ids, and will create an event context for @@ -1281,9 +1295,30 @@ async def create_new_client_event( current_state_group: the current state group, used only for creating events for batch persisting + prev_state_events: + The state event IDs which represent the current forward extremities of the state DAG. + Only applicable on room versions which use a state DAG (MSC4242). + If unset, populates them from the current state dag forward extremities. + Returns: Tuple of created event, UnpersistedEventContext """ + if builder.room_version.msc4242_state_dags: + assert auth_event_ids is None + # (kegan) I can't find any call-site which uses this. We can't risk letting in + # untrusted input, so for now assert that we aren't told about any state. + assert state_event_ids is None + + if builder.room_id: + if prev_state_events is None: + prev_state_events = list( + await self.store.get_state_dag_extremities(builder.room_id) + ) + else: + # create event doesn't need prev_state_events to be fetched, but it must be non-None. + assert builder.type == EventTypes.Create and builder.state_key == "" + prev_state_events = [] + # Strip down the state_event_ids to only what we need to auth the event. # For example, we don't need extra m.room.member that don't match event.sender if state_event_ids is not None: @@ -1357,7 +1392,10 @@ async def create_new_client_event( assert state_map is not None auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map) event = await builder.build( - prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth + prev_event_ids=prev_event_ids, + auth_event_ids=auth_ids, + depth=depth, + prev_state_events=prev_state_events, ) context: UnpersistedEventContextBase = ( @@ -1374,6 +1412,7 @@ async def create_new_client_event( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, depth=depth, + prev_state_events=prev_state_events, ) # Pass on the outlier property from the builder to the event @@ -1563,6 +1602,20 @@ async def handle_new_client_event( auth_event = event_id_to_event.get(event_id) if auth_event: batched_auth_events[event_id] = auth_event + if event.room_version.msc4242_state_dags: + assert isinstance(event, FrozenEventVMSC4242) + # State DAG rooms will check that the prev_state_events are not rejected. + # To do that, we need to make sure we pass in the prev_state_events as + # batched_auth_events, else we will fail the event due to the + # prev_state_events not existing in the database. + for prev_state_event_id in event.prev_state_events: + prev_state_event = event_id_to_event.get( + prev_state_event_id + ) + if prev_state_event: + batched_auth_events[prev_state_event_id] = ( + prev_state_event + ) await self._event_auth_handler.check_auth_rules_from_context( event, batched_auth_events ) @@ -1817,7 +1870,10 @@ async def cache_joined_hosts_for_events( # set for a while, so that the expiry time is reset. state_entry = await self.state.resolve_state_groups_for_events( - event.room_id, event_ids=event.prev_event_ids() + event.room_id, + event_ids=event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids(), ) if state_entry.state_group: @@ -2360,9 +2416,16 @@ async def _rebuild_event_after_third_party_rules( # case. prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + prev_state_events = None + if original_event.room_version.msc4242_state_dags: + prev_state_events = list( + await self.store.get_state_dag_extremities(builder.room_id) + ) + event = await builder.build( prev_event_ids=prev_event_ids, auth_event_ids=None, + prev_state_events=prev_state_events, ) # we rebuild the event context, to be on the safe side. If nothing else, diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 9074d7916b6..f110be0a2fc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -65,7 +65,7 @@ from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version -from synapse.events import EventBase +from synapse.events import EventBase, event_exists_in_state_dag from synapse.events.snapshot import UnpersistedEventContext from synapse.events.utils import FilteredEvent, copy_and_fixup_power_levels_contents from synapse.handlers.relations import BundledAggregations @@ -1237,6 +1237,10 @@ async def create_room( creation_content = config.get("creation_content", {}) # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier + # We do not currently support federating state DAG rooms. + # See related restriction in /send_join requests in federation_client.py. + if room_version.msc4242_state_dags: + creation_content[EventContentFields.FEDERATE] = False # trusted private chats have the invited users marked as additional creators if ( @@ -1486,6 +1490,11 @@ async def _send_events_for_new_room( # the most recently created event prev_event: list[str] = [] + # This should be the most recently created state event as we create each event + prev_state_events: list[str] | None = ( + [] if room_version.msc4242_state_dags else None + ) + # a map of event types, state keys -> event_ids. We collect these mappings this as events are # created (but not persisted to the db) to determine state for future created events # (as this info can't be pulled from the db) @@ -1512,6 +1521,7 @@ async def create_event( """ nonlocal depth nonlocal prev_event + nonlocal prev_state_events # Create the event dictionary. event_dict = {"type": etype, "content": content} @@ -1525,6 +1535,7 @@ async def create_event( creator, event_dict, prev_event_ids=prev_event, + prev_state_events=prev_state_events, depth=depth, # Take a copy to ensure each event gets a unique copy of # state_map since it is modified below. @@ -1535,7 +1546,8 @@ async def create_event( depth += 1 prev_event = [new_event.event_id] state_map[(new_event.type, new_event.state_key)] = new_event.event_id - + if room_version.msc4242_state_dags and event_exists_in_state_dag(new_event): + prev_state_events = [new_event.event_id] return new_event, new_unpersisted_context preset_config, config = self._room_preset_config(room_config) @@ -1568,6 +1580,8 @@ async def create_event( ignore_shadow_ban=True, ) last_sent_event_id = ev.event_id + if room_version.msc4242_state_dags: + prev_state_events = [ev.event_id] member_event_id, _ = await self.room_member_handler.update_membership( creator, @@ -1579,8 +1593,11 @@ async def create_event( new_room=True, prev_event_ids=[last_sent_event_id], depth=depth, + prev_state_events=prev_state_events, ) prev_event = [member_event_id] + if room_version.msc4242_state_dags: + prev_state_events = [member_event_id] # update the depth and state map here as the membership event has been created # through a different code path diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index b2e678e90e9..236c8ca03c7 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -36,9 +36,11 @@ from synapse.api.errors import ( AuthError, Codes, + NotFoundError, PartialStateConflictError, ShadowBanError, SynapseError, + UnsupportedRoomVersionError, ) from synapse.api.ratelimiting import Ratelimiter from synapse.event_auth import get_named_level, get_power_level_event @@ -408,6 +410,7 @@ async def _local_membership_update( require_consent: bool = True, outlier: bool = False, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """ @@ -494,6 +497,7 @@ async def _local_membership_update( depth=depth, require_consent=require_consent, outlier=outlier, + prev_state_events=prev_state_events, delay_id=delay_id, ) context = await unpersisted_context.persist(event) @@ -590,6 +594,7 @@ async def update_membership( state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """Update a user's membership in a room. @@ -684,6 +689,7 @@ async def update_membership( state_event_ids=state_event_ids, depth=depth, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -707,6 +713,7 @@ async def update_membership_locked( state_event_ids: list[str] | None = None, depth: int | None = None, origin_server_ts: int | None = None, + prev_state_events: list[str] | None = None, delay_id: str | None = None, ) -> tuple[str, int]: """Helper for update_membership. @@ -951,10 +958,21 @@ async def update_membership_locked( require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) - latest_event_ids = await self.store.get_prev_events_for_room(room_id) + is_state_dags = False + try: + room_version = await self.store.get_room_version(room_id) + is_state_dags = room_version.msc4242_state_dags + except (NotFoundError, UnsupportedRoomVersionError): + pass + + if is_state_dags: + latest_event_ids = list(await self.store.get_state_dag_extremities(room_id)) + else: + latest_event_ids = await self.store.get_prev_events_for_room(room_id) is_partial_state_room = await self.store.is_partial_state_room(room_id) partial_state_before_join = await self.state_handler.compute_state_after_events( @@ -1165,6 +1183,8 @@ async def update_membership_locked( # see: https://github.com/matrix-org/synapse/issues/7139 if len(latest_event_ids) == 0: latest_event_ids = [invite.event_id] + if invite.room_version.msc4242_state_dags: + prev_state_events = [invite.event_id] # or perhaps this is a remote room that a local user has knocked on elif current_membership_type == Membership.KNOCK: @@ -1210,6 +1230,7 @@ async def update_membership_locked( require_consent=require_consent, outlier=outlier, origin_server_ts=origin_server_ts, + prev_state_events=prev_state_events, delay_id=delay_id, ) @@ -2108,10 +2129,21 @@ async def _generate_local_out_of_band_leave( # # the prev_events consist solely of the previous membership event. prev_event_ids = [previous_membership_event.event_id] - auth_event_ids = ( - list(previous_membership_event.auth_event_ids()) + prev_event_ids - ) + auth_event_ids = None + # Authorise the leave by referencing the previous membership + prev_state_event_ids = None + if previous_membership_event.room_version.msc4242_state_dags: + prev_state_event_ids = [ + previous_membership_event.event_id, + ] + else: + auth_event_ids = ( + list(previous_membership_event.auth_event_ids()) + prev_event_ids + ) + # State DAG rooms should not have auth events specified + # Normal rooms should not have prev state event IDs specified + assert not (prev_state_event_ids is not None and auth_event_ids is not None) # Try several times, it could fail with PartialStateConflictError # in handle_new_client_event, cf comment in except block. max_retries = 5 @@ -2127,6 +2159,7 @@ async def _generate_local_out_of_band_leave( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids, outlier=True, + prev_state_events=prev_state_event_ids, ) context = await unpersisted_context.persist(event) event.internal_metadata.out_of_band_membership = True diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index a92233c863e..2f0e3f2c3e5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -37,7 +37,7 @@ from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions -from synapse.events import EventBase +from synapse.events import EventBase, FrozenEventVMSC4242 from synapse.events.snapshot import ( EventContext, UnpersistedEventContext, @@ -239,31 +239,6 @@ async def compute_state_after_events( ) return await ret.get_state(self._state_storage_controller, state_filter) - async def get_current_user_ids_in_room( - self, room_id: str, latest_event_ids: StrCollection - ) -> set[str]: - """ - Get the users IDs who are currently in a room. - - Note: This is much slower than using the equivalent method - `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`, - so this should only be used when wanting the users at a particular point - in the room. - - Args: - room_id: The ID of the room. - latest_event_ids: Precomputed list of latest event IDs. Will be computed if None. - Returns: - Set of user IDs in the room. - """ - - assert latest_event_ids is not None - - logger.debug("calling resolve_state_groups from get_current_user_ids_in_room") - entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = await entry.get_state(self._state_storage_controller, StateFilter.all()) - return await self.store.get_joined_user_ids_from_state(room_id, state) - async def get_hosts_in_room_at_events( self, room_id: str, event_ids: StrCollection ) -> frozenset[str]: @@ -303,7 +278,8 @@ async def calculate_context_info( membership events. `False` if `state_ids_before_event` is the full state. `None` when `state_ids_before_event` is not provided. In this case, the - flag will be calculated based on `event`'s prev events. + flag will be calculated based on `event`'s `prev_events` or `prev_state_events` + for state DAG rooms. state_group_before_event: the current state group at the time of event, if known Returns: @@ -337,7 +313,11 @@ async def calculate_context_info( # (This is slightly racy - the prev-events might get fixed up before we use # their states - but I don't think that really matters; it just means we # might redundantly recalculate the state for this event later.) - prev_event_ids = event.prev_event_ids() + prev_event_ids = frozenset( + event.prev_state_events + if isinstance(event, FrozenEventVMSC4242) + else event.prev_event_ids() + ) incomplete_prev_events = await self.store.get_partial_state_events( prev_event_ids ) @@ -355,7 +335,7 @@ async def calculate_context_info( entry = await self.resolve_state_groups_for_events( event.room_id, - event.prev_event_ids(), + prev_event_ids, await_full_state=False, ) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 2948227807f..7cc6a39639a 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -35,6 +35,7 @@ Generic, Iterable, TypeVar, + cast, ) import attr @@ -43,7 +44,9 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase +from synapse.api.errors import SynapseError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.events import EventBase, FrozenEventVMSC4242, event_exists_in_state_dag from synapse.events.snapshot import EventContext, EventPersistencePair from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable @@ -68,6 +71,7 @@ from synapse.types.state import StateFilter from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results from synapse.util.metrics import Measure +from synapse.util.stringutils import shortstr if TYPE_CHECKING: from synapse.server import HomeServer @@ -111,6 +115,14 @@ buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) +# The number of forward extremities for each new event. +msc4242_state_dag_forward_extremities_counter = Histogram( + "synapse_storage_msc4242_state_dag_forward_extremities_persisted", + "Number of forward extremities for each new event in the state DAG", + labelnames=[SERVER_NAME_LABEL], + buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), +) + state_resolutions_during_persistence = Counter( "synapse_storage_events_state_resolutions_during_persistence", "Number of times we had to do state res to calculate new current state", @@ -529,7 +541,15 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]: Returns: map from (type, state_key) to event id for the current state in the room """ - latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id) + room_version = await self.main_store.get_room_version_id(room_id) + room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + if room_version_obj.msc4242_state_dags: + latest_event_ids = await self.main_store.get_state_dag_extremities(room_id) + else: + latest_event_ids = await self.main_store.get_latest_event_ids_in_room( + room_id + ) + state_groups = set( ( await self.main_store._get_state_group_for_events(latest_event_ids) @@ -551,7 +571,6 @@ async def _calculate_current_state(self, room_id: str) -> StateMap[str]: # Avoid a circular import. from synapse.state import StateResolutionStore - room_version = await self.main_store.get_room_version_id(room_id) res = await self._state_resolution_handler.resolve_state_groups( room_id, room_version, @@ -615,28 +634,52 @@ async def _persist_event_batch( for x in range(0, len(events_and_contexts), 100) ] + # Get the room version for the first event. This room version is the same for all events + # as events_and_contexts is all for one room. + assert len(events_and_contexts) > 0 + room_version = events_and_contexts[0][0].room_version + for chunk in chunks: # We can't easily parallelize these since different chunks # might contain the same event. :( new_forward_extremities = None state_delta_for_room = None + new_state_dag_extrems = None if not backfilled: - with Measure( - self._clock, - name="_calculate_state_and_extrem", - server_name=self.server_name, - ): - # Work out the new "current state" for the room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - ( - new_forward_extremities, - state_delta_for_room, - ) = await self._calculate_new_forward_extremities_and_state_delta( - room_id, chunk - ) + if room_version.msc4242_state_dags: + with Measure( + self._clock, + name="_process_state_dag_forward_extremities_and_state_delta", + server_name=self.server_name, + ): + assert all( + isinstance(ev, FrozenEventVMSC4242) for ev, _ in chunk + ) + ( + new_forward_extremities, # for prev_events + state_delta_for_room, # for state groups + new_state_dag_extrems, # for prev_state_events + ) = await self._process_state_dag_forward_extremities_and_state_delta( + room_id, + cast(list[tuple[FrozenEventVMSC4242, EventContext]], chunk), + ) + else: + with Measure( + self._clock, + name="_calculate_state_and_extrem", + server_name=self.server_name, + ): + # Work out the new "current state" for the room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + ( + new_forward_extremities, + state_delta_for_room, + ) = await self._calculate_new_forward_extremities_and_state_delta( + room_id, chunk + ) with Measure( self._clock, @@ -666,6 +709,7 @@ async def _persist_event_batch( use_negative_stream_ordering=backfilled, inhibit_local_membership_updates=backfilled, new_event_links=new_event_links, + new_state_dag_forward_extremities=new_state_dag_extrems, ) return replaced_events @@ -793,6 +837,216 @@ async def _calculate_new_forward_extremities_and_state_delta( return (new_forward_extremities, delta) + async def _process_state_dag_forward_extremities_and_state_delta( + self, + room_id: str, + event_contexts: list[tuple[FrozenEventVMSC4242, EventContext]], + ) -> tuple[set[str] | None, DeltaState | None, set[str] | None]: + """Process the forwards extremities for state DAG rooms. + Returns: + - the new room dag extremities which should be written when these events are persisted. + - the state delta for the room, if applicable. + - the new state dag extremities which should be written when these events are persisted. + + NB: this does not write them because if it did, new events may see them _before_ the events + get persisted, causing failures in retrieving state groups. + """ + # Update forward extremities + # ...for the state DAG + existing_state_dag_fwd_extrems = ( + await self.main_store.get_state_dag_extremities(room_id) + ) + new_state_dag_fwd_extrems = await self._calculate_new_state_dag_extremities( + room_id, + existing_state_dag_fwd_extrems, + event_contexts, + ) + # ...and the room DAG + existing_room_dag_fwd_extrems = ( + await self.main_store.get_latest_event_ids_in_room(room_id) + ) + new_room_dag_fwd_extrems = await self._calculate_new_extremities( + room_id, + cast(list[EventPersistencePair], event_contexts), + existing_room_dag_fwd_extrems, + ) + assert new_room_dag_fwd_extrems, ( + f"No room dag forward extremities left in room {room_id}!" + ) + + # See if we need to calculate a state delta + if new_state_dag_fwd_extrems == existing_state_dag_fwd_extrems: + # No change in state extremities, so no new state to calculate + return new_room_dag_fwd_extrems, None, new_state_dag_fwd_extrems + + with Measure( + self._clock, + name="persist_events.state_dag.get_new_state_after_events", + server_name=self.server_name, + ): + (current_state, delta_ids, _) = await self._get_new_state_after_events( + room_id, + cast(list[EventPersistencePair], event_contexts), + existing_state_dag_fwd_extrems, + new_state_dag_fwd_extrems, + # do not prune forward extremities in the state DAG + # else we lose eventual delivery + should_prune=False, + ) + + # Following logic cargoculted from _calculate_new_forward_extremities_and_state_delta + # If either are not None then there has been a change, + # and we need to work out the delta (or use that + # given) + delta = None + if delta_ids is not None: + # If there is a delta we know that we've + # only added or replaced state, never + # removed keys entirely. + delta = DeltaState([], delta_ids) + elif current_state is not None: + with Measure( + self._clock, + name="persist_events.calculate_state_delta", + server_name=self.server_name, + ): + delta = await self._calculate_state_delta(room_id, current_state) + + if delta: + # If we have a change of state then lets check + # whether we're actually still a member of the room, + # or if our last user left. If we're no longer in + # the room then we delete the current state and + # extremities. + is_still_joined = await self._is_server_still_joined( + room_id, + cast(list[EventPersistencePair], event_contexts), + delta, + ) + if not is_still_joined: + logger.info("Server no longer in room %s", room_id) + delta.no_longer_in_room = True + + return new_room_dag_fwd_extrems, delta, new_state_dag_fwd_extrems + + async def _calculate_new_state_dag_extremities( + self, + room_id: str, + existing_fwd_extrems: frozenset[str], + event_contexts: list[tuple[FrozenEventVMSC4242, EventContext]], + ) -> set[str]: + """Calculate the new state dag forward extremities. Modifies existing_fwd_extrems. + + Assumes that event_contexts are only state events which should be in the state DAG. + + Raises: + SynapseError: if the new events include unknown prev_state_events + AssertionError: if there are no state DAG forward extremities remaining in the room + """ + # Events are always processed in causal order without any gaps in the DAG + # (prev_state_events are always known), guaranteeing that processed events have a path to the + # create event. This is an emergent property of state DAGs as asserting that there is a path + # to the create event every time we insert an event would be prohibitively expensive. + # This is similar to how doubly-linked lists can potentially not refer to previous items correctly + # without verifying the list's integrity, but doing it on every insert is too expensive. + + # filter out events which don't belong in the state dag. + new_state_events_contexts = [ + (e, ctx) for e, ctx in event_contexts if event_exists_in_state_dag(e) + ] + if len(new_state_events_contexts) == 0: + # if there are no state events being persisted, then the fwd extremities of the state dag + # do not change. + return set(existing_fwd_extrems) + + # This logic is very similar to _calculate_new_extremities with a few key differences: + # - We do not "Remove any events which are prev_events of any existing events." because the + # state DAG mandates that events are processed in causal order, so there MUST NOT be any + # existing, processed events which have the to-be-persisted events as prev_state_events. + # - We don't care if they are an "outlier" in the main room dag, so long as they AREN'T + # an outlier on the state dag, which this function checks, so we don't check outlier-ness. + # - We allow *soft-failed* events to become forward extremities, as per the MSC. We do not + # allow *rejected* events to become forward extremities though. + + rejected_events = [ev for ev, ctx in new_state_events_contexts if ctx.rejected] + new_state_events = [ + ev for ev, ctx in new_state_events_contexts if not ctx.rejected + ] + # We want to check that we are not missing any prev_state_events. + # To do this, we include rejected events in this check because other events may point to them. + # If we didn't include them, we might incorrectly say we are missing events when we are not. + all_new_state_events = set(rejected_events + new_state_events) + + # First, verify that we know all prev_state_events. If we fail this check then we don't have + # a complete DAG and that is bad, so bail out. + + # Start with them all missing. + missing_prev_state_events = { + e_id for event in all_new_state_events for e_id in event.prev_state_events + } + + # remove prev events which appear in all_events + missing_prev_state_events.difference_update( + event.event_id for event in all_new_state_events + ) + # the rest of these events should be present in the DB. Some of them may be forward extremities, + # some may not be, that's ok. + seen_events = await self.main_store.have_seen_events( + room_id, + missing_prev_state_events, + ) + missing_prev_state_events.difference_update(seen_events) + + if len(missing_prev_state_events) > 0: + logger.error( + "_calculate_new_state_dag_extremities: missing the following prev_state_events in room %s : %s", + room_id, + missing_prev_state_events, + ) + logger.error( + "_calculate_new_state_dag_extremities: was handling %s", + shortstr([ev.event_id for ev in all_new_state_events]), + ) + raise SynapseError( + code=500, + msg=f"missing {len(missing_prev_state_events)} prev_state_events in room {room_id}", + ) + + # Now calculate the forward extremities. + + # start with the existing forward extremities + result = set(existing_fwd_extrems) + + # add all the new events to the list + result.update(event.event_id for event in new_state_events) + + # Now remove all events which are prev_state_events of any of the new events + result.difference_update( + e_id for event in new_state_events for e_id in event.prev_state_events + ) + + # Finally handle the case where the new events have rejected/soft-failed `prev_state_events`. + # If they do we need to remove them and their `prev_state_events`, + # otherwise we end up with dangling extremities. + # Specifically, this handles the case where (F=fwd extrem, SF=soft-failed, N=new event) + # F <-- SF <-- SF <-- N + # where we want to remove F as a forward extremity and replace with N. + existing_prevs = await self.persist_events_store._get_prevs_before_rejected( + (e_id for event in new_state_events for e_id in event.prev_state_events), + include_soft_failed=False, + ) + result.difference_update(existing_prevs) + + # We only update metrics for events that change forward extremities + if result != existing_fwd_extrems: + msc4242_state_dag_forward_extremities_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe(len(result)) + + # There should always be at least one forward extremity. + assert result, f"No state dag forward extremities left in room {room_id}!" + return result + async def _calculate_new_extremities( self, room_id: str, @@ -859,6 +1113,7 @@ async def _get_new_state_after_events( events_context: list[EventPersistencePair], old_latest_event_ids: AbstractSet[str], new_latest_event_ids: set[str], + should_prune: bool = True, ) -> tuple[StateMap[str] | None, StateMap[str] | None, set[str]]: """Calculate the current state dict after adding some new events to a room @@ -873,9 +1128,15 @@ async def _get_new_state_after_events( old_latest_event_ids: the old forward extremities for the room. - new_latest_event_ids : + new_latest_event_ids: the new forward extremities for the room. + should_prune: + if true, attempt to prune the forward extremities. + Pruning means we will not communicate some new events to other servers, + which can compromise eventual delivery, so graphs which are fully synchronised + e.g. state DAGs should not prune. + Returns: Returns a tuple of two state maps and a set of new forward extremities. @@ -1015,7 +1276,7 @@ async def _get_new_state_after_events( # If the returned state matches the state group of one of the new # forward extremities then we check if we are able to prune some state # extremities. - if res.state_group and res.state_group in new_state_groups: + if should_prune and res.state_group and res.state_group in new_state_groups: new_latest_event_ids = await self._prune_extremities( room_id, new_latest_event_ids, diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index cc7083b605d..415926eb0a3 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1493,6 +1493,15 @@ async def get_latest_event_ids_in_room(self, room_id: str) -> frozenset[str]: ) return frozenset(event_ids) + async def get_state_dag_extremities(self, room_id: str) -> frozenset[str]: + event_ids = await self.db_pool.simple_select_onecol( + table="msc4242_state_dag_forward_extremities", + keyvalues={"room_id": room_id}, + retcol="event_id", + desc="get_state_dag_extremities", + ) + return frozenset(event_ids) + async def get_min_depth(self, room_id: str) -> int | None: """For the given room, get the minimum depth we have seen for it.""" return await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 6d3bc15777b..12c918eca64 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -48,7 +48,9 @@ from synapse.api.room_versions import RoomVersions from synapse.events import ( EventBase, + FrozenEventVMSC4242, StrippedStateEvent, + event_exists_in_state_dag, is_creator, relation_from_event, ) @@ -295,6 +297,7 @@ async def _persist_events_and_state_updates( new_event_links: dict[str, NewEventChainLinks], use_negative_stream_ordering: bool = False, inhibit_local_membership_updates: bool = False, + new_state_dag_forward_extremities: set[str] | None = None, ) -> None: """Persist a set of events alongside updates to the current state and forward extremities tables. @@ -315,6 +318,8 @@ async def _persist_events_and_state_updates( from being updated by these events. This should be set to True for backfilled events because backfilled events in the past do not affect the current local state. + new_state_dag_forward_extremities: A set of event IDs that are the new forward + extremities for the state DAG for this room. MSC4242 only. Returns: Resolves when the events have been persisted @@ -379,6 +384,7 @@ async def _persist_events_and_state_updates( new_forward_extremities=new_forward_extremities, new_event_links=new_event_links, sliding_sync_table_changes=sliding_sync_table_changes, + new_state_dag_forward_extremities=new_state_dag_forward_extremities, ) persist_event_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc( len(events_and_contexts) @@ -962,8 +968,10 @@ def _get_events_which_are_prevs_txn( return results - async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> set[str]: - """Get soft-failed ancestors to remove from the extremities. + async def _get_prevs_before_rejected( + self, event_ids: Iterable[str], include_soft_failed: bool = True + ) -> set[str]: + """Get soft-failed/rejected ancestors to remove from the extremities. Given a set of events, find all those that have been soft-failed or rejected. Returns those soft failed/rejected events and their prev @@ -976,7 +984,8 @@ async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> set[str] Args: event_ids: Events to find prev events for. Note that these must have already been persisted. - + include_soft_failed: Soft-failed events are included in the search. If false, only + rejected events are included. Returns: The previous events. """ @@ -1016,7 +1025,7 @@ def _get_prevs_before_rejected_txn( continue soft_failed = db_to_json(metadata).get("soft_failed") - if soft_failed or rejected: + if (include_soft_failed and soft_failed) or rejected: to_recursively_check.append(prev_event_id) existing_prevs.add(prev_event_id) @@ -1038,6 +1047,7 @@ def _persist_events_txn( new_forward_extremities: set[str] | None, new_event_links: dict[str, NewEventChainLinks], sliding_sync_table_changes: SlidingSyncTableChanges | None, + new_state_dag_forward_extremities: set[str] | None = None, ) -> None: """Insert some number of room events into the necessary database tables. @@ -1146,6 +1156,11 @@ def _persist_events_txn( max_stream_order=max_stream_order, ) + if new_state_dag_forward_extremities: + self._set_state_dag_extremities_txn( + txn, room_id, new_state_dag_forward_extremities + ) + self._persist_transaction_ids_txn(txn, events_and_contexts) # Insert into event_to_state_groups. @@ -2475,6 +2490,29 @@ def _update_forward_extremities_txn( ], ) + def _set_state_dag_extremities_txn( + self, txn: LoggingTransaction, room_id: str, new_extrems: Collection[str] + ) -> None: + self.db_pool.simple_delete_txn( + txn, + table="msc4242_state_dag_forward_extremities", + keyvalues={ + "room_id": room_id, + }, + ) + self.db_pool.simple_insert_many_txn( + txn, + table="msc4242_state_dag_forward_extremities", + keys=("room_id", "event_id"), + values=[ + ( + room_id, + event_id, + ) + for event_id in new_extrems + ], + ) + @classmethod def _filter_events_and_contexts_for_duplicates( cls, events_and_contexts: list[EventPersistencePair] @@ -2859,6 +2897,12 @@ def _update_metadata_tables_txn( self._handle_event_relations(txn, event) + if event.room_version.msc4242_state_dags and event_exists_in_state_dag( + event + ): + assert isinstance(event, FrozenEventVMSC4242) + self._store_state_dag_edges(txn, event) + # Store the labels for this event. labels = event.content.get(EventContentFields.LABELS) if labels: @@ -2935,6 +2979,36 @@ def local_prefill() -> None: txn.async_call_after(external_prefill) txn.call_after(local_prefill) + def _store_state_dag_edges( + self, txn: LoggingTransaction, event: FrozenEventVMSC4242 + ) -> None: + # the create event has no edge but we still need to persist it as get_state_dag just + # yanks all rows in this table. It's a bit gross to store NULL as the prev_state_event_id + # though. + if len(event.prev_state_events) == 0 and event.type == EventTypes.Create: + self.db_pool.simple_insert_txn( + txn, + table="msc4242_state_dag_edges", + values={ + "room_id": event.room_id, + "event_id": event.event_id, + "prev_state_event_id": None, + }, + ) + return + assert len(event.prev_state_events) > 0 + self.db_pool.simple_upsert_many_txn( + txn, + table="msc4242_state_dag_edges", + key_names=["room_id", "event_id", "prev_state_event_id"], + key_values=[ + (event.room_id, event.event_id, prev_state_event) + for prev_state_event in event.prev_state_events + ], + value_names=(), + value_values=(), + ) + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: assert event.redacts is not None self.db_pool.simple_upsert_txn( @@ -3456,7 +3530,13 @@ def _store_event_state_mappings_txn( """ state_groups = {} for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): + # state dag rooms allow outliers to have state, as `/get_missing_events` state dag events are nominally + # outliers (not present in the timeline) but do need state persisted so we can calculate + # what the auth_events are for the event. + if ( + not event.room_version.msc4242_state_dags + and event.internal_metadata.is_outlier() + ): # double-check that we don't have any events that claim to be outliers # *and* have partial state (which is meaningless: we should have no # state at all for an outlier) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index d55ea5cf7d4..fe8079c2010 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -71,6 +71,10 @@ # so must be deleted first. "sliding_sync_joined_rooms", "sliding_sync_membership_snapshots", + # Note: msc4242_state_dag_forward_extremities/edges have a foreign key to the `events` table + # so must be deleted first. + "msc4242_state_dag_forward_extremities", + "msc4242_state_dag_edges", "events", "federation_inbound_events_staging", "receipts_graph", diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index cfde107b486..87523e6f180 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -38,7 +38,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion -from synapse.events import EventBase +from synapse.events import EventBase, EventMetadata from synapse.events.snapshot import EventContext from synapse.logging.opentracing import trace from synapse.replication.tcp.streams import UnPartialStatedEventStream @@ -78,16 +78,6 @@ class Sentinel: ROOM_UNKNOWN_SENTINEL = Sentinel() -@attr.s(slots=True, frozen=True, auto_attribs=True) -class EventMetadata: - """Returned by `get_metadata_for_events`""" - - room_id: str - event_type: str - state_key: str | None - rejection_reason: str | None - - def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: v = KNOWN_ROOM_VERSIONS.get(room_version_id) if not v: diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index e3095a9d0d0..1afc6d0b2a6 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -174,6 +174,7 @@ Changes in SCHEMA_VERSION = 94 - Add `recheck` column (boolean, default true) to the `redactions` table. + - MSC4242: Add state DAG tables. """ diff --git a/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql b/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql new file mode 100644 index 00000000000..bc5c738ba53 --- /dev/null +++ b/synapse/storage/schema/main/delta/94/03_state_dag_fwd_extrems.sql @@ -0,0 +1,38 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2026 Element Creations, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE TABLE IF NOT EXISTS msc4242_state_dag_forward_extremities( + -- we always expect the room to exist. If it gets removed, delete fwd extremities. + room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE, + event_id TEXT NOT NULL REFERENCES events(event_id) ON DELETE CASCADE, + -- it doesn't make sense to reference the same event multiple times, and this uniqueness + -- index is also used to delete events once they are no longer forward extremities. + UNIQUE (event_id) +); +-- When creating events, we want to select all forward extremities for a room which this index helps with. +CREATE INDEX msc4242_state_dag_room ON msc4242_state_dag_forward_extremities(room_id); + + +CREATE TABLE IF NOT EXISTS msc4242_state_dag_edges( + -- Deleting the room deletes the state DAG. + room_id TEXT NOT NULL REFERENCES rooms(room_id) ON DELETE CASCADE, + -- the event IDs being referenced must exist (hence REFERENCES) and we do not want to accidentally delete + -- the event and create a hole in the state DAG. It is not possible for a state + -- DAG room to function with an holey DAG, so these events _cannot_ be purged. To purge them, the + -- entire room would need to be deleted. + event_id TEXT NOT NULL REFERENCES events(event_id), + -- one of the `prev_state_events` for this event ID. We must have it since we must have the entire state DAG. + -- can be NULL for the create event. + prev_state_event_id TEXT REFERENCES events(event_id) +); +CREATE UNIQUE INDEX msc4242_state_dag_edges_key ON msc4242_state_dag_edges(room_id, event_id, prev_state_event_id); diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 29432bdd560..fe0ca044208 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -60,6 +60,9 @@ class EventInternalMetadata: device_id: str """The device ID of the user who sent this event, if any.""" + # MSC4242 state dags + calculated_auth_event_ids: list[str] + def get_dict(self) -> JsonDict: ... def is_outlier(self) -> bool: ... def copy(self) -> "EventInternalMetadata": ... diff --git a/synapse/synapse_rust/room_versions.pyi b/synapse/synapse_rust/room_versions.pyi index 909e3a1c26f..9bbb538f185 100644 --- a/synapse/synapse_rust/room_versions.pyi +++ b/synapse/synapse_rust/room_versions.pyi @@ -31,6 +31,8 @@ class EventFormatVersions: """MSC1884-style format: introduced for room v4""" ROOM_V11_HYDRA_PLUS: int """MSC4291 room IDs as hashes: introduced for room HydraV11""" + ROOM_VMSC4242: int + """MSC4242 state DAGs: adds prev_state_events, removes auth_events""" KNOWN_EVENT_FORMAT_VERSIONS: frozenset[int] @@ -113,6 +115,14 @@ class RoomVersion: rather than in codepoints. If true, this room version uses stricter event size validation.""" + msc4242_state_dags: bool + """MSC4242: State DAGs. Creates events with prev_state_events instead of auth_events and derives + state from it. Events are always processed in causal order without any gaps in the DAG + (prev_state_events are always known), guaranteeing that processed events have a path to the + create event. This is an emergent property of state DAGs as asserting that there is a path + to the create event every time we insert an event would be prohibitively expensive. + This is similar to how doubly-linked lists can potentially not refer to previous items correctly + without verifying the list's integrity, but doing it on every insert is too expensive.""" class RoomVersions: V1: RoomVersion @@ -132,6 +142,7 @@ class RoomVersions: MSC3757v11: RoomVersion HydraV11: RoomVersion V12: RoomVersion + MSC4242v12: RoomVersion class KnownRoomVersionsMapping(Mapping[str, RoomVersion]): def add_room_version(self, room_version: RoomVersion) -> None: ... diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 275e5dfa1d8..a40e0b06807 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -432,7 +432,14 @@ def _test_get_extremities_common(self, room_version: str) -> None: self.assertEqual(channel.json_body["error"], "Server is banned from room") self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") - @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) + # FIXME: Exclude MSC4242 room versions whilst it lacks federation support + @parameterized.expand( + [ + (k,) + for k in KNOWN_ROOM_VERSIONS.keys() + if k != RoomVersions.MSC4242v12.identifier + ] + ) @override_config( {"use_frozen_dicts": True, "experimental_features": {"msc4370_enabled": True}} ) @@ -440,7 +447,14 @@ def test_get_extremities_with_frozen_dicts(self, room_version: str) -> None: """Test GET /extremities with USE_FROZEN_DICTS=True""" self._test_get_extremities_common(room_version) - @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) + # FIXME: Exclude MSC4242 room versions whilst it lacks federation support + @parameterized.expand( + [ + (k,) + for k in KNOWN_ROOM_VERSIONS.keys() + if k != RoomVersions.MSC4242v12.identifier + ] + ) @override_config( {"use_frozen_dicts": False, "experimental_features": {"msc4370_enabled": True}} ) @@ -573,12 +587,18 @@ def _test_send_join_common(self, room_version: str) -> None: @override_config({"use_frozen_dicts": True}) def test_send_join_with_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=True""" + if room_version == RoomVersions.MSC4242v12.identifier: + # TODO: This room version doesn't work over federation in this PR. + return self._test_send_join_common(room_version) @parameterized.expand([(k,) for k in KNOWN_ROOM_VERSIONS.keys()]) @override_config({"use_frozen_dicts": False}) def test_send_join_without_frozen_dicts(self, room_version: str) -> None: """Test send_join with USE_FROZEN_DICTS=False""" + if room_version == RoomVersions.MSC4242v12.identifier: + # TODO: This room version doesn't work over federation in this PR. + return self._test_send_join_common(room_version) def test_send_join_partial_state(self) -> None: diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 3d856b93462..1aaa86e2e87 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -1121,7 +1121,9 @@ async def get_event( return {"pdus": [missing_event.get_pdu_json()]} async def get_room_state_ids( - destination: str, room_id: str, event_id: str + destination: str, + room_id: str, + event_id: str, ) -> JsonDict: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, missing_event.event_id) @@ -1131,7 +1133,10 @@ async def get_room_state_ids( } async def get_room_state( - room_version: RoomVersion, destination: str, room_id: str, event_id: str + room_version: RoomVersion, + destination: str, + room_id: str, + event_id: str, ) -> StateRequestResponse: self.assertEqual(destination, self.OTHER_SERVER_NAME) self.assertEqual(event_id, missing_event.event_id) diff --git a/tests/storage/test_msc4242_state_dag.py b/tests/storage/test_msc4242_state_dag.py new file mode 100644 index 00000000000..8775e5c8eb7 --- /dev/null +++ b/tests/storage/test_msc4242_state_dag.py @@ -0,0 +1,371 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2026 Element Creations, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . + +from typing import Iterable +from unittest.mock import Mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError +from synapse.api.room_versions import RoomVersions +from synapse.events import FrozenEventVMSC4242, make_event_from_dict +from synapse.events.snapshot import EventContext +from synapse.rest.client import room +from synapse.server import HomeServer +from synapse.util.clock import Clock + +from tests.unittest import HomeserverTestCase, override_config + + +class MSC4242StateDagsTests(HomeserverTestCase): + user_id = "@user1:server" + servlets = [room.register_servlets] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + hs = self.setup_test_homeserver("server") + return hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.room_id = self.helper.create_room_as( + self.user_id, + room_version=RoomVersions.MSC4242v12.identifier, + ) + + self.store = hs.get_datastores().main + self._storage_controllers = self.hs.get_storage_controllers() + + def _get_prev_state_events(self, event_id: str) -> list[str]: + ev = self.helper.get_event(self.room_id, event_id) + prev_state_events: list[str] | None = ev.get("prev_state_events", None) + assert prev_state_events is not None + return prev_state_events + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_forward_extremities_are_calculated(self) -> None: + """ + Check that forward extremities are set as prev_state_events and that they don't change + for non-state events. + """ + # they don't change for messages + first_event_id = self.helper.send(self.room_id, body="test1")["event_id"] + first_prev_state_events = self._get_prev_state_events(first_event_id) + assert len(first_prev_state_events) == 1 + second_id = self.helper.send(self.room_id, body="test2")["event_id"] + second_prev_state_events = self._get_prev_state_events(second_id) + assert len(second_prev_state_events) == 1 + self.assertIncludes( + set(first_prev_state_events), set(second_prev_state_events), exact=True + ) + + # send an auth event, which should change the prev_state_events on *subsequent* events + join_rule_state_event_id = self.helper.send_state( + self.room_id, + EventTypes.JoinRules, + { + "join_rule": "knock", + }, + tok="nope", + )["event_id"] + join_rule_prev_state_event_ids = self._get_prev_state_events( + join_rule_state_event_id + ) + self.assertIncludes( + set(second_prev_state_events), + set(join_rule_prev_state_event_ids), + exact=True, + ) + + # prev_state_events should always point to the join rule now + third_event_id = self.helper.send(self.room_id, body="test3")["event_id"] + third_prev_state_events = self._get_prev_state_events(third_event_id) + self.assertIncludes( + set(third_prev_state_events), {join_rule_state_event_id}, exact=True + ) + # and non-auth state should also update prev_state_events + name_state_event_id = self.helper.send_state( + self.room_id, + EventTypes.Name, + { + "name": "State DAGs!", + }, + tok="nope", + )["event_id"] + name_prev_state_event_ids = self._get_prev_state_events(name_state_event_id) + self.assertIncludes( + set(name_prev_state_event_ids), {join_rule_state_event_id}, exact=True + ) + fourth_event_id = self.helper.send(self.room_id, body="test4")["event_id"] + fourth_prev_state_events = self._get_prev_state_events(fourth_event_id) + self.assertIncludes( + set(fourth_prev_state_events), {name_state_event_id}, exact=True + ) + + +class MSC4242EventPersistenceStateDagsStoreTestCase(HomeserverTestCase): + servlets = [ + room.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + persistence = hs.get_storage_controllers().persistence + assert persistence is not None + self.persistence = persistence + self.room_id = "!foo:bar" + self.seen_event_ids: set[str] = set() + self.persistence.main_store = Mock(spec=["have_seen_events"]) + self.persistence.main_store.have_seen_events.side_effect = ( + self._have_seen_events + ) + self.rejected_event_ids_and_their_prevs: set[str] = set() + self.persistence.persist_events_store = Mock( + spec=["_get_prevs_before_rejected"] + ) + self.persistence.persist_events_store._get_prevs_before_rejected.side_effect = ( + self._get_prevs_before_rejected + ) + + async def _have_seen_events( + self, room_id: str, event_ids: Iterable[str] + ) -> set[str]: + unknown_events = set(event_ids) + return self.seen_event_ids.intersection(unknown_events) + + async def _get_prevs_before_rejected( + self, event_ids: Iterable[str], include_soft_failed: bool = True + ) -> set[str]: + return self.rejected_event_ids_and_their_prevs + + def _make_event( + self, + id: str, + prev_state_events: list[str], + rejected: bool = False, + ) -> tuple[FrozenEventVMSC4242, EventContext]: + ev = make_event_from_dict( + { + "prev_state_events": prev_state_events, + "content": { + "membership": "join", + }, + "sender": "@unimportant:info", + "state_key": "@unimportant:info", + "type": "m.room.member", + "room_id": self.room_id, + }, + room_version=RoomVersions.MSC4242v12, + ) + assert isinstance(ev, FrozenEventVMSC4242) + ev._event_id = id + ctx = Mock() + ctx.rejected = rejected + return ev, ctx + + def _test( + self, + current_fwds: list[str], + new_events: list[tuple[FrozenEventVMSC4242, EventContext]], + want_new_extrems: set[str], + want_raises: bool = False, + ) -> None: + """ + Tests the logic of _calculate_new_state_dag_extremities. + + Tests that the new extremities calculated as a result of processing current_fwds and new_events + matches want_new_extrems or raises if want_raises is True. + """ + coroutine = self.persistence._calculate_new_state_dag_extremities( + self.room_id, + frozenset(current_fwds), + new_events, + ) + if want_raises: + f = self.get_failure(coroutine, SynapseError) + assert f is not None + return + + new_extrems = set(self.get_success(coroutine)) + self.assertIncludes( + new_extrems, + set(want_new_extrems), + exact=True, + message=f"want_new_extrems={want_new_extrems} got={new_extrems}", + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_simple(self) -> None: + # Simple linear chain + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$2"]), + self._make_event("$4", ["$3"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_fork(self) -> None: + # Simple fork so we end up with two forward extrems + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$2"]), + self._make_event("$4", ["$2"]), + ], + want_new_extrems={"$3", "$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge(self) -> None: + # Simple fork so we end up with two forward extrems + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$1"]), + self._make_event("$4", ["$2", "$3"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_fork_on_existing(self) -> None: + # Fork where we are adding to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"]), # append to the forward extrem + self._make_event("$5", ["$1"]), # append to the root + ], + want_new_extrems={"$4", "$5"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge_on_existing(self) -> None: + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3", "$2"]), + ], + want_new_extrems={"$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_merge_on_not_current(self) -> None: + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$1", "$2"]), + ], + want_new_extrems={"$3", "$4"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_append_with_rejected(self) -> None: + # rejected events cannot be forward extremities + self.seen_event_ids = {"$1", "$2", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + ], + want_new_extrems={"$3"}, + ) + + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + self._make_event("$5", ["$4"], rejected=True), + ], + want_new_extrems={"$3"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_append_with_rejected_in_chain( + self, + ) -> None: + # rejected events cannot be forward extremities, but events that come after them can. + # this shouldn't cause multiple forward extremities. + self.seen_event_ids = {"$1", "$2", "$3"} + self.rejected_event_ids_and_their_prevs = {"$4", "$3"} + self._test( + current_fwds=["$3"], + new_events=[ + self._make_event("$4", ["$3"], rejected=True), + self._make_event("$5", ["$4"]), + ], + want_new_extrems={"$5"}, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_missing_prevs_raises(self) -> None: + self._test( + current_fwds=[], + new_events=[ + self._make_event("$1", []), + self._make_event("$2", ["$1"]), + self._make_event("$3", ["$unknown"]), + self._make_event("$4", ["$3"]), + ], + want_new_extrems={"$4"}, + want_raises=True, + ) + + @override_config({"experimental_features": {"msc4242_enabled": True}}) + def test_calculate_new_state_dag_extremities_complex(self) -> None: + """ + 1 + | \ + 2 4 + | + 3 + + Exists already, then becomes... + + 1______ + | \\ | + 2 4 5R + | | | + 3--7 6R + | \\ / \ + 10R 8 9 + + """ + # Merge where we are merging to older events + self.seen_event_ids = {"$1", "$2", "$3", "$4"} + self.rejected_event_ids_and_their_prevs = {"$1", "$5", "$6", "$3", "$10"} + self._test( + current_fwds=["$3", "$4"], + new_events=[ + self._make_event("$5", ["$1"], rejected=True), + self._make_event("$6", ["$5"], rejected=True), + self._make_event("$7", ["$4", "$3"]), + self._make_event("$8", ["$6", "$7"]), + self._make_event("$9", ["$6"]), + self._make_event("$10", ["$3"], rejected=True), + ], + want_new_extrems={"$8", "$9"}, + ) diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 67a5c31c448..c3462457060 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -232,6 +232,7 @@ async def build( prev_event_ids: list[str], auth_event_ids: list[str] | None, depth: int | None = None, + prev_state_events: list[str] | None = None, ) -> EventBase: built_event = await self._base_builder.build( prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index 934a2fd3071..9258f0d4dc1 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -20,15 +20,16 @@ # import unittest +from collections import namedtuple from typing import Any, Collection, Iterable from parameterized import parameterized from synapse import event_auth -from synapse.api.constants import EventContentFields +from synapse.api.constants import EventContentFields, RejectedReason from synapse.api.errors import AuthError, SynapseError from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions -from synapse.events import EventBase, make_event_from_dict +from synapse.events import EventBase, event_exists_in_state_dag, make_event_from_dict from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import JsonDict, get_domain_from_id @@ -374,6 +375,195 @@ def test_msc2432_alias_event(self) -> None: auth_events, ) + def test_msc4242_state_dag_rules(self) -> None: + """Tests additional rules in place for state DAG rooms. + + 1. m.room.create => if it has any prev_state_events, reject. + 2. Considering the event's prev_state_events: + i. If there are entries which do not belong in the same room, reject. + ii. If there are entries which do not have a state_key, reject. + iii. If there are entries which were themselves rejected under the checks performed on receipt of a PDU, reject. + """ + creator = "@creator:example.com" + room_version = RoomVersions.MSC4242v12 + + create_event = make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator}, + "prev_events": [], + "prev_state_events": [], + }, + room_version, + ) + create_event_2 = make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator, "another": "room"}, + "prev_events": [], + "prev_state_events": [], + }, + room_version, + ) + room_id = create_event.room_id + another_room_id = create_event_2.room_id + join_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [create_event.event_id], + "prev_state_events": [create_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id]}, + ) + event_in_another_room = make_event_from_dict( + { + "room_id": another_room_id, + "type": "m.room.join_rules", + "sender": creator, + "state_key": "", + "content": {"join_rule": "public"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + ) + msg_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.message", + "sender": creator, + "content": {"msgtype": "m.text", "body": "I am a message"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + ) + rejected_event = make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "REJECTED"}, + "prev_events": [join_event.event_id], + "prev_state_events": [join_event.event_id], + }, + room_version, + {"calculated_auth_event_ids": [create_event.event_id, join_event.event_id]}, + rejected_reason=RejectedReason.AUTH_ERROR, + ) + RejectingTestCase = namedtuple( + "RejectingTestCase", "name events_in_store test_event" + ) + rejecting_test_cases = [ + RejectingTestCase( + name="create event has prev_state_events", + events_in_store=[], + test_event=make_event_from_dict( + { + "type": "m.room.create", + "sender": creator, + "state_key": "", + "content": {"creator": creator}, + "prev_events": [], + "prev_state_events": [create_event.event_id], + }, + room_version, + {}, + ), + ), + RejectingTestCase( + name="prev_state_event belongs in a different room", + events_in_store=[create_event, join_event, event_in_another_room], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev_state_event is in another room"}, + "prev_events": [join_event.event_id], + "prev_state_events": [event_in_another_room.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + RejectingTestCase( + name="prev_state_event is a message event", + events_in_store=[create_event, join_event, msg_event], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev state event is a message"}, + "prev_events": [msg_event.event_id], + "prev_state_events": [msg_event.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + RejectingTestCase( + name="prev_state_event was rejected", + events_in_store=[create_event, join_event, rejected_event], + test_event=make_event_from_dict( + { + "room_id": room_id, + "type": "m.room.name", + "sender": creator, + "state_key": "", + "content": {"name": "prev state event was rejected"}, + "prev_events": [join_event.event_id], + "prev_state_events": [rejected_event.event_id], + }, + room_version, + { + "calculated_auth_event_ids": [ + create_event.event_id, + join_event.event_id, + ] + }, + ), + ), + ] + + for test_case in rejecting_test_cases: + event_store = _StubEventSourceStore() + event_store.add_events(test_case.events_in_store) + + with self.assertRaises( + AuthError, msg=f"test case {test_case.name} was not rejected" + ): + get_awaitable_result( + event_auth.check_state_independent_auth_rules( + event_store, test_case.test_event + ) + ) + @parameterized.expand([(RoomVersions.V1, True), (RoomVersions.V6, False)]) def test_notifications( self, room_version: RoomVersion, allow_modification: bool @@ -769,6 +959,105 @@ def create_event(pl_event_content: dict[str, Any]) -> EventBase: with self.assertRaises(SynapseError): event_auth._check_power_levels(event.room_version, event, {}) + def test_event_exists_in_state_dag(self) -> None: + events_that_exist_in_state_dag = [ + { + "type": "m.room.create", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.join_rules", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.power_levels", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.server_acl", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "@alice:somewhere", + "content": {}, + }, + { + "type": "m.room.third_party_invite", + "state_key": "flibble", + "content": {}, + }, + { + "type": "m.room.create", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.join_rules", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.power_levels", + "state_key": " ", + "content": {}, + }, + { + "type": "m.room.name", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "", + "content": {}, + }, + { + "type": "m.room.member", + "state_key": "hello_world", + "content": {}, + }, + ] + events_that_dont_exist_in_state_dag = [ + { + "type": "m.room.message", + "content": {}, + }, + { + "type": "m.room.create", + "content": {}, + }, + { + "type": "m.room.join_rules", + "content": {}, + }, + { + "type": "m.room.power_levels", + "content": {}, + }, + ] + + def check_events(events: list[dict], should_exist: bool) -> None: + for ev in events: + base = { + "room_id": TEST_ROOM_ID, + "sender": "@test:test.com", + "signatures": {"test.com": {"ed25519:0": "some9signature"}}, + } + base.update(ev) + event = make_event_from_dict(base, RoomVersions.V10) + got = event_exists_in_state_dag(event) + self.assertEqual( + got, should_exist, f"{ev} should_exist={should_exist} but got {got}" + ) + + check_events(events_that_exist_in_state_dag, should_exist=True) + check_events(events_that_dont_exist_in_state_dag, should_exist=False) + # helpers for making events TEST_DOMAIN = "example.com"