diff --git a/.gitignore b/.gitignore index 7c18b313..16611204 100644 --- a/.gitignore +++ b/.gitignore @@ -11,7 +11,7 @@ client !client/ vendor/ v3/ - +/.idea/ # Test binary, build with `go test -c` *.test diff --git a/examples/simple/client/main.go b/examples/simple/client/main.go index 35640800..8e6f45d4 100644 --- a/examples/simple/client/main.go +++ b/examples/simple/client/main.go @@ -10,7 +10,7 @@ import ( ) func main() { - co, err := udp.Dial("localhost:5688") + co, err := udp.Dial("localhost:5685") if err != nil { log.Fatalf("Error dialing: %v", err) } diff --git a/net/blockwise/blockwise.go b/net/blockwise/blockwise.go index 33bc53b5..223ddf96 100644 --- a/net/blockwise/blockwise.go +++ b/net/blockwise/blockwise.go @@ -5,7 +5,9 @@ import ( "context" "errors" "fmt" + "hash/fnv" "io" + "net" "time" "github.com/dsnet/golib/memfile" @@ -156,7 +158,7 @@ func newRequestGuard(request *pool.Message) *messageGuard { } // New provides blockwise. -// getSentRequestFromOutside must returns a copy of request which will be released after use. +// getSentRequestFromOutside must return a copy of request which will be released after use. func New[C Client]( cc C, expiration time.Duration, @@ -202,8 +204,8 @@ func payloadSizeError(err error) error { return fmt.Errorf("cannot get size of payload: %w", err) } -// Do sends an coap message and returns an coap response via blockwise transfer. -func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do func(req *pool.Message) (*pool.Message, error)) (*pool.Message, error) { +// Do sends a coap message and returns a coap response via blockwise transfer. +func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, remoteAddr net.Addr, do func(req *pool.Message) (*pool.Message, error)) (*pool.Message, error) { if maxSzx > SZXBERT { return nil, errors.New("invalid szx") } @@ -215,11 +217,13 @@ func (b *BlockWise[C]) Do(r *pool.Message, maxSzx SZX, maxMessageSize uint32, do if !ok { expire = time.Now().Add(b.expiration) } - _, loaded := b.sendingMessagesCache.LoadOrStore(r.Token().Hash(), cache.NewElement(r, expire, nil)) + + exchangeKey := r.Token().Hash() + _, loaded := b.sendingMessagesCache.LoadOrStore(exchangeKey, cache.NewElement(r, expire, nil)) if loaded { return nil, errors.New("invalid token") } - defer b.sendingMessagesCache.Delete(r.Token().Hash()) + defer b.sendingMessagesCache.Delete(exchangeKey) if r.Body() == nil { return do(r) } @@ -289,7 +293,7 @@ func (b *BlockWise[C]) WriteMessage(request *pool.Message, maxSZX SZX, maxMessag } w := newWriteRequestResponse(b.cc, request) - err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock) + err = b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, 0) if err != nil { return fmt.Errorf("cannot start writing request: %w", err) } @@ -347,44 +351,48 @@ func (b *BlockWise[C]) getSendingMessageCode(token uint64) (codes.Code, bool) { } // Handle middleware which constructs COAP request from blockwise transfer and send COAP response via blockwise. -func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) { +func (b *BlockWise[C]) Handle(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, remoteAddr net.Addr, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) { if maxSZX > SZXBERT { panic("invalid maxSZX") } - token := r.Token() - if len(token) == 0 { - err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next) + var exchangeKey uint64 + if len(r.Token()) != 0 { + exchangeKey = r.Token().Hash() + } else { + var err error + exchangeKey, err = getExchangeKey(remoteAddr) if err != nil { - b.sendEntityIncomplete(w, token) - b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err)) + b.errors(fmt.Errorf("cannot get exchange key from request Parameters(%v): %w", r, err)) + return } - return + } - tokenStr := token.Hash() - sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(tokenStr) + token := r.Token() + + sendingMessageCode, sendingMessageExist := b.getSendingMessageCode(exchangeKey) if !sendingMessageExist || wantsToBeReceived(r) { - err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, next) + err := b.handleReceivedMessage(w, r, maxSZX, maxMessageSize, exchangeKey, next) if err != nil { b.sendEntityIncomplete(w, token) b.errors(fmt.Errorf("handleReceivedMessage(%v): %w", r, err)) } return } - more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode) + more, err := b.continueSendingMessage(w, r, maxSZX, maxMessageSize, sendingMessageCode, exchangeKey) if err != nil { - b.sendingMessagesCache.Delete(tokenStr) + b.sendingMessagesCache.Delete(exchangeKey) b.errors(fmt.Errorf("continueSendingMessage(%v): %w", r, err)) return } // For codes GET,POST,PUT,DELETE, we want them to wait for pairing response and then delete them when the full response comes in or when timeout occurs. if !more && sendingMessageCode > codes.DELETE { - b.sendingMessagesCache.Delete(tokenStr) + b.sendingMessagesCache.Delete(exchangeKey) } } -func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error { +func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, exchangeKey uint64, next func(w *responsewriter.ResponseWriter[C], r *pool.Message)) error { startSendingMessageBlock, err := EncodeBlockOption(maxSZX, 0, true) if err != nil { return fmt.Errorf("cannot encode start sending message block option(%v,%v,%v): %w", maxSZX, 0, true, err) @@ -402,18 +410,18 @@ func (b *BlockWise[C]) handleReceivedMessage(w *responsewriter.ResponseWriter[C] } case codes.POST, codes.PUT: maxSZX = fitSZX(r, message.Block1, maxSZX) - errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block1, message.Size1) + errP := b.processReceivedMessage(w, r, maxSZX, next, exchangeKey, message.Block1, message.Size1) if errP != nil { return errP } default: maxSZX = fitSZX(r, message.Block2, maxSZX) - errP := b.processReceivedMessage(w, r, maxSZX, next, message.Block2, message.Size2) + errP := b.processReceivedMessage(w, r, maxSZX, next, exchangeKey, message.Block2, message.Size2) if errP != nil { return errP } } - return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock) + return b.startSendingMessage(w, maxSZX, maxMessageSize, startSendingMessageBlock, exchangeKey) } func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX SZX, maxMessageSize uint32, block uint32) (sendMessage *pool.Message, more bool, err error) { @@ -497,7 +505,7 @@ func (b *BlockWise[C]) createSendingMessage(sendingMessage *pool.Message, maxSZX return sendMessage, more, nil } -func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, sendingMessageCode codes.Code /* msg *pool.Message*/) (bool, error) { +func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSZX SZX, maxMessageSize uint32, sendingMessageCode codes.Code /* msg *pool.Message*/, exchangeKey uint64) (bool, error) { blockType := message.Block2 switch sendingMessageCode { case codes.POST, codes.PUT: @@ -510,7 +518,7 @@ func (b *BlockWise[C]) continueSendingMessage(w *responsewriter.ResponseWriter[C } var sendMessage *pool.Message var more bool - b.sendingMessagesCache.LoadWithFunc(r.Token().Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { + b.sendingMessagesCache.LoadWithFunc(exchangeKey, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { sendMessage, more, err = b.createSendingMessage(value.Data(), maxSZX, maxMessageSize, block) if err != nil { err = fmt.Errorf("cannot create sending message: %w", err) @@ -535,7 +543,7 @@ func isObserveResponse(msg *pool.Message) bool { return msg.Code() >= codes.Created } -func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32) error { +func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], maxSZX SZX, maxMessageSize uint32, block uint32, exchangeKey uint64) error { payloadSize, err := w.Message().BodySize() if err != nil { return payloadSizeError(err) @@ -558,7 +566,12 @@ func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], if !ok { expire = time.Now().Add(b.expiration) } - el, loaded := b.sendingMessagesCache.LoadOrStore(sendingMessage.Token().Hash(), cache.NewElement(originalSendingMessage, expire, nil)) + + if exchangeKey == 0 { + exchangeKey = sendingMessage.Token().Hash() + } + + el, loaded := b.sendingMessagesCache.LoadOrStore(exchangeKey, cache.NewElement(originalSendingMessage, expire, nil)) if loaded { defer b.cc.ReleaseMessage(originalSendingMessage) return fmt.Errorf("cannot add message (%v) to sending message cache: message(%v) with token(%v) already exist", originalSendingMessage, el.Data(), sendingMessage.Token()) @@ -566,8 +579,9 @@ func (b *BlockWise[C]) startSendingMessage(w *responsewriter.ResponseWriter[C], return nil } -func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message { - data, ok := b.sendingMessagesCache.LoadWithFunc(token.Hash(), func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { +func (b *BlockWise[C]) getSentRequest(exchangeKey uint64, token message.Token) *pool.Message { + + data, ok := b.sendingMessagesCache.LoadWithFunc(exchangeKey, func(value *cache.Element[*pool.Message]) *cache.Element[*pool.Message] { if value == nil { return nil } @@ -589,7 +603,7 @@ func (b *BlockWise[C]) getSentRequest(token message.Token) *pool.Message { return nil } -func (b *BlockWise[C]) handleObserveResponse(sentRequest *pool.Message) (message.Token, time.Time, error) { +func (b *BlockWise[C]) handleObserveResponse(sentRequest *pool.Message, exchangeKey uint64) (message.Token, time.Time, error) { // https://tools.ietf.org/html/rfc7959#section-2.6 - performs GET with new token. if sentRequest == nil { return nil, time.Time{}, errors.New("observation is not registered") @@ -601,7 +615,7 @@ func (b *BlockWise[C]) handleObserveResponse(sentRequest *pool.Message) (message validUntil := time.Now().Add(b.expiration) // context of observation can be expired. bwSentRequest := b.cloneMessage(sentRequest) bwSentRequest.SetToken(token) - _, loaded := b.sendingMessagesCache.LoadOrStore(token.Hash(), cache.NewElement(bwSentRequest, validUntil, nil)) + _, loaded := b.sendingMessagesCache.LoadOrStore(exchangeKey, cache.NewElement(bwSentRequest, validUntil, nil)) if loaded { return nil, time.Time{}, errors.New("cannot process message: message with token already exist") } @@ -739,7 +753,7 @@ func (b *BlockWise[C]) getCachedReceivedMessage(mg *messageGuard, r *pool.Messag } //nolint:gocyclo,gocognit -func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), blockType message.OptionID, sizeType message.OptionID) error { +func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C], r *pool.Message, maxSzx SZX, next func(w *responsewriter.ResponseWriter[C], r *pool.Message), exchangeKey uint64, blockType message.OptionID, sizeType message.OptionID) error { token := r.Token() if len(token) == 0 { next(w, r) @@ -761,7 +775,8 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C if err != nil { return fmt.Errorf("cannot decode block option: %w", err) } - sentRequest := b.getSentRequest(token) + + sentRequest := b.getSentRequest(exchangeKey, token) if sentRequest != nil { defer b.cc.ReleaseMessage(sentRequest) } @@ -770,15 +785,14 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C return errors.New("cannot request body without paired request") } if isObserveResponse(r) { - token, validUntil, err = b.handleObserveResponse(sentRequest) + token, validUntil, err = b.handleObserveResponse(sentRequest, exchangeKey) if err != nil { return fmt.Errorf("cannot process message: %w", err) } } - tokenStr := token.Hash() var cachedReceivedMessageGuard *messageGuard - if e := b.receivingMessagesCache.Load(tokenStr); e != nil { + if e := b.receivingMessagesCache.Load(exchangeKey); e != nil { cachedReceivedMessageGuard = e.Data() } if cachedReceivedMessageGuard == nil { @@ -789,7 +803,7 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C return nil } } - cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, tokenStr, validUntil) + cachedReceivedMessage, closeCachedReceivedMessage, err := b.getCachedReceivedMessage(cachedReceivedMessageGuard, r, exchangeKey, validUntil) if err != nil { return err } @@ -797,7 +811,7 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C defer func(err *error) { if *err != nil { - b.receivingMessagesCache.Delete(tokenStr) + b.receivingMessagesCache.Delete(exchangeKey) } }(&err) payloadFile, payloadSize, err := b.getPayloadFromCachedReceivedMessage(r, cachedReceivedMessage) @@ -811,12 +825,12 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C return fmt.Errorf("cannot copy data to payload: %w", err) } if !more { - b.receivingMessagesCache.Delete(tokenStr) + b.receivingMessagesCache.Delete(exchangeKey) cachedReceivedMessage.Remove(blockType) cachedReceivedMessage.Remove(sizeType) cachedReceivedMessage.SetType(r.Type()) if !bytes.Equal(cachedReceivedMessage.Token(), token) { - b.sendingMessagesCache.Delete(tokenStr) + b.sendingMessagesCache.Delete(exchangeKey) } _, errS := cachedReceivedMessage.Body().Seek(0, io.SeekStart) if errS != nil { @@ -849,3 +863,17 @@ func (b *BlockWise[C]) processReceivedMessage(w *responsewriter.ResponseWriter[C w.SetMessage(sendMessage) return nil } + +// getExchangeKey returns a key for the blockwise exchange cache. +// According to RFC 7252: "[...] An empty token value is appropriate e.g., when no other tokens are in use to a destination" +// so we need to generate a key from the remote address of the corresponding endpoint. +func getExchangeKey(addr net.Addr) (uint64, error) { + + if addr == nil { + return 0, errors.New("cannot get exchange key: addr is nil") + } + + h := fnv.New64a() + h.Write([]byte(addr.String())) + return h.Sum64(), nil +} diff --git a/tcp/client/conn.go b/tcp/client/conn.go index 018ead4d..563f48d0 100644 --- a/tcp/client/conn.go +++ b/tcp/client/conn.go @@ -207,7 +207,7 @@ func (cc *Conn) do(req *pool.Message) (*pool.Message, error) { if !cc.peerBlockWiseTranferEnabled.Load() || cc.blockWise == nil { return cc.doInternal(req) } - resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.maxMessageSize, cc.doInternal) + resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.maxMessageSize, cc.RemoteAddr(), cc.doInternal) if err != nil { return nil, err } @@ -352,7 +352,7 @@ func (cc *Conn) blockwiseHandle(w *responsewriter.ResponseWriter[*Conn], r *pool func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], r *pool.Message) { if cc.blockWise != nil && cc.peerBlockWiseTranferEnabled.Load() { - cc.blockWise.Handle(w, r, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.blockwiseHandle) + cc.blockWise.Handle(w, r, cc.blockwiseSZX, cc.Session().maxMessageSize, cc.RemoteAddr(), cc.blockwiseHandle) return } if h, ok := cc.tokenHandlerContainer.LoadAndDelete(r.Token().Hash()); ok { diff --git a/udp/client/conn.go b/udp/client/conn.go index 6e1703c4..4d1d0dab 100644 --- a/udp/client/conn.go +++ b/udp/client/conn.go @@ -440,7 +440,7 @@ func (cc *Conn) do(req *pool.Message) (*pool.Message, error) { if cc.blockWise == nil { return cc.doInternal(req) } - resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(bwReq *pool.Message) (*pool.Message, error) { + resp, err := cc.blockWise.Do(req, cc.blockwiseSZX, cc.session.MaxMessageSize(), cc.RemoteAddr(), func(bwReq *pool.Message) (*pool.Message, error) { if bwReq.Options().HasOption(message.Block1) || bwReq.Options().HasOption(message.Block2) { bwReq.SetMessageID(cc.GetMessageID()) } @@ -663,7 +663,7 @@ func (cc *Conn) handle(w *responsewriter.ResponseWriter[*Conn], m *pool.Message) return } if cc.blockWise != nil { - cc.blockWise.Handle(w, m, cc.blockwiseSZX, cc.session.MaxMessageSize(), func(rw *responsewriter.ResponseWriter[*Conn], rm *pool.Message) { + cc.blockWise.Handle(w, m, cc.blockwiseSZX, cc.session.MaxMessageSize(), cc.RemoteAddr(), func(rw *responsewriter.ResponseWriter[*Conn], rm *pool.Message) { if h, ok := cc.tokenHandlerContainer.LoadAndDelete(rm.Token().Hash()); ok { h(rw, rm) return