diff --git a/pkg/finality-grandpa/environment_test.go b/pkg/finality-grandpa/environment_test.go index 729a8ae987..ff66f942c4 100644 --- a/pkg/finality-grandpa/environment_test.go +++ b/pkg/finality-grandpa/environment_test.go @@ -187,63 +187,108 @@ func (*environment) PrecommitEquivocation( // p2p network data for a round. type BroadcastNetwork[M, N any] struct { - receiver chan M - senders []chan M - history []M - routing bool - wg sync.WaitGroup + receiver chan M + stop chan struct{} + mu sync.Mutex + senders []chan M + history []M + routing bool + stopped bool + routeWG sync.WaitGroup + forwarderWG sync.WaitGroup } func NewBroadcastNetwork[M, N any]() *BroadcastNetwork[M, N] { bn := BroadcastNetwork[M, N]{ receiver: make(chan M, 10000), + stop: make(chan struct{}), } return &bn } func (bm *BroadcastNetwork[M, N]) SendMessage(message M) { - bm.receiver <- message + select { + case bm.receiver <- message: + case <-bm.stop: + } } func (bm *BroadcastNetwork[M, N]) AddNode(f func(N) M, out chan N) (in chan M) { // buffer to 100 messages for now in = make(chan M, 10000) + bm.mu.Lock() // get history to the node. for _, priorMessage := range bm.history { in <- priorMessage } - bm.senders = append(bm.senders, in) - - if !bm.routing { + startRoute := !bm.routing + if startRoute { bm.routing = true - bm.wg.Add(1) + bm.routeWG.Add(1) + } + bm.mu.Unlock() + + if startRoute { go bm.route() } + bm.forwarderWG.Add(1) go func() { - for n := range out { - bm.receiver <- f(n) + defer bm.forwarderWG.Done() + for { + select { + case n, ok := <-out: + if !ok { + return + } + select { + case bm.receiver <- f(n): + case <-bm.stop: + return + } + case <-bm.stop: + return + } } }() return in } func (bm *BroadcastNetwork[M, N]) route() { - defer bm.wg.Done() + defer bm.routeWG.Done() for msg := range bm.receiver { + bm.mu.Lock() bm.history = append(bm.history, msg) - for _, sender := range bm.senders { + senders := append([]chan M(nil), bm.senders...) + bm.mu.Unlock() + for _, sender := range senders { sender <- msg } } } func (bm *BroadcastNetwork[M, N]) Stop() { + bm.mu.Lock() + if bm.stopped { + bm.mu.Unlock() + return + } + bm.stopped = true + close(bm.stop) + bm.mu.Unlock() + + // Order matters: drain forwarders first so they stop sending into receiver, + // then close receiver so route can exit, then close per-node senders. + bm.forwarderWG.Wait() close(bm.receiver) - bm.wg.Wait() - for _, sender := range bm.senders { + bm.routeWG.Wait() + bm.mu.Lock() + senders := bm.senders + bm.senders = nil + bm.mu.Unlock() + for _, sender := range senders { close(sender) } } diff --git a/pkg/finality-grandpa/timer.go b/pkg/finality-grandpa/timer.go index 3a044ef1c4..97c2755f7a 100644 --- a/pkg/finality-grandpa/timer.go +++ b/pkg/finality-grandpa/timer.go @@ -5,13 +5,14 @@ package grandpa import ( "sync" + "sync/atomic" "time" ) type timer struct { wakerChan *wakerChan[error] - mtx sync.Mutex - expired bool + closeOnce sync.Once + expired atomic.Bool } func newTimer(in <-chan time.Time) *timer { @@ -24,14 +25,11 @@ func newTimer(in <-chan time.Time) *timer { func (t *timer) poll(in <-chan time.Time) { <-in - t.mtx.Lock() - defer t.mtx.Unlock() - if t.wakerChan.in != nil { + t.closeOnce.Do(func() { t.wakerChan.in <- nil close(t.wakerChan.in) - t.wakerChan.in = nil - } - t.expired = true + }) + t.expired.Store(true) } func (t *timer) SetWaker(waker *waker) { @@ -39,14 +37,11 @@ func (t *timer) SetWaker(waker *waker) { } func (t *timer) Elapsed() (bool, error) { - return t.expired, nil + return t.expired.Load(), nil } func (t *timer) Close() { - t.mtx.Lock() - defer t.mtx.Unlock() - if t.wakerChan.in != nil { + t.closeOnce.Do(func() { close(t.wakerChan.in) - t.wakerChan.in = nil - } + }) } diff --git a/pkg/finality-grandpa/timer_test.go b/pkg/finality-grandpa/timer_test.go new file mode 100644 index 0000000000..b363339521 --- /dev/null +++ b/pkg/finality-grandpa/timer_test.go @@ -0,0 +1,111 @@ +// Copyright 2025 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package grandpa + +import ( + "sync" + "testing" + "time" +) + +// TestTimer_ElapsedConcurrentWithFiring exercises the read of `expired` from +// one goroutine while the timer's `poll` goroutine is writing it. With the +// pre-fix code (unsynchronized read in Elapsed) this test trips the race +// detector under `go test -race`. +func TestTimer_ElapsedConcurrentWithFiring(t *testing.T) { + t.Parallel() + tick := make(chan time.Time, 1) + timer := newTimer(tick) + + stop := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + for { + select { + case <-stop: + return + default: + _, _ = timer.Elapsed() + } + } + }() + + tick <- time.Now() + close(tick) + + deadline := time.Now().Add(2 * time.Second) + for { + elapsed, _ := timer.Elapsed() + if elapsed { + break + } + if time.Now().After(deadline) { + close(stop) + wg.Wait() + t.Fatal("timer never reported elapsed after the tick was consumed") + } + time.Sleep(time.Millisecond) + } + close(stop) + wg.Wait() +} + +// TestTimer_CloseIsIdempotent ensures Close() can be called more than once +// (and after poll has already drained the channel) without panicking — the +// `closed` flag prevents the double-close. +func TestTimer_CloseIsIdempotent(t *testing.T) { + t.Parallel() + tick := make(chan time.Time) + timer := newTimer(tick) + + timer.Close() + timer.Close() +} + +// TestWakerChan_SetWakerConcurrentWithItems exercises the write of `waker` +// from one goroutine while the `start` goroutine is reading it on every item. +// With the pre-fix code (plain *waker field) this trips the race detector. +func TestWakerChan_SetWakerConcurrentWithItems(t *testing.T) { + t.Parallel() + in := make(chan int, 100) + wc := newWakerChan(in) + + w1 := &waker{wakeCh: make(chan struct{}, 1000)} + w2 := &waker{wakeCh: make(chan struct{}, 1000)} + + // Drain the output channel so start() can keep making progress. + drained := make(chan struct{}) + go func() { + defer close(drained) + for range wc.channel() { + } + }() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for i := 0; i < 500; i++ { + in <- i + } + }() + + go func() { + defer wg.Done() + for i := 0; i < 500; i++ { + if i%2 == 0 { + wc.setWaker(w1) + } else { + wc.setWaker(w2) + } + } + }() + + wg.Wait() + close(in) + <-drained +} diff --git a/pkg/finality-grandpa/voter.go b/pkg/finality-grandpa/voter.go index 50d3ac6082..2b0d48afac 100644 --- a/pkg/finality-grandpa/voter.go +++ b/pkg/finality-grandpa/voter.go @@ -6,6 +6,7 @@ package grandpa import ( "fmt" "sync" + "sync/atomic" "time" "github.com/tidwall/btree" @@ -15,14 +16,13 @@ import ( type wakerChan[Item any] struct { in chan Item out chan Item - waker *waker + waker atomic.Pointer[waker] } func newWakerChan[Item any](in chan Item) *wakerChan[Item] { wc := &wakerChan[Item]{ - in: in, - out: make(chan Item), - waker: nil, + in: in, + out: make(chan Item), } go wc.start() return wc @@ -34,15 +34,15 @@ func (wc *wakerChan[Item]) start() { return } for item := range wc.in { - if wc.waker != nil { - wc.waker.wake() + if w := wc.waker.Load(); w != nil { + w.wake() } wc.out <- item } } func (wc *wakerChan[Item]) setWaker(waker *waker) { - wc.waker = waker + wc.waker.Store(waker) } // Chan returns a channel to consume `Item`. Not thread safe, only supports one consumer @@ -612,8 +612,11 @@ func NewVoter[Hash constraints.Ordered, Number constraints.Unsigned, Signature c } func (v *Voter[Hash, Number, Signature, ID]) pruneBackgroundRounds(waker *waker) error { + // Collect finalize notifications under the lock, then invoke + // env.FinalizeBlock outside it. Holding inner.Mutex across user-supplied + // callbacks is a deadlock hazard: a slow environment can block readers + // of the voter state and stall Stop(). v.inner.Lock() - defer v.inner.Unlock() pastRounds: for { @@ -622,6 +625,7 @@ pastRounds: switch ready { case true: if err != nil { + v.inner.Unlock() return err } if nc != nil { @@ -636,31 +640,31 @@ pastRounds: } v.finalizedNotifications.setWaker(waker) + var toFinalize []finalizedNotification[Hash, Number, Signature, ID] finalizedNotifications: for { select { case notif := <-v.finalizedNotifications.channel(): - fHash := notif.Hash fNum := notif.Number - round := notif.Round - commit := notif.Commit - v.inner.pastRounds.UpdateFinalized(fNum) if v.setLastFinalizedNumber(fNum) { - err := v.env.FinalizeBlock(fHash, fNum, round, commit) - if err != nil { - return err - } + toFinalize = append(toFinalize, notif) } - if fNum > v.lastFinalizedInRounds.Number { - v.lastFinalizedInRounds = HashNumber[Hash, Number]{fHash, fNum} + v.lastFinalizedInRounds = HashNumber[Hash, Number]{notif.Hash, fNum} } default: break finalizedNotifications } } + v.inner.Unlock() + + for _, n := range toFinalize { + if err := v.env.FinalizeBlock(n.Hash, n.Number, n.Round, n.Commit); err != nil { + return err + } + } return nil } @@ -818,6 +822,7 @@ func (v *Voter[Hash, Number, Signature, ID]) processBestRound(waker *waker) (boo var shouldStartNext bool completable, err := v.inner.bestRound.poll(waker) if err != nil { + v.inner.Unlock() return true, err } @@ -991,7 +996,6 @@ type sharedVoteState[ E Environment[Hash, Number, Signature, ID], ] struct { inner *innerVoterState[Hash, Number, Signature, ID, E] - mtx sync.Mutex } func (svs *sharedVoteState[Hash, Number, Signature, ID, E]) Get() VoterStateReport[ID] { @@ -1006,8 +1010,8 @@ func (svs *sharedVoteState[Hash, Number, Signature, ID, E]) Get() VoterStateRepo } } - svs.mtx.Lock() - defer svs.mtx.Unlock() + svs.inner.Lock() + defer svs.inner.Unlock() bestRoundNum, bestRound := toRoundState(svs.inner.bestRound) backgroundRounds := svs.inner.pastRounds.votingRounds() diff --git a/pkg/finality-grandpa/voter_test.go b/pkg/finality-grandpa/voter_test.go index 407630e4c4..da0a4cef72 100644 --- a/pkg/finality-grandpa/voter_test.go +++ b/pkg/finality-grandpa/voter_test.go @@ -5,6 +5,7 @@ package grandpa import ( "sync" + "sync/atomic" "testing" "time" @@ -700,12 +701,13 @@ func TestBuffered(t *testing.T) { return nil }) - run := true + var run atomic.Bool + run.Store(true) wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() - for run { + for run.Load() { buffered.Push(999) time.Sleep(1 * time.Millisecond) } @@ -714,7 +716,7 @@ func TestBuffered(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - for run { + for run.Load() { buffered.flush(newWaker()) time.Sleep(1 * time.Millisecond) } @@ -723,6 +725,6 @@ func TestBuffered(t *testing.T) { time.Sleep(100 * time.Millisecond) buffered.Close() - run = false + run.Store(false) wg.Wait() }