Skip to content
11 changes: 11 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,17 @@ func (c *dedicatedSingleClient) SetPubSubHooks(hooks PubSubHooks) <-chan error {
return c.wire.SetPubSubHooks(hooks)
}

func (c *dedicatedSingleClient) SetOnInvalidations(fn func([]RedisMessage)) <-chan error {
if err := c.check(); err != nil {
ch := make(chan error, 1)
ch <- err
return ch
}
hooks := c.wire.GetPubSubHooks()
hooks.onInvalidations = fn
return c.SetPubSubHooks(hooks)
}
Comment thread
cursor[bot] marked this conversation as resolved.

func (c *dedicatedSingleClient) Close() {
c.wire.Close()
c.release()
Expand Down
4 changes: 4 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ func (m *mockConn) SetPubSubHooks(_ PubSubHooks) <-chan error {
panic("not implemented")
}

func (m *mockConn) GetPubSubHooks() PubSubHooks {
return PubSubHooks{}
}

func (m *mockConn) SetOnCloseHook(func(error)) {

}
Expand Down
13 changes: 13 additions & 0 deletions cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -1611,6 +1611,19 @@ func (c *dedicatedClusterClient) SetPubSubHooks(hooks PubSubHooks) <-chan error
return ch
}

func (c *dedicatedClusterClient) SetOnInvalidations(fn func([]RedisMessage)) <-chan error {
c.mu.Lock()
var hooks PubSubHooks
if c.wire != nil {
hooks = c.wire.GetPubSubHooks()
} else if c.pshks != nil {
hooks = c.pshks.hooks
}
c.mu.Unlock()
hooks.onInvalidations = fn
return c.SetPubSubHooks(hooks)
}
Comment thread
cursor[bot] marked this conversation as resolved.

func (c *dedicatedClusterClient) Close() {
c.mu.Lock()
if p := c.pshks; p != nil {
Expand Down
4 changes: 4 additions & 0 deletions internal/cmds/cmds.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ var (
cs: newCommandSlice([]string{"UNSUBSCRIBE", "+sentinel", "+slave", "-sdown", "+sdown", "+switch-master", "+reboot"}),
cf: unsubTag,
}
// ClientTrackingOffCmd is predefined CLIENT TRACKING OFF
ClientTrackingOffCmd = Completed{
cs: newCommandSlice([]string{"CLIENT", "TRACKING", "OFF"}),
}

// DiscardCmd is predefined DISCARD
DiscardCmd = Completed{
Expand Down
14 changes: 14 additions & 0 deletions mock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,17 @@ func (mr *DedicatedClientMockRecorder) SetPubSubHooks(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetPubSubHooks", reflect.TypeOf((*DedicatedClient)(nil).SetPubSubHooks), arg0)
}

// SetOnInvalidations mocks base method.
func (m *DedicatedClient) SetOnInvalidations(arg0 func([]rueidis.RedisMessage)) <-chan error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SetOnInvalidations", arg0)
ret0, _ := ret[0].(<-chan error)
return ret0
}

// SetOnInvalidations indicates an expected call of SetOnInvalidations.
func (mr *DedicatedClientMockRecorder) SetOnInvalidations(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOnInvalidations", reflect.TypeOf((*DedicatedClient)(nil).SetOnInvalidations), arg0)
}
4 changes: 4 additions & 0 deletions mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,8 +394,12 @@ func (m *mux) Acquire(ctx context.Context) wire {
}

func (m *mux) Store(w wire) {
hasOnInvalidations := w.GetPubSubHooks().onInvalidations != nil
w.SetPubSubHooks(PubSubHooks{})
w.CleanSubscriptions()
if hasOnInvalidations {
w.Do(context.Background(), cmds.ClientTrackingOffCmd)
}
Comment thread
cursor[bot] marked this conversation as resolved.
m.dpool.Store(w)
}

Expand Down
91 changes: 91 additions & 0 deletions mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,89 @@ func TestMuxReuseWire(t *testing.T) {
t.Fatalf("CleanSubscriptions not called")
}
})

