Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 61 additions & 16 deletions pkg/finality-grandpa/environment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,63 +187,108 @@ func (*environment) PrecommitEquivocation(

// p2p network data for a round.
type BroadcastNetwork[M, N any] struct {
receiver chan M
senders []chan M
history []M
routing bool
wg sync.WaitGroup
receiver chan M
stop chan struct{}
mu sync.Mutex
senders []chan M
history []M
routing bool
stopped bool
routeWG sync.WaitGroup
forwarderWG sync.WaitGroup
}

func NewBroadcastNetwork[M, N any]() *BroadcastNetwork[M, N] {
bn := BroadcastNetwork[M, N]{
receiver: make(chan M, 10000),
stop: make(chan struct{}),
}
return &bn
}

func (bm *BroadcastNetwork[M, N]) SendMessage(message M) {
bm.receiver <- message
select {
case bm.receiver <- message:
case <-bm.stop:
}
}

func (bm *BroadcastNetwork[M, N]) AddNode(f func(N) M, out chan N) (in chan M) {
// buffer to 100 messages for now
in = make(chan M, 10000)

bm.mu.Lock()
// get history to the node.
for _, priorMessage := range bm.history {
in <- priorMessage
}

bm.senders = append(bm.senders, in)

if !bm.routing {
startRoute := !bm.routing
if startRoute {
bm.routing = true
bm.wg.Add(1)
bm.routeWG.Add(1)
}
bm.mu.Unlock()

if startRoute {
go bm.route()
}

bm.forwarderWG.Add(1)
go func() {
for n := range out {
bm.receiver <- f(n)
defer bm.forwarderWG.Done()
for {
select {
case n, ok := <-out:
if !ok {
return
}
select {
case bm.receiver <- f(n):
case <-bm.stop:
return
}
case <-bm.stop:
return
}
}
}()
return in
}

func (bm *BroadcastNetwork[M, N]) route() {
defer bm.wg.Done()
defer bm.routeWG.Done()
for msg := range bm.receiver {
bm.mu.Lock()
bm.history = append(bm.history, msg)
for _, sender := range bm.senders {
senders := append([]chan M(nil), bm.senders...)
bm.mu.Unlock()
for _, sender := range senders {
sender <- msg
}
}
}

func (bm *BroadcastNetwork[M, N]) Stop() {
bm.mu.Lock()
if bm.stopped {
bm.mu.Unlock()
return
}
bm.stopped = true
close(bm.stop)
bm.mu.Unlock()

// Order matters: drain forwarders first so they stop sending into receiver,
// then close receiver so route can exit, then close per-node senders.
bm.forwarderWG.Wait()
close(bm.receiver)
bm.wg.Wait()
for _, sender := range bm.senders {
bm.routeWG.Wait()
bm.mu.Lock()
senders := bm.senders
bm.senders = nil
bm.mu.Unlock()
for _, sender := range senders {
close(sender)
}
}
Expand Down
23 changes: 9 additions & 14 deletions pkg/finality-grandpa/timer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ package grandpa

import (
"sync"
"sync/atomic"
"time"
)

type timer struct {
wakerChan *wakerChan[error]
mtx sync.Mutex
expired bool
closeOnce sync.Once
expired atomic.Bool
}

func newTimer(in <-chan time.Time) *timer {
Expand All @@ -24,29 +25,23 @@ func newTimer(in <-chan time.Time) *timer {

func (t *timer) poll(in <-chan time.Time) {
<-in
t.mtx.Lock()
defer t.mtx.Unlock()
if t.wakerChan.in != nil {
t.closeOnce.Do(func() {
t.wakerChan.in <- nil
close(t.wakerChan.in)
t.wakerChan.in = nil
}
t.expired = true
})
t.expired.Store(true)
}

func (t *timer) SetWaker(waker *waker) {
t.wakerChan.setWaker(waker)
}

func (t *timer) Elapsed() (bool, error) {
return t.expired, nil
return t.expired.Load(), nil
}

func (t *timer) Close() {
t.mtx.Lock()
defer t.mtx.Unlock()
if t.wakerChan.in != nil {
t.closeOnce.Do(func() {
close(t.wakerChan.in)
t.wakerChan.in = nil
}
})
}
111 changes: 111 additions & 0 deletions pkg/finality-grandpa/timer_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright 2025 ChainSafe Systems (ON)
// SPDX-License-Identifier: LGPL-3.0-only

package grandpa

import (
"sync"
"testing"
"time"
)

// TestTimer_ElapsedConcurrentWithFiring exercises the read of `expired` from
// one goroutine while the timer's `poll` goroutine is writing it. With the
// pre-fix code (unsynchronized read in Elapsed) this test trips the race
// detector under `go test -race`.
func TestTimer_ElapsedConcurrentWithFiring(t *testing.T) {
t.Parallel()
tick := make(chan time.Time, 1)
timer := newTimer(tick)

stop := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
for {
select {
case <-stop:
return
default:
_, _ = timer.Elapsed()
}
}
}()

tick <- time.Now()
close(tick)

deadline := time.Now().Add(2 * time.Second)
for {
elapsed, _ := timer.Elapsed()
if elapsed {
break
}
if time.Now().After(deadline) {
close(stop)
wg.Wait()
t.Fatal("timer never reported elapsed after the tick was consumed")
}
time.Sleep(time.Millisecond)
}
close(stop)
wg.Wait()
}

// TestTimer_CloseIsIdempotent ensures Close() can be called more than once
// (and after poll has already drained the channel) without panicking — the
// `closed` flag prevents the double-close.
func TestTimer_CloseIsIdempotent(t *testing.T) {
t.Parallel()
tick := make(chan time.Time)
timer := newTimer(tick)

timer.Close()
timer.Close()
}

// TestWakerChan_SetWakerConcurrentWithItems exercises the write of `waker`
// from one goroutine while the `start` goroutine is reading it on every item.
// With the pre-fix code (plain *waker field) this trips the race detector.
func TestWakerChan_SetWakerConcurrentWithItems(t *testing.T) {
t.Parallel()
in := make(chan int, 100)
wc := newWakerChan(in)

w1 := &waker{wakeCh: make(chan struct{}, 1000)}
w2 := &waker{wakeCh: make(chan struct{}, 1000)}

// Drain the output channel so start() can keep making progress.
drained := make(chan struct{})
go func() {
defer close(drained)
for range wc.channel() {
}
}()

var wg sync.WaitGroup
wg.Add(2)

go func() {
defer wg.Done()
for i := 0; i < 500; i++ {
in <- i
}
}()

go func() {
defer wg.Done()
for i := 0; i < 500; i++ {
if i%2 == 0 {
wc.setWaker(w1)
} else {
wc.setWaker(w2)
}
}
}()

wg.Wait()
close(in)
<-drained
}
Loading
Loading