diff --git a/stream.go b/stream.go index 293fcc2..7197694 100644 --- a/stream.go +++ b/stream.go @@ -27,9 +27,15 @@ type Stream[T any] interface { // the current value. WaitNext() T + // WaitNextFiltered does the same as WaitNext but only returns when filterFunc evaluates to true. + WaitNextFiltered(filterFunc func(T) bool) T + // WaitNextCtx does the same as WaitNext but returns earlier with an error if the given context is cancelled first. WaitNextCtx(ctx context.Context) (T, error) + // WaitNextCtxFiltered does the same as WaitNextCtx and additionally only returns the value when filterFunc evaluates to true. + WaitNextCtxFiltered(ctx context.Context, filterFunc func(T) bool) (T, error) + // Clone creates a new independent stream from this one but sharing the same // Property. Updates to the property will be reflected in both streams but // they may have different values depending on when they advance the stream @@ -77,6 +83,15 @@ func (s *stream[T]) WaitNext() T { return s.state.value } +func (s *stream[T]) WaitNextFiltered(filterFunc func(T) bool) T { + for { + val := s.WaitNext() + if evaluateFilterFunc[T](val, filterFunc) { + return val + } + } +} + func (s *stream[T]) WaitNextCtx(ctx context.Context) (T, error) { select { case <-s.Changes(): @@ -92,6 +107,22 @@ func (s *stream[T]) WaitNextCtx(ctx context.Context) (T, error) { return zeroVal, ctx.Err() } +func (s *stream[T]) WaitNextCtxFiltered(ctx context.Context, filterFunc func(T) bool) (T, error) { + for { + val, err := s.WaitNextCtx(ctx) + if err != nil { + return val, err + } + if evaluateFilterFunc[T](val, filterFunc) { + return val, nil + } + } +} + func (s *stream[T]) Peek() T { return s.state.next.value } + +func evaluateFilterFunc[T any](val T, filterFunc func(T) bool) bool { + return filterFunc == nil || filterFunc(val) +} diff --git a/stream_test.go b/stream_test.go index 6a0fb11..6d37707 100644 --- a/stream_test.go +++ b/stream_test.go @@ -102,6 +102,35 @@ func TestStreamWaitsNext(t *testing.T) { } } +func TestStreamWaitsNextFiltered(t *testing.T) { + state := newState(0) + stream := &stream[int]{state: state} + + values := []int{1, 11, 2, 3, 33, 4, 5, 6, 7, 77, 8, 9, 10} + filteredValues := []int{2, 4, 6, 8, 10} + + onlyEvenFilter := func(i int) bool { + return i%2 == 0 + } + + for _, i := range values { + state = state.update(i) + } + + i := 0 + for stream.HasNext() { + val := stream.WaitNextFiltered(onlyEvenFilter) + if val != filteredValues[i] { + t.Fatalf("Expecting %#v but got %#v\n", filteredValues[i], val) + } + i++ + } + + if stream.HasNext() { + t.Fatalf("Expecting no changes\n") + } +} + func TestStreamWaitNextBackgroundCtx(t *testing.T) { state := newState(0) stream := &stream[int]{state: state} @@ -158,6 +187,54 @@ func TestStreamWaitNextCanceledCtx(t *testing.T) { } } +func TestStreamWaitNextCtxFiltered(t *testing.T) { + state := newState(0) + stream := &stream[int]{state: state} + + onlyOddFilter := func(i int) bool { + return i%2 != 0 + } + + state = state.update(2) + state = state.update(3) + + lastAwaitedVal, err := stream.WaitNextCtxFiltered(context.Background(), onlyOddFilter) + if err != nil { + t.Fatalf("Expecting no error\n") + } + if lastAwaitedVal != 3 { + t.Fatalf("Expecting 3 but got %#v\n", lastAwaitedVal) + } + + if stream.HasNext() { + t.Fatalf("Expecting no changes\n") + } + + state = state.update(4) + + // ensure the method returns an error when a canceled context is used and doesn't advance the stream + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = stream.WaitNextCtxFiltered(ctx, onlyOddFilter) + if err == nil { + t.Fatalf("Expecting error but got none\n") + } + if stream.Value() != lastAwaitedVal { + t.Fatalf("Expecting stream's current value to be %#v but it is %#v\n", lastAwaitedVal, stream.Value()) + } + + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + _, err = stream.WaitNextCtxFiltered(ctx, onlyOddFilter) + if err == nil { + t.Fatalf("Expecting error but got none\n") + } + + if stream.HasNext() { + t.Fatalf("Expecting no changes\n") + } +} + func TestStreamClone(t *testing.T) { state := newState(10) stream1 := &stream[int]{state: state}