t.Run("send CLIENT TRACKING OFF on store when onInvalidations was set", func(t *testing.T) {
cleaned := false
trackingOffCalls := 0

m, checkClean := setupMux([]*mockWire{
{
// leave first wire for pipeline calls
},
{
GetPubSubHooksFn: func() PubSubHooks {
return PubSubHooks{onInvalidations: func([]RedisMessage) {}}
},
CleanSubscriptionsFn: func() {
cleaned = true
},
DoFn: func(cmd Completed) RedisResult {
got := cmd.Commands()
if len(got) == 3 && got[0] == "CLIENT" && got[1] == "TRACKING" && got[2] == "OFF" {
trackingOffCalls++
return newResult(strmsg('+', "OK"), nil)
}
t.Fatalf("unexpected command: %v", got)
return RedisResult{}
},
},
})
defer checkClean(t)
defer m.Close()

if err := m.Dial(); err != nil {
t.Fatalf("unexpected dial error %v", err)
}

wire1 := m.Acquire(context.Background())
m.Store(wire1)

if !cleaned {
t.Fatalf("CleanSubscriptions not called")
}
if trackingOffCalls != 1 {
t.Fatalf("unexpected CLIENT TRACKING OFF calls: %d", trackingOffCalls)
}
})

t.Run("skip CLIENT TRACKING OFF on store when no onInvalidations was set", func(t *testing.T) {
cleaned := false
doCalled := false

m, checkClean := setupMux([]*mockWire{
{
// leave first wire for pipeline calls
},
{
GetPubSubHooksFn: func() PubSubHooks {
return PubSubHooks{OnMessage: func(PubSubMessage) {}}
},
CleanSubscriptionsFn: func() {
cleaned = true
},
DoFn: func(cmd Completed) RedisResult {
doCalled = true
return newResult(strmsg('+', "OK"), nil)
},
},
})
defer checkClean(t)
defer m.Close()

if err := m.Dial(); err != nil {
t.Fatalf("unexpected dial error %v", err)
}

wire1 := m.Acquire(context.Background())
m.Store(wire1)

if !cleaned {
t.Fatalf("CleanSubscriptions not called")
}
if doCalled {
t.Fatalf("CLIENT TRACKING OFF should not be sent when onInvalidations was not set")
}
})
}

//gocyclo:ignore
Expand Down Expand Up @@ -1136,6 +1219,7 @@ type mockWire struct {

CleanSubscriptionsFn func()
SetPubSubHooksFn func(hooks PubSubHooks) <-chan error
GetPubSubHooksFn func() PubSubHooks
SetOnCloseHookFn func(fn func(error))
}

Expand Down Expand Up @@ -1201,6 +1285,13 @@ func (m *mockWire) SetPubSubHooks(hooks PubSubHooks) <-chan error {
return nil
}

func (m *mockWire) GetPubSubHooks() PubSubHooks {
if m.GetPubSubHooksFn != nil {
return m.GetPubSubHooksFn()
}
return PubSubHooks{}
}

