diff --git a/.travis.yml b/.travis.yml index 45d4e02..40078cf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -18,6 +18,9 @@ before_install: - wget https://github.com/nats-io/nats-streaming-server/releases/download/v0.5.0/nats-streaming-server-v0.5.0-linux-amd64.zip - unzip -d gnatsd -j nats-streaming-server-v0.5.0-linux-amd64.zip - ./gnatsd/nats-streaming-server & + # install EMQ + - docker pull emqx/emqx:v3.0.0 + - docker run -d -p 127.0.0.1:1883:1883 --name mqtt -e 'EMQX_ALLOW_ANONYMOUS=true' emqx/emqx:v3.0.0 # give the queues some time to start. - sleep 5 @@ -29,6 +32,7 @@ before_script: services: - redis - rabbitmq + - docker script: - go test -v -timeout 30s -race ./... diff --git a/README.md b/README.md index 2ef9dfa..300ce6c 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # Go channels at horizontal scale [![Build Status](https://travis-ci.org/matryer/vice.svg?branch=master)](https://travis-ci.org/matryer/vice) -* Use Go channels transparently over a [messaging queue technology of your choice](https://github.com/matryer/vice/tree/master/queues) (Currently [NATS](http://nats.io), [Redis](http://redis.io) or [NSQ](http://nsq.io), [Amazon SQS](https://aws.amazon.com/sqs/)) +* Use Go channels transparently over a [messaging queue technology of your choice](https://github.com/matryer/vice/tree/master/queues) (Currently [NATS](http://nats.io), [Redis](http://redis.io), [NSQ](http://nsq.io), [Amazon SQS](https://aws.amazon.com/sqs/)), [RabbitMQ](https://rabbitmq.com), or [MQTT, with shared subscription](https://mqtt.org). * Swap `vice.Transport` to change underlying queueing technologies transparently * Write idiomatic Go code instead of learning queue specific APIs * Develop against in-memory implementation before putting it into the wild diff --git a/queues/mqtt/mqtt.go b/queues/mqtt/mqtt.go new file mode 100644 index 0000000..ce7962e --- /dev/null +++ b/queues/mqtt/mqtt.go @@ -0,0 +1,199 @@ +package mqtt + +import ( + "sync" + "time" + + "github.com/eclipse/paho.mqtt.golang" + "github.com/matryer/vice" +) + +const ( + DefaultMQTTAddress = "tcp://127.0.0.1:1883" + + SharedQueuePrefix = "$share/vice/" +) + +// make sure Transport satisfies vice.Transport interface. +var _ vice.Transport = (*Transport)(nil) + +type Transport struct { + sync.Mutex + wg sync.WaitGroup + + clientOptions *mqtt.ClientOptions + subQoS byte + pubQoS byte + pubRetained bool + subTimeout time.Duration + pubTimeout time.Duration + + subClients map[string]mqtt.Client + + pubChans map[string]chan []byte + subChans map[string]chan []byte + + errChan chan error + stopChan chan struct{} + stopPubChan chan struct{} +} + +func New(opts ...Option) *Transport { + var options Options + for _, o := range opts { + o(&options) + } + + if options.ClientOptions == nil { + options.ClientOptions = mqtt.NewClientOptions() + options.ClientOptions.AddBroker(DefaultMQTTAddress) + } + + if options.PubTimeout == 0 { + options.PubTimeout = time.Second + } + if options.SubTimeout == 0 { + options.SubTimeout = time.Second + } + + return &Transport{ + clientOptions: options.ClientOptions, + + subQoS: options.SubQoS, + subTimeout: options.SubTimeout, + + pubQoS: options.PubQoS, + pubRetained: options.PubRetained, + pubTimeout: options.PubTimeout, + + subClients: make(map[string]mqtt.Client), + + pubChans: make(map[string]chan []byte), + subChans: make(map[string]chan []byte), + + errChan: make(chan error, 10), + stopChan: make(chan struct{}), + stopPubChan: make(chan struct{}), + } +} + +func (t *Transport) Receive(name string) <-chan []byte { + t.Lock() + defer t.Unlock() + + subCh, ok := t.subChans[name] + if ok { + return subCh + } + + subCh, err := t.makeSubscriber(name) + if err != nil { + t.errChan <- &vice.Err{Name: name, Err: err} + return make(chan []byte) + } + + t.subChans[name] = subCh + return subCh +} + +func (t *Transport) makeSubscriber(topic string) (chan []byte, error) { + ch := make(chan []byte, 1024) + + cli := mqtt.NewClient(t.clientOptions) + if token := cli.Connect(); token.Wait() && token.Error() != nil { + return nil, token.Error() + } + if token := cli.Subscribe(SharedQueuePrefix+topic, t.subQoS, func(c mqtt.Client, msg mqtt.Message) { + if !cli.IsConnected() { + if token := cli.Connect(); token.Wait() && token.Error() != nil { + t.errChan <- &vice.Err{Name: topic, Err: token.Error(), Message: msg.Payload()} + } + return + } + ch <- msg.Payload() + }); token.WaitTimeout(t.subTimeout) && token.Error() != nil { + return nil, token.Error() + } + + t.subClients[topic] = cli + + return ch, nil +} + +func (t *Transport) Send(name string) chan<- []byte { + t.Lock() + defer t.Unlock() + + pubCh, ok := t.pubChans[name] + if ok { + return pubCh + } + + pubCh, err := t.makePublisher(name) + if err != nil { + t.errChan <- &vice.Err{Name: name, Err: err} + return make(chan []byte) + } + + t.pubChans[name] = pubCh + return pubCh +} + +func (t *Transport) makePublisher(topic string) (chan []byte, error) { + + ch := make(chan []byte, 1024) + + cli := mqtt.NewClient(t.clientOptions) + if token := cli.Connect(); token.Wait() && token.Error() != nil { + return nil, token.Error() + } + + t.wg.Add(1) + go func() { + defer t.wg.Done() + for { + select { + case <-t.stopPubChan: + cli.Disconnect(100) + return + case msg := <-ch: + if !cli.IsConnected() { + if token := cli.Connect(); token.Wait() && token.Error() != nil { + t.errChan <- &vice.Err{Name: topic, Err: token.Error(), Message: msg} + } + continue + } + if token := cli.Publish(topic, t.pubQoS, t.pubRetained, msg); token.WaitTimeout(t.pubTimeout) && token.Error() != nil { + t.errChan <- &vice.Err{Name: topic, Err: token.Error()} + } + } + } + }() + + return ch, nil +} + +func (t *Transport) ErrChan() <-chan error { + return t.errChan +} + +func (t *Transport) Stop() { + t.Lock() + defer t.Unlock() + + for topic, cli := range t.subClients { + if token := cli.Unsubscribe(SharedQueuePrefix + topic); token.Wait() && token.Error() != nil { + t.errChan <- &vice.Err{Name: topic, Err: token.Error()} + } + cli.Disconnect(100) + } + + close(t.stopPubChan) + t.wg.Wait() + + close(t.stopChan) +} + +func (t *Transport) Done() chan struct{} { + return t.stopChan +} diff --git a/queues/mqtt/mqtt_test.go b/queues/mqtt/mqtt_test.go new file mode 100644 index 0000000..34fc291 --- /dev/null +++ b/queues/mqtt/mqtt_test.go @@ -0,0 +1,105 @@ +package mqtt + +import ( + "sync" + "testing" + "time" + + "github.com/eclipse/paho.mqtt.golang" + "github.com/matryer/is" + "github.com/matryer/vice" + "github.com/matryer/vice/vicetest" +) + +func TestDefaultTransport(t *testing.T) { + new := func() vice.Transport { + return New() + } + + vicetest.Transport(t, new) +} + +func TestReceive(t *testing.T) { + is := is.New(t) + + transport := New() + + var wg sync.WaitGroup + + go func() { + for { + select { + case <-transport.Done(): + return + case err := <-transport.ErrChan(): + is.NoErr(err) + case msg := <-transport.Receive("test_receive"): + is.Equal(msg, []byte("hello vice")) + wg.Done() + case <-time.After(2 * time.Second): + is.Fail() // time out: transport.Receive + } + } + }() + + time.Sleep(time.Millisecond * 100) + + // create new client + opts := mqtt.NewClientOptions().AddBroker("tcp://localhost:1883") + cli := mqtt.NewClient(opts) + if token := cli.Connect(); token.Wait() && token.Error() != nil { + is.NoErr(token.Error()) + } + + wg.Add(1) + + // publish + if token := cli.Publish("test_receive", 0, false, []byte("hello vice")); token.Wait() && token.Error() != nil { + is.NoErr(token.Error()) + } + + wg.Wait() + + transport.Stop() + <-transport.Done() +} + +func TestSend(t *testing.T) { + is := is.New(t) + + transport := New() + + var wg sync.WaitGroup + + var msgs [][]byte + + go func() { + for { + select { + case <-transport.Done(): + return + case err := <-transport.ErrChan(): + is.NoErr(err) + case msg := <-transport.Receive("test_send"): + msgs = append(msgs, msg) + is.Equal(msg, []byte("hello vice")) + wg.Done() + case <-time.After(3 * time.Second): + is.Fail() // time out: transport.Receive + } + } + }() + + time.Sleep(time.Millisecond * 100) + wg.Add(1) + + transport.Send("test_send") <- []byte("hello vice") + + wg.Wait() + + is.Equal(len(msgs), 1) + is.Equal(transport.Send("test_send"), transport.Send("test_send")) + is.True(transport.Send("test_send") != transport.Send("different")) + + transport.Stop() +} diff --git a/queues/mqtt/options.go b/queues/mqtt/options.go new file mode 100644 index 0000000..b0ddf21 --- /dev/null +++ b/queues/mqtt/options.go @@ -0,0 +1,40 @@ +package mqtt + +import ( + "time" + + "github.com/eclipse/paho.mqtt.golang" +) + +type Options struct { + ClientOptions *mqtt.ClientOptions + SubQoS byte + PubQoS byte + PubRetained bool + + SubTimeout time.Duration + PubTimeout time.Duration +} + +type Option func(*Options) + +func WithClientOptions(c *mqtt.ClientOptions) Option { + return func(o *Options) { + o.ClientOptions = c + } +} + +func WithPub(qos byte, retained bool, timeout time.Duration) Option { + return func(o *Options) { + o.PubQoS = qos + o.PubRetained = retained + o.PubTimeout = timeout + } +} + +func WithSub(qos byte, timeout time.Duration) Option { + return func(o *Options) { + o.SubQoS = qos + o.SubTimeout = timeout + } +}