Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ client
!client/
vendor/
v3/

/.idea/
# Test binary, build with `go test -c`
*.test

Expand Down
2 changes: 1 addition & 1 deletion examples/simple/client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

func main() {
co, err := udp.Dial("localhost:5688")
co, err := udp.Dial("localhost:5685")
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if err != nil {
log.Fatalf("Error dialing: %v", err)
}
Expand Down
104 changes: 67 additions & 37 deletions net/blockwise/blockwise.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
"context"
"errors"
"fmt"
"hash/fnv"
"io"
"net"
"time"

"github.com/dsnet/golib/memfile"
Expand Down Expand Up @@ -156,7 +158,7 @@
}

// New provides blockwise.
// getSentRequestFromOutside must returns a copy of request which will be released after use.
// getSentRequestFromOutside must return a copy of request which will be released after use.
func New[C Client](
cc C,
expiration time.Duration,
Expand Down Expand Up @@ -202,8 +204,8 @@
return fmt.Errorf("cannot get size of payload: %w", err)
}

// Do sends an coap message and returns an coap response via blockwise transfer.
func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do func(req *pool.Message) (*pool.Message, error)) (*pool.Message, error) {
// Do sends a coap message and returns a coap response via blockwise transfer.
func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, remoteAddr net.Addr, do func(req *pool.Message) (*pool.Message, error)) (*pool.Message, error) {

Check failure on line 208 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'remoteAddr' seems to be unused, consider removing or renaming it as _ (revive)
if maxSzx > SZXBERT {
return nil, errors.New("invalid szx")
}
Expand All @@ -215,11 +217,13 @@
if !ok {
expire = time.Now().Add(b.expiration)
}
_, loaded := b.sendingMessagesCache.LoadOrStore(r.Token().Hash(), cache.NewElement(r, expire, nil))

exchangeKey := r.Token().Hash()
_, loaded := b.sendingMessagesCache.LoadOrStore(exchangeKey, cache.NewElement(r, expire, nil))
if loaded {
return nil, errors.New("invalid token")
}
defer b.sendingMessagesCache.Delete(r.Token().Hash())
defer b.sendingMessagesCache.Delete(exchangeKey)
if r.Body() == nil {
return do(r)
}
Expand Down Expand Up @@ -289,7 +293,7 @@
}

w := newWriteRequestResponse(b.cc, request)
err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock)
err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, 0)
if err != nil {
return fmt.Errorf("cannot start writing request: %w", err)
}
Expand Down Expand Up @@ -347,44 +351,48 @@
}

// Handle middleware which constructs COAP request from blockwise transfer and send COAP response via blockwise.
func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) {
func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, remoteAddr net.Addr, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) {
if maxSZX > SZXBERT {
panic("invalid maxSZX")
}
token := r.Token()

if len(token) == 0 {
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next)
var exchangeKey uint64
if len(r.Token()) != 0 {
exchangeKey = r.Token().Hash()
} else {
var err error
exchangeKey, err = getExchangeKey(remoteAddr)
if err != nil {
b.sendEntityIncomplete(w, token)
b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err))
b.errors(fmt.Errorf("cannot get exchange key from request Parameters(%v): %w", r, err))
return
}
return

}

Check failure on line 370 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary trailing newline (whitespace)
Comment thread
jkralik marked this conversation as resolved.
tokenStr := token.Hash()

sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(tokenStr)
token := r.Token()

sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(exchangeKey)
if !sendingMessageExist || wantsToBeReceived(r) {
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next)
err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, exchangeKey, next)
if err != nil {
b.sendEntityIncomplete(w, token)
b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err))
}
return
}
more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode)
more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode, exchangeKey)
if err != nil {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(exchangeKey)
b.errors(fmt.Errorf("continueSendingMessage(%v): %w", r, err))
return
}
// For codes GET,POST,PUT,DELETE, we want them to wait for pairing response and then delete them when the full response comes in or when timeout occurs.
if !more && sendingMessageCode > codes.DELETE {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(exchangeKey)
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error {
func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, exchangeKey uint64, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error {
startSendingMessageBlock, err := EncodeBlockOption(maxSZX, 0, true)
if err != nil {
return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err)
Expand All @@ -402,18 +410,18 @@
}
case codes.POST, codes.PUT:
maxSZX = fitSZX(r, message.Block1, maxSZX)
errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block1, message.Size1)
errP := b.processReceivedMessage(w, r, maxSZX, next, exchangeKey, message.Block1, message.Size1)
if errP != nil {
return errP
}
default:
maxSZX = fitSZX(r, message.Block2, maxSZX)
errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block2, message.Size2)
errP := b.processReceivedMessage(w, r, maxSZX, next, exchangeKey, message.Block2, message.Size2)
if errP != nil {
return errP
}
}
return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock)
return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, exchangeKey)
}

func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX SZX, maxMessageSize uint32, block uint32) (sendMessage *pool.Message, more bool, err error) {
Expand Down Expand Up @@ -497,7 +505,7 @@
return sendMessage, more, nil
}

func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, sendingMessageCode codes.Code /* msg *pool.Message*/) (bool, error) {
func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, sendingMessageCode codes.Code /* msg *pool.Message*/, exchangeKey uint64) (bool, error) {
blockType := message.Block2
switch sendingMessageCode {
case codes.POST, codes.PUT:
Expand All @@ -510,7 +518,7 @@
}
var sendMessage *pool.Message
var more bool
b.sendingMessagesCache.LoadWithFunc(r.Token().Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
b.sendingMessagesCache.LoadWithFunc(exchangeKey, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
sendMessage, more, err = b.createSendingMessage(value.Data(), maxSZX, maxMessageSize, block)
if err != nil {
err = fmt.Errorf("cannot create sending message: %w", err)
Expand All @@ -535,7 +543,7 @@
return msg.Code() >= codes.Created
}

func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32) error {
func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32, exchangeKey uint64) error {
payloadSize, err := w.Message().BodySize()
if err != nil {
return payloadSizeError(err)
Expand All @@ -558,16 +566,22 @@
if !ok {
expire = time.Now().Add(b.expiration)
}
el, loaded := b.sendingMessagesCache.LoadOrStore(sendingMessage.Token().Hash(), cache.NewElement(originalSendingMessage, expire, nil))

if exchangeKey == 0 {
exchangeKey = sendingMessage.Token().Hash()
}

el, loaded := b.sendingMessagesCache.LoadOrStore(exchangeKey, cache.NewElement(originalSendingMessage, expire, nil))
if loaded {
defer b.cc.ReleaseMessage(originalSendingMessage)
return fmt.Errorf("cannot add message (%v) to sending message cache: message(%v) with token(%v) already exist", originalSendingMessage, el.Data(), sendingMessage.Token())
}
return nil
}

func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message {
data, ok := b.sendingMessagesCache.LoadWithFunc(token.Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
func (b *BlockWise[C]) getSentRequest(exchangeKey uint64, token message.Token) *pool.Message {

Check failure on line 582 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)

Check failure on line 583 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

File is not properly formatted (gofumpt)
data, ok := b.sendingMessagesCache.LoadWithFunc(exchangeKey, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] {
if value == nil {
return nil
}
Expand Down Expand Up @@ -739,7 +753,7 @@
}

//nolint:gocyclo,gocognit
func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), blockType message.OptionID, sizeType message.OptionID) error {
func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), exchangeKey uint64, blockType message.OptionID, sizeType message.OptionID) error {
token := r.Token()
if len(token) == 0 {
next(w, r)
Expand All @@ -761,7 +775,8 @@
if err != nil {
return fmt.Errorf("cannot decode block option: %w", err)
}
sentRequest := b.getSentRequest(token)

sentRequest := b.getSentRequest(exchangeKey, token)
if sentRequest != nil {
defer b.cc.ReleaseMessage(sentRequest)
}
Expand All @@ -774,11 +789,12 @@
if err != nil {
return fmt.Errorf("cannot process message: %w", err)
}
// we need to create new exchangeKey because token was changed
exchangeKey = token.Hash()
}

tokenStr := token.Hash()
var cachedReceivedMessageGuard *messageGuard
if e := b.receivingMessagesCache.Load(tokenStr); e != nil {
if e := b.receivingMessagesCache.Load(exchangeKey); e != nil {
cachedReceivedMessageGuard = e.Data()
}
if cachedReceivedMessageGuard == nil {
Expand All @@ -789,15 +805,15 @@
return nil
}
}
cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, tokenStr, validUntil)
cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, exchangeKey, validUntil)
if err != nil {
return err
}
defer closeCachedReceivedMessage()

defer func(err *error) {
if *err != nil {
b.receivingMessagesCache.Delete(tokenStr)
b.receivingMessagesCache.Delete(exchangeKey)
}
}(&err)
payloadFile, payloadSize, err := b.getPayloadFromCachedReceivedMessage(r, cachedReceivedMessage)
Expand All @@ -811,12 +827,12 @@
return fmt.Errorf("cannot copy data to payload: %w", err)
}
if !more {
b.receivingMessagesCache.Delete(tokenStr)
b.receivingMessagesCache.Delete(exchangeKey)
cachedReceivedMessage.Remove(blockType)
cachedReceivedMessage.Remove(sizeType)
cachedReceivedMessage.SetType(r.Type())
if !bytes.Equal(cachedReceivedMessage.Token(), token) {
b.sendingMessagesCache.Delete(tokenStr)
b.sendingMessagesCache.Delete(exchangeKey)
}
_, errS := cachedReceivedMessage.Body().Seek(0, io.SeekStart)
if errS != nil {
Expand Down Expand Up @@ -849,3 +865,17 @@
w.SetMessage(sendMessage)
return nil
}

// getExchangeKey returns a key for the blockwise exchange cache.
// According to RFC 7252: "[...] An empty token value is appropriate e.g., when no other tokens are in use to a destination"
// so we need to generate a key from the remote address of the corresponding endpoint.
func getExchangeKey(addr net.Addr) (uint64, error) {

Check failure on line 872 in net/blockwise/blockwise.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary leading newline (whitespace)

if addr == nil {
return 0, errors.New("cannot get exchange key: addr is nil")
}

h := fnv.New64a()
h.Write([]byte(addr.String()))
return h.Sum64(), nil
}
10 changes: 5 additions & 5 deletions net/blockwise/blockwise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,10 @@ func makeDo[C Client](t *testing.T, sender, receiver *BlockWise[C], senderMaxSZX
for {
var resp *pool.Message
receiverResp := responsewriter.New(receiver.cc.AcquireMessage(roReq.Context()), receiver.cc)
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, next)
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, nil, next)
t.Logf("receiver.Handle - receiverResp %v senderResp.Message: %v\n", receiverResp.Message(), roReq)
senderResp := responsewriter.New(sender.cc.AcquireMessage(roReq.Context()), sender.cc)
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, func(_ *responsewriter.ResponseWriter[C], r *pool.Message) {
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, nil, func(_ *responsewriter.ResponseWriter[C], r *pool.Message) {
resp = r
})
t.Logf("sender.Handle - senderResp %v receiverResp.Message: %v\n", senderResp.Message(), receiverResp.Message())
Expand Down Expand Up @@ -299,7 +299,7 @@ func TestBlockWiseDo(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := sender.Do(toPoolMessage(tt.args.r), tt.args.szx, uint32(tt.args.maxMessageSize), tt.args.do)
got, err := sender.Do(toPoolMessage(tt.args.r), tt.args.szx, uint32(tt.args.maxMessageSize), nil, tt.args.do)
if tt.wantErr {
require.Error(t, err)
return
Expand Down Expand Up @@ -496,13 +496,13 @@ func makeWriteReq[C Client](sender, receiver *BlockWise[C], senderMaxSZX SZX, se
roReq = req
for {
receiverResp := responsewriter.New(receiver.cc.AcquireMessage(roReq.Context()), receiver.cc)
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, func(w *responsewriter.ResponseWriter[C], r *pool.Message) {
receiver.Handle(receiverResp, roReq, senderMaxSZX, senderMaxMessageSize, nil, func(w *responsewriter.ResponseWriter[C], r *pool.Message) {
defer close(c)
next(w, r)
})
senderResp := responsewriter.New(sender.cc.AcquireMessage(roReq.Context()), sender.cc)
orig := senderResp.Message()
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, func(*responsewriter.ResponseWriter[C], *pool.Message) {
sender.Handle(senderResp, receiverResp.Message(), receiverMaxSZX, receiverMaxMessageSize, nil, func(*responsewriter.ResponseWriter[C], *pool.Message) {
})
if orig == senderResp.Message() {
select {
Expand Down
4 changes: 2 additions & 2 deletions tcp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ func (cc *Conn) do(req *pool.Message) (*pool.Message, error) {
if !cc.peerBlockWiseTranferEnabled.Load() || cc.blockWise == nil {
return cc.doInternal(req)
}
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.maxMessageSize, cc.doInternal)
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.maxMessageSize, cc.RemoteAddr(), cc.doInternal)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -352,7 +352,7 @@ func (cc *Conn) blockwiseHandle(w *responsewriter.ResponseWriter[*Conn], r *pool

func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) {
if cc.blockWise != nil && cc.peerBlockWiseTranferEnabled.Load() {
cc.blockWise.Handle(w, r, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.blockwiseHandle)
cc.blockWise.Handle(w, r, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.RemoteAddr(), cc.blockwiseHandle)
return
}
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok {
Expand Down
4 changes: 2 additions & 2 deletions udp/client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ func (cc *Conn) do(req *pool.Message) (*pool.Message, error) {
if cc.blockWise == nil {
return cc.doInternal(req)
}
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(bwReq *pool.Message) (*pool.Message, error) {
resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), cc.RemoteAddr(), func(bwReq *pool.Message) (*pool.Message, error) {
if bwReq.Options().HasOption(message.Block1) || bwReq.Options().HasOption(message.Block2) {
bwReq.SetMessageID(cc.GetMessageID())
}
Expand Down Expand Up @@ -663,7 +663,7 @@ func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], m *pool.Message)
return
}
if cc.blockWise != nil {
cc.blockWise.Handle(w, m, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(rw *responsewriter.ResponseWriter[*Conn], rm *pool.Message) {
cc.blockWise.Handle(w, m, cc.blockwiseSZX, cc.session.MaxMessageSize(), cc.RemoteAddr(), func(rw *responsewriter.ResponseWriter[*Conn], rm *pool.Message) {
if h, ok := cc.tokenHandlerContainer.LoadAndDelete(rm.Token().Hash()); ok {
h(rw, rm)
return
Expand Down
Loading