func (m *mockWire) SetOnCloseHook(fn func(error)) {
if m.SetOnCloseHookFn != nil {
m.SetOnCloseHookFn(fn)
Expand Down
74 changes: 51 additions & 23 deletions pipe.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ type wire interface {

CleanSubscriptions()
SetPubSubHooks(hooks PubSubHooks) <-chan error
GetPubSubHooks() PubSubHooks
SetOnCloseHook(fn func(error))
StopTimer() bool
ResetTimer() bool
Expand Down Expand Up @@ -434,7 +435,8 @@ func (p *pipe) _background() {
p.nsubs.Close()
p.psubs.Close()
p.ssubs.Close()
if old := p.pshks.Swap(emptypshks); old.close != nil {
old := p.pshks.Swap(emptypshks)
if old.close != nil {
old.close <- err
close(old.close)
}
Expand All @@ -451,6 +453,9 @@ func (p *pipe) _background() {
if p.onInvalidations != nil {
p.onInvalidations(nil)
}
if old.hooks.onInvalidations != nil {
old.hooks.onInvalidations(nil)
}

resp := newErrResult(err)
for p.loadWaits() != 0 {
Expand Down Expand Up @@ -743,64 +748,89 @@ func (p *pipe) handlePush(values []RedisMessage) (reply bool, unsubscribe bool)
p.onInvalidations(values[1].values())
}
}
if fn := p.pshks.Load().hooks.onInvalidations; fn != nil {
Comment thread
cursor[bot] marked this conversation as resolved.
if values[1].IsNil() {
fn(nil)
} else {
fn(values[1].values())
}
}
case "message":
if len(values) >= 3 {
m := PubSubMessage{Channel: values[1].string(), Message: values[2].string()}
p.nsubs.Publish(values[1].string(), m)
p.pshks.Load().hooks.OnMessage(m)
if fn := p.pshks.Load().hooks.OnMessage; fn != nil {
fn(m)
}
}
case "pmessage":
if len(values) >= 4 {
m := PubSubMessage{Pattern: values[1].string(), Channel: values[2].string(), Message: values[3].string()}
p.psubs.Publish(values[1].string(), m)
p.pshks.Load().hooks.OnMessage(m)
if fn := p.pshks.Load().hooks.OnMessage; fn != nil {
fn(m)
}
}
case "smessage":
if len(values) >= 3 {
m := PubSubMessage{Channel: values[1].string(), Message: values[2].string()}
p.ssubs.Publish(values[1].string(), m)
p.pshks.Load().hooks.OnMessage(m)
if fn := p.pshks.Load().hooks.OnMessage; fn != nil {
fn(m)
}
}
case "unsubscribe":
if len(values) >= 3 {
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
p.nsubs.Unsubscribe(s)
p.pshks.Load().hooks.OnSubscription(s)
if fn := p.pshks.Load().hooks.OnSubscription; fn != nil {
fn(s)
}
}
return true, true
case "punsubscribe":
if len(values) >= 3 {
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
p.psubs.Unsubscribe(s)
p.pshks.Load().hooks.OnSubscription(s)
if fn := p.pshks.Load().hooks.OnSubscription; fn != nil {
fn(s)
}
}
return true, true
case "sunsubscribe":
if len(values) >= 3 {
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
p.ssubs.Unsubscribe(s)
p.pshks.Load().hooks.OnSubscription(s)
if fn := p.pshks.Load().hooks.OnSubscription; fn != nil {
fn(s)
}
}
return true, true
case "subscribe":
if len(values) >= 3 {
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
p.nsubs.Confirm(s)
p.pshks.Load().hooks.OnSubscription(s)
if fn := p.pshks.Load().hooks.OnSubscription; fn != nil {
fn(s)
}
}
return true, false
case "psubscribe":
if len(values) >= 3 {
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
p.psubs.Confirm(s)
p.pshks.Load().hooks.OnSubscription(s)
if fn := p.pshks.Load().hooks.OnSubscription; fn != nil {
fn(s)
}
}
return true, false
case "ssubscribe":
if len(values) >= 3 {
s := PubSubSubscription{Kind: values[0].string(), Channel: values[1].string(), Count: values[2].intlen}
p.ssubs.Confirm(s)
p.pshks.Load().hooks.OnSubscription(s)
if fn := p.pshks.Load().hooks.OnSubscription; fn != nil {
fn(s)
}
}
return true, false
}
Expand Down Expand Up @@ -892,6 +922,16 @@ func (p *pipe) CleanSubscriptions() {
}
}

func (p *pipe) GetPubSubHooks() PubSubHooks {
if p.r2p != nil {
return p.r2p.pipe(context.Background()).GetPubSubHooks()
}
if pshks := p.pshks.Load(); pshks != emptypshks {
return pshks.hooks
}
return PubSubHooks{}
}
Comment thread
cursor[bot] marked this conversation as resolved.

func (p *pipe) SetPubSubHooks(hooks PubSubHooks) <-chan error {
if p.r2p != nil {
return p.r2p.pipe(context.Background()).SetPubSubHooks(hooks)
Expand All @@ -902,12 +942,6 @@ func (p *pipe) SetPubSubHooks(hooks PubSubHooks) <-chan error {
}
return nil
}
if hooks.OnMessage == nil {
hooks.OnMessage = func(m PubSubMessage) {}
}
if hooks.OnSubscription == nil {
hooks.OnSubscription = func(s PubSubSubscription) {}
}
ch := make(chan error, 1)
if old := p.pshks.Swap(&pshks{hooks: hooks, close: ch}); old.close != nil {
close(old.close)
Expand Down Expand Up @@ -1784,13 +1818,7 @@ type pshks struct {
close chan error
}

var emptypshks = &pshks{
hooks: PubSubHooks{
OnMessage: func(m PubSubMessage) {},
OnSubscription: func(s PubSubSubscription) {},
},
close: nil,
}
var emptypshks = &pshks{}

var emptyclhks = func(error) {}

Expand Down
Loading
Loading