From 3785a053958b5c6af85c26a18e59f745fadf15fe Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Mon, 20 Oct 2025 12:33:35 +0200 Subject: [PATCH 01/16] fix: on binding the server to [::] the response happened to come from different addresses than sent to. --- net/connUDP.go | 7 +++++++ udp/client/conn.go | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/net/connUDP.go b/net/connUDP.go index 4d2b9521..60eca43a 100644 --- a/net/connUDP.go +++ b/net/connUDP.go @@ -57,6 +57,13 @@ func (c *ControlMessage) GetIfIndex() int { return c.IfIndex } +func (c *ControlMessage) FlipSrcDst() { + if c == nil { + return + } + c.Dst, c.Src = c.Src, c.Dst +} + type packetConn interface { SetWriteDeadline(t time.Time) error WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) diff --git a/udp/client/conn.go b/udp/client/conn.go index 6e1703c4..cecb7590 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -824,6 +824,10 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con }() resp := cc.AcquireMessage(cc.Context()) resp.SetToken(req.Token()) + + resp.SetControlMessage(req.ControlMessage()) + resp.ControlMessage().FlipSrcDst() + ifIndex := req.ControlMessage().GetIfIndex() w := responsewriter.New(resp, cc, req.Options()...) defer func() { From f8782d9ba474bdffdc9e377ba4d454cafe3598cd Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Mon, 2 Mar 2026 14:14:48 +0100 Subject: [PATCH 02/16] fix: add .idea to gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 7c18b313..fca55899 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ client !client/ vendor/ v3/ - +.idea/ # Test binary, build with `go test -c` *.test From b24bb254c6ceda5078422225bf273b9cecdf41a4 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Mon, 9 Mar 2026 09:54:29 +0100 Subject: [PATCH 03/16] feat: store source addr in conn --- net/connUDP.go | 7 - udp/client/conn.go | 31 +- udp/server/server.go | 776 ++++++++++++++++++++++--------------------- 3 files changed, 412 insertions(+), 402 deletions(-) diff --git a/net/connUDP.go b/net/connUDP.go index 60eca43a..4d2b9521 100644 --- a/net/connUDP.go +++ b/net/connUDP.go @@ -57,13 +57,6 @@ func (c *ControlMessage) GetIfIndex() int { return c.IfIndex } -func (c *ControlMessage) FlipSrcDst() { - if c == nil { - return - } - c.Dst, c.Src = c.Src, c.Dst -} - type packetConn interface { SetWriteDeadline(t time.Time) error WriteTo(b []byte, cm *ControlMessage, dst net.Addr) (n int, err error) diff --git a/udp/client/conn.go b/udp/client/conn.go index cecb7590..0cd228f3 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -206,6 +206,9 @@ type Conn struct { msgID atomic.Uint32 blockwiseSZX blockwise.SZX + localAddr atomic.Pointer[net.IP] + interfaceIndex atomic.Int64 + /* An outstanding interaction is either a CON for which an ACK has not yet been received but is still expected (message layer) or a request @@ -550,6 +553,8 @@ func (cc *Conn) writeMessageAsync(req *pool.Message) error { func (cc *Conn) writeMessage(req *pool.Message) error { req.UpsertType(message.Confirmable) req.UpsertMessageID(cc.GetMessageID()) + cc.upsertControlInformation(req) + if req.Type() != message.Confirmable { return cc.writeMessageAsync(req) } @@ -755,6 +760,7 @@ func (cc *Conn) processResponse(reqType message.Type, reqMessageID int32, w *res w.Message().SetType(message.Acknowledgement) w.Message().SetMessageID(reqMessageID) w.Message().SetToken(nil) + err := cc.addResponseToCache(w.Message()) if err != nil { return fmt.Errorf("cannot cache response: %w", err) @@ -825,10 +831,9 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con resp := cc.AcquireMessage(cc.Context()) resp.SetToken(req.Token()) - resp.SetControlMessage(req.ControlMessage()) - resp.ControlMessage().FlipSrcDst() + cc.interfaceIndex.Store(int64(req.ControlMessage().GetIfIndex())) + cc.localAddr.Store(&req.ControlMessage().Src) - ifIndex := req.ControlMessage().GetIfIndex() w := responsewriter.New(resp, cc, req.Options()...) defer func() { cc.ReleaseMessage(w.Message()) @@ -843,7 +848,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con // nothing to send return } - upsertInterfaceToMessage(w.Message(), ifIndex) + cc.upsertControlInformation(w.Message()) errW := cc.writeMessageAsync(w.Message()) if errW != nil { cc.closeConnection() @@ -855,13 +860,18 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess cc.sendPong(w, r) } -func upsertInterfaceToMessage(m *pool.Message, ifIndex int) { - if ifIndex >= 1 { - cm := coapNet.ControlMessage{ - IfIndex: ifIndex, - } - m.UpsertControlMessage(&cm) +func (cc *Conn) upsertControlInformation(msg *pool.Message) { + + cm := coapNet.ControlMessage{} + if cc.interfaceIndex.Load() >= 0 { + cm.IfIndex = int(cc.interfaceIndex.Load()) + } + + if cc.localAddr.Load() != nil { + cm.Src = *cc.localAddr.Load() } + msg.UpsertControlMessage(&cm) + } func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { @@ -876,7 +886,6 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { elem.ReleaseMessage(cc) resp := cc.AcquireMessage(cc.Context()) resp.SetToken(r.Token()) - upsertInterfaceToMessage(resp, r.ControlMessage().GetIfIndex()) w := responsewriter.New(resp, cc, r.Options()...) defer func() { cc.ReleaseMessage(w.Message()) diff --git a/udp/server/server.go b/udp/server/server.go index 1ec8100a..abb5d11c 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -1,384 +1,392 @@ -package server - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "time" - - "github.com/plgd-dev/go-coap/v3/message" - "github.com/plgd-dev/go-coap/v3/message/pool" - coapNet "github.com/plgd-dev/go-coap/v3/net" - "github.com/plgd-dev/go-coap/v3/net/blockwise" - "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" - "github.com/plgd-dev/go-coap/v3/net/responsewriter" - coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" - "github.com/plgd-dev/go-coap/v3/udp/client" -) - -type Server struct { - doneCtx context.Context - ctx context.Context - multicastRequests *client.RequestsMap - multicastHandler *coapSync.Map[uint64, HandlerFunc] - serverStartedChan chan struct{} - doneCancel context.CancelFunc - cancel context.CancelFunc - - connsMutex sync.Mutex - conns map[string]*client.Conn - - listenMutex sync.Mutex - listen *coapNet.UDPConn - - cfg *Config -} - -// A Option sets options such as credentials, codec and keepalive parameters, etc. -type Option interface { - UDPServerApply(cfg *Config) -} - -func New(opt ...Option) *Server { - cfg := DefaultConfig - for _, o := range opt { - o.UDPServerApply(&cfg) - } - - if cfg.Errors == nil { - cfg.Errors = func(error) { - // default no-op - } - } - - if cfg.GetMID == nil { - cfg.GetMID = message.GetMID - } - - if cfg.GetToken == nil { - cfg.GetToken = message.GetToken - } - - if cfg.CreateInactivityMonitor == nil { - cfg.CreateInactivityMonitor = func() client.InactivityMonitor { - return inactivity.NewNilMonitor[*client.Conn]() - } - } - if cfg.MessagePool == nil { - cfg.MessagePool = pool.New(0, 0) - } - - ctx, cancel := context.WithCancel(cfg.Ctx) - serverStartedChan := make(chan struct{}) - - doneCtx, doneCancel := context.WithCancel(context.Background()) - errorsFunc := cfg.Errors - cfg.Errors = func(err error) { - if coapNet.IsCancelOrCloseError(err) { - // this error was produced by cancellation context or closing connection. - return - } - errorsFunc(fmt.Errorf("udp: %w", err)) - } - return &Server{ - ctx: ctx, - cancel: cancel, - multicastHandler: coapSync.NewMap[uint64, HandlerFunc](), - multicastRequests: coapSync.NewMap[uint64, *pool.Message](), - serverStartedChan: serverStartedChan, - doneCtx: doneCtx, - doneCancel: doneCancel, - conns: make(map[string]*client.Conn), - - cfg: &cfg, - } -} - -func (s *Server) checkAndSetListener(l *coapNet.UDPConn) error { - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - if s.listen != nil { - return fmt.Errorf("server already serve: %v", s.listen.LocalAddr().String()) - } - s.listen = l - close(s.serverStartedChan) - return nil -} - -func (s *Server) closeConnection(cc *client.Conn) { - if err := cc.Close(); err != nil { - s.cfg.Errors(fmt.Errorf("cannot close connection: %w", err)) - } -} - -func (s *Server) Serve(l *coapNet.UDPConn) error { - if s.cfg.BlockwiseSZX > blockwise.SZX1024 { - return errors.New("invalid blockwiseSZX") - } - - err := s.checkAndSetListener(l) - if err != nil { - return err - } - - defer func() { - s.closeSessions() - s.doneCancel() - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - s.listen = nil - s.serverStartedChan = make(chan struct{}, 1) - }() - - m := make([]byte, s.cfg.MaxMessageSize) - var wg sync.WaitGroup - - s.cfg.PeriodicRunner(func(now time.Time) bool { - s.handleInactivityMonitors(now) - return s.ctx.Err() == nil - }) - - for { - buf := m - var raddr *net.UDPAddr - var cm *coapNet.ControlMessage - n, err := l.ReadWithOptions(buf, coapNet.WithContext(s.ctx), coapNet.WithGetControlMessage(&cm), coapNet.WithGetRemoteAddr(&raddr)) - if err != nil { - wg.Wait() - - select { - case <-s.ctx.Done(): - return nil - default: - if coapNet.IsCancelOrCloseError(err) { - return nil - } - return err - } - } - buf = buf[:n] - cc, err := s.getConn(l, raddr, true) - if err != nil { - s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) - continue - } - err = cc.Process(cm, buf) - if err != nil { - s.closeConnection(cc) - s.cfg.Errors(fmt.Errorf("%v: cannot process packet: %w", cc.RemoteAddr(), err)) - } - } -} - -func (s *Server) getListener() *coapNet.UDPConn { - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - return s.listen -} - -// Stop stops server without wait of ends Serve function. -func (s *Server) Stop() { - s.cancel() - l := s.getListener() - if l != nil { - if errC := l.Close(); errC != nil { - s.cfg.Errors(fmt.Errorf("cannot close listener: %w", errC)) - } - } - s.closeSessions() -} - -func (s *Server) closeSessions() { - s.connsMutex.Lock() - conns := s.conns - s.conns = make(map[string]*client.Conn) - s.connsMutex.Unlock() - for _, cc := range conns { - s.closeConnection(cc) - if closeFn := getClose(cc); closeFn != nil { - closeFn() - } - } -} - -func (s *Server) conn() *coapNet.UDPConn { - s.listenMutex.Lock() - serverStartedChan := s.serverStartedChan - s.listenMutex.Unlock() - select { - case <-serverStartedChan: - case <-s.ctx.Done(): - } - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - return s.listen -} - -const closeKey = "gocoapCloseConnection" - -func (s *Server) getConns() []*client.Conn { - s.connsMutex.Lock() - defer s.connsMutex.Unlock() - conns := make([]*client.Conn, 0, 32) - for _, c := range s.conns { - conns = append(conns, c) - } - return conns -} - -func (s *Server) handleInactivityMonitors(now time.Time) { - for _, cc := range s.getConns() { - select { - case <-cc.Context().Done(): - if closeFn := getClose(cc); closeFn != nil { - closeFn() - } - continue - default: - cc.CheckExpirations(now) - } - } -} - -func getClose(cc *client.Conn) func() { - v := cc.Context().Value(closeKey) - if v == nil { - return nil - } - closeFn, ok := v.(func()) - if !ok { - panic(fmt.Errorf("invalid type(%T) of context value for key %s", v, closeKey)) - } - return closeFn -} - -func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (cc *client.Conn, created bool) { - s.connsMutex.Lock() - defer s.connsMutex.Unlock() - key := raddr.String() - cc = s.conns[key] - - if cc != nil { - return cc, false - } - - createBlockWise := func(*client.Conn) *blockwise.BlockWise[*client.Conn] { - return nil - } - if s.cfg.BlockwiseEnable { - createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { - v := cc - return blockwise.New( - v, - s.cfg.BlockwiseTransferTimeout, - s.cfg.Errors, - func(token message.Token) (*pool.Message, bool) { - msg, ok := v.GetObservationRequest(token) - if ok { - return msg, ok - } - return s.multicastRequests.LoadWithFunc(token.Hash(), func(m *pool.Message) *pool.Message { - msg := v.AcquireMessage(m.Context()) - msg.ResetOptionsTo(m.Options()) - msg.SetCode(m.Code()) - msg.SetToken(m.Token()) - msg.SetMessageID(m.MessageID()) - return msg - }) - }) - } - } - session := NewSession( - s.ctx, - s.doneCtx, - udpConn, - raddr, - s.cfg.MaxMessageSize, - s.cfg.MTU, - false, - ) - monitor := s.cfg.CreateInactivityMonitor() - cfg := client.DefaultConfig - cfg.TransmissionNStart = s.cfg.TransmissionNStart - cfg.TransmissionAcknowledgeTimeout = s.cfg.TransmissionAcknowledgeTimeout - cfg.TransmissionMaxRetransmit = s.cfg.TransmissionMaxRetransmit - cfg.Handler = func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { - h, ok := s.multicastHandler.Load(r.Token().Hash()) - if ok { - h(w, r) - return - } - s.cfg.Handler(w, r) - } - cfg.BlockwiseSZX = s.cfg.BlockwiseSZX - cfg.Errors = s.cfg.Errors - cfg.GetMID = s.cfg.GetMID - cfg.GetToken = s.cfg.GetToken - cfg.MessagePool = s.cfg.MessagePool - cfg.ProcessReceivedMessage = s.cfg.ProcessReceivedMessage - cfg.ReceivedMessageQueueSize = s.cfg.ReceivedMessageQueueSize - - requestMonitor := s.cfg.RequestMonitor - cc = client.NewConnWithOpts( - session, - &cfg, - client.WithInactivityMonitor(monitor), - client.WithRequestMonitor(requestMonitor), - client.WithBlockWise(createBlockWise), - ) - cc.SetContextValue(closeKey, func() { - if err := session.Close(); err != nil { - s.cfg.Errors(fmt.Errorf("cannot close session: %w", err)) - } - session.shutdown() - }) - cc.AddOnClose(func() { - s.connsMutex.Lock() - defer s.connsMutex.Unlock() - if cc == s.conns[key] { - delete(s.conns, key) - } - }) - s.conns[key] = cc - return cc, true -} - -func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { - cc, created := s.getOrCreateConn(l, raddr) - if created { - if s.cfg.OnNewConn != nil { - s.cfg.OnNewConn(cc) - } - } else { - // check if client is not expired now + 10ms - if so, close it - // 10ms - The expected maximum time taken by cc.CheckExpirations and cc.InactivityMonitor().Notify() - cc.CheckExpirations(time.Now().Add(10 * time.Millisecond)) - if cc.Context().Err() == nil { - // if client is not closed, extend expiration time - cc.InactivityMonitor().Notify() - } - } - - if cc.Context().Err() != nil { - // connection is closed so we need to create new one - if closeFn := getClose(cc); closeFn != nil { - closeFn() - } - if firstTime { - return s.getConn(l, raddr, false) - } - return nil, errors.New("connection is closed") - } - return cc, nil -} - -func (s *Server) NewConn(addr *net.UDPAddr) (*client.Conn, error) { - l := s.getListener() - if l == nil { - // server is not started/stopped - return nil, errors.New("server is not running") - } - return s.getConn(l, addr, true) -} +package server + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "github.com/plgd-dev/go-coap/v3/udp/client" +) + +type Server struct { + doneCtx context.Context + ctx context.Context + multicastRequests *client.RequestsMap + multicastHandler *coapSync.Map[uint64, HandlerFunc] + serverStartedChan chan struct{} + doneCancel context.CancelFunc + cancel context.CancelFunc + + connsMutex sync.Mutex + conns map[string]*client.Conn + + listenMutex sync.Mutex + listen *coapNet.UDPConn + + cfg *Config +} + +// A Option sets options such as credentials, codec and keepalive parameters, etc. +type Option interface { + UDPServerApply(cfg *Config) +} + +func New(opt ...Option) *Server { + cfg := DefaultConfig + for _, o := range opt { + o.UDPServerApply(&cfg) + } + + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + + if cfg.GetMID == nil { + cfg.GetMID = message.GetMID + } + + if cfg.GetToken == nil { + cfg.GetToken = message.GetToken + } + + if cfg.CreateInactivityMonitor == nil { + cfg.CreateInactivityMonitor = func() client.InactivityMonitor { + return inactivity.NewNilMonitor[*client.Conn]() + } + } + if cfg.MessagePool == nil { + cfg.MessagePool = pool.New(0, 0) + } + + ctx, cancel := context.WithCancel(cfg.Ctx) + serverStartedChan := make(chan struct{}) + + doneCtx, doneCancel := context.WithCancel(context.Background()) + errorsFunc := cfg.Errors + cfg.Errors = func(err error) { + if coapNet.IsCancelOrCloseError(err) { + // this error was produced by cancellation context or closing connection. + return + } + errorsFunc(fmt.Errorf("udp: %w", err)) + } + return &Server{ + ctx: ctx, + cancel: cancel, + multicastHandler: coapSync.NewMap[uint64, HandlerFunc](), + multicastRequests: coapSync.NewMap[uint64, *pool.Message](), + serverStartedChan: serverStartedChan, + doneCtx: doneCtx, + doneCancel: doneCancel, + conns: make(map[string]*client.Conn), + + cfg: &cfg, + } +} + +func (s *Server) checkAndSetListener(l *coapNet.UDPConn) error { + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + if s.listen != nil { + return fmt.Errorf("server already serve: %v", s.listen.LocalAddr().String()) + } + s.listen = l + close(s.serverStartedChan) + return nil +} + +func (s *Server) closeConnection(cc *client.Conn) { + if err := cc.Close(); err != nil { + s.cfg.Errors(fmt.Errorf("cannot close connection: %w", err)) + } +} + +func (s *Server) Serve(l *coapNet.UDPConn) error { + if s.cfg.BlockwiseSZX > blockwise.SZX1024 { + return errors.New("invalid blockwiseSZX") + } + + err := s.checkAndSetListener(l) + if err != nil { + return err + } + + defer func() { + s.closeSessions() + s.doneCancel() + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + s.listen = nil + s.serverStartedChan = make(chan struct{}, 1) + }() + + m := make([]byte, s.cfg.MaxMessageSize) + var wg sync.WaitGroup + + s.cfg.PeriodicRunner(func(now time.Time) bool { + s.handleInactivityMonitors(now) + return s.ctx.Err() == nil + }) + + for { + buf := m + var raddr *net.UDPAddr + var cm *coapNet.ControlMessage + n, err := l.ReadWithOptions(buf, coapNet.WithContext(s.ctx), coapNet.WithGetControlMessage(&cm), coapNet.WithGetRemoteAddr(&raddr)) + if err != nil { + wg.Wait() + + select { + case <-s.ctx.Done(): + return nil + default: + if coapNet.IsCancelOrCloseError(err) { + return nil + } + return err + } + } + buf = buf[:n] + + // UDPConn.LocalAddr() only takes into account the address it is bound to. + // In the case of a wildcard address, the actual destination address is in the control message. + laddr := l.LocalAddr().(*net.UDPAddr) + if cm != nil && cm.Dst != nil { + laddr.IP = cm.Dst + } + + cc, err := s.getConn(l, raddr, laddr, true) + if err != nil { + s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) + continue + } + err = cc.Process(cm, buf) + if err != nil { + s.closeConnection(cc) + s.cfg.Errors(fmt.Errorf("%v: cannot process packet: %w", cc.RemoteAddr(), err)) + } + } +} + +func (s *Server) getListener() *coapNet.UDPConn { + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + return s.listen +} + +// Stop stops server without wait of ends Serve function. +func (s *Server) Stop() { + s.cancel() + l := s.getListener() + if l != nil { + if errC := l.Close(); errC != nil { + s.cfg.Errors(fmt.Errorf("cannot close listener: %w", errC)) + } + } + s.closeSessions() +} + +func (s *Server) closeSessions() { + s.connsMutex.Lock() + conns := s.conns + s.conns = make(map[string]*client.Conn) + s.connsMutex.Unlock() + for _, cc := range conns { + s.closeConnection(cc) + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + } +} + +func (s *Server) conn() *coapNet.UDPConn { + s.listenMutex.Lock() + serverStartedChan := s.serverStartedChan + s.listenMutex.Unlock() + select { + case <-serverStartedChan: + case <-s.ctx.Done(): + } + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + return s.listen +} + +const closeKey = "gocoapCloseConnection" + +func (s *Server) getConns() []*client.Conn { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + conns := make([]*client.Conn, 0, 32) + for _, c := range s.conns { + conns = append(conns, c) + } + return conns +} + +func (s *Server) handleInactivityMonitors(now time.Time) { + for _, cc := range s.getConns() { + select { + case <-cc.Context().Done(): + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + continue + default: + cc.CheckExpirations(now) + } + } +} + +func getClose(cc *client.Conn) func() { + v := cc.Context().Value(closeKey) + if v == nil { + return nil + } + closeFn, ok := v.(func()) + if !ok { + panic(fmt.Errorf("invalid type(%T) of context value for key %s", v, closeKey)) + } + return closeFn +} + +func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr) (cc *client.Conn, created bool) { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + key := raddr.String() + "-" + laddr.String() + cc = s.conns[key] + + if cc != nil { + return cc, false + } + + createBlockWise := func(*client.Conn) *blockwise.BlockWise[*client.Conn] { + return nil + } + if s.cfg.BlockwiseEnable { + createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + v := cc + return blockwise.New( + v, + s.cfg.BlockwiseTransferTimeout, + s.cfg.Errors, + func(token message.Token) (*pool.Message, bool) { + msg, ok := v.GetObservationRequest(token) + if ok { + return msg, ok + } + return s.multicastRequests.LoadWithFunc(token.Hash(), func(m *pool.Message) *pool.Message { + msg := v.AcquireMessage(m.Context()) + msg.ResetOptionsTo(m.Options()) + msg.SetCode(m.Code()) + msg.SetToken(m.Token()) + msg.SetMessageID(m.MessageID()) + return msg + }) + }) + } + } + session := NewSession( + s.ctx, + s.doneCtx, + udpConn, + raddr, + s.cfg.MaxMessageSize, + s.cfg.MTU, + false, + ) + monitor := s.cfg.CreateInactivityMonitor() + cfg := client.DefaultConfig + cfg.TransmissionNStart = s.cfg.TransmissionNStart + cfg.TransmissionAcknowledgeTimeout = s.cfg.TransmissionAcknowledgeTimeout + cfg.TransmissionMaxRetransmit = s.cfg.TransmissionMaxRetransmit + cfg.Handler = func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { + h, ok := s.multicastHandler.Load(r.Token().Hash()) + if ok { + h(w, r) + return + } + s.cfg.Handler(w, r) + } + cfg.BlockwiseSZX = s.cfg.BlockwiseSZX + cfg.Errors = s.cfg.Errors + cfg.GetMID = s.cfg.GetMID + cfg.GetToken = s.cfg.GetToken + cfg.MessagePool = s.cfg.MessagePool + cfg.ProcessReceivedMessage = s.cfg.ProcessReceivedMessage + cfg.ReceivedMessageQueueSize = s.cfg.ReceivedMessageQueueSize + + requestMonitor := s.cfg.RequestMonitor + cc = client.NewConnWithOpts( + session, + &cfg, + client.WithInactivityMonitor(monitor), + client.WithRequestMonitor(requestMonitor), + client.WithBlockWise(createBlockWise), + ) + cc.SetContextValue(closeKey, func() { + if err := session.Close(); err != nil { + s.cfg.Errors(fmt.Errorf("cannot close session: %w", err)) + } + session.shutdown() + }) + cc.AddOnClose(func() { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + if cc == s.conns[key] { + delete(s.conns, key) + } + }) + s.conns[key] = cc + return cc, true +} + +func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { + cc, created := s.getOrCreateConn(l, raddr, laddr) + if created { + if s.cfg.OnNewConn != nil { + s.cfg.OnNewConn(cc) + } + } else { + // check if client is not expired now + 10ms - if so, close it + // 10ms - The expected maximum time taken by cc.CheckExpirations and cc.InactivityMonitor().Notify() + cc.CheckExpirations(time.Now().Add(10 * time.Millisecond)) + if cc.Context().Err() == nil { + // if client is not closed, extend expiration time + cc.InactivityMonitor().Notify() + } + } + + if cc.Context().Err() != nil { + // connection is closed so we need to create new one + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + if firstTime { + return s.getConn(l, raddr, laddr, false) + } + return nil, errors.New("connection is closed") + } + return cc, nil +} + +func (s *Server) NewConn(addr *net.UDPAddr, laddr *net.UDPAddr) (*client.Conn, error) { + l := s.getListener() + if l == nil { + // server is not started/stopped + return nil, errors.New("server is not running") + } + return s.getConn(l, addr, laddr, true) +} From 94fce92c30421992dd00bf1c588e910c4df0ac16 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Mon, 9 Mar 2026 14:06:24 +0100 Subject: [PATCH 04/16] fix: did not swap --- .gitignore | 1 + udp/client/conn.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index fca55899..ea05bcb7 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ client vendor/ v3/ .idea/ +.vscode/ # Test binary, build with `go test -c` *.test diff --git a/udp/client/conn.go b/udp/client/conn.go index 0cd228f3..d5ccac45 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -832,7 +832,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con resp.SetToken(req.Token()) cc.interfaceIndex.Store(int64(req.ControlMessage().GetIfIndex())) - cc.localAddr.Store(&req.ControlMessage().Src) + cc.localAddr.Store(&req.ControlMessage().Dst) w := responsewriter.New(resp, cc, req.Options()...) defer func() { From b187e369dce5b61569362a83dd076f7ea6dccec3 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Mon, 9 Mar 2026 14:18:07 +0100 Subject: [PATCH 05/16] fix: remove vscode folder from gitignore --- .gitignore | 1 - .vscode/launch.json | 12 ++++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 .vscode/launch.json diff --git a/.gitignore b/.gitignore index ea05bcb7..fca55899 100644 --- a/.gitignore +++ b/.gitignore @@ -12,7 +12,6 @@ client vendor/ v3/ .idea/ -.vscode/ # Test binary, build with `go test -c` *.test diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..86102eb2 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,12 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Run simple server", + "type": "go", + "request": "launch", + "mode": "auto", + "program": "${workspaceFolder}/examples/observe/server" + } + ] +} \ No newline at end of file From ec6f6e186db268d68116fc91fc278d6469c0d9b7 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Mon, 9 Mar 2026 14:47:43 +0100 Subject: [PATCH 06/16] fix: ignore mcast adresses --- udp/client/conn.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/udp/client/conn.go b/udp/client/conn.go index d5ccac45..84bfa94a 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -832,7 +832,9 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con resp.SetToken(req.Token()) cc.interfaceIndex.Store(int64(req.ControlMessage().GetIfIndex())) - cc.localAddr.Store(&req.ControlMessage().Dst) + if !req.ControlMessage().Dst.IsMulticast() { + cc.localAddr.Store(&req.ControlMessage().Dst) + } w := responsewriter.New(resp, cc, req.Options()...) defer func() { From 7a6dfeb58e2456988ca6164fd16b23d0dbd736b9 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Mon, 9 Mar 2026 15:06:37 +0100 Subject: [PATCH 07/16] fix: modify signatures in NewConn in tests --- udp/server_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/udp/server_test.go b/udp/server_test.go index 119adf80..26cd9854 100644 --- a/udp/server_test.go +++ b/udp/server_test.go @@ -467,7 +467,7 @@ func TestServerNewClient(t *testing.T) { time.Sleep(time.Second) - cc, err := s1.NewConn(peer) + cc, err := s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -478,7 +478,7 @@ func TestServerNewClient(t *testing.T) { require.NoError(t, err) // repeat ping - new client should be created - cc, err = s1.NewConn(peer) + cc, err = s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) require.NoError(t, err) err = cc.Ping(ctx) require.NoError(t, err) @@ -600,7 +600,7 @@ func TestServerReconnectNewClient(t *testing.T) { time.Sleep(time.Second) - cc, err := s1.NewConn(peer) + cc, err := s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -626,7 +626,7 @@ func TestServerReconnectNewClient(t *testing.T) { } // new client - cc, err = s1.NewConn(peer) + cc, err = s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*1) defer cancel() From 90efe21401c0d64f00b8c2b1e7549e789f939478 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Tue, 10 Mar 2026 09:36:33 +0100 Subject: [PATCH 08/16] fix: data race condition --- udp/server/server.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/udp/server/server.go b/udp/server/server.go index abb5d11c..a3b17e95 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -162,7 +162,9 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { // UDPConn.LocalAddr() only takes into account the address it is bound to. // In the case of a wildcard address, the actual destination address is in the control message. - laddr := l.LocalAddr().(*net.UDPAddr) + // Copy the UDPAddr before mutating to avoid a data race with concurrent LocalAddr() readers. + laddrVal := *l.LocalAddr().(*net.UDPAddr) + laddr := &laddrVal if cm != nil && cm.Dst != nil { laddr.IP = cm.Dst } From abaa681af0bb2549e59bcfed43116bb5536d7cac Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Tue, 10 Mar 2026 09:51:32 +0100 Subject: [PATCH 09/16] fix: don't use dummy adresses in test --- udp/server_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/udp/server_test.go b/udp/server_test.go index 26cd9854..f3e47344 100644 --- a/udp/server_test.go +++ b/udp/server_test.go @@ -467,7 +467,7 @@ func TestServerNewClient(t *testing.T) { time.Sleep(time.Second) - cc, err := s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) + cc, err := s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -478,7 +478,7 @@ func TestServerNewClient(t *testing.T) { require.NoError(t, err) // repeat ping - new client should be created - cc, err = s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) + cc, err = s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) err = cc.Ping(ctx) require.NoError(t, err) @@ -600,7 +600,7 @@ func TestServerReconnectNewClient(t *testing.T) { time.Sleep(time.Second) - cc, err := s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) + cc, err := s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -626,7 +626,7 @@ func TestServerReconnectNewClient(t *testing.T) { } // new client - cc, err = s1.NewConn(peer, &net.UDPAddr{IP: net.IP{127, 0, 0, 1}}) + cc, err = s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*1) defer cancel() From aef7b2b4134b53a96b8cf914e3571379ee80aead Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Tue, 10 Mar 2026 10:58:00 +0100 Subject: [PATCH 10/16] fix: missed nil check --- udp/client/conn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/udp/client/conn.go b/udp/client/conn.go index 84bfa94a..e8ce6eca 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -832,8 +832,8 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con resp.SetToken(req.Token()) cc.interfaceIndex.Store(int64(req.ControlMessage().GetIfIndex())) - if !req.ControlMessage().Dst.IsMulticast() { - cc.localAddr.Store(&req.ControlMessage().Dst) + if cm := req.ControlMessage(); cm != nil && !cm.Dst.IsMulticast() { + cc.localAddr.Store(&cm.Dst) } w := responsewriter.New(resp, cc, req.Options()...) From ef50a58922455d6046f8df824a9026cdb2dbb53b Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Tue, 10 Mar 2026 11:57:36 +0100 Subject: [PATCH 11/16] fix: don't upsert if nothing to do --- udp/client/conn.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/udp/client/conn.go b/udp/client/conn.go index e8ce6eca..5d7eba85 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -863,17 +863,19 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess } func (cc *Conn) upsertControlInformation(msg *pool.Message) { - + ifIndex := int(cc.interfaceIndex.Load()) + localAddr := cc.localAddr.Load() + if ifIndex < 1 && localAddr == nil { + return + } cm := coapNet.ControlMessage{} - if cc.interfaceIndex.Load() >= 0 { - cm.IfIndex = int(cc.interfaceIndex.Load()) + if ifIndex >= 1 { + cm.IfIndex = ifIndex } - - if cc.localAddr.Load() != nil { - cm.Src = *cc.localAddr.Load() + if localAddr != nil { + cm.Src = *localAddr } msg.UpsertControlMessage(&cm) - } func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { From ce20149cb63f54ce6a8b3dec063895d0e8716969 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Tue, 10 Mar 2026 13:04:02 +0100 Subject: [PATCH 12/16] fix: prevent against race conditions and dangling pointers --- udp/client/conn.go | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/udp/client/conn.go b/udp/client/conn.go index 5d7eba85..891887f5 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -833,7 +833,9 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con cc.interfaceIndex.Store(int64(req.ControlMessage().GetIfIndex())) if cm := req.ControlMessage(); cm != nil && !cm.Dst.IsMulticast() { - cc.localAddr.Store(&cm.Dst) + dst := make(net.IP, len(cm.Dst)) + copy(dst, cm.Dst) + cc.localAddr.Store(&dst) } w := responsewriter.New(resp, cc, req.Options()...) @@ -864,18 +866,17 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess func (cc *Conn) upsertControlInformation(msg *pool.Message) { ifIndex := int(cc.interfaceIndex.Load()) - localAddr := cc.localAddr.Load() - if ifIndex < 1 && localAddr == nil { + localAddrPtr := cc.localAddr.Load() + if ifIndex < 1 && localAddrPtr == nil { return } - cm := coapNet.ControlMessage{} - if ifIndex >= 1 { - cm.IfIndex = ifIndex - } - if localAddr != nil { - cm.Src = *localAddr + + var localAddr net.IP + if localAddrPtr != nil { + localAddr = *localAddrPtr } - msg.UpsertControlMessage(&cm) + + msg.UpsertControlMessage(&coapNet.ControlMessage{IfIndex: ifIndex, Src: localAddr}) } func (cc *Conn) handleSpecialMessages(r *pool.Message) bool { From 7eb30bacaa851399be13b1ef6826f1f031b57972 Mon Sep 17 00:00:00 2001 From: Theodor Rauch Date: Tue, 10 Mar 2026 18:51:40 +0100 Subject: [PATCH 13/16] fix: comment --- udp/server/server.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/udp/server/server.go b/udp/server/server.go index a3b17e95..cb454a18 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -162,7 +162,7 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { // UDPConn.LocalAddr() only takes into account the address it is bound to. // In the case of a wildcard address, the actual destination address is in the control message. - // Copy the UDPAddr before mutating to avoid a data race with concurrent LocalAddr() readers. + // On server-initiated exchanges, listeners Local Addr can be used as the client has no assumptions of the source. laddrVal := *l.LocalAddr().(*net.UDPAddr) laddr := &laddrVal if cm != nil && cm.Dst != nil { From b175457494d6ae6611cf866285a52a9ba2481d41 Mon Sep 17 00:00:00 2001 From: Jozef Kralik Date: Fri, 20 Mar 2026 10:00:52 +0100 Subject: [PATCH 14/16] fix for golangci-lint --- udp/server/server.go | 792 ++++++++++++++++++++++--------------------- 1 file changed, 398 insertions(+), 394 deletions(-) diff --git a/udp/server/server.go b/udp/server/server.go index cb454a18..79ba46ba 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -1,394 +1,398 @@ -package server - -import ( - "context" - "errors" - "fmt" - "net" - "sync" - "time" - - "github.com/plgd-dev/go-coap/v3/message" - "github.com/plgd-dev/go-coap/v3/message/pool" - coapNet "github.com/plgd-dev/go-coap/v3/net" - "github.com/plgd-dev/go-coap/v3/net/blockwise" - "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" - "github.com/plgd-dev/go-coap/v3/net/responsewriter" - coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" - "github.com/plgd-dev/go-coap/v3/udp/client" -) - -type Server struct { - doneCtx context.Context - ctx context.Context - multicastRequests *client.RequestsMap - multicastHandler *coapSync.Map[uint64, HandlerFunc] - serverStartedChan chan struct{} - doneCancel context.CancelFunc - cancel context.CancelFunc - - connsMutex sync.Mutex - conns map[string]*client.Conn - - listenMutex sync.Mutex - listen *coapNet.UDPConn - - cfg *Config -} - -// A Option sets options such as credentials, codec and keepalive parameters, etc. -type Option interface { - UDPServerApply(cfg *Config) -} - -func New(opt ...Option) *Server { - cfg := DefaultConfig - for _, o := range opt { - o.UDPServerApply(&cfg) - } - - if cfg.Errors == nil { - cfg.Errors = func(error) { - // default no-op - } - } - - if cfg.GetMID == nil { - cfg.GetMID = message.GetMID - } - - if cfg.GetToken == nil { - cfg.GetToken = message.GetToken - } - - if cfg.CreateInactivityMonitor == nil { - cfg.CreateInactivityMonitor = func() client.InactivityMonitor { - return inactivity.NewNilMonitor[*client.Conn]() - } - } - if cfg.MessagePool == nil { - cfg.MessagePool = pool.New(0, 0) - } - - ctx, cancel := context.WithCancel(cfg.Ctx) - serverStartedChan := make(chan struct{}) - - doneCtx, doneCancel := context.WithCancel(context.Background()) - errorsFunc := cfg.Errors - cfg.Errors = func(err error) { - if coapNet.IsCancelOrCloseError(err) { - // this error was produced by cancellation context or closing connection. - return - } - errorsFunc(fmt.Errorf("udp: %w", err)) - } - return &Server{ - ctx: ctx, - cancel: cancel, - multicastHandler: coapSync.NewMap[uint64, HandlerFunc](), - multicastRequests: coapSync.NewMap[uint64, *pool.Message](), - serverStartedChan: serverStartedChan, - doneCtx: doneCtx, - doneCancel: doneCancel, - conns: make(map[string]*client.Conn), - - cfg: &cfg, - } -} - -func (s *Server) checkAndSetListener(l *coapNet.UDPConn) error { - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - if s.listen != nil { - return fmt.Errorf("server already serve: %v", s.listen.LocalAddr().String()) - } - s.listen = l - close(s.serverStartedChan) - return nil -} - -func (s *Server) closeConnection(cc *client.Conn) { - if err := cc.Close(); err != nil { - s.cfg.Errors(fmt.Errorf("cannot close connection: %w", err)) - } -} - -func (s *Server) Serve(l *coapNet.UDPConn) error { - if s.cfg.BlockwiseSZX > blockwise.SZX1024 { - return errors.New("invalid blockwiseSZX") - } - - err := s.checkAndSetListener(l) - if err != nil { - return err - } - - defer func() { - s.closeSessions() - s.doneCancel() - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - s.listen = nil - s.serverStartedChan = make(chan struct{}, 1) - }() - - m := make([]byte, s.cfg.MaxMessageSize) - var wg sync.WaitGroup - - s.cfg.PeriodicRunner(func(now time.Time) bool { - s.handleInactivityMonitors(now) - return s.ctx.Err() == nil - }) - - for { - buf := m - var raddr *net.UDPAddr - var cm *coapNet.ControlMessage - n, err := l.ReadWithOptions(buf, coapNet.WithContext(s.ctx), coapNet.WithGetControlMessage(&cm), coapNet.WithGetRemoteAddr(&raddr)) - if err != nil { - wg.Wait() - - select { - case <-s.ctx.Done(): - return nil - default: - if coapNet.IsCancelOrCloseError(err) { - return nil - } - return err - } - } - buf = buf[:n] - - // UDPConn.LocalAddr() only takes into account the address it is bound to. - // In the case of a wildcard address, the actual destination address is in the control message. - // On server-initiated exchanges, listeners Local Addr can be used as the client has no assumptions of the source. - laddrVal := *l.LocalAddr().(*net.UDPAddr) - laddr := &laddrVal - if cm != nil && cm.Dst != nil { - laddr.IP = cm.Dst - } - - cc, err := s.getConn(l, raddr, laddr, true) - if err != nil { - s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) - continue - } - err = cc.Process(cm, buf) - if err != nil { - s.closeConnection(cc) - s.cfg.Errors(fmt.Errorf("%v: cannot process packet: %w", cc.RemoteAddr(), err)) - } - } -} - -func (s *Server) getListener() *coapNet.UDPConn { - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - return s.listen -} - -// Stop stops server without wait of ends Serve function. -func (s *Server) Stop() { - s.cancel() - l := s.getListener() - if l != nil { - if errC := l.Close(); errC != nil { - s.cfg.Errors(fmt.Errorf("cannot close listener: %w", errC)) - } - } - s.closeSessions() -} - -func (s *Server) closeSessions() { - s.connsMutex.Lock() - conns := s.conns - s.conns = make(map[string]*client.Conn) - s.connsMutex.Unlock() - for _, cc := range conns { - s.closeConnection(cc) - if closeFn := getClose(cc); closeFn != nil { - closeFn() - } - } -} - -func (s *Server) conn() *coapNet.UDPConn { - s.listenMutex.Lock() - serverStartedChan := s.serverStartedChan - s.listenMutex.Unlock() - select { - case <-serverStartedChan: - case <-s.ctx.Done(): - } - s.listenMutex.Lock() - defer s.listenMutex.Unlock() - return s.listen -} - -const closeKey = "gocoapCloseConnection" - -func (s *Server) getConns() []*client.Conn { - s.connsMutex.Lock() - defer s.connsMutex.Unlock() - conns := make([]*client.Conn, 0, 32) - for _, c := range s.conns { - conns = append(conns, c) - } - return conns -} - -func (s *Server) handleInactivityMonitors(now time.Time) { - for _, cc := range s.getConns() { - select { - case <-cc.Context().Done(): - if closeFn := getClose(cc); closeFn != nil { - closeFn() - } - continue - default: - cc.CheckExpirations(now) - } - } -} - -func getClose(cc *client.Conn) func() { - v := cc.Context().Value(closeKey) - if v == nil { - return nil - } - closeFn, ok := v.(func()) - if !ok { - panic(fmt.Errorf("invalid type(%T) of context value for key %s", v, closeKey)) - } - return closeFn -} - -func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr) (cc *client.Conn, created bool) { - s.connsMutex.Lock() - defer s.connsMutex.Unlock() - key := raddr.String() + "-" + laddr.String() - cc = s.conns[key] - - if cc != nil { - return cc, false - } - - createBlockWise := func(*client.Conn) *blockwise.BlockWise[*client.Conn] { - return nil - } - if s.cfg.BlockwiseEnable { - createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { - v := cc - return blockwise.New( - v, - s.cfg.BlockwiseTransferTimeout, - s.cfg.Errors, - func(token message.Token) (*pool.Message, bool) { - msg, ok := v.GetObservationRequest(token) - if ok { - return msg, ok - } - return s.multicastRequests.LoadWithFunc(token.Hash(), func(m *pool.Message) *pool.Message { - msg := v.AcquireMessage(m.Context()) - msg.ResetOptionsTo(m.Options()) - msg.SetCode(m.Code()) - msg.SetToken(m.Token()) - msg.SetMessageID(m.MessageID()) - return msg - }) - }) - } - } - session := NewSession( - s.ctx, - s.doneCtx, - udpConn, - raddr, - s.cfg.MaxMessageSize, - s.cfg.MTU, - false, - ) - monitor := s.cfg.CreateInactivityMonitor() - cfg := client.DefaultConfig - cfg.TransmissionNStart = s.cfg.TransmissionNStart - cfg.TransmissionAcknowledgeTimeout = s.cfg.TransmissionAcknowledgeTimeout - cfg.TransmissionMaxRetransmit = s.cfg.TransmissionMaxRetransmit - cfg.Handler = func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { - h, ok := s.multicastHandler.Load(r.Token().Hash()) - if ok { - h(w, r) - return - } - s.cfg.Handler(w, r) - } - cfg.BlockwiseSZX = s.cfg.BlockwiseSZX - cfg.Errors = s.cfg.Errors - cfg.GetMID = s.cfg.GetMID - cfg.GetToken = s.cfg.GetToken - cfg.MessagePool = s.cfg.MessagePool - cfg.ProcessReceivedMessage = s.cfg.ProcessReceivedMessage - cfg.ReceivedMessageQueueSize = s.cfg.ReceivedMessageQueueSize - - requestMonitor := s.cfg.RequestMonitor - cc = client.NewConnWithOpts( - session, - &cfg, - client.WithInactivityMonitor(monitor), - client.WithRequestMonitor(requestMonitor), - client.WithBlockWise(createBlockWise), - ) - cc.SetContextValue(closeKey, func() { - if err := session.Close(); err != nil { - s.cfg.Errors(fmt.Errorf("cannot close session: %w", err)) - } - session.shutdown() - }) - cc.AddOnClose(func() { - s.connsMutex.Lock() - defer s.connsMutex.Unlock() - if cc == s.conns[key] { - delete(s.conns, key) - } - }) - s.conns[key] = cc - return cc, true -} - -func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { - cc, created := s.getOrCreateConn(l, raddr, laddr) - if created { - if s.cfg.OnNewConn != nil { - s.cfg.OnNewConn(cc) - } - } else { - // check if client is not expired now + 10ms - if so, close it - // 10ms - The expected maximum time taken by cc.CheckExpirations and cc.InactivityMonitor().Notify() - cc.CheckExpirations(time.Now().Add(10 * time.Millisecond)) - if cc.Context().Err() == nil { - // if client is not closed, extend expiration time - cc.InactivityMonitor().Notify() - } - } - - if cc.Context().Err() != nil { - // connection is closed so we need to create new one - if closeFn := getClose(cc); closeFn != nil { - closeFn() - } - if firstTime { - return s.getConn(l, raddr, laddr, false) - } - return nil, errors.New("connection is closed") - } - return cc, nil -} - -func (s *Server) NewConn(addr *net.UDPAddr, laddr *net.UDPAddr) (*client.Conn, error) { - l := s.getListener() - if l == nil { - // server is not started/stopped - return nil, errors.New("server is not running") - } - return s.getConn(l, addr, laddr, true) -} +package server + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "time" + + "github.com/plgd-dev/go-coap/v3/message" + "github.com/plgd-dev/go-coap/v3/message/pool" + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/plgd-dev/go-coap/v3/net/blockwise" + "github.com/plgd-dev/go-coap/v3/net/monitor/inactivity" + "github.com/plgd-dev/go-coap/v3/net/responsewriter" + coapSync "github.com/plgd-dev/go-coap/v3/pkg/sync" + "github.com/plgd-dev/go-coap/v3/udp/client" +) + +type Server struct { + doneCtx context.Context + ctx context.Context + multicastRequests *client.RequestsMap + multicastHandler *coapSync.Map[uint64, HandlerFunc] + serverStartedChan chan struct{} + doneCancel context.CancelFunc + cancel context.CancelFunc + + connsMutex sync.Mutex + conns map[string]*client.Conn + + listenMutex sync.Mutex + listen *coapNet.UDPConn + + cfg *Config +} + +// A Option sets options such as credentials, codec and keepalive parameters, etc. +type Option interface { + UDPServerApply(cfg *Config) +} + +func New(opt ...Option) *Server { + cfg := DefaultConfig + for _, o := range opt { + o.UDPServerApply(&cfg) + } + + if cfg.Errors == nil { + cfg.Errors = func(error) { + // default no-op + } + } + + if cfg.GetMID == nil { + cfg.GetMID = message.GetMID + } + + if cfg.GetToken == nil { + cfg.GetToken = message.GetToken + } + + if cfg.CreateInactivityMonitor == nil { + cfg.CreateInactivityMonitor = func() client.InactivityMonitor { + return inactivity.NewNilMonitor[*client.Conn]() + } + } + if cfg.MessagePool == nil { + cfg.MessagePool = pool.New(0, 0) + } + + ctx, cancel := context.WithCancel(cfg.Ctx) + serverStartedChan := make(chan struct{}) + + doneCtx, doneCancel := context.WithCancel(context.Background()) + errorsFunc := cfg.Errors + cfg.Errors = func(err error) { + if coapNet.IsCancelOrCloseError(err) { + // this error was produced by cancellation context or closing connection. + return + } + errorsFunc(fmt.Errorf("udp: %w", err)) + } + return &Server{ + ctx: ctx, + cancel: cancel, + multicastHandler: coapSync.NewMap[uint64, HandlerFunc](), + multicastRequests: coapSync.NewMap[uint64, *pool.Message](), + serverStartedChan: serverStartedChan, + doneCtx: doneCtx, + doneCancel: doneCancel, + conns: make(map[string]*client.Conn), + + cfg: &cfg, + } +} + +func (s *Server) checkAndSetListener(l *coapNet.UDPConn) error { + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + if s.listen != nil { + return fmt.Errorf("server already serve: %v", s.listen.LocalAddr().String()) + } + s.listen = l + close(s.serverStartedChan) + return nil +} + +func (s *Server) closeConnection(cc *client.Conn) { + if err := cc.Close(); err != nil { + s.cfg.Errors(fmt.Errorf("cannot close connection: %w", err)) + } +} + +func (s *Server) Serve(l *coapNet.UDPConn) error { + if s.cfg.BlockwiseSZX > blockwise.SZX1024 { + return errors.New("invalid blockwiseSZX") + } + + err := s.checkAndSetListener(l) + if err != nil { + return err + } + + defer func() { + s.closeSessions() + s.doneCancel() + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + s.listen = nil + s.serverStartedChan = make(chan struct{}, 1) + }() + + m := make([]byte, s.cfg.MaxMessageSize) + var wg sync.WaitGroup + + s.cfg.PeriodicRunner(func(now time.Time) bool { + s.handleInactivityMonitors(now) + return s.ctx.Err() == nil + }) + + for { + buf := m + var raddr *net.UDPAddr + var cm *coapNet.ControlMessage + n, err := l.ReadWithOptions(buf, coapNet.WithContext(s.ctx), coapNet.WithGetControlMessage(&cm), coapNet.WithGetRemoteAddr(&raddr)) + if err != nil { + wg.Wait() + + select { + case <-s.ctx.Done(): + return nil + default: + if coapNet.IsCancelOrCloseError(err) { + return nil + } + return err + } + } + buf = buf[:n] + + // UDPConn.LocalAddr() only takes into account the address it is bound to. + // In the case of a wildcard address, the actual destination address is in the control message. + // On server-initiated exchanges, listeners Local Addr can be used as the client has no assumptions of the source. + localAddr, ok := l.LocalAddr().(*net.UDPAddr) + if !ok || localAddr == nil { + return fmt.Errorf("unexpected listener local addr type: %T", l.LocalAddr()) + } + laddrVal := *localAddr + laddr := &laddrVal + if cm != nil && cm.Dst != nil { + laddr.IP = cm.Dst + } + + cc, err := s.getConn(l, raddr, laddr, true) + if err != nil { + s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err)) + continue + } + err = cc.Process(cm, buf) + if err != nil { + s.closeConnection(cc) + s.cfg.Errors(fmt.Errorf("%v: cannot process packet: %w", cc.RemoteAddr(), err)) + } + } +} + +func (s *Server) getListener() *coapNet.UDPConn { + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + return s.listen +} + +// Stop stops server without wait of ends Serve function. +func (s *Server) Stop() { + s.cancel() + l := s.getListener() + if l != nil { + if errC := l.Close(); errC != nil { + s.cfg.Errors(fmt.Errorf("cannot close listener: %w", errC)) + } + } + s.closeSessions() +} + +func (s *Server) closeSessions() { + s.connsMutex.Lock() + conns := s.conns + s.conns = make(map[string]*client.Conn) + s.connsMutex.Unlock() + for _, cc := range conns { + s.closeConnection(cc) + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + } +} + +func (s *Server) conn() *coapNet.UDPConn { + s.listenMutex.Lock() + serverStartedChan := s.serverStartedChan + s.listenMutex.Unlock() + select { + case <-serverStartedChan: + case <-s.ctx.Done(): + } + s.listenMutex.Lock() + defer s.listenMutex.Unlock() + return s.listen +} + +const closeKey = "gocoapCloseConnection" + +func (s *Server) getConns() []*client.Conn { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + conns := make([]*client.Conn, 0, 32) + for _, c := range s.conns { + conns = append(conns, c) + } + return conns +} + +func (s *Server) handleInactivityMonitors(now time.Time) { + for _, cc := range s.getConns() { + select { + case <-cc.Context().Done(): + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + continue + default: + cc.CheckExpirations(now) + } + } +} + +func getClose(cc *client.Conn) func() { + v := cc.Context().Value(closeKey) + if v == nil { + return nil + } + closeFn, ok := v.(func()) + if !ok { + panic(fmt.Errorf("invalid type(%T) of context value for key %s", v, closeKey)) + } + return closeFn +} + +func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr) (cc *client.Conn, created bool) { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + key := raddr.String() + "-" + laddr.String() + cc = s.conns[key] + + if cc != nil { + return cc, false + } + + createBlockWise := func(*client.Conn) *blockwise.BlockWise[*client.Conn] { + return nil + } + if s.cfg.BlockwiseEnable { + createBlockWise = func(cc *client.Conn) *blockwise.BlockWise[*client.Conn] { + v := cc + return blockwise.New( + v, + s.cfg.BlockwiseTransferTimeout, + s.cfg.Errors, + func(token message.Token) (*pool.Message, bool) { + msg, ok := v.GetObservationRequest(token) + if ok { + return msg, ok + } + return s.multicastRequests.LoadWithFunc(token.Hash(), func(m *pool.Message) *pool.Message { + msg := v.AcquireMessage(m.Context()) + msg.ResetOptionsTo(m.Options()) + msg.SetCode(m.Code()) + msg.SetToken(m.Token()) + msg.SetMessageID(m.MessageID()) + return msg + }) + }) + } + } + session := NewSession( + s.ctx, + s.doneCtx, + udpConn, + raddr, + s.cfg.MaxMessageSize, + s.cfg.MTU, + false, + ) + monitor := s.cfg.CreateInactivityMonitor() + cfg := client.DefaultConfig + cfg.TransmissionNStart = s.cfg.TransmissionNStart + cfg.TransmissionAcknowledgeTimeout = s.cfg.TransmissionAcknowledgeTimeout + cfg.TransmissionMaxRetransmit = s.cfg.TransmissionMaxRetransmit + cfg.Handler = func(w *responsewriter.ResponseWriter[*client.Conn], r *pool.Message) { + h, ok := s.multicastHandler.Load(r.Token().Hash()) + if ok { + h(w, r) + return + } + s.cfg.Handler(w, r) + } + cfg.BlockwiseSZX = s.cfg.BlockwiseSZX + cfg.Errors = s.cfg.Errors + cfg.GetMID = s.cfg.GetMID + cfg.GetToken = s.cfg.GetToken + cfg.MessagePool = s.cfg.MessagePool + cfg.ProcessReceivedMessage = s.cfg.ProcessReceivedMessage + cfg.ReceivedMessageQueueSize = s.cfg.ReceivedMessageQueueSize + + requestMonitor := s.cfg.RequestMonitor + cc = client.NewConnWithOpts( + session, + &cfg, + client.WithInactivityMonitor(monitor), + client.WithRequestMonitor(requestMonitor), + client.WithBlockWise(createBlockWise), + ) + cc.SetContextValue(closeKey, func() { + if err := session.Close(); err != nil { + s.cfg.Errors(fmt.Errorf("cannot close session: %w", err)) + } + session.shutdown() + }) + cc.AddOnClose(func() { + s.connsMutex.Lock() + defer s.connsMutex.Unlock() + if cc == s.conns[key] { + delete(s.conns, key) + } + }) + s.conns[key] = cc + return cc, true +} + +func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { + cc, created := s.getOrCreateConn(l, raddr, laddr) + if created { + if s.cfg.OnNewConn != nil { + s.cfg.OnNewConn(cc) + } + } else { + // check if client is not expired now + 10ms - if so, close it + // 10ms - The expected maximum time taken by cc.CheckExpirations and cc.InactivityMonitor().Notify() + cc.CheckExpirations(time.Now().Add(10 * time.Millisecond)) + if cc.Context().Err() == nil { + // if client is not closed, extend expiration time + cc.InactivityMonitor().Notify() + } + } + + if cc.Context().Err() != nil { + // connection is closed so we need to create new one + if closeFn := getClose(cc); closeFn != nil { + closeFn() + } + if firstTime { + return s.getConn(l, raddr, laddr, false) + } + return nil, errors.New("connection is closed") + } + return cc, nil +} + +func (s *Server) NewConn(addr *net.UDPAddr, laddr *net.UDPAddr) (*client.Conn, error) { + l := s.getListener() + if l == nil { + // server is not started/stopped + return nil, errors.New("server is not running") + } + return s.getConn(l, addr, laddr, true) +} From bec995cf0cf114760762d62f7e6f66e8e7e5980d Mon Sep 17 00:00:00 2001 From: Jozef Kralik Date: Fri, 20 Mar 2026 10:41:18 +0100 Subject: [PATCH 15/16] fix(udp): avoid multicast conn key leaks and stale control state --- udp/client/conn.go | 26 +++++++++---- udp/client/controlmessage_test.go | 49 +++++++++++++++++++++++++ udp/server/server.go | 61 ++++++++++++++++++++++++++----- udp/server/server_key_test.go | 26 +++++++++++++ udp/server_test.go | 11 ++++-- 5 files changed, 152 insertions(+), 21 deletions(-) create mode 100644 udp/client/controlmessage_test.go create mode 100644 udp/server/server_key_test.go diff --git a/udp/client/conn.go b/udp/client/conn.go index 891887f5..694f8f2c 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -830,13 +830,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con }() resp := cc.AcquireMessage(cc.Context()) resp.SetToken(req.Token()) - - cc.interfaceIndex.Store(int64(req.ControlMessage().GetIfIndex())) - if cm := req.ControlMessage(); cm != nil && !cm.Dst.IsMulticast() { - dst := make(net.IP, len(cm.Dst)) - copy(dst, cm.Dst) - cc.localAddr.Store(&dst) - } + cc.setControlInformation(req.ControlMessage()) w := responsewriter.New(resp, cc, req.Options()...) defer func() { @@ -864,6 +858,24 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess cc.sendPong(w, r) } +func (cc *Conn) setControlInformation(cm *coapNet.ControlMessage) { + if cm == nil { + cc.interfaceIndex.Store(0) + cc.localAddr.Store(nil) + return + } + + cc.interfaceIndex.Store(int64(cm.GetIfIndex())) + if len(cm.Dst) == 0 || cm.Dst.IsMulticast() { + cc.localAddr.Store(nil) + return + } + + dst := make(net.IP, len(cm.Dst)) + copy(dst, cm.Dst) + cc.localAddr.Store(&dst) +} + func (cc *Conn) upsertControlInformation(msg *pool.Message) { ifIndex := int(cc.interfaceIndex.Load()) localAddrPtr := cc.localAddr.Load() diff --git a/udp/client/controlmessage_test.go b/udp/client/controlmessage_test.go new file mode 100644 index 00000000..6dd4bed0 --- /dev/null +++ b/udp/client/controlmessage_test.go @@ -0,0 +1,49 @@ +package client + +import ( + "net" + "testing" + + coapNet "github.com/plgd-dev/go-coap/v3/net" + "github.com/stretchr/testify/require" +) + +func TestSetControlInformationNil(t *testing.T) { + var cc Conn + addr := net.ParseIP("192.0.2.1").To4() + cc.localAddr.Store(&addr) + cc.interfaceIndex.Store(7) + + cc.setControlInformation(nil) + + require.Equal(t, int64(0), cc.interfaceIndex.Load()) + require.Nil(t, cc.localAddr.Load()) +} + +func TestSetControlInformationUnicastCopiesAddress(t *testing.T) { + var cc Conn + dst := net.ParseIP("2001:db8::1") + cm := &coapNet.ControlMessage{Dst: append(net.IP(nil), dst...), IfIndex: 9} + + cc.setControlInformation(cm) + + require.Equal(t, int64(9), cc.interfaceIndex.Load()) + stored := cc.localAddr.Load() + require.NotNil(t, stored) + expected := append(net.IP(nil), dst...) + require.True(t, (*stored).Equal(expected)) + + cm.Dst[0] ^= 0xff + require.True(t, (*stored).Equal(expected)) +} + +func TestSetControlInformationMulticastClearsAddress(t *testing.T) { + var cc Conn + addr := net.ParseIP("192.0.2.1").To4() + cc.localAddr.Store(&addr) + + cc.setControlInformation(&coapNet.ControlMessage{Dst: net.ParseIP("ff02::1"), IfIndex: 11}) + + require.Equal(t, int64(11), cc.interfaceIndex.Load()) + require.Nil(t, cc.localAddr.Load()) +} diff --git a/udp/server/server.go b/udp/server/server.go index 79ba46ba..f1871cd1 100644 --- a/udp/server/server.go +++ b/udp/server/server.go @@ -162,14 +162,12 @@ func (s *Server) Serve(l *coapNet.UDPConn) error { // UDPConn.LocalAddr() only takes into account the address it is bound to. // In the case of a wildcard address, the actual destination address is in the control message. - // On server-initiated exchanges, listeners Local Addr can be used as the client has no assumptions of the source. - localAddr, ok := l.LocalAddr().(*net.UDPAddr) - if !ok || localAddr == nil { - return fmt.Errorf("unexpected listener local addr type: %T", l.LocalAddr()) + // On server-initiated exchanges, listener's LocalAddr can be used as the client has no assumptions of the source. + laddr, err := s.getListenerLocalAddr(l) + if err != nil { + return err } - laddrVal := *localAddr - laddr := &laddrVal - if cm != nil && cm.Dst != nil { + if cm != nil && len(cm.Dst) > 0 && !cm.Dst.IsMulticast() { laddr.IP = cm.Dst } @@ -192,6 +190,15 @@ func (s *Server) getListener() *coapNet.UDPConn { return s.listen } +func (s *Server) getListenerLocalAddr(l *coapNet.UDPConn) (*net.UDPAddr, error) { + localAddr, ok := l.LocalAddr().(*net.UDPAddr) + if !ok || localAddr == nil { + return nil, fmt.Errorf("unexpected listener local addr type: %T", l.LocalAddr()) + } + laddrVal := *localAddr + return &laddrVal, nil +} + // Stop stops server without wait of ends Serve function. func (s *Server) Stop() { s.cancel() @@ -268,10 +275,21 @@ func getClose(cc *client.Conn) func() { return closeFn } +func getConnKey(raddr *net.UDPAddr, laddr *net.UDPAddr) string { + normalizedLocalAddr := *laddr + if len(normalizedLocalAddr.IP) > 0 && normalizedLocalAddr.IP.IsMulticast() { + // Multicast destination address does not identify a unique server-side source address. + // Normalize it to avoid creating one conn key per multicast group. + normalizedLocalAddr.IP = nil + normalizedLocalAddr.Zone = "" + } + return raddr.String() + "-" + normalizedLocalAddr.String() +} + func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr) (cc *client.Conn, created bool) { s.connsMutex.Lock() defer s.connsMutex.Unlock() - key := raddr.String() + "-" + laddr.String() + key := getConnKey(raddr, laddr) cc = s.conns[key] if cc != nil { @@ -360,6 +378,17 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, l } func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr, firstTime bool) (*client.Conn, error) { + if raddr == nil { + return nil, errors.New("invalid remote address") + } + if laddr == nil { + var err error + laddr, err = s.getListenerLocalAddr(l) + if err != nil { + return nil, err + } + } + cc, created := s.getOrCreateConn(l, raddr, laddr) if created { if s.cfg.OnNewConn != nil { @@ -388,11 +417,23 @@ func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPA return cc, nil } -func (s *Server) NewConn(addr *net.UDPAddr, laddr *net.UDPAddr) (*client.Conn, error) { +// NewConn creates or gets a connection for the provided remote address. +// +// Optional laddr may be used to pin a concrete local address when the listener is bound to a wildcard address. +// If laddr is omitted or nil, listener's local address is used. +func (s *Server) NewConn(addr *net.UDPAddr, laddr ...*net.UDPAddr) (*client.Conn, error) { + if len(laddr) > 1 { + return nil, fmt.Errorf("invalid number of local addresses: %d", len(laddr)) + } + var localAddr *net.UDPAddr + if len(laddr) == 1 { + localAddr = laddr[0] + } + l := s.getListener() if l == nil { // server is not started/stopped return nil, errors.New("server is not running") } - return s.getConn(l, addr, laddr, true) + return s.getConn(l, addr, localAddr, true) } diff --git a/udp/server/server_key_test.go b/udp/server/server_key_test.go new file mode 100644 index 00000000..1565aa93 --- /dev/null +++ b/udp/server/server_key_test.go @@ -0,0 +1,26 @@ +package server + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetConnKeyIgnoresMulticastLocalAddress(t *testing.T) { + raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830} + mcastV6 := &net.UDPAddr{IP: net.ParseIP("ff02::fd"), Port: 5683} + mcastV4 := &net.UDPAddr{IP: net.ParseIP("224.0.1.187"), Port: 5683} + normalized := &net.UDPAddr{Port: 5683} + + require.Equal(t, getConnKey(raddr, mcastV6), getConnKey(raddr, normalized)) + require.Equal(t, getConnKey(raddr, mcastV4), getConnKey(raddr, normalized)) +} + +func TestGetConnKeyKeepsUnicastLocalAddressDistinct(t *testing.T) { + raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830} + laddrA := &net.UDPAddr{IP: net.ParseIP("2001:db8::10"), Port: 5683} + laddrB := &net.UDPAddr{IP: net.ParseIP("2001:db8::11"), Port: 5683} + + require.NotEqual(t, getConnKey(raddr, laddrA), getConnKey(raddr, laddrB)) +} diff --git a/udp/server_test.go b/udp/server_test.go index f3e47344..9c6799a3 100644 --- a/udp/server_test.go +++ b/udp/server_test.go @@ -467,7 +467,10 @@ func TestServerNewClient(t *testing.T) { time.Sleep(time.Second) - cc, err := s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) + _, err = s1.NewConn(nil) + require.ErrorContains(t, err, "invalid remote address") + + cc, err := s1.NewConn(peer) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -478,7 +481,7 @@ func TestServerNewClient(t *testing.T) { require.NoError(t, err) // repeat ping - new client should be created - cc, err = s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) + cc, err = s1.NewConn(peer, nil) require.NoError(t, err) err = cc.Ping(ctx) require.NoError(t, err) @@ -600,7 +603,7 @@ func TestServerReconnectNewClient(t *testing.T) { time.Sleep(time.Second) - cc, err := s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) + cc, err := s1.NewConn(peer) require.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) @@ -626,7 +629,7 @@ func TestServerReconnectNewClient(t *testing.T) { } // new client - cc, err = s1.NewConn(peer, l1.LocalAddr().(*net.UDPAddr)) + cc, err = s1.NewConn(peer, nil) require.NoError(t, err) ctx, cancel = context.WithTimeout(context.Background(), time.Second*1) defer cancel() From 2250c2777b9bc405e79304b07de110a303611b92 Mon Sep 17 00:00:00 2001 From: Jozef Kralik Date: Fri, 20 Mar 2026 11:49:27 +0100 Subject: [PATCH 16/16] fix the two gocritic warnings --- udp/client/controlmessage_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/udp/client/controlmessage_test.go b/udp/client/controlmessage_test.go index 6dd4bed0..3f0d0745 100644 --- a/udp/client/controlmessage_test.go +++ b/udp/client/controlmessage_test.go @@ -31,10 +31,10 @@ func TestSetControlInformationUnicastCopiesAddress(t *testing.T) { stored := cc.localAddr.Load() require.NotNil(t, stored) expected := append(net.IP(nil), dst...) - require.True(t, (*stored).Equal(expected)) + require.True(t, stored.Equal(expected)) cm.Dst[0] ^= 0xff - require.True(t, (*stored).Equal(expected)) + require.True(t, stored.Equal(expected)) } func TestSetControlInformationMulticastClearsAddress(t *testing.T) {