diff --git a/.gitignore b/.gitignore index 7c18b313..fca55899 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ client !client/ vendor/ v3/ - +.idea/ # Test binary, build with `go test -c` *.test diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..86102eb2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,12 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Run simple server", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/examples/observe/server" + } + ] +} \ No newline at end of file diff --git a/udp/client/conn.go b/udp/client/conn.go index 6e1703c4..694f8f2c 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -206,6 +206,9 @@ type Conn struct { msgID atomic.Uint32 blockwiseSZX blockwise.SZX + localAddr atomic.Pointer[net.IP] + interfaceIndex atomic.Int64 + /* An outstanding interaction is either a CON for which an ACK has not yet been received but is still expected (message layer) or a request @@ -550,6 +553,8 @@ func (cc *Conn) writeMessageAsync(req *pool.Message) error { func (cc *Conn) writeMessage(req *pool.Message) error { req.UpsertType(message.Confirmable) req.UpsertMessageID(cc.GetMessageID()) + cc.upsertControlInformation(req) + if req.Type() != message.Confirmable { return cc.writeMessageAsync(req) } @@ -755,6 +760,7 @@ func (cc *Conn) processResponse(reqType message.Type, reqMessageID int32, w *res w.Message().SetType(message.Acknowledgement) w.Message().SetMessageID(reqMessageID) w.Message().SetToken(nil) + err := cc.addResponseToCache(w.Message()) if err != nil { return fmt.Errorf("cannot cache response: %w", err) @@ -824,7 +830,8 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con }() resp := cc.AcquireMessage(cc.Context()) resp.SetToken(req.Token()) - ifIndex := req.ControlMessage().GetIfIndex() + cc.setControlInformation(req.ControlMessage()) + w := responsewriter.New(resp, cc, req.Options()...) defer func() { cc.ReleaseMessage(w.Message()) @@ -839,7 +846,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con // nothing to send return } - upsertInterfaceToMessage(w.Message(), ifIndex) + cc.upsertControlInformation(w.Message()) errW := cc.writeMessageAsync(w.Message()) if errW != nil { cc.closeConnection() @@ -851,13 +858,37 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess cc.sendPong(w, r) } -func upsertInterfaceToMessage(m *pool.Message, ifIndex int) { - if ifIndex >= 1 { - cm := coapNet.ControlMessage{ - IfIndex: ifIndex, - } - m.UpsertControlMessage(&cm) +func (cc *Conn) setControlInformation(cm *coapNet.ControlMessage) { + if cm == nil { + cc.interfaceIndex.Store(0) + cc.localAddr.Store(nil) + return + } + + cc.interfaceIndex.Store(int64(cm.GetIfIndex())) + if len(cm.Dst) == 0 || cm.Dst.IsMulticast() { + cc.localAddr.Store(nil) + return } + + dst := make(net.IP, len(cm.Dst)) + copy(dst, cm.Dst) + cc.localAddr.Store(&dst) +} + +func (cc *Conn) upsertControlInformation(msg *pool.Message) { + ifIndex := int(cc.interfaceIndex.Load()) + localAddrPtr := cc.localAddr.Load() + if ifIndex < 1 && localAddrPtr == nil { + return + } + + var localAddr net.IP + if localAddrPtr != nil { + localAddr = *localAddrPtr + } + + msg.UpsertControlMessage(&coapNet.ControlMessage{IfIndex: ifIndex, Src: localAddr}) } func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { @@ -872,7 +903,6 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { elem.ReleaseMessage(cc) resp := cc.AcquireMessage(cc.Context()) resp.SetToken(r.Token()) - upsertInterfaceToMessage(resp, r.ControlMessage().GetIfIndex()) w := responsewriter.New(resp, cc, r.Options()...) defer func() { cc.ReleaseMessage(w.Message()) diff --git a/udp/client/controlmessage_test.go b/udp/client/controlmessage_test.go new file mode 100644 index 00000000..3f0d0745 --- /dev/null +++ b/udp/client/controlmessage_test.go @@ -0,0 +1,49 @@ +package client + +import ( + "net" + "testing" + + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/stretchr/testify/require" +) + +func TestSetControlInformationNil(t *testing.T) { + var cc Conn + addr := net.ParseIP("192.0.2.1").To4() + cc.localAddr.Store(&addr) + cc.interfaceIndex.Store(7) + + cc.setControlInformation(nil) + + require.Equal(t, int64(0), cc.interfaceIndex.Load()) + require.Nil(t, cc.localAddr.Load()) +} + +func TestSetControlInformationUnicastCopiesAddress(t *testing.T) { + var cc Conn + dst := net.ParseIP("2001:db8::1") + cm := &coapNet.ControlMessage{Dst: append(net.IP(nil), dst...), IfIndex: 9} + + cc.setControlInformation(cm) + + require.Equal(t, int64(9), cc.interfaceIndex.Load()) + stored := cc.localAddr.Load() + require.NotNil(t, stored) + expected := append(net.IP(nil), dst...) + require.True(t, stored.Equal(expected)) + + cm.Dst[0] ^= 0xff + require.True(t, stored.Equal(expected)) +} + +func TestSetControlInformationMulticastClearsAddress(t *testing.T) { + var cc Conn + addr := net.ParseIP("192.0.2.1").To4() + cc.localAddr.Store(&addr) + + cc.setControlInformation(&coapNet.ControlMessage{Dst: net.ParseIP("ff02::1"), IfIndex: 11}) + + require.Equal(t, int64(11), cc.interfaceIndex.Load()) + require.Nil(t, cc.localAddr.Load()) +} diff --git a/udp/server/server.go b/udp/server/server.go index 1ec8100a..f1871cd1 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -159,7 +159,19 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { } } buf = buf[:n] - cc, err := s.getConn(l, raddr, true) + + // UDPConn.LocalAddr() only takes into account the address it is bound to. + // In the case of a wildcard address, the actual destination address is in the control message. + // On server-initiated exchanges, listener's LocalAddr can be used as the client has no assumptions of the source. + laddr, err := s.getListenerLocalAddr(l) + if err != nil { + return err + } + if cm != nil && len(cm.Dst) > 0 && !cm.Dst.IsMulticast() { + laddr.IP = cm.Dst + } + + cc, err := s.getConn(l, raddr, laddr, true) if err != nil { s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) continue @@ -178,6 +190,15 @@ func (s *Server) getListener() *coapNet.UDPConn { return s.listen } +func (s *Server) getListenerLocalAddr(l *coapNet.UDPConn) (*net.UDPAddr, error) { + localAddr, ok := l.LocalAddr().(*net.UDPAddr) + if !ok || localAddr == nil { + return nil, fmt.Errorf("unexpected listener local addr type: %T", l.LocalAddr()) + } + laddrVal := *localAddr + return &laddrVal, nil +} + // Stop stops server without wait of ends Serve function. func (s *Server) Stop() { s.cancel() @@ -254,10 +275,21 @@ func getClose(cc *client.Conn) func() { return closeFn } -func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (cc *client.Conn, created bool) { +func getConnKey(raddr *net.UDPAddr, laddr *net.UDPAddr) string { + normalizedLocalAddr := *laddr + if len(normalizedLocalAddr.IP) > 0 && normalizedLocalAddr.IP.IsMulticast() { + // Multicast destination address does not identify a unique server-side source address. + // Normalize it to avoid creating one conn key per multicast group. + normalizedLocalAddr.IP = nil + normalizedLocalAddr.Zone = "" + } + return raddr.String() + "-" + normalizedLocalAddr.String() +} + +func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr) (cc *client.Conn, created bool) { s.connsMutex.Lock() defer s.connsMutex.Unlock() - key := raddr.String() + key := getConnKey(raddr, laddr) cc = s.conns[key] if cc != nil { @@ -345,8 +377,19 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) ( return cc, true } -func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { - cc, created := s.getOrCreateConn(l, raddr) +func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { + if raddr == nil { + return nil, errors.New("invalid remote address") + } + if laddr == nil { + var err error + laddr, err = s.getListenerLocalAddr(l) + if err != nil { + return nil, err + } + } + + cc, created := s.getOrCreateConn(l, raddr, laddr) if created { if s.cfg.OnNewConn != nil { s.cfg.OnNewConn(cc) @@ -367,18 +410,30 @@ func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) closeFn() } if firstTime { - return s.getConn(l, raddr, false) + return s.getConn(l, raddr, laddr, false) } return nil, errors.New("connection is closed") } return cc, nil } -func (s *Server) NewConn(addr *net.UDPAddr) (*client.Conn, error) { +// NewConn creates or gets a connection for the provided remote address. +// +// Optional laddr may be used to pin a concrete local address when the listener is bound to a wildcard address. +// If laddr is omitted or nil, listener's local address is used. +func (s *Server) NewConn(addr *net.UDPAddr, laddr ...*net.UDPAddr) (*client.Conn, error) { + if len(laddr) > 1 { + return nil, fmt.Errorf("invalid number of local addresses: %d", len(laddr)) + } + var localAddr *net.UDPAddr + if len(laddr) == 1 { + localAddr = laddr[0] + } + l := s.getListener() if l == nil { // server is not started/stopped return nil, errors.New("server is not running") } - return s.getConn(l, addr, true) + return s.getConn(l, addr, localAddr, true) } diff --git a/udp/server/server_key_test.go b/udp/server/server_key_test.go new file mode 100644 index 00000000..1565aa93 --- /dev/null +++ b/udp/server/server_key_test.go @@ -0,0 +1,26 @@ +package server + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetConnKeyIgnoresMulticastLocalAddress(t *testing.T) { + raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830} + mcastV6 := &net.UDPAddr{IP: net.ParseIP("ff02::fd"), Port: 5683} + mcastV4 := &net.UDPAddr{IP: net.ParseIP("224.0.1.187"), Port: 5683} + normalized := &net.UDPAddr{Port: 5683} + + require.Equal(t, getConnKey(raddr, mcastV6), getConnKey(raddr, normalized)) + require.Equal(t, getConnKey(raddr, mcastV4), getConnKey(raddr, normalized)) +} + +func TestGetConnKeyKeepsUnicastLocalAddressDistinct(t *testing.T) { + raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830} + laddrA := &net.UDPAddr{IP: net.ParseIP("2001:db8::10"), Port: 5683} + laddrB := &net.UDPAddr{IP: net.ParseIP("2001:db8::11"), Port: 5683} + + require.NotEqual(t, getConnKey(raddr, laddrA), getConnKey(raddr, laddrB)) +} diff --git a/udp/server_test.go b/udp/server_test.go index 119adf80..9c6799a3 100644 --- a/udp/server_test.go +++ b/udp/server_test.go @@ -467,6 +467,9 @@ func TestServerNewClient(t *testing.T) { time.Sleep(time.Second) + _, err = s1.NewConn(nil) + require.ErrorContains(t, err, "invalid remote address") + cc, err := s1.NewConn(peer) require.NoError(t, err) @@ -478,7 +481,7 @@ func TestServerNewClient(t *testing.T) { require.NoError(t, err) // repeat ping - new client should be created - cc, err = s1.NewConn(peer) + cc, err = s1.NewConn(peer, nil) require.NoError(t, err) err = cc.Ping(ctx) require.NoError(t, err) @@ -626,7 +629,7 @@ func TestServerReconnectNewClient(t *testing.T) { } // new client - cc, err = s1.NewConn(peer) + cc, err = s1.NewConn(peer, nil) require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*1) defer cancel()