From 8f660e4b4b2e0bec628182559ef3cad201993fdb Mon Sep 17 00:00:00 2001 From: luren Date: Thu, 16 Apr 2026 13:03:12 +0800 Subject: [PATCH 1/2] Update live tunnel PSK on peer updates --- gotatun/src/device/configure.rs | 8 ++- gotatun/src/device/tests.rs | 96 ++++++++++++++++++++++++++++++++- gotatun/src/noise/handshake.rs | 13 +++++ gotatun/src/noise/mod.rs | 26 +++++++++ 4 files changed, 141 insertions(+), 2 deletions(-) diff --git a/gotatun/src/device/configure.rs b/gotatun/src/device/configure.rs index a49fa84e..c7cde346 100644 --- a/gotatun/src/device/configure.rs +++ b/gotatun/src/device/configure.rs @@ -320,7 +320,13 @@ impl DeviceWrite<'_, T> { } = peer_mut; if let Update::Set(preshared_key) = preshared_key { - existing_peer.preshared_key = preshared_key; + if existing_peer.preshared_key != preshared_key { + // Keep the stored config and live tunnel in sync. `modify_peer` / `update_peer` + // are expected to affect subsequent handshakes, not only the value returned by + // inspection APIs. + existing_peer.preshared_key = preshared_key; + existing_peer.tunnel.set_preshared_key(preshared_key); + } } if let Update::Set(keepalive) = keepalive { diff --git a/gotatun/src/device/tests.rs b/gotatun/src/device/tests.rs index 27b9f1f6..d9ab8a79 100644 --- a/gotatun/src/device/tests.rs +++ b/gotatun/src/device/tests.rs @@ -14,10 +14,14 @@ use std::{future::ready, time::Duration}; use futures::{StreamExt, future::pending}; use mock::MockEavesdropper; use rand::{SeedableRng, rngs::StdRng}; -use tokio::{join, select, time::sleep}; +use tokio::{ + join, select, + time::{sleep, timeout}, +}; use zerocopy::IntoBytes; use crate::noise::index_table::IndexTable; +use crate::packet::{Ip, Packet}; pub mod mock; @@ -179,11 +183,101 @@ async fn test_endpoint_roaming() { ping_pong("1.2.3.4".parse().unwrap()).await; } +#[derive(Clone, Copy)] +enum PeerUpdateApi { + Modify, + Update, +} + +#[tokio::test] +#[test_log::test] +async fn modify_peer_updates_live_preshared_key() { + assert_peer_psk_update_changes_live_handshake(PeerUpdateApi::Modify).await; +} + +#[tokio::test] +#[test_log::test] +async fn update_peer_updates_live_preshared_key() { + assert_peer_psk_update_changes_live_handshake(PeerUpdateApi::Update).await; +} + /// The number of packets we send through the tunnel fn packet_count() -> usize { mock::packets_of_every_size().len() } +async fn assert_peer_psk_update_changes_live_handshake(api: PeerUpdateApi) { + let (mut alice, mut bob, _eve) = mock::device_pair().await; + let packet = mock::packet(b"Hello!"); + let preshared_key = [0xA5; 32]; + + apply_peer_psk_update(&mut alice, api, Some(preshared_key)).await; + send_and_expect_blocked(&alice, &mut bob, &packet).await; + + apply_peer_psk_update(&mut bob, api, Some(preshared_key)).await; + // Reset Alice's failed in-flight handshake so the next packet starts a fresh exchange. + apply_peer_psk_update(&mut alice, api, Some([0x5A; 32])).await; + apply_peer_psk_update(&mut alice, api, Some(preshared_key)).await; + + send_and_expect_delivery(&alice, &mut bob, &packet).await; +} + +async fn apply_peer_psk_update( + device: &mut mock::MockDevice, + api: PeerUpdateApi, + preshared_key: Option<[u8; 32]>, +) { + let mut peers = device.device.peers().await; + assert_eq!(peers.len(), 1, "expected exactly one peer"); + + let mut peer = peers.pop().expect("missing peer").peer; + let updated = match api { + PeerUpdateApi::Modify => device + .device + .modify_peer(&peer.public_key, |peer_mut| { + peer_mut.set_preshared_key(preshared_key); + }) + .await + .expect("modify_peer should succeed"), + PeerUpdateApi::Update => { + peer.preshared_key = preshared_key; + device + .device + .update_peer(peer) + .await + .expect("update_peer should succeed") + } + }; + + assert!(updated, "peer update should affect an existing peer"); +} + +async fn send_and_expect_blocked( + sender: &mock::MockDevice, + receiver: &mut mock::MockDevice, + packet: &Packet, +) { + sender.app_tx.send(packet.clone()).await; + assert!( + timeout(Duration::from_millis(500), receiver.app_rx.recv()) + .await + .is_err(), + "packet should not be delivered while peers disagree on the live PSK" + ); +} + +async fn send_and_expect_delivery( + sender: &mock::MockDevice, + receiver: &mut mock::MockDevice, + packet: &Packet, +) { + sender.app_tx.send(packet.clone()).await; + let received = timeout(Duration::from_secs(1), receiver.app_rx.recv()) + .await + .expect("expected packet delivery once both live PSKs match"); + assert_eq!(received.as_bytes(), packet.as_bytes()); +} + /// Helper method to test that packets can be sent from one [`Device`] to another. /// Use `eavesdrop` to sniff wireguard packets and assert things about the connection. async fn test_device_pair(eavesdrop: impl AsyncFnOnce(MockEavesdropper) + Send) { diff --git a/gotatun/src/noise/handshake.rs b/gotatun/src/noise/handshake.rs index edd5a33a..cfb27de9 100644 --- a/gotatun/src/noise/handshake.rs +++ b/gotatun/src/noise/handshake.rs @@ -434,6 +434,10 @@ impl NoiseParams { self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public); } + + fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) { + self.preshared_key = preshared_key; + } } impl Handshake { @@ -499,6 +503,15 @@ impl Handshake { self.params.set_static_private(private_key, public_key) } + /// Update the preshared key and invalidate handshake state derived from the previous value. + pub(crate) fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) { + self.params.set_preshared_key(preshared_key); + // Any in-flight handshake transcript was mixed with the previous PSK and cannot continue. + self.previous = HandshakeState::None; + self.state = HandshakeState::None; + self.last_rtt = None; + } + pub(super) fn receive_handshake_initialization( &mut self, packet: crate::packet::Packet, diff --git a/gotatun/src/noise/mod.rs b/gotatun/src/noise/mod.rs index 7770fa26..9a845c89 100644 --- a/gotatun/src/noise/mod.rs +++ b/gotatun/src/noise/mod.rs @@ -163,6 +163,17 @@ impl Tunn { } } + /// Update the preshared key and discard crypto state derived from the previous one. + pub(crate) fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) { + self.handshake.set_preshared_key(preshared_key); + // Established sessions are keyed from the previous PSK and must not remain usable. + for s in &mut self.sessions { + *s = None; + } + // Reset timer-driven handshake/keepalive state so the next packet starts a fresh exchange. + self.timers.clear(); + } + /// Encapsulate a single packet. /// /// If there's an active session, return the encapsulated packet. Otherwise, if needed, return @@ -795,6 +806,21 @@ mod tests { assert_eq!(sent_packet_buf.as_bytes(), recv_packet_buf.as_bytes()); } + #[test] + fn set_preshared_key_invalidates_existing_sessions() { + let (mut my_tun, _their_tun) = create_two_tuns_and_handshake(); + + assert!(my_tun.time_since_last_handshake().is_some()); + + my_tun.set_preshared_key(Some([7; 32])); + + assert!(my_tun.time_since_last_handshake().is_none()); + assert!(matches!( + my_tun.handle_outgoing_packet(create_ipv4_udp_packet().into_bytes(), None), + Some(WgKind::HandshakeInit(..)) + )); + } + /// Test that [`Tunn::update_timers`] does not panic if clock jumps back. #[test] #[cfg(feature = "mock_instant")] From 97d07e9374bff3fc79d432844d1edd230614c8a7 Mon Sep 17 00:00:00 2001 From: luren Date: Thu, 16 Apr 2026 14:08:08 +0800 Subject: [PATCH 2/2] Fix PSK update CI regressions --- gotatun/src/noise/handshake.rs | 5 +++++ gotatun/src/noise/mod.rs | 25 +++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/gotatun/src/noise/handshake.rs b/gotatun/src/noise/handshake.rs index cfb27de9..f0cb47d8 100644 --- a/gotatun/src/noise/handshake.rs +++ b/gotatun/src/noise/handshake.rs @@ -435,6 +435,7 @@ impl NoiseParams { self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public); } + #[cfg(any(feature = "device", test))] fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) { self.preshared_key = preshared_key; } @@ -504,11 +505,15 @@ impl Handshake { } /// Update the preshared key and invalidate handshake state derived from the previous value. + #[cfg(any(feature = "device", test))] pub(crate) fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) { self.params.set_preshared_key(preshared_key); // Any in-flight handshake transcript was mixed with the previous PSK and cannot continue. self.previous = HandshakeState::None; self.state = HandshakeState::None; + // Replay protection for handshake initiations is tied to the previous crypto context. + // Reset it so an immediate retry after a PSK change is not rejected as stale. + self.last_handshake_timestamp = Tai64N::zero(); self.last_rtt = None; } diff --git a/gotatun/src/noise/mod.rs b/gotatun/src/noise/mod.rs index 9a845c89..88b09822 100644 --- a/gotatun/src/noise/mod.rs +++ b/gotatun/src/noise/mod.rs @@ -164,6 +164,7 @@ impl Tunn { } /// Update the preshared key and discard crypto state derived from the previous one. + #[cfg(any(feature = "device", test))] pub(crate) fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) { self.handshake.set_preshared_key(preshared_key); // Established sessions are keyed from the previous PSK and must not remain usable. @@ -821,6 +822,30 @@ mod tests { )); } + #[test] + fn set_preshared_key_resets_handshake_replay_state() { + let (mut my_tun, mut their_tun) = create_two_tuns(); + let preshared_key = [7; 32]; + + my_tun.set_preshared_key(Some(preshared_key)); + + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, init); + assert!(matches!( + my_tun.handle_incoming_packet(WgKind::HandshakeResp(resp)), + TunnResult::Err(WireGuardError::InvalidAeadTag) + )); + + their_tun.set_preshared_key(Some(preshared_key)); + my_tun.set_preshared_key(Some([8; 32])); + my_tun.set_preshared_key(Some(preshared_key)); + + let init = create_handshake_init(&mut my_tun); + let resp = create_handshake_response(&mut their_tun, init); + let keepalive = parse_handshake_resp(&mut my_tun, resp); + parse_keepalive(&mut their_tun, keepalive); + } + /// Test that [`Tunn::update_timers`] does not panic if clock jumps back. #[test] #[cfg(feature = "mock_instant")]