diff --git a/boringtun/src/noise/mod.rs b/boringtun/src/noise/mod.rs index 76e377b63..8661463cf 100644 --- a/boringtun/src/noise/mod.rs +++ b/boringtun/src/noise/mod.rs @@ -14,6 +14,7 @@ use crate::noise::rate_limiter::RateLimiter; use crate::noise::timers::{TimerName, Timers}; use crate::x25519; +use self::session::Session; use std::collections::VecDeque; use std::convert::{TryFrom, TryInto}; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; @@ -236,9 +237,7 @@ impl Tunn { }); self.handshake .set_static_private(static_private, static_public); - for s in &mut self.sessions { - *s = None; - } + self.clear_sessions(); } /// Encapsulate a single packet from the tunnel interface. @@ -248,8 +247,7 @@ impl Tunn { /// Panics if dst buffer is too small. /// Size of dst should be at least src.len() + 32, and no less than 148 bytes. pub fn encapsulate<'a>(&mut self, src: &[u8], dst: &'a mut [u8]) -> TunnResult<'a> { - let current = self.current; - if let Some(ref session) = self.sessions[current % N_SESSIONS] { + if let Some(session) = self.current_session() { // Send the packet using an established session let packet = session.format_packet_data(src, dst); self.timer_tick(TimerName::TimeLastPacketSent); @@ -328,8 +326,7 @@ impl Tunn { let (packet, session) = self.handshake.receive_handshake_initialization(p, dst)?; // Store new session in ring buffer - let index = session.local_index(); - self.sessions[index % N_SESSIONS] = Some(session); + let index = self.set_session(session); self.timer_tick(TimerName::TimeLastPacketReceived); self.timer_tick(TimerName::TimeLastPacketSent); @@ -355,13 +352,11 @@ impl Tunn { let keepalive_packet = session.format_packet_data(&[], dst); // Store new session in ring buffer - let l_idx = session.local_index(); - let index = l_idx % N_SESSIONS; - self.sessions[index] = Some(session); + let index = self.set_session(session); self.timer_tick(TimerName::TimeLastPacketReceived); self.timer_tick_session_established(true, index); // New session established, we are the initiator - self.set_current_session(l_idx); + self.set_current_session(index); tracing::debug!("Sending keepalive"); @@ -393,7 +388,7 @@ impl Tunn { // There is nothing to do, already using this session, this is the common case return; } - if self.sessions[cur_idx % N_SESSIONS].is_none() + if self.session(cur_idx).is_none() || self.timers.session_timers[new_idx % N_SESSIONS] >= self.timers.session_timers[cur_idx % N_SESSIONS] { @@ -413,11 +408,11 @@ impl Tunn { // Get the (probably) right session let decapsulated_packet = { - let session = self.sessions[idx].as_ref(); - let session = session.ok_or_else(|| { + let session = self.session(idx).ok_or_else(|| { tracing::trace!(message = "No current session available", remote_idx = r_idx); WireGuardError::NoCurrentSession })?; + session.receive_packet_data(packet, dst)? }; @@ -548,7 +543,7 @@ impl Tunn { let mut total_weight = 0.0; for i in 0..N_SESSIONS { - if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] { + if let Some(session) = self.session(session_idx.wrapping_sub(i)) { let (expected, received) = session.current_packet_cnt(); let loss = if expected == 0 { @@ -583,6 +578,27 @@ impl Tunn { (time, tx_bytes, rx_bytes, loss, rtt) } + + fn clear_sessions(&mut self) { + for s in &mut self.sessions { + *s = None; + } + } + + fn current_session(&self) -> Option<&Session> { + self.session(self.current) + } + + fn session(&self, index: usize) -> Option<&Session> { + self.sessions[index % N_SESSIONS].as_ref() + } + + fn set_session(&mut self, session: Session) -> usize { + let index = session.local_index() % N_SESSIONS; + self.sessions[index] = Some(session); + + index + } } #[cfg(test)] diff --git a/boringtun/src/noise/timers.rs b/boringtun/src/noise/timers.rs index 6b91d5767..bb3408daf 100644 --- a/boringtun/src/noise/timers.rs +++ b/boringtun/src/noise/timers.rs @@ -140,9 +140,7 @@ impl Tunn { // We don't really clear the timers, but we set them to the current time to // so the reference time frame is the same fn clear_all(&mut self) { - for session in &mut self.sessions { - *session = None; - } + self.clear_sessions(); self.packet_queue.clear(); @@ -312,15 +310,12 @@ impl Tunn { } pub fn time_since_last_handshake(&self) -> Option { - let current_session = self.current; - if self.sessions[current_session % super::N_SESSIONS].is_some() { - let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started); - let duration_since_session_established = self.timers[TimeSessionEstablished]; + let _current = self.current_session()?; // Guard to ensure we have a session. - Some(duration_since_tun_start - duration_since_session_established) - } else { - None - } + let duration_since_tun_start = Instant::now().duration_since(self.timers.time_started); + let duration_since_session_established = self.timers[TimeSessionEstablished]; + + Some(duration_since_tun_start - duration_since_session_established) } pub fn persistent_keepalive(&self) -> Option {