diff --git a/federation/handle.go b/federation/handle.go index 772e3d44..a87d9e1e 100644 --- a/federation/handle.go +++ b/federation/handle.go @@ -59,29 +59,22 @@ func MakeJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request // or dealing with HTTP responses itself. func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeJoin, err error) { // Generate a join event - proto := gomatrixserverlib.ProtoEvent{ - SenderID: userID, - RoomID: room.RoomID, - Type: "m.room.member", - StateKey: &userID, - PrevEvents: []string{room.Timeline[len(room.Timeline)-1].EventID()}, - Depth: room.Timeline[len(room.Timeline)-1].Depth() + 1, - } - err = proto.SetContent(map[string]interface{}{"membership": spec.Join}) - if err != nil { - err = fmt.Errorf("make_join cannot set membership content: %w", err) - return - } - stateNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(&proto) + proto, err := room.ProtoEventCreator(Event{ + Type: "m.room.member", + StateKey: &userID, + Content: map[string]interface{}{ + "membership": spec.Join, + }, + Sender: userID, + }) if err != nil { - err = fmt.Errorf("make_join cannot calculate auth_events: %w", err) + err = fmt.Errorf("make_join cannot set create proto event: %w", err) return } - proto.AuthEvents = room.AuthEvents(stateNeeded) resp = fclient.RespMakeJoin{ RoomVersion: room.Version, - JoinEvent: proto, + JoinEvent: *proto, } return } @@ -91,29 +84,22 @@ func MakeRespMakeJoin(s *Server, room *ServerRoom, userID string) (resp fclient. // or dealing with HTTP responses itself. func MakeRespMakeKnock(s *Server, room *ServerRoom, userID string) (resp fclient.RespMakeKnock, err error) { // Generate a knock event - proto := gomatrixserverlib.ProtoEvent{ - SenderID: userID, - RoomID: room.RoomID, - Type: "m.room.member", - StateKey: &userID, - PrevEvents: []string{room.Timeline[len(room.Timeline)-1].EventID()}, - Depth: room.Timeline[len(room.Timeline)-1].Depth() + 1, - } - err = proto.SetContent(map[string]interface{}{"membership": spec.Join}) - if err != nil { - err = fmt.Errorf("make_knock cannot set membership content: %w", err) - return - } - stateNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(&proto) + proto, err := room.ProtoEventCreator(Event{ + Type: "m.room.member", + StateKey: &userID, + Content: map[string]interface{}{ + "membership": spec.Join, // XXX this feels wrong? + }, + Sender: userID, + }) if err != nil { - err = fmt.Errorf("make_knock cannot calculate auth_events: %w", err) + err = fmt.Errorf("make_knock cannot set create proto event: %w", err) return } - proto.AuthEvents = room.AuthEvents(stateNeeded) resp = fclient.RespMakeKnock{ RoomVersion: room.Version, - KnockEvent: proto, + KnockEvent: *proto, } return } @@ -173,40 +159,8 @@ func SendJoinRequestsHandler(s *Server, w http.ResponseWriter, req *http.Request return } - // build the state list *before* we insert the new event - var stateEvents []gomatrixserverlib.PDU - room.StateMutex.RLock() - for _, ev := range room.State { - // filter out non-critical memberships if this is a partial-state join - if expectPartialState { - if ev.Type() == "m.room.member" && ev.StateKey() != event.StateKey() { - continue - } - } - stateEvents = append(stateEvents, ev) - } - room.StateMutex.RUnlock() - - authEvents := room.AuthChainForEvents(stateEvents) - - // get servers in room *before* the join event - serversInRoom := []string{s.serverName} - if !omitServersInRoom { - serversInRoom = room.ServersInRoom() - } - - // insert the join event into the room state - room.AddEvent(event) - log.Printf("Received send-join of event %s", event.EventID()) - - // return state and auth chain - b, err := json.Marshal(fclient.RespSendJoin{ - Origin: spec.ServerName(s.serverName), - AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents), - StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents), - MembersOmitted: expectPartialState, - ServersInRoom: serversInRoom, - }) + resp := room.GenerateSendJoinResponse(s, event, expectPartialState, omitServersInRoom) + b, err := json.Marshal(resp) if err != nil { w.WriteHeader(500) w.Write([]byte("complement: HandleMakeSendJoinRequests send_join cannot marshal RespSendJoin: " + err.Error())) @@ -410,7 +364,7 @@ func HandleEventAuthRequests() func(*Server) { authEvents := room.AuthChainForEvents([]gomatrixserverlib.PDU{event}) resp := fclient.RespEventAuth{ - gomatrixserverlib.NewEventJSONsFromEvents(authEvents), + AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents), } respJSON, err := json.Marshal(resp) if err != nil { @@ -590,8 +544,8 @@ func HandleTransactionRequests(pduCallback func(gomatrixserverlib.PDU), eduCallb verImpl, err := gomatrixserverlib.GetRoomVersion(room.Version) if err != nil { log.Printf( - "complement: Transaction '%s': Failed to get room version '%s': %s", - transaction.TransactionID, event.EventID(), err.Error(), + "complement: Transaction '%s': Failed to get room version: %s", + transaction.TransactionID, err.Error(), ) continue } diff --git a/federation/server.go b/federation/server.go index a5b6d46a..eb35e5d2 100644 --- a/federation/server.go +++ b/federation/server.go @@ -303,58 +303,15 @@ func (s *Server) DoFederationRequest( // It does not insert this event into the room however. See ServerRoom.AddEvent for that. func (s *Server) MustCreateEvent(t ct.TestLike, room *ServerRoom, ev Event) gomatrixserverlib.PDU { t.Helper() - content, err := json.Marshal(ev.Content) + proto, err := room.ProtoEventCreator(ev) if err != nil { - ct.Fatalf(t, "MustCreateEvent: failed to marshal event content %s - %+v", err, ev.Content) + ct.Fatalf(t, "MustCreateEvent: failed to create proto event: %v", err) } - var unsigned []byte - if ev.Unsigned != nil { - unsigned, err = json.Marshal(ev.Unsigned) - if err != nil { - ct.Fatalf(t, "MustCreateEvent: failed to marshal event unsigned: %s - %+v", err, ev.Unsigned) - } - } - - var prevEvents interface{} - if ev.PrevEvents != nil { - // We deliberately want to set the prev events. - prevEvents = ev.PrevEvents - } else { - // No other prev events were supplied so we'll just - // use the forward extremities of the room, which is - // the usual behaviour. - prevEvents = room.ForwardExtremities - } - proto := gomatrixserverlib.ProtoEvent{ - SenderID: ev.Sender, - Depth: int64(room.Depth + 1), // depth starts at 1 - Type: ev.Type, - StateKey: ev.StateKey, - Content: content, - RoomID: room.RoomID, - PrevEvents: prevEvents, - Unsigned: unsigned, - AuthEvents: ev.AuthEvents, - Redacts: ev.Redacts, - } - if proto.AuthEvents == nil { - var stateNeeded gomatrixserverlib.StateNeeded - stateNeeded, err = gomatrixserverlib.StateNeededForProtoEvent(&proto) - if err != nil { - ct.Fatalf(t, "MustCreateEvent: failed to work out auth_events : %s", err) - } - proto.AuthEvents = room.AuthEvents(stateNeeded) - } - verImpl, err := gomatrixserverlib.GetRoomVersion(room.Version) + pdu, err := room.EventCreator(s, proto) if err != nil { - ct.Fatalf(t, "MustCreateEvent: invalid room version: %s", err) + ct.Fatalf(t, "MustCreateEvent: failed to create PDU: %v", err) } - eb := verImpl.NewEventBuilderFromProtoEvent(&proto) - signedEvent, err := eb.Build(time.Now(), spec.ServerName(s.serverName), s.KeyID, s.Priv) - if err != nil { - ct.Fatalf(t, "MustCreateEvent: failed to sign event: %s", err) - } - return signedEvent + return pdu } // MustJoinRoom will make the server send a make_join and a send_join to join a room @@ -424,12 +381,8 @@ func (s *Server) MustJoinRoom(t ct.TestLike, deployment FederationDeployment, re if err != nil { ct.Fatalf(t, "MustJoinRoom: send_join failed: %v", err) } - stateEvents := sendJoinResp.StateEvents.UntrustedEvents(roomVer) room := NewServerRoom(roomVer, roomID) - for _, ev := range stateEvents { - room.ReplaceCurrentState(ev) - } - room.AddEvent(joinEvent) + room.PopulateFromSendJoinResponse(joinEvent, sendJoinResp) s.rooms[roomID] = room t.Logf("Server.MustJoinRoom joined room ID %s", roomID) diff --git a/federation/server_room.go b/federation/server_room.go index 22f6cd1d..9beb8414 100644 --- a/federation/server_room.go +++ b/federation/server_room.go @@ -3,9 +3,13 @@ package federation import ( "encoding/json" "fmt" + "log" "sync" + "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/gomatrixserverlib/fclient" + "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/complement/b" "github.com/matrix-org/complement/ct" @@ -33,6 +37,14 @@ type Event struct { // EXPERIMENTAL // ServerRoom represents a room on this test federation server type ServerRoom struct { + // Functions to map Complement events into actual PDUs. + // Most tests don't care about this and can use the default functions, + // but if your MSC or tests fiddle with the raw JSON in some way then these + // function need to be replaced. By replacing these functions, helper functions + // which indirectly create events (e.g joins and leaves) will automatically use + // them and will hence work with your custom code. + ServerRoomImpl + Version gomatrixserverlib.RoomVersion RoomID string State map[string]gomatrixserverlib.PDU @@ -47,7 +59,7 @@ type ServerRoom struct { // NewServerRoom creates an empty room structure with no events func NewServerRoom(roomVer gomatrixserverlib.RoomVersion, roomId string) *ServerRoom { - return &ServerRoom{ + room := &ServerRoom{ RoomID: roomId, Version: roomVer, State: make(map[string]gomatrixserverlib.PDU), @@ -55,6 +67,8 @@ func NewServerRoom(roomVer gomatrixserverlib.RoomVersion, roomId string) *Server waiters: make(map[string][]*helpers.Waiter), waitersMu: &sync.Mutex{}, } + room.ServerRoomImpl = &ServerRoomImplDefault{Room: room} + return room } // AddEvent adds a new event to the timeline, updating current state if it is a state event. @@ -335,3 +349,154 @@ func (r *ServerRoom) EventIDsOrReferences(events []gomatrixserverlib.PDU) (refs } return } + +type ServerRoomImpl interface { + // ProtoEventCreator converts a Complement Event into a gomatrixserverlib proto event, ready to be signed. + // This function is used in /make_x endpoints to create proto events to return to other servers. + // This function is one of two used when creating events, the other being EventCreator. + ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) + // EventCreator converts a proto event into a signed PDU. + EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) + // PopulateFromSendJoinResponse should replace the state of this ServerRoom with the information contained + // in RespSendJoin and the join event. + PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) + // GenerateSendJoinResponse generates a /send_join response to send back to a server. + GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin +} + +type ServerRoomImplCustom struct { + ServerRoomImplDefault + ProtoEventCreatorFn func(def ServerRoomImpl, ev Event) (*gomatrixserverlib.ProtoEvent, error) + EventCreatorFn func(def ServerRoomImpl, s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) + PopulateFromSendJoinResponseFn func(def ServerRoomImpl, joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) + GenerateSendJoinResponseFn func(def ServerRoomImpl, s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin +} + +func (i *ServerRoomImplCustom) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) { + if i.ProtoEventCreatorFn != nil { + return i.ProtoEventCreatorFn(&i.ServerRoomImplDefault, ev) + } + return i.ServerRoomImplDefault.ProtoEventCreator(ev) +} + +func (i *ServerRoomImplCustom) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) { + if i.EventCreatorFn != nil { + return i.EventCreatorFn(&i.ServerRoomImplDefault, s, proto) + } + return i.ServerRoomImplDefault.EventCreator(s, proto) +} + +func (i *ServerRoomImplCustom) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) { + if i.PopulateFromSendJoinResponseFn != nil { + i.PopulateFromSendJoinResponseFn(&i.ServerRoomImplDefault, joinEvent, resp) + return + } + i.ServerRoomImplDefault.PopulateFromSendJoinResponse(joinEvent, resp) +} + +func (i *ServerRoomImplCustom) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin { + if i.GenerateSendJoinResponseFn != nil { + return i.GenerateSendJoinResponseFn(&i.ServerRoomImplDefault, s, joinEvent, expectPartialState, omitServersInRoom) + } + return i.ServerRoomImplDefault.GenerateSendJoinResponse(s, joinEvent, expectPartialState, omitServersInRoom) +} + +type ServerRoomImplDefault struct { + Room *ServerRoom +} + +func (i *ServerRoomImplDefault) ProtoEventCreator(ev Event) (*gomatrixserverlib.ProtoEvent, error) { + var prevEvents interface{} + if ev.PrevEvents != nil { + // We deliberately want to set the prev events. + prevEvents = ev.PrevEvents + } else { + // No other prev events were supplied so we'll just + // use the forward extremities of the room, which is + // the usual behaviour. + prevEvents = i.Room.ForwardExtremities + } + proto := gomatrixserverlib.ProtoEvent{ + SenderID: ev.Sender, + Depth: int64(i.Room.Depth + 1), // depth starts at 1 + Type: ev.Type, + StateKey: ev.StateKey, + RoomID: i.Room.RoomID, + PrevEvents: prevEvents, + AuthEvents: ev.AuthEvents, + Redacts: ev.Redacts, + } + if err := proto.SetContent(ev.Content); err != nil { + return nil, fmt.Errorf("EventCreator: failed to marshal event content: %s - %+v", err, ev.Content) + } + if err := proto.SetUnsigned(ev.Content); err != nil { + return nil, fmt.Errorf("EventCreator: failed to marshal event unsigned: %s - %+v", err, ev.Unsigned) + } + if proto.AuthEvents == nil { + var stateNeeded gomatrixserverlib.StateNeeded + stateNeeded, err := gomatrixserverlib.StateNeededForProtoEvent(&proto) + if err != nil { + return nil, fmt.Errorf("EventCreator: failed to work out auth_events : %s", err) + } + proto.AuthEvents = i.Room.AuthEvents(stateNeeded) + } + return &proto, nil +} + +func (i *ServerRoomImplDefault) EventCreator(s *Server, proto *gomatrixserverlib.ProtoEvent) (gomatrixserverlib.PDU, error) { + verImpl, err := gomatrixserverlib.GetRoomVersion(i.Room.Version) + if err != nil { + return nil, fmt.Errorf("EventCreator: invalid room version: %s", err) + } + eb := verImpl.NewEventBuilderFromProtoEvent(proto) + signedEvent, err := eb.Build(time.Now(), spec.ServerName(s.serverName), s.KeyID, s.Priv) + if err != nil { + return nil, fmt.Errorf("EventCreator: failed to sign event: %s", err) + } + return signedEvent, nil +} + +func (i *ServerRoomImplDefault) PopulateFromSendJoinResponse(joinEvent gomatrixserverlib.PDU, resp fclient.RespSendJoin) { + stateEvents := resp.StateEvents.UntrustedEvents(i.Room.Version) + for _, ev := range stateEvents { + i.Room.ReplaceCurrentState(ev) + } + i.Room.AddEvent(joinEvent) +} + +func (i *ServerRoomImplDefault) GenerateSendJoinResponse(s *Server, joinEvent gomatrixserverlib.PDU, expectPartialState, omitServersInRoom bool) fclient.RespSendJoin { + // build the state list *before* we insert the new event + var stateEvents []gomatrixserverlib.PDU + i.Room.StateMutex.RLock() + for _, ev := range i.Room.State { + // filter out non-critical memberships if this is a partial-state join + if expectPartialState { + if ev.Type() == "m.room.member" && ev.StateKey() != joinEvent.StateKey() { + continue + } + } + stateEvents = append(stateEvents, ev) + } + i.Room.StateMutex.RUnlock() + + authEvents := i.Room.AuthChainForEvents(stateEvents) + + // get servers in room *before* the join event + serversInRoom := []string{s.serverName} + if !omitServersInRoom { + serversInRoom = i.Room.ServersInRoom() + } + + // insert the join event into the room state + i.Room.AddEvent(joinEvent) + log.Printf("Received send-join of event %s", joinEvent.EventID()) + + // return state and auth chain + return fclient.RespSendJoin{ + Origin: spec.ServerName(s.serverName), + AuthEvents: gomatrixserverlib.NewEventJSONsFromEvents(authEvents), + StateEvents: gomatrixserverlib.NewEventJSONsFromEvents(stateEvents), + MembersOmitted: expectPartialState, + ServersInRoom: serversInRoom, + } +}