Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion gotatun/src/device/configure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,13 @@ impl<T: DeviceTransports> DeviceWrite<'_, T> {
} = peer_mut;

if let Update::Set(preshared_key) = preshared_key {
existing_peer.preshared_key = preshared_key;
if existing_peer.preshared_key != preshared_key {
// Keep the stored config and live tunnel in sync. `modify_peer` / `update_peer`
// are expected to affect subsequent handshakes, not only the value returned by
// inspection APIs.
existing_peer.preshared_key = preshared_key;
existing_peer.tunnel.set_preshared_key(preshared_key);
}
}

if let Update::Set(keepalive) = keepalive {
Expand Down
96 changes: 95 additions & 1 deletion gotatun/src/device/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ use std::{future::ready, time::Duration};
use futures::{StreamExt, future::pending};
use mock::MockEavesdropper;
use rand::{SeedableRng, rngs::StdRng};
use tokio::{join, select, time::sleep};
use tokio::{
join, select,
time::{sleep, timeout},
};
use zerocopy::IntoBytes;

use crate::noise::index_table::IndexTable;
use crate::packet::{Ip, Packet};

pub mod mock;

Expand Down Expand Up @@ -179,11 +183,101 @@ async fn test_endpoint_roaming() {
ping_pong("1.2.3.4".parse().unwrap()).await;
}

#[derive(Clone, Copy)]
enum PeerUpdateApi {
Modify,
Update,
}

#[tokio::test]
#[test_log::test]
async fn modify_peer_updates_live_preshared_key() {
assert_peer_psk_update_changes_live_handshake(PeerUpdateApi::Modify).await;
}

#[tokio::test]
#[test_log::test]
async fn update_peer_updates_live_preshared_key() {
assert_peer_psk_update_changes_live_handshake(PeerUpdateApi::Update).await;
}

/// The number of packets we send through the tunnel
fn packet_count() -> usize {
mock::packets_of_every_size().len()
}

async fn assert_peer_psk_update_changes_live_handshake(api: PeerUpdateApi) {
let (mut alice, mut bob, _eve) = mock::device_pair().await;
let packet = mock::packet(b"Hello!");
let preshared_key = [0xA5; 32];

apply_peer_psk_update(&mut alice, api, Some(preshared_key)).await;
send_and_expect_blocked(&alice, &mut bob, &packet).await;

apply_peer_psk_update(&mut bob, api, Some(preshared_key)).await;
// Reset Alice's failed in-flight handshake so the next packet starts a fresh exchange.
apply_peer_psk_update(&mut alice, api, Some([0x5A; 32])).await;
apply_peer_psk_update(&mut alice, api, Some(preshared_key)).await;

send_and_expect_delivery(&alice, &mut bob, &packet).await;
}

async fn apply_peer_psk_update(
device: &mut mock::MockDevice,
api: PeerUpdateApi,
preshared_key: Option<[u8; 32]>,
) {
let mut peers = device.device.peers().await;
assert_eq!(peers.len(), 1, "expected exactly one peer");

let mut peer = peers.pop().expect("missing peer").peer;
let updated = match api {
PeerUpdateApi::Modify => device
.device
.modify_peer(&peer.public_key, |peer_mut| {
peer_mut.set_preshared_key(preshared_key);
})
.await
.expect("modify_peer should succeed"),
PeerUpdateApi::Update => {
peer.preshared_key = preshared_key;
device
.device
.update_peer(peer)
.await
.expect("update_peer should succeed")
}
};

assert!(updated, "peer update should affect an existing peer");
}

async fn send_and_expect_blocked(
sender: &mock::MockDevice,
receiver: &mut mock::MockDevice,
packet: &Packet<Ip>,
) {
sender.app_tx.send(packet.clone()).await;
assert!(
timeout(Duration::from_millis(500), receiver.app_rx.recv())
.await
.is_err(),
"packet should not be delivered while peers disagree on the live PSK"
);
}

async fn send_and_expect_delivery(
sender: &mock::MockDevice,
receiver: &mut mock::MockDevice,
packet: &Packet<Ip>,
) {
sender.app_tx.send(packet.clone()).await;
let received = timeout(Duration::from_secs(1), receiver.app_rx.recv())
.await
.expect("expected packet delivery once both live PSKs match");
assert_eq!(received.as_bytes(), packet.as_bytes());
}

/// Helper method to test that packets can be sent from one [`Device`] to another.
/// Use `eavesdrop` to sniff wireguard packets and assert things about the connection.
async fn test_device_pair(eavesdrop: impl AsyncFnOnce(MockEavesdropper) + Send) {
Expand Down
18 changes: 18 additions & 0 deletions gotatun/src/noise/handshake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,11 @@ impl NoiseParams {

self.static_shared = self.static_private.diffie_hellman(&self.peer_static_public);
}

#[cfg(any(feature = "device", test))]
fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) {
self.preshared_key = preshared_key;
}
}

impl Handshake {
Expand Down Expand Up @@ -499,6 +504,19 @@ impl Handshake {
self.params.set_static_private(private_key, public_key)
}

/// Update the preshared key and invalidate handshake state derived from the previous value.
#[cfg(any(feature = "device", test))]
pub(crate) fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) {
self.params.set_preshared_key(preshared_key);
// Any in-flight handshake transcript was mixed with the previous PSK and cannot continue.
self.previous = HandshakeState::None;
self.state = HandshakeState::None;
// Replay protection for handshake initiations is tied to the previous crypto context.
// Reset it so an immediate retry after a PSK change is not rejected as stale.
self.last_handshake_timestamp = Tai64N::zero();
self.last_rtt = None;
}

pub(super) fn receive_handshake_initialization(
&mut self,
packet: crate::packet::Packet<WgHandshakeInit>,
Expand Down
51 changes: 51 additions & 0 deletions gotatun/src/noise/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,18 @@ impl<R: RngCore + Send> Tunn<R> {
}
}

/// Update the preshared key and discard crypto state derived from the previous one.
#[cfg(any(feature = "device", test))]
pub(crate) fn set_preshared_key(&mut self, preshared_key: Option<[u8; 32]>) {
self.handshake.set_preshared_key(preshared_key);
// Established sessions are keyed from the previous PSK and must not remain usable.
for s in &mut self.sessions {
*s = None;
}
// Reset timer-driven handshake/keepalive state so the next packet starts a fresh exchange.
self.timers.clear();
}

/// Encapsulate a single packet.
///
/// If there's an active session, return the encapsulated packet. Otherwise, if needed, return
Expand Down Expand Up @@ -795,6 +807,45 @@ mod tests {
assert_eq!(sent_packet_buf.as_bytes(), recv_packet_buf.as_bytes());
}

#[test]
fn set_preshared_key_invalidates_existing_sessions() {
let (mut my_tun, _their_tun) = create_two_tuns_and_handshake();

assert!(my_tun.time_since_last_handshake().is_some());

my_tun.set_preshared_key(Some([7; 32]));

assert!(my_tun.time_since_last_handshake().is_none());
assert!(matches!(
my_tun.handle_outgoing_packet(create_ipv4_udp_packet().into_bytes(), None),
Some(WgKind::HandshakeInit(..))
));
}

#[test]
fn set_preshared_key_resets_handshake_replay_state() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let preshared_key = [7; 32];

my_tun.set_preshared_key(Some(preshared_key));

let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init);
assert!(matches!(
my_tun.handle_incoming_packet(WgKind::HandshakeResp(resp)),
TunnResult::Err(WireGuardError::InvalidAeadTag)
));

their_tun.set_preshared_key(Some(preshared_key));
my_tun.set_preshared_key(Some([8; 32]));
my_tun.set_preshared_key(Some(preshared_key));

let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init);
let keepalive = parse_handshake_resp(&mut my_tun, resp);
parse_keepalive(&mut their_tun, keepalive);
}

/// Test that [`Tunn::update_timers`] does not panic if clock jumps back.
#[test]
#[cfg(feature = "mock_instant")]
Expand Down
Loading