Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ client
!client/
vendor/
v3/

/.idea/
# Test binary, build with `go test -c`
*.test

Expand Down
2 changes: 1 addition & 1 deletion examples/simple/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func main() {
co, err := udp.Dial("localhost:5688")
co, err := udp.Dial("localhost:5685")
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if err != nil {
log.Fatalf("Error dialing: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion examples/simple/server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,5 @@ func main() {
r.Handle("/a", mux.HandlerFunc(handleA))
r.Handle("/b", mux.HandlerFunc(handleB))

log.Fatal(coap.ListenAndServe("udp", ":5688", r))
log.Fatal(coap.ListenAndServe("udp", ":5685", r))
}
104 changes: 66 additions & 38 deletions net/blockwise/blockwise.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"context"
"errors"
"fmt"
"hash/fnv"
"io"
"net"
"time"

"github.com/dsnet/golib/memfile"
Expand Down Expand Up @@ -156,7 +158,7 @@
}

// New provides blockwise.
// getSentRequestFromOutside must returns a copy of request which will be released after use.
// getSentRequestFromOutside must return a copy of request which will be released after use.
func New[C Client](
cc C,
expiration time.Duration,
Expand Down Expand Up @@ -202,8 +204,8 @@
return fmt.Errorf("cannot get size of payload: %w", err)
}

// Do sends an coap message and returns an coap response via blockwise transfer.
func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do func(req *pool.Message) (*pool.Message, error)) (*pool.Message, error) {
// Do sends a coap message and returns a coap response via blockwise transfer.
func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, remoteAddr net.Addr, do func(req *pool.Message) (*pool.Message, error)) (*pool.Message, error) {

Check failure on line 208 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

cyclomatic complexity 16 of func `(*BlockWise).Do` is high (> 15) (gocyclo)
if maxSzx > SZXBERT {
return nil, errors.New("invalid szx")
}
Expand All @@ -215,11 +217,16 @@
if !ok {
expire = time.Now().Add(b.expiration)
}
_, loaded := b.sendingMessagesCache.LoadOrStore(r.Token().Hash(), cache.NewElement(r, expire, nil))

exchangeKey, err := getExchangeKey(r.Token(), remoteAddr)
if err != nil {
return nil, fmt.Errorf("cannot get exchange key from request Parameters(%v): %w", r, err)
}
_, loaded := b.sendingMessagesCache.LoadOrStore(exchangeKey, cache.NewElement(r, expire, nil))
if loaded {
return nil, errors.New("invalid token")
}
defer b.sendingMessagesCache.Delete(r.Token().Hash())
defer b.sendingMessagesCache.Delete(exchangeKey)
if r.Body() == nil {
return do(r)
}
Expand Down Expand Up @@ -289,7 +296,7 @@
}

w := newWriteRequestResponse(b.cc, request)
err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock)
err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, 0)
if err != nil {
return fmt.Errorf("cannot start writing request: %w", err)
}
Expand Down Expand Up @@ -347,44 +354,40 @@
}

// Handle middleware which constructs COAP request from blockwise transfer and send COAP response via blockwise.
func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) {
func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, remoteAddr net.Addr, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) {
if maxSZX > SZXBERT {
panic("invalid maxSZX")
}
token := r.Token()

if len(token) == 0 {
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next)
if err != nil {
b.sendEntityIncomplete(w, token)
b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err))
}
exchangeKey, err := getExchangeKey(r.Token(), remoteAddr)
if err != nil {
b.errors(fmt.Errorf("cannot get exchange key from request Parameters(%v): %w", r, err))
return
}
Comment thread
jkralik marked this conversation as resolved.
tokenStr := token.Hash()
token := r.Token()

sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(tokenStr)
sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(exchangeKey)
if !sendingMessageExist || wantsToBeReceived(r) {
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next)
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, exchangeKey, next)

Check failure on line 371 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

shadow: declaration of "err" shadows declaration at line 362 (govet)
if err != nil {
b.sendEntityIncomplete(w, token)
b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err))
}
return
}
more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode)
more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode, exchangeKey)
if err != nil {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(exchangeKey)
b.errors(fmt.Errorf("continueSendingMessage(%v): %w", r, err))
return
}
// For codes GET,POST,PUT,DELETE, we want them to wait for pairing response and then delete them when the full response comes in or when timeout occurs.
if !more && sendingMessageCode > codes.DELETE {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(exchangeKey)
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error {
func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, exchangeKey uint64, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error {
startSendingMessageBlock, err := EncodeBlockOption(maxSZX, 0, true)
if err != nil {
return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err)
Expand All @@ -402,18 +405,18 @@
}
case codes.POST, codes.PUT:
maxSZX = fitSZX(r, message.Block1, maxSZX)
errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block1, message.Size1)
errP := b.processReceivedMessage(w, r, maxSZX, next, exchangeKey, message.Block1, message.Size1)
if errP != nil {
return errP
}
default:
maxSZX = fitSZX(r, message.Block2, maxSZX)
errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block2, message.Size2)
errP := b.processReceivedMessage(w, r, maxSZX, next, exchangeKey, message.Block2, message.Size2)
if errP != nil {
return errP
}
}
return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock)
return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, exchangeKey)
}

func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX SZX, maxMessageSize uint32, block uint32) (sendMessage *pool.Message, more bool, err error) {
Expand Down Expand Up @@ -497,7 +500,7 @@
return sendMessage, more, nil
}

func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, sendingMessageCode codes.Code /* msg *pool.Message*/) (bool, error) {
func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, sendingMessageCode codes.Code /* msg *pool.Message*/, exchangeKey uint64) (bool, error) {
blockType := message.Block2
switch sendingMessageCode {
case codes.POST, codes.PUT:
Expand All @@ -510,7 +513,7 @@
}
var sendMessage *pool.Message
var more bool
b.sendingMessagesCache.LoadWithFunc(r.Token().Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
b.sendingMessagesCache.LoadWithFunc(exchangeKey, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
sendMessage, more, err = b.createSendingMessage(value.Data(), maxSZX, maxMessageSize, block)
if err != nil {
err = fmt.Errorf("cannot create sending message: %w", err)
Expand All @@ -535,7 +538,7 @@
return msg.Code() >= codes.Created
}

func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32) error {
func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32, exchangeKey uint64) error {
payloadSize, err := w.Message().BodySize()
if err != nil {
return payloadSizeError(err)
Expand All @@ -558,16 +561,22 @@
if !ok {
expire = time.Now().Add(b.expiration)
}
el, loaded := b.sendingMessagesCache.LoadOrStore(sendingMessage.Token().Hash(), cache.NewElement(originalSendingMessage, expire, nil))

if exchangeKey == 0 {
exchangeKey = sendingMessage.Token().Hash()
}

el, loaded := b.sendingMessagesCache.LoadOrStore(exchangeKey, cache.NewElement(originalSendingMessage, expire, nil))
if loaded {
defer b.cc.ReleaseMessage(originalSendingMessage)
return fmt.Errorf("cannot add message (%v) to sending message cache: message(%v) with token(%v) already exist", originalSendingMessage, el.Data(), sendingMessage.Token())
}
return nil
}

func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message {
data, ok := b.sendingMessagesCache.LoadWithFunc(token.Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
func (b *BlockWise[C]) getSentRequest(exchangeKey uint64, token message.Token) *pool.Message {

Check failure on line 577 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)

Check failure on line 578 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

File is not properly formatted (gofumpt)
data, ok := b.sendingMessagesCache.LoadWithFunc(exchangeKey, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
if value == nil {
return nil
}
Expand Down Expand Up @@ -739,7 +748,7 @@
}

//nolint:gocyclo,gocognit
func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), blockType message.OptionID, sizeType message.OptionID) error {
func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), exchangeKey uint64, blockType message.OptionID, sizeType message.OptionID) error {
token := r.Token()
if len(token) == 0 {
next(w, r)
Expand All @@ -761,7 +770,8 @@
if err != nil {
return fmt.Errorf("cannot decode block option: %w", err)
}
sentRequest := b.getSentRequest(token)

sentRequest := b.getSentRequest(exchangeKey, token)
if sentRequest != nil {
defer b.cc.ReleaseMessage(sentRequest)
}
Expand All @@ -774,11 +784,12 @@
if err != nil {
return fmt.Errorf("cannot process message: %w", err)
}
// we need to create new exchangeKey because token was changed
exchangeKey = token.Hash()
}

tokenStr := token.Hash()
var cachedReceivedMessageGuard *messageGuard
if e := b.receivingMessagesCache.Load(tokenStr); e != nil {
if e := b.receivingMessagesCache.Load(exchangeKey); e != nil {
cachedReceivedMessageGuard = e.Data()
}
if cachedReceivedMessageGuard == nil {
Expand All @@ -789,15 +800,15 @@
return nil
}
}
cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, tokenStr, validUntil)
cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, exchangeKey, validUntil)
if err != nil {
return err
}
defer closeCachedReceivedMessage()

defer func(err *error) {
if *err != nil {
b.receivingMessagesCache.Delete(tokenStr)
b.receivingMessagesCache.Delete(exchangeKey)
}
}(&err)
payloadFile, payloadSize, err := b.getPayloadFromCachedReceivedMessage(r, cachedReceivedMessage)
Expand All @@ -811,12 +822,12 @@
return fmt.Errorf("cannot copy data to payload: %w", err)
}
if !more {
b.receivingMessagesCache.Delete(tokenStr)
b.receivingMessagesCache.Delete(exchangeKey)
cachedReceivedMessage.Remove(blockType)
cachedReceivedMessage.Remove(sizeType)
cachedReceivedMessage.SetType(r.Type())
if !bytes.Equal(cachedReceivedMessage.Token(), token) {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(exchangeKey)
}
_, errS := cachedReceivedMessage.Body().Seek(0, io.SeekStart)
if errS != nil {
Expand Down Expand Up @@ -849,3 +860,20 @@
w.SetMessage(sendMessage)
return nil
}

// getExchangeKey returns a key for the blockwise exchange cache.
// According to RFC 7252: "[...] An empty token value is appropriate e.g., when no other tokens are in use to a destination"
// so we need to generate a key from the remote address of the corresponding endpoint.
func getExchangeKey(token message.Token, addr net.Addr) (uint64, error) {
if len(token) != 0 {
return token.Hash(), nil
}

if addr == nil {
return 0, errors.New("cannot get exchange key: addr is nil")
}

h := fnv.New64a()
h.Write([]byte(addr.String()))
return h.Sum64(), nil
}
10 changes: 5 additions & 5 deletions net/blockwise/blockwise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ func makeDo[C Client](t *testing.T, sender, receiver *BlockWise[C], senderMaxSZX
for {
var resp *pool.Message
receiverResp := responsewriter.New(receiver.cc.AcquireMessage(roReq.Context()), receiver.cc)
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, next)
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, nil, next)
t.Logf("receiver.Handle - receiverResp %v senderResp.Message: %v\n", receiverResp.Message(), roReq)
senderResp := responsewriter.New(sender.cc.AcquireMessage(roReq.Context()), sender.cc)
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, func(_ *responsewriter.ResponseWriter[C], r *pool.Message) {
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, nil, func(_ *responsewriter.ResponseWriter[C], r *pool.Message) {
resp = r
})
t.Logf("sender.Handle - senderResp %v receiverResp.Message: %v\n", senderResp.Message(), receiverResp.Message())
Expand Down Expand Up @@ -299,7 +299,7 @@ func TestBlockWiseDo(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := sender.Do(toPoolMessage(tt.args.r), tt.args.szx, uint32(tt.args.maxMessageSize), tt.args.do)
got, err := sender.Do(toPoolMessage(tt.args.r), tt.args.szx, uint32(tt.args.maxMessageSize), nil, tt.args.do)
if tt.wantErr {
require.Error(t, err)
return
Expand Down Expand Up @@ -496,13 +496,13 @@ func makeWriteReq[C Client](sender, receiver *BlockWise[C], senderMaxSZX SZX, se
roReq = req
for {
receiverResp := responsewriter.New(receiver.cc.AcquireMessage(roReq.Context()), receiver.cc)
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, func(w *responsewriter.ResponseWriter[C], r *pool.Message) {
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, nil, func(w *responsewriter.ResponseWriter[C], r *pool.Message) {
defer close(c)
next(w, r)
})
senderResp := responsewriter.New(sender.cc.AcquireMessage(roReq.Context()), sender.cc)
orig := senderResp.Message()
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, func(*responsewriter.ResponseWriter[C], *pool.Message) {
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, nil, func(*responsewriter.ResponseWriter[C], *pool.Message) {
})
if orig == senderResp.Message() {
select {
Expand Down
4 changes: 2 additions & 2 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (cc *Conn) do(req *pool.Message) (*pool.Message, error) {
if !cc.peerBlockWiseTranferEnabled.Load() || cc.blockWise == nil {
return cc.doInternal(req)
}
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.maxMessageSize, cc.doInternal)
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.maxMessageSize, cc.RemoteAddr(), cc.doInternal)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -352,7 +352,7 @@ func (cc *Conn) blockwiseHandle(w *responsewriter.ResponseWriter[*Conn], r *pool

func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) {
if cc.blockWise != nil && cc.peerBlockWiseTranferEnabled.Load() {
cc.blockWise.Handle(w, r, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.blockwiseHandle)
cc.blockWise.Handle(w, r, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.RemoteAddr(), cc.blockwiseHandle)
return
}
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
Expand Down
4 changes: 2 additions & 2 deletions udp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func (cc *Conn) do(req *pool.Message) (*pool.Message, error) {
if cc.blockWise == nil {
return cc.doInternal(req)
}
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(bwReq *pool.Message) (*pool.Message, error) {
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), cc.RemoteAddr(), func(bwReq *pool.Message) (*pool.Message, error) {
if bwReq.Options().HasOption(message.Block1) || bwReq.Options().HasOption(message.Block2) {
bwReq.SetMessageID(cc.GetMessageID())
}
Expand Down Expand Up @@ -663,7 +663,7 @@ func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], m *pool.Message)
return
}
if cc.blockWise != nil {
cc.blockWise.Handle(w, m, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(rw *responsewriter.ResponseWriter[*Conn], rm *pool.Message) {
cc.blockWise.Handle(w, m, cc.blockwiseSZX, cc.session.MaxMessageSize(), cc.RemoteAddr(), func(rw *responsewriter.ResponseWriter[*Conn], rm *pool.Message) {
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(rm.Token().Hash()); ok {
h(rw, rm)
return
Expand Down
Loading