diff --git a/client/internal/dns/upstream.go b/client/internal/dns/upstream.go index 375f6df1cae..52d2079e606 100644 --- a/client/internal/dns/upstream.go +++ b/client/internal/dns/upstream.go @@ -45,6 +45,19 @@ const ( const testRecord = "com." +// netbirdCGNATRange is the CGNAT range (100.64.0.0/10) used by NetBird for +// overlay peer IP addresses. Addresses in this range are only reachable once +// the WireGuard tunnel to the peer is fully established. +var netbirdCGNATRange = netip.MustParsePrefix("100.64.0.0/10") + +// isNetBirdOverlayAddr returns true if the given address is within the NetBird +// CGNAT range (100.64.0.0/10). Such addresses are only reachable after the +// WireGuard tunnel is established, so probing them at engine startup will +// always produce a false-positive "unable to reach" warning. +func isNetBirdOverlayAddr(addr netip.AddrPort) bool { + return netbirdCGNATRange.Contains(addr.Addr()) +} + type upstreamClient interface { exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) } @@ -65,6 +78,7 @@ type upstreamResolverBase struct { mutex sync.Mutex reactivatePeriod time.Duration upstreamTimeout time.Duration + wg sync.WaitGroup deactivate func(error) reactivate func() @@ -115,6 +129,8 @@ func (u *upstreamResolverBase) MatchSubdomains() bool { func (u *upstreamResolverBase) Stop() { log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers) u.cancel() + + u.wg.Wait() } // ServeDNS handles a DNS request @@ -259,51 +275,76 @@ func formatFailures(failures []upstreamFailure) string { } // ProbeAvailability tests all upstream servers simultaneously and -// disables the resolver if none work +// disables the resolver if none work. +// +// Nameservers configured with NetBird overlay IPs (100.64.0.0/10 CGNAT range) +// are skipped during this probe. Those addresses are only reachable once the +// WireGuard tunnel to the peer is established, which has not yet happened at +// engine startup. Probing them at this point always produces a false-positive +// "unable to reach" warning even though they will be available moments later. +// They continue to be exercised by real DNS queries and the waitUntilResponse +// retry loop if an actual failure occurs. func (u *upstreamResolverBase) ProbeAvailability() { u.mutex.Lock() defer u.mutex.Unlock() - select { - case <-u.ctx.Done(): - return - default: - } - // avoid probe if upstreams could resolve at least one query if u.successCount.Load() > 0 { return } var success bool + var probedAny bool var mu sync.Mutex var wg sync.WaitGroup - var errors *multierror.Error + var errs *multierror.Error for _, upstream := range u.upstreamServers { - upstream := upstream + // Skip probing NetBird overlay addresses (100.64.0.0/10 CGNAT range). + // These peers are only reachable once the WireGuard tunnel is established. + // Probing them at engine startup produces a spurious warning; they will + // be exercised normally once real DNS queries arrive. + if isNetBirdOverlayAddr(upstream) { + log.Debugf("skipping probe for NetBird overlay nameserver %s (tunnel not yet established)", upstream) + continue + } + probedAny = true wg.Add(1) - go func() { + go func(upstream netip.AddrPort) { defer wg.Done() - err := u.testNameserver(upstream, 500*time.Millisecond) + err := u.testNameserver(u.ctx, nil, upstream, 500*time.Millisecond) if err != nil { - errors = multierror.Append(errors, err) + mu.Lock() + errs = multierror.Append(errs, err) + mu.Unlock() log.Warnf("probing upstream nameserver %s: %s", upstream, err) return } mu.Lock() - defer mu.Unlock() success = true - }() + mu.Unlock() + }(upstream) } wg.Wait() + select { + case <-u.ctx.Done(): + return + default: + } + + // All upstreams were overlay addresses — skip startup disable/warning. + // They will be exercised by real DNS queries once the tunnel is established. + if !probedAny { + return + } + // didn't find a working upstream server, let's disable and try later if !success { - u.disable(errors.ErrorOrNil()) + u.disable(errs.ErrorOrNil()) if u.statusRecorder == nil { return @@ -339,7 +380,7 @@ func (u *upstreamResolverBase) waitUntilResponse() { } for _, upstream := range u.upstreamServers { - if err := u.testNameserver(upstream, probeTimeout); err != nil { + if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil { log.Tracef("upstream check for %s: %s", upstream, err) } else { // at least one upstream server is available, stop probing @@ -364,7 +405,9 @@ func (u *upstreamResolverBase) waitUntilResponse() { log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString()) u.successCount.Add(1) u.reactivate() + u.mutex.Lock() u.disabled = false + u.mutex.Unlock() } // isTimeout returns true if the given error is a network timeout error. @@ -387,7 +430,11 @@ func (u *upstreamResolverBase) disable(err error) { u.successCount.Store(0) u.deactivate(err) u.disabled = true - go u.waitUntilResponse() + u.wg.Add(1) + go func() { + defer u.wg.Done() + u.waitUntilResponse() + }() } func (u *upstreamResolverBase) upstreamServersString() string { @@ -398,13 +445,18 @@ func (u *upstreamResolverBase) upstreamServersString() string { return strings.Join(servers, ", ") } -func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error { - ctx, cancel := context.WithTimeout(u.ctx, timeout) +func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error { + mergedCtx, cancel := context.WithTimeout(baseCtx, timeout) defer cancel() + if externalCtx != nil { + stop2 := context.AfterFunc(externalCtx, cancel) + defer stop2() + } + r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA) - _, _, err := u.upstreamClient.exchange(ctx, server.String(), r) + _, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r) return err } diff --git a/client/internal/dns/upstream_test.go b/client/internal/dns/upstream_test.go index 8b06e447530..481e551e0cb 100644 --- a/client/internal/dns/upstream_test.go +++ b/client/internal/dns/upstream_test.go @@ -475,3 +475,255 @@ func TestFormatFailures(t *testing.T) { }) } } + +// TestIsNetBirdOverlayAddr verifies that the CGNAT range check correctly +// identifies NetBird overlay addresses vs regular addresses. +func TestIsNetBirdOverlayAddr(t *testing.T) { + testCases := []struct { + name string + addr string + expected bool + }{ + { + name: "NetBird overlay address at start of range", + addr: "100.64.0.1:53", + expected: true, + }, + { + name: "NetBird overlay address mid-range", + addr: "100.110.145.54:53", + expected: true, + }, + { + name: "NetBird overlay address end of range", + addr: "100.127.255.254:53", + expected: true, + }, + { + name: "LAN address", + addr: "192.168.1.5:53", + expected: false, + }, + { + name: "LAN address secondary DC", + addr: "192.168.1.30:53", + expected: false, + }, + { + name: "public DNS address", + addr: "8.8.8.8:53", + expected: false, + }, + { + name: "address just below CGNAT range", + addr: "100.63.255.255:53", + expected: false, + }, + { + name: "address just above CGNAT range", + addr: "100.128.0.0:53", + expected: false, + }, + { + name: "loopback address", + addr: "127.0.0.1:53", + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + addr := netip.MustParseAddrPort(tc.addr) + result := isNetBirdOverlayAddr(addr) + assert.Equal(t, tc.expected, result, "isNetBirdOverlayAddr(%s)", tc.addr) + }) + } +} + +// TestProbeAvailability_SkipsOverlayAddresses verifies that ProbeAvailability +// does not probe or deactivate when all nameservers are NetBird overlay +// addresses (100.64.0.0/10). These are unreachable at engine startup because +// the WireGuard tunnel has not yet been established. +func TestProbeAvailability_SkipsOverlayAddresses(t *testing.T) { + overlayUpstream := netip.MustParseAddrPort("100.110.145.54:53") + + probed := false + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + overlayUpstream.String(): {err: fmt.Errorf("tunnel not up")}, + }, + rtt: time.Millisecond, + } + + trackingClient := &trackingUpstreamClient{ + inner: mockClient, + probed: &probed, + } + + deactivated := false + resolver := &upstreamResolverBase{ + ctx: context.Background(), + upstreamClient: trackingClient, + upstreamServers: []netip.AddrPort{overlayUpstream}, + upstreamTimeout: UpstreamTimeout, + deactivate: func(error) { deactivated = true }, + reactivate: func() {}, + } + + resolver.ProbeAvailability() + + assert.False(t, probed, "overlay address should not have been probed") + assert.False(t, deactivated, "resolver should not have been deactivated for overlay address") + assert.False(t, resolver.disabled, "resolver should not be disabled") +} + +// TestProbeAvailability_ProbesLANAddresses verifies that ProbeAvailability +// still probes non-overlay (LAN/public) nameservers normally. +func TestProbeAvailability_ProbesLANAddresses(t *testing.T) { + lanUpstream := netip.MustParseAddrPort("192.168.1.5:53") + + probed := false + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + lanUpstream.String(): {err: fmt.Errorf("connection refused")}, + }, + rtt: time.Millisecond, + } + + trackingClient := &trackingUpstreamClient{ + inner: mockClient, + probed: &probed, + } + + deactivated := false + resolver := &upstreamResolverBase{ + ctx: context.Background(), + upstreamClient: trackingClient, + upstreamServers: []netip.AddrPort{lanUpstream}, + upstreamTimeout: UpstreamTimeout, + deactivate: func(error) { deactivated = true }, + reactivate: func() {}, + } + + resolver.ProbeAvailability() + + assert.True(t, probed, "LAN address should have been probed") + assert.True(t, deactivated, "resolver should have been deactivated when LAN probe fails") +} + +// TestProbeAvailability_MixedOverlayAndLAN verifies that when a nameserver +// group contains both overlay and LAN addresses, the overlay is skipped but +// the LAN address is still probed. A successful LAN probe prevents deactivation. +func TestProbeAvailability_MixedOverlayAndLAN(t *testing.T) { + overlayUpstream := netip.MustParseAddrPort("100.110.145.54:53") + lanUpstream := netip.MustParseAddrPort("192.168.1.5:53") + + probedAddrs := make(map[string]bool) + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + overlayUpstream.String(): {err: fmt.Errorf("tunnel not up")}, + lanUpstream.String(): {msg: &dns.Msg{MsgHdr: dns.MsgHdr{Response: true}}}, + }, + rtt: time.Millisecond, + } + + trackingClient := &trackingUpstreamClientPerAddr{ + inner: mockClient, + probedAddrs: probedAddrs, + } + + deactivated := false + resolver := &upstreamResolverBase{ + ctx: context.Background(), + upstreamClient: trackingClient, + upstreamServers: []netip.AddrPort{overlayUpstream, lanUpstream}, + upstreamTimeout: UpstreamTimeout, + deactivate: func(error) { deactivated = true }, + reactivate: func() {}, + } + + resolver.ProbeAvailability() + + assert.False(t, probedAddrs[overlayUpstream.String()], "overlay address should not have been probed") + assert.True(t, probedAddrs[lanUpstream.String()], "LAN address should have been probed") + assert.False(t, deactivated, "resolver should not be deactivated when LAN probe succeeds") +} + +// TestProbeAvailability_OnlyOverlay_NoDeactivation verifies that a nameserver +// group containing only overlay addresses is never deactivated. The probedAny +// flag ensures we return early without triggering the disable path. +func TestProbeAvailability_OnlyOverlay_NoDeactivation(t *testing.T) { + overlay1 := netip.MustParseAddrPort("100.110.145.54:53") + overlay2 := netip.MustParseAddrPort("100.64.0.1:53") + + deactivated := false + resolver := &upstreamResolverBase{ + ctx: context.Background(), + // upstreamClient intentionally nil — it must never be called + upstreamServers: []netip.AddrPort{overlay1, overlay2}, + upstreamTimeout: UpstreamTimeout, + deactivate: func(error) { deactivated = true }, + reactivate: func() {}, + } + + // Should not panic (nil client never called) and should not deactivate + require.NotPanics(t, func() { + resolver.ProbeAvailability() + }) + + assert.False(t, deactivated, "resolver should not be deactivated for overlay-only nameserver group") + assert.False(t, resolver.disabled, "resolver should not be disabled") +} + +// TestProbeAvailability_MixedOverlayAndLAN_LANFails verifies that when a +// nameserver group contains both overlay and LAN addresses, and the LAN probe +// fails, the resolver is deactivated. The overlay skip must not mask a genuine +// LAN failure. +func TestProbeAvailability_MixedOverlayAndLAN_LANFails(t *testing.T) { + overlayUpstream := netip.MustParseAddrPort("100.110.145.54:53") + lanUpstream := netip.MustParseAddrPort("192.168.1.5:53") + + mockClient := &mockUpstreamResolverPerServer{ + responses: map[string]mockUpstreamResponse{ + overlayUpstream.String(): {err: fmt.Errorf("tunnel not up")}, + lanUpstream.String(): {err: fmt.Errorf("connection refused")}, + }, + rtt: time.Millisecond, + } + + deactivated := false + resolver := &upstreamResolverBase{ + ctx: context.Background(), + upstreamClient: mockClient, + upstreamServers: []netip.AddrPort{overlayUpstream, lanUpstream}, + upstreamTimeout: UpstreamTimeout, + deactivate: func(error) { deactivated = true }, + reactivate: func() {}, + } + + resolver.ProbeAvailability() + + assert.True(t, deactivated, "resolver should be deactivated when LAN probe fails even with overlay present") +} + +// trackingUpstreamClient records whether any probe was attempted, regardless of address. +type trackingUpstreamClient struct { + inner *mockUpstreamResolverPerServer + probed *bool +} + +func (t *trackingUpstreamClient) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) { + *t.probed = true + return t.inner.exchange(ctx, upstream, r) +} + +// trackingUpstreamClientPerAddr records which addresses were probed. +type trackingUpstreamClientPerAddr struct { + inner *mockUpstreamResolverPerServer + probedAddrs map[string]bool +} + +func (t *trackingUpstreamClientPerAddr) exchange(ctx context.Context, upstream string, r *dns.Msg) (*dns.Msg, time.Duration, error) { + t.probedAddrs[upstream] = true + return t.inner.exchange(ctx, upstream, r) +}