|
12 | 12 | # <https://www.gnu.org/licenses/agpl-3.0.html>. |
13 | 13 | # |
14 | 14 | # |
| 15 | +from http import HTTPStatus |
15 | 16 | from typing import Literal |
16 | 17 |
|
17 | 18 | from twisted.internet.testing import MemoryReactor |
18 | 19 |
|
19 | | -from synapse.api.constants import EventContentFields, EventTypes |
| 20 | +from synapse.api.constants import ( |
| 21 | + EventContentFields, |
| 22 | + EventTypes, |
| 23 | + Membership, |
| 24 | +) |
| 25 | +from synapse.api.room_versions import RoomVersions |
20 | 26 | from synapse.config.server import DEFAULT_ROOM_VERSION |
| 27 | +from synapse.events import make_event_from_dict |
| 28 | +from synapse.module_api import EventBase |
21 | 29 | from synapse.rest import admin, login, room, room_upgrade_rest_servlet |
22 | 30 | from synapse.server import HomeServer |
23 | 31 | from synapse.types import Codes, JsonDict |
24 | 32 | from synapse.util.clock import Clock |
25 | 33 |
|
| 34 | +from tests import unittest |
26 | 35 | from tests.server import FakeChannel |
27 | 36 | from tests.unittest import HomeserverTestCase |
28 | 37 |
|
29 | 38 |
|
30 | 39 | class SpamCheckerTestCase(HomeserverTestCase): |
| 40 | + """Tests for the spam checker module API.""" |
| 41 | + |
31 | 42 | servlets = [ |
32 | 43 | room.register_servlets, |
33 | 44 | admin.register_servlets, |
@@ -284,3 +295,178 @@ async def user_may_send_state_event( |
284 | 295 |
|
285 | 296 | self.assertEqual(channel.code, 403) |
286 | 297 | self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) |
| 298 | + |
| 299 | + |
| 300 | +class FederatedEventSpamCheckMetadataTestCase(unittest.FederatingHomeserverTestCase): |
| 301 | + servlets = [ |
| 302 | + admin.register_servlets, |
| 303 | + login.register_servlets, |
| 304 | + room.register_servlets, |
| 305 | + ] |
| 306 | + |
| 307 | + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: |
| 308 | + super().prepare(reactor, clock, hs) |
| 309 | + self._module_api = hs.get_module_api() |
| 310 | + self._store = hs.get_datastores().main |
| 311 | + self._storage_controllers = hs.get_storage_controllers() |
| 312 | + self._federation_event_handler = hs.get_federation_event_handler() |
| 313 | + self._federation_server = hs.get_federation_server() |
| 314 | + self._state_handler = hs.get_state_handler() |
| 315 | + self._persistence_controller = hs.get_storage_controllers().persistence |
| 316 | + |
| 317 | + # Create a room |
| 318 | + user1_id = self.register_user("user1", "pass") |
| 319 | + user1_tok = self.login(user1_id, "pass") |
| 320 | + self.room_id = self.helper.create_room_as( |
| 321 | + user1_id, |
| 322 | + tok=user1_tok, |
| 323 | + is_public=True, |
| 324 | + room_version=RoomVersions.V10.identifier, |
| 325 | + ) |
| 326 | + |
| 327 | + # Prepare a join for the 'remote' user |
| 328 | + state_map = self.get_success( |
| 329 | + self._storage_controllers.state.get_current_state(self.room_id) |
| 330 | + ) |
| 331 | + forward_extremity_event_ids = self.get_success( |
| 332 | + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) |
| 333 | + ) |
| 334 | + self.remote_user_id = f"@remoteuser:{self.OTHER_SERVER_NAME}" |
| 335 | + self.remote_user_join_event = make_event_from_dict( |
| 336 | + self.add_hashes_and_signatures_from_other_server( |
| 337 | + { |
| 338 | + "room_id": self.room_id, |
| 339 | + "sender": self.remote_user_id, |
| 340 | + "state_key": self.remote_user_id, |
| 341 | + "depth": 1000, |
| 342 | + "origin_server_ts": 1, |
| 343 | + "type": EventTypes.Member, |
| 344 | + "content": {"membership": Membership.JOIN}, |
| 345 | + "auth_events": [ |
| 346 | + state_map[(EventTypes.Create, "")].event_id, |
| 347 | + state_map[(EventTypes.JoinRules, "")].event_id, |
| 348 | + ], |
| 349 | + "prev_events": list(forward_extremity_event_ids), |
| 350 | + } |
| 351 | + ), |
| 352 | + room_version=RoomVersions.V10, |
| 353 | + ) |
| 354 | + |
| 355 | + # Send the join |
| 356 | + self.get_success( |
| 357 | + self._federation_event_handler.on_receive_pdu( |
| 358 | + self.OTHER_SERVER_NAME, self.remote_user_join_event |
| 359 | + ) |
| 360 | + ) |
| 361 | + |
| 362 | + # Check the join made it to the 'local' view of the room |
| 363 | + self.helper.get_event( |
| 364 | + room_id=self.room_id, |
| 365 | + event_id=self.remote_user_join_event.event_id, |
| 366 | + tok=user1_tok, |
| 367 | + expect_code=HTTPStatus.OK, |
| 368 | + ) |
| 369 | + |
| 370 | + def test_federated_events_with_spam_checker_metadata(self) -> None: |
| 371 | + """ |
| 372 | + Simulates receiving spammy and non-spammy events over federation, |
| 373 | + then checks their `spam_checker_spammy` flag is set properly. |
| 374 | + """ |
| 375 | + |
| 376 | + async def check_event_for_spam(event: EventBase) -> Literal["NOT_SPAM"] | Codes: |
| 377 | + if event.type == EventTypes.Message: |
| 378 | + if "ham" not in event.content["body"]: |
| 379 | + return Codes.FORBIDDEN |
| 380 | + return "NOT_SPAM" |
| 381 | + |
| 382 | + # Register a spam checker callback that only allows messages with 'ham' |
| 383 | + self._module_api.register_spam_checker_callbacks( |
| 384 | + check_event_for_spam=check_event_for_spam |
| 385 | + ) |
| 386 | + |
| 387 | + # Prepare a spammy and a non-spammy event. |
| 388 | + forward_extremity_event_ids = self.get_success( |
| 389 | + self._store.get_latest_event_ids_in_room(self.room_id) |
| 390 | + ) |
| 391 | + state_map = self.get_success( |
| 392 | + self._storage_controllers.state.get_current_state(self.room_id) |
| 393 | + ) |
| 394 | + spammy_event = make_event_from_dict( |
| 395 | + self.add_hashes_and_signatures_from_other_server( |
| 396 | + { |
| 397 | + "room_id": self.room_id, |
| 398 | + "sender": self.remote_user_id, |
| 399 | + "depth": 2000, |
| 400 | + "origin_server_ts": 2, |
| 401 | + "type": EventTypes.Message, |
| 402 | + "content": {"body": "this is spam", "msgtype": "m.text"}, |
| 403 | + "auth_events": [ |
| 404 | + state_map[(EventTypes.Create, "")].event_id, |
| 405 | + state_map[(EventTypes.JoinRules, "")].event_id, |
| 406 | + state_map[(EventTypes.Member, self.remote_user_id)].event_id, |
| 407 | + ], |
| 408 | + "prev_events": list(forward_extremity_event_ids), |
| 409 | + } |
| 410 | + ), |
| 411 | + room_version=RoomVersions.V10, |
| 412 | + ) |
| 413 | + non_spammy_event = make_event_from_dict( |
| 414 | + self.add_hashes_and_signatures_from_other_server( |
| 415 | + { |
| 416 | + "room_id": self.room_id, |
| 417 | + "sender": self.remote_user_id, |
| 418 | + "depth": 2000, |
| 419 | + "origin_server_ts": 2, |
| 420 | + "type": EventTypes.Message, |
| 421 | + "content": {"body": "delicious ham", "msgtype": "m.text"}, |
| 422 | + "auth_events": [ |
| 423 | + state_map[(EventTypes.Create, "")].event_id, |
| 424 | + state_map[(EventTypes.JoinRules, "")].event_id, |
| 425 | + state_map[(EventTypes.Member, self.remote_user_id)].event_id, |
| 426 | + ], |
| 427 | + "prev_events": list(forward_extremity_event_ids), |
| 428 | + } |
| 429 | + ), |
| 430 | + room_version=RoomVersions.V10, |
| 431 | + ) |
| 432 | + |
| 433 | + # Receive these events over federation |
| 434 | + # We need to let the federation server have them because it will |
| 435 | + # invoke `_check_sigs_and_hash` which invokes the spam checker. |
| 436 | + self.get_success( |
| 437 | + self._federation_server._handle_received_pdu( |
| 438 | + self.OTHER_SERVER_NAME, spammy_event |
| 439 | + ) |
| 440 | + ) |
| 441 | + self.get_success( |
| 442 | + self._federation_server._handle_received_pdu( |
| 443 | + self.OTHER_SERVER_NAME, non_spammy_event |
| 444 | + ) |
| 445 | + ) |
| 446 | + |
| 447 | + # Retrieve the events from the database |
| 448 | + retrieved_spammy_event = self.get_success( |
| 449 | + self._store.get_event(spammy_event.event_id, allow_rejected=True) |
| 450 | + ) |
| 451 | + retrieved_non_spammy_event = self.get_success( |
| 452 | + self._store.get_event(non_spammy_event.event_id, allow_rejected=True) |
| 453 | + ) |
| 454 | + |
| 455 | + # Assert the spammy flags (and soft-failed flags, for good measure) are set properly |
| 456 | + self.assertTrue( |
| 457 | + retrieved_spammy_event.internal_metadata.spam_checker_spammy, |
| 458 | + "Spammy inbound event should be marked as spam_checker_spammy!", |
| 459 | + ) |
| 460 | + self.assertTrue( |
| 461 | + retrieved_spammy_event.internal_metadata.is_soft_failed(), |
| 462 | + "Spammy inbound event should be soft-failed.", |
| 463 | + ) |
| 464 | + |
| 465 | + self.assertFalse( |
| 466 | + retrieved_non_spammy_event.internal_metadata.spam_checker_spammy, |
| 467 | + "Non-spammy inbound event should not be marked as spam_checker_spammy!", |
| 468 | + ) |
| 469 | + self.assertFalse( |
| 470 | + retrieved_non_spammy_event.internal_metadata.is_soft_failed(), |
| 471 | + "Non-spammy inbound event should not be soft-failed.", |
| 472 | + ) |
0 commit comments