Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ client
!client/
vendor/
v3/

.idea/
# Test binary, build with `go test -c`
*.test

Expand Down
12 changes: 12 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Run simple server",
"type": "go",
"request": "launch",
"mode": "auto",
"program": "${workspaceFolder}/examples/observe/server"
}
]
}
48 changes: 39 additions & 9 deletions udp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,9 @@ type Conn struct {
msgID atomic.Uint32
blockwiseSZX blockwise.SZX

localAddr atomic.Pointer[net.IP]
interfaceIndex atomic.Int64

/*
An outstanding interaction is either a CON for which an ACK has not
yet been received but is still expected (message layer) or a request
Expand Down Expand Up @@ -550,6 +553,8 @@ func (cc *Conn) writeMessageAsync(req *pool.Message) error {
func (cc *Conn) writeMessage(req *pool.Message) error {
req.UpsertType(message.Confirmable)
req.UpsertMessageID(cc.GetMessageID())
cc.upsertControlInformation(req)

if req.Type() != message.Confirmable {
return cc.writeMessageAsync(req)
}
Expand Down Expand Up @@ -755,6 +760,7 @@ func (cc *Conn) processResponse(reqType message.Type, reqMessageID int32, w *res
w.Message().SetType(message.Acknowledgement)
w.Message().SetMessageID(reqMessageID)
w.Message().SetToken(nil)

err := cc.addResponseToCache(w.Message())
if err != nil {
return fmt.Errorf("cannot cache response: %w", err)
Expand Down Expand Up @@ -824,7 +830,8 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con
}()
resp := cc.AcquireMessage(cc.Context())
resp.SetToken(req.Token())
ifIndex := req.ControlMessage().GetIfIndex()
cc.setControlInformation(req.ControlMessage())

w := responsewriter.New(resp, cc, req.Options()...)
defer func() {
cc.ReleaseMessage(w.Message())
Expand All @@ -839,7 +846,7 @@ func (cc *Conn) ProcessReceivedMessageWithHandler(req *pool.Message, handler con
// nothing to send
return
}
upsertInterfaceToMessage(w.Message(), ifIndex)
cc.upsertControlInformation(w.Message())
errW := cc.writeMessageAsync(w.Message())
if errW != nil {
cc.closeConnection()
Expand All @@ -851,13 +858,37 @@ func (cc *Conn) handlePong(w *responsewriter.ResponseWriter[*Conn], r *pool.Mess
cc.sendPong(w, r)
}

func upsertInterfaceToMessage(m *pool.Message, ifIndex int) {
if ifIndex >= 1 {
cm := coapNet.ControlMessage{
IfIndex: ifIndex,
}
m.UpsertControlMessage(&cm)
func (cc *Conn) setControlInformation(cm *coapNet.ControlMessage) {
if cm == nil {
cc.interfaceIndex.Store(0)
cc.localAddr.Store(nil)
return
}

cc.interfaceIndex.Store(int64(cm.GetIfIndex()))
if len(cm.Dst) == 0 || cm.Dst.IsMulticast() {
cc.localAddr.Store(nil)
return
}

dst := make(net.IP, len(cm.Dst))
copy(dst, cm.Dst)
cc.localAddr.Store(&dst)
}

func (cc *Conn) upsertControlInformation(msg *pool.Message) {
ifIndex := int(cc.interfaceIndex.Load())
localAddrPtr := cc.localAddr.Load()
if ifIndex < 1 && localAddrPtr == nil {
return
}

var localAddr net.IP
if localAddrPtr != nil {
localAddr = *localAddrPtr
}

msg.UpsertControlMessage(&coapNet.ControlMessage{IfIndex: ifIndex, Src: localAddr})
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
Expand All @@ -872,7 +903,6 @@ func (cc *Conn) handleSpecialMessages(r *pool.Message) bool {
elem.ReleaseMessage(cc)
resp := cc.AcquireMessage(cc.Context())
resp.SetToken(r.Token())
upsertInterfaceToMessage(resp, r.ControlMessage().GetIfIndex())
w := responsewriter.New(resp, cc, r.Options()...)
defer func() {
cc.ReleaseMessage(w.Message())
Expand Down
49 changes: 49 additions & 0 deletions udp/client/controlmessage_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package client

import (
"net"
"testing"

coapNet "github.com/plgd-dev/go-coap/v3/net"
"github.com/stretchr/testify/require"
)

func TestSetControlInformationNil(t *testing.T) {
var cc Conn
addr := net.ParseIP("192.0.2.1").To4()
cc.localAddr.Store(&addr)
cc.interfaceIndex.Store(7)

cc.setControlInformation(nil)

require.Equal(t, int64(0), cc.interfaceIndex.Load())
require.Nil(t, cc.localAddr.Load())
}

func TestSetControlInformationUnicastCopiesAddress(t *testing.T) {
var cc Conn
dst := net.ParseIP("2001:db8::1")
cm := &coapNet.ControlMessage{Dst: append(net.IP(nil), dst...), IfIndex: 9}

cc.setControlInformation(cm)

require.Equal(t, int64(9), cc.interfaceIndex.Load())
stored := cc.localAddr.Load()
require.NotNil(t, stored)
expected := append(net.IP(nil), dst...)
require.True(t, stored.Equal(expected))

cm.Dst[0] ^= 0xff
require.True(t, stored.Equal(expected))
}

func TestSetControlInformationMulticastClearsAddress(t *testing.T) {
var cc Conn
addr := net.ParseIP("192.0.2.1").To4()
cc.localAddr.Store(&addr)

cc.setControlInformation(&coapNet.ControlMessage{Dst: net.ParseIP("ff02::1"), IfIndex: 11})

require.Equal(t, int64(11), cc.interfaceIndex.Load())
require.Nil(t, cc.localAddr.Load())
}
71 changes: 63 additions & 8 deletions udp/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,19 @@ func (s *Server) Serve(l *coapNet.UDPConn) error {
}
}
buf = buf[:n]
cc, err := s.getConn(l, raddr, true)

// UDPConn.LocalAddr() only takes into account the address it is bound to.
// In the case of a wildcard address, the actual destination address is in the control message.
// On server-initiated exchanges, listener's LocalAddr can be used as the client has no assumptions of the source.
laddr, err := s.getListenerLocalAddr(l)
if err != nil {
return err
}
if cm != nil && len(cm.Dst) > 0 && !cm.Dst.IsMulticast() {
laddr.IP = cm.Dst
}

cc, err := s.getConn(l, raddr, laddr, true)
if err != nil {
s.cfg.Errors(fmt.Errorf("%v: cannot get client connection: %w", raddr, err))
continue
Expand All @@ -178,6 +190,15 @@ func (s *Server) getListener() *coapNet.UDPConn {
return s.listen
}

func (s *Server) getListenerLocalAddr(l *coapNet.UDPConn) (*net.UDPAddr, error) {
localAddr, ok := l.LocalAddr().(*net.UDPAddr)
if !ok || localAddr == nil {
return nil, fmt.Errorf("unexpected listener local addr type: %T", l.LocalAddr())
}
laddrVal := *localAddr
return &laddrVal, nil
}

// Stop stops server without wait of ends Serve function.
func (s *Server) Stop() {
s.cancel()
Expand Down Expand Up @@ -254,10 +275,21 @@ func getClose(cc *client.Conn) func() {
return closeFn
}

func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (cc *client.Conn, created bool) {
func getConnKey(raddr *net.UDPAddr, laddr *net.UDPAddr) string {
normalizedLocalAddr := *laddr
if len(normalizedLocalAddr.IP) > 0 && normalizedLocalAddr.IP.IsMulticast() {
// Multicast destination address does not identify a unique server-side source address.
// Normalize it to avoid creating one conn key per multicast group.
normalizedLocalAddr.IP = nil
normalizedLocalAddr.Zone = ""
}
return raddr.String() + "-" + normalizedLocalAddr.String()
}

func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr) (cc *client.Conn, created bool) {
s.connsMutex.Lock()
defer s.connsMutex.Unlock()
key := raddr.String()
key := getConnKey(raddr, laddr)
cc = s.conns[key]

if cc != nil {
Expand Down Expand Up @@ -345,8 +377,19 @@ func (s *Server) getOrCreateConn(udpConn *coapNet.UDPConn, raddr *net.UDPAddr) (
return cc, true
}

func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool) (*client.Conn, error) {
cc, created := s.getOrCreateConn(l, raddr)
func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, laddr *net.UDPAddr, firstTime bool) (*client.Conn, error) {
if raddr == nil {
return nil, errors.New("invalid remote address")
}
if laddr == nil {
var err error
laddr, err = s.getListenerLocalAddr(l)
if err != nil {
return nil, err
}
}

cc, created := s.getOrCreateConn(l, raddr, laddr)
if created {
if s.cfg.OnNewConn != nil {
s.cfg.OnNewConn(cc)
Expand All @@ -367,18 +410,30 @@ func (s *Server) getConn(l *coapNet.UDPConn, raddr *net.UDPAddr, firstTime bool)
closeFn()
}
if firstTime {
return s.getConn(l, raddr, false)
return s.getConn(l, raddr, laddr, false)
}
return nil, errors.New("connection is closed")
}
return cc, nil
}

func (s *Server) NewConn(addr *net.UDPAddr) (*client.Conn, error) {
// NewConn creates or gets a connection for the provided remote address.
//
// Optional laddr may be used to pin a concrete local address when the listener is bound to a wildcard address.
// If laddr is omitted or nil, listener's local address is used.
func (s *Server) NewConn(addr *net.UDPAddr, laddr ...*net.UDPAddr) (*client.Conn, error) {
if len(laddr) > 1 {
return nil, fmt.Errorf("invalid number of local addresses: %d", len(laddr))
}
var localAddr *net.UDPAddr
if len(laddr) == 1 {
localAddr = laddr[0]
}

l := s.getListener()
if l == nil {
// server is not started/stopped
return nil, errors.New("server is not running")
}
return s.getConn(l, addr, true)
return s.getConn(l, addr, localAddr, true)
}
26 changes: 26 additions & 0 deletions udp/server/server_key_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package server

import (
"net"
"testing"

"github.com/stretchr/testify/require"
)

func TestGetConnKeyIgnoresMulticastLocalAddress(t *testing.T) {
raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830}
mcastV6 := &net.UDPAddr{IP: net.ParseIP("ff02::fd"), Port: 5683}
mcastV4 := &net.UDPAddr{IP: net.ParseIP("224.0.1.187"), Port: 5683}
normalized := &net.UDPAddr{Port: 5683}

require.Equal(t, getConnKey(raddr, mcastV6), getConnKey(raddr, normalized))
require.Equal(t, getConnKey(raddr, mcastV4), getConnKey(raddr, normalized))
}

func TestGetConnKeyKeepsUnicastLocalAddressDistinct(t *testing.T) {
raddr := &net.UDPAddr{IP: net.ParseIP("2001:db8::1"), Port: 56830}
laddrA := &net.UDPAddr{IP: net.ParseIP("2001:db8::10"), Port: 5683}
laddrB := &net.UDPAddr{IP: net.ParseIP("2001:db8::11"), Port: 5683}

require.NotEqual(t, getConnKey(raddr, laddrA), getConnKey(raddr, laddrB))
}
7 changes: 5 additions & 2 deletions udp/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,9 @@ func TestServerNewClient(t *testing.T) {

time.Sleep(time.Second)

_, err = s1.NewConn(nil)
require.ErrorContains(t, err, "invalid remote address")

cc, err := s1.NewConn(peer)
require.NoError(t, err)

Expand All @@ -478,7 +481,7 @@ func TestServerNewClient(t *testing.T) {
require.NoError(t, err)

// repeat ping - new client should be created
cc, err = s1.NewConn(peer)
cc, err = s1.NewConn(peer, nil)
require.NoError(t, err)
err = cc.Ping(ctx)
require.NoError(t, err)
Expand Down Expand Up @@ -626,7 +629,7 @@ func TestServerReconnectNewClient(t *testing.T) {
}

// new client
cc, err = s1.NewConn(peer)
cc, err = s1.NewConn(peer, nil)
require.NoError(t, err)
ctx, cancel = context.WithTimeout(context.Background(), time.Second*1)
defer cancel()
Expand Down
Loading