From 4000aaf9e19921174a470825e79183dac5f96fc3 Mon Sep 17 00:00:00 2001 From: Quentin Monnet Date: Wed, 20 May 2026 22:28:30 +0100 Subject: [PATCH 1/5] refactor(nat): Clean a little the error translation code in NAT Rather than the large translate_error() functions, implement From for DoneReason. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Quentin Monnet --- nat/src/icmp_handler/icmp_error_msg.rs | 21 ++++++++++- nat/src/stateful/allocation.rs | 18 ++++++++++ nat/src/stateful/nf.rs | 49 +++++++++----------------- nat/src/stateless/nf.rs | 36 ++++++------------- 4 files changed, 65 insertions(+), 59 deletions(-) diff --git a/nat/src/icmp_handler/icmp_error_msg.rs b/nat/src/icmp_handler/icmp_error_msg.rs index 344fb32fd2..a0b6e7a39e 100644 --- a/nat/src/icmp_handler/icmp_error_msg.rs +++ b/nat/src/icmp_handler/icmp_error_msg.rs @@ -15,7 +15,7 @@ use net::headers::{ use net::icmp_any::TruncatedIcmpAny; use net::icmp_any::{IcmpAnyChecksumErrorPlaceholder, IcmpAnyChecksumPayload}; use net::ipv4::Ipv4; -use net::packet::Packet; +use net::packet::{DoneReason, Packet}; use std::net::IpAddr; use std::num::NonZero; @@ -41,6 +41,25 @@ pub enum IcmpErrorMsgError { NoTranslationPossible, } +impl From<&IcmpErrorMsgError> for DoneReason { + fn from(error: &IcmpErrorMsgError) -> Self { + match error { + IcmpErrorMsgError::NoIpHeader => DoneReason::NotIp, + IcmpErrorMsgError::InvalidPort(_) => DoneReason::Malformed, + IcmpErrorMsgError::NotUnicast(_) => DoneReason::NatFailure, + IcmpErrorMsgError::InvalidIpVersion | IcmpErrorMsgError::NoTranslationPossible => { + DoneReason::InternalFailure + } + IcmpErrorMsgError::BadChecksumIcmp(_) | IcmpErrorMsgError::BadChecksumInnerIpv4(_) => { + DoneReason::InvalidChecksum + } + IcmpErrorMsgError::NoEmbeddedHeaders | IcmpErrorMsgError::NoInnerIpHeader => { + DoneReason::Filtered + } + } + } +} + // # Return // // * `Ok(())` if checksums are valid and we can translate the inner packet diff --git a/nat/src/stateful/allocation.rs b/nat/src/stateful/allocation.rs index 78e605271c..c5426ff7a2 100644 --- a/nat/src/stateful/allocation.rs +++ b/nat/src/stateful/allocation.rs @@ -5,6 +5,7 @@ use crate::port::NatPortError; use net::ip::NextHeader; +use net::packet::DoneReason; use std::fmt::{Debug, Display}; use std::time::Duration; @@ -34,6 +35,23 @@ pub enum AllocatorError { Denied, } +impl From<&AllocatorError> for DoneReason { + fn from(error: &AllocatorError) -> Self { + match error { + AllocatorError::UnsupportedProtocol(_) => DoneReason::NatUnsupportedProto, + AllocatorError::NoFreeIp + | AllocatorError::NoPortBlock + | AllocatorError::NoFreePort(_) => DoneReason::NatOutOfResources, + AllocatorError::PortAllocationFailed(_) + | AllocatorError::PortReservationFailed(_) + | AllocatorError::MissingDiscriminant + | AllocatorError::UnsupportedDiscriminant => DoneReason::NatFailure, + AllocatorError::InternalIssue(_) => DoneReason::InternalFailure, + AllocatorError::Denied => DoneReason::Filtered, + } + } +} + /// `AllocationResult` is a struct to represent the result of an allocation. /// It contains the allocated IP address and port to masquerade a packet, /// and the time for the allocation (flow timeout). diff --git a/nat/src/stateful/nf.rs b/nat/src/stateful/nf.rs index ba5e759de9..1572b1640f 100644 --- a/nat/src/stateful/nf.rs +++ b/nat/src/stateful/nf.rs @@ -484,7 +484,7 @@ impl StatefulNat { // TODO: Check whether the packet is fragmented if let Err(error) = self.masquerade_packet(packet) { - packet.done(translate_error(&error)); + packet.done((&error).into()); debug!("Did not masquerade packet: {error}"); } else { packet.meta_mut().set_checksum_refresh(true); @@ -492,39 +492,22 @@ impl StatefulNat { } } -fn translate_error(error: &StatefulNatError) -> DoneReason { - match error { - StatefulNatError::BadTransportHeader - | StatefulNatError::AllocationFailure(AllocatorError::UnsupportedProtocol(_)) => { - DoneReason::NatUnsupportedProto - } - - StatefulNatError::FlowKeyError | StatefulNatError::InvalidPort(_) => DoneReason::Malformed, - - StatefulNatError::AllocationFailure( - AllocatorError::NoFreeIp | AllocatorError::NoPortBlock | AllocatorError::NoFreePort(_), - ) => DoneReason::NatOutOfResources, - - StatefulNatError::CapacityExceeded => DoneReason::FlowCapacityExceeded, - StatefulNatError::NoAllocator - | StatefulNatError::UnexpectedKeyVariant - | StatefulNatError::IcmpUnsupportedCategory - | StatefulNatError::IcmpError - | StatefulNatError::AllocationFailure( - AllocatorError::PortAllocationFailed(_) - | AllocatorError::PortReservationFailed(_) - | AllocatorError::MissingDiscriminant - | AllocatorError::UnsupportedDiscriminant, - ) - | StatefulNatError::NatError(_) => DoneReason::NatFailure, - - StatefulNatError::AllocationFailure(AllocatorError::InternalIssue(_)) => { - DoneReason::InternalFailure +impl From<&StatefulNatError> for DoneReason { + fn from(error: &StatefulNatError) -> Self { + match error { + StatefulNatError::BadTransportHeader => DoneReason::NatUnsupportedProto, + StatefulNatError::FlowKeyError | StatefulNatError::InvalidPort(_) => { + DoneReason::Malformed + } + StatefulNatError::CapacityExceeded => DoneReason::FlowCapacityExceeded, + StatefulNatError::NoAllocator + | StatefulNatError::UnexpectedKeyVariant + | StatefulNatError::IcmpUnsupportedCategory + | StatefulNatError::IcmpError + | StatefulNatError::NatError(_) => DoneReason::NatFailure, + StatefulNatError::Bug(_) | StatefulNatError::IntendedDrop(_) => DoneReason::Filtered, + StatefulNatError::AllocationFailure(inner) => inner.into(), } - - StatefulNatError::AllocationFailure(AllocatorError::Denied) - | StatefulNatError::Bug(_) - | StatefulNatError::IntendedDrop(_) => DoneReason::Filtered, } } diff --git a/nat/src/stateless/nf.rs b/nat/src/stateless/nf.rs index a4dac63d3b..80c7bdc44e 100644 --- a/nat/src/stateless/nf.rs +++ b/nat/src/stateless/nf.rs @@ -322,7 +322,7 @@ impl StatelessNat { match self.translate(nat_tables, packet, src_vni, dst_vni) { Err(error) => { debug!("{nfi}: Translation failed: {error}"); - packet.done(translate_error(&error)); + packet.done((&error).into()); } Ok(modified) => { if modified { @@ -336,30 +336,16 @@ impl StatelessNat { } } -fn translate_error(error: &StatelessNatError) -> DoneReason { - match error { - StatelessNatError::NoIpHeader - | StatelessNatError::IcmpErrorMsg(IcmpErrorMsgError::NoIpHeader) => DoneReason::NotIp, - - StatelessNatError::MissingTable(_) => DoneReason::Unroutable, - - StatelessNatError::IcmpErrorMsg(IcmpErrorMsgError::InvalidPort(_)) => DoneReason::Malformed, - - StatelessNatError::IcmpErrorMsg(IcmpErrorMsgError::NotUnicast(_)) => DoneReason::NatFailure, - - StatelessNatError::FailedToSetDestIp(_) - | StatelessNatError::FailedToSetSourceIp(_) - | StatelessNatError::IcmpErrorMsg( - IcmpErrorMsgError::InvalidIpVersion | IcmpErrorMsgError::NoTranslationPossible, - ) => DoneReason::InternalFailure, - - StatelessNatError::IcmpErrorMsg( - IcmpErrorMsgError::BadChecksumIcmp(_) | IcmpErrorMsgError::BadChecksumInnerIpv4(_), - ) => DoneReason::InvalidChecksum, - - StatelessNatError::IcmpErrorMsg( - IcmpErrorMsgError::NoEmbeddedHeaders | IcmpErrorMsgError::NoInnerIpHeader, - ) => DoneReason::Filtered, +impl From<&StatelessNatError> for DoneReason { + fn from(error: &StatelessNatError) -> Self { + match error { + StatelessNatError::NoIpHeader => DoneReason::NotIp, + StatelessNatError::MissingTable(_) => DoneReason::Unroutable, + StatelessNatError::FailedToSetSourceIp(_) | StatelessNatError::FailedToSetDestIp(_) => { + DoneReason::InternalFailure + } + StatelessNatError::IcmpErrorMsg(inner) => inner.into(), + } } } From db526053a42cdd96d7fc38891b272e74afb2a075 Mon Sep 17 00:00:00 2001 From: Quentin Monnet Date: Wed, 20 May 2026 22:40:54 +0100 Subject: [PATCH 2/5] refactor(net): Drop FlowKey's Unidirectional enum variant FlowKey is an enum with a single Unidirectional variant wrapping FlowKeyData, the second (bidirectional) variant was removed some time ago. Turn FlowKey into a tuple struct so the wrapping no longer requires pattern-matching, and drop the now-trivial match arms in PartialEq, Hash, and the Display impl. The inner field is left pub for now, a follow-up commit will merge FlowKey and FlowKeyData into a single struct. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Quentin Monnet --- flow-entry/src/flow_table/table.rs | 20 +++++++-------- net/src/flows/display.rs | 4 +-- net/src/flows/flow_key.rs | 41 ++++++------------------------ 3 files changed, 19 insertions(+), 46 deletions(-) diff --git a/flow-entry/src/flow_table/table.rs b/flow-entry/src/flow_table/table.rs index 4a0ed52c71..b71c48dd61 100644 --- a/flow-entry/src/flow_table/table.rs +++ b/flow-entry/src/flow_table/table.rs @@ -439,7 +439,7 @@ mod tests { let five_seconds_from_now = now + five_seconds; let flow_table = FlowTable::default(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -469,7 +469,7 @@ mod tests { let one_second = Duration::from_secs(1); let flow_table = FlowTable::default(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI(Vni::new_checked(42).unwrap())), "10.0.0.1".parse::().unwrap(), "10.0.0.2".parse::().unwrap(), @@ -501,7 +501,7 @@ mod tests { let second_expiry_time = now + Duration::from_secs(10); let flow_table = FlowTable::default(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -586,7 +586,7 @@ mod tests { let mut flow_keys = vec![]; for src_port in 1..=NUM_FLOWS { - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -627,7 +627,7 @@ mod tests { let now = Instant::now(); let deadline = now + Duration::from_secs(2); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -663,7 +663,7 @@ mod tests { for i in 1u16..=2 { let src_port = TcpPort::new_checked(1000 + i).unwrap(); let dst_port = TcpPort::new_checked(80).unwrap(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey(FlowKeyData::new( Some(src_vpcd), src_ip, dst_ip, @@ -675,7 +675,7 @@ mod tests { } // One more insert must fail with CapacityExceeded. - let overflow_key = FlowKey::Unidirectional(FlowKeyData::new( + let overflow_key = FlowKey(FlowKeyData::new( Some(src_vpcd), src_ip, dst_ip, @@ -709,7 +709,7 @@ mod tests { let two_seconds = Duration::from_secs(2); let flow_keys: Vec<_> = (0u16..2u16) .map(|i| { - FlowKey::Unidirectional(FlowKeyData::new( + FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI( Vni::new_checked(u32::from(i) + 1).unwrap(), )), @@ -825,7 +825,7 @@ mod tests { let flow_table = Arc::new(FlowTable::default()); let five_seconds_from_now = Instant::now() + Duration::from_secs(5); - let flow_key1 = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key1 = FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -835,7 +835,7 @@ mod tests { }), )); - let flow_key2 = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key2 = FlowKey(FlowKeyData::new( Some(VpcDiscriminant::VNI(Vni::new_checked(10).unwrap())), "10.2.3.4".parse::().unwrap(), "40.5.6.7".parse::().unwrap(), diff --git a/net/src/flows/display.rs b/net/src/flows/display.rs index 63b9695fc3..56d03fec93 100644 --- a/net/src/flows/display.rs +++ b/net/src/flows/display.rs @@ -33,9 +33,7 @@ impl Display for FlowKeyData { impl Display for FlowKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FlowKey::Unidirectional(data) => write!(f, "{data}"), - } + write!(f, "{}", self.data()) } } diff --git a/net/src/flows/flow_key.rs b/net/src/flows/flow_key.rs index 6b3ad16b79..11016f35d3 100644 --- a/net/src/flows/flow_key.rs +++ b/net/src/flows/flow_key.rs @@ -581,23 +581,17 @@ impl Hash for FlowKeyData { } } -#[derive(Debug, Clone, Copy, Eq, PartialOrd, Ord)] -pub enum FlowKey { - Unidirectional(FlowKeyData), -} +#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] +pub struct FlowKey(pub FlowKeyData); impl FlowKey { #[must_use] pub fn data(&self) -> &FlowKeyData { - match self { - FlowKey::Unidirectional(data) => data, - } + &self.0 } #[must_use] pub fn data_mut(&mut self) -> &mut FlowKeyData { - match self { - FlowKey::Unidirectional(data) => data, - } + &mut self.0 } /// Create a unidirectional flow key @@ -610,32 +604,13 @@ impl FlowKey { dst_ip: IpAddr, proto_key_info: IpProtoKey, ) -> FlowKey { - FlowKey::Unidirectional(FlowKeyData::new(src_vpcd, src_ip, dst_ip, proto_key_info)) + FlowKey(FlowKeyData::new(src_vpcd, src_ip, dst_ip, proto_key_info)) } // Creates the flow key with src and dst swapped #[must_use] pub fn reverse(&self, src_vpcd: Option) -> FlowKey { - match self { - FlowKey::Unidirectional(data) => FlowKey::Unidirectional(data.reverse(src_vpcd)), - } - } -} - -// The FlowKey Eq is symmetric, src == src or src == dst -impl PartialEq for FlowKey { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (FlowKey::Unidirectional(a), FlowKey::Unidirectional(b)) => a == b, - } - } -} - -impl Hash for FlowKey { - fn hash(&self, state: &mut H) { - match self { - FlowKey::Unidirectional(a) => a.hash(state), - } + FlowKey(self.0.reverse(src_vpcd)) } } @@ -884,7 +859,7 @@ mod contract { impl TypeGenerator for FlowKey { fn generate(driver: &mut D) -> Option { let data = FlowKeyData::generate(driver)?; - Some(FlowKey::Unidirectional(data)) + Some(FlowKey(data)) } } } @@ -1152,7 +1127,7 @@ mod tests { bolero::check!() .with_generator(FlowKeyAndPacket) .for_each(|(flow_key, packet)| match flow_key { - Some(FlowKey::Unidirectional(_)) => { + Some(_) => { let gen_flow_key = FlowKey::try_from(Uni(packet)).unwrap(); assert_eq!( gen_flow_key, From a9ab07dba65ba2b47d4e1bf8c3afe98e26ed9f2a Mon Sep 17 00:00:00 2001 From: Quentin Monnet Date: Wed, 20 May 2026 22:46:08 +0100 Subject: [PATCH 3/5] refactor(net): Merge FlowKey and FlowKeyData FlowKey had become a tuple struct wrapping FlowKeyData, with accessors data() and data_mut() used at every call site. The two types described the same flow tuple, so we can fold FlowKeyData's fields, methods, and Hash implementation directly into FlowKey and drop the indirection. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Quentin Monnet --- flow-entry/src/flow_table/table.rs | 42 ++++++------- nat/src/portfw/flow_state.rs | 10 ++- nat/src/stateful/flows.rs | 6 +- nat/src/stateful/nf.rs | 12 ++-- net/src/flows/display.rs | 10 +-- net/src/flows/flow_key.rs | 97 +++++++----------------------- net/src/lib.rs | 4 +- 7 files changed, 60 insertions(+), 121 deletions(-) diff --git a/flow-entry/src/flow_table/table.rs b/flow-entry/src/flow_table/table.rs index b71c48dd61..8700392150 100644 --- a/flow-entry/src/flow_table/table.rs +++ b/flow-entry/src/flow_table/table.rs @@ -423,7 +423,7 @@ mod tests { use net::tcp::TcpPort; use net::vxlan::Vni; - use net::{FlowKey, FlowKeyData, IpProtoKey, TcpProtoKey}; + use net::{FlowKey, IpProtoKey, TcpProtoKey}; #[concurrency_mode(std)] mod std_tests { @@ -439,7 +439,7 @@ mod tests { let five_seconds_from_now = now + five_seconds; let flow_table = FlowTable::default(); - let flow_key = FlowKey(FlowKeyData::new( + let flow_key = FlowKey::uni( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -447,7 +447,7 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, five_seconds_from_now); @@ -469,7 +469,7 @@ mod tests { let one_second = Duration::from_secs(1); let flow_table = FlowTable::default(); - let flow_key = FlowKey(FlowKeyData::new( + let flow_key = FlowKey::uni( Some(VpcDiscriminant::VNI(Vni::new_checked(42).unwrap())), "10.0.0.1".parse::().unwrap(), "10.0.0.2".parse::().unwrap(), @@ -477,7 +477,7 @@ mod tests { src_port: TcpPort::new_checked(1234).unwrap(), dst_port: TcpPort::new_checked(5678).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, now + two_seconds); flow_table.insert(flow_info).unwrap(); @@ -501,7 +501,7 @@ mod tests { let second_expiry_time = now + Duration::from_secs(10); let flow_table = FlowTable::default(); - let flow_key = FlowKey(FlowKeyData::new( + let flow_key = FlowKey::uni( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -509,7 +509,7 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); // Insert first entry. let first_arc = Arc::new(FlowInfo::new(flow_key, first_expiry_time)); @@ -586,7 +586,7 @@ mod tests { let mut flow_keys = vec![]; for src_port in 1..=NUM_FLOWS { - let flow_key = FlowKey(FlowKeyData::new( + let flow_key = FlowKey::uni( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -594,7 +594,7 @@ mod tests { src_port: TcpPort::new_checked(src_port).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, deadline); flow_table.insert(flow_info).unwrap(); flow_keys.push(flow_key); @@ -627,7 +627,7 @@ mod tests { let now = Instant::now(); let deadline = now + Duration::from_secs(2); - let flow_key = FlowKey(FlowKeyData::new( + let flow_key = FlowKey::uni( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -635,7 +635,7 @@ mod tests { src_port: TcpPort::new_checked(1).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, deadline); flow_table.insert(flow_info).unwrap(); @@ -663,19 +663,19 @@ mod tests { for i in 1u16..=2 { let src_port = TcpPort::new_checked(1000 + i).unwrap(); let dst_port = TcpPort::new_checked(80).unwrap(); - let flow_key = FlowKey(FlowKeyData::new( + let flow_key = FlowKey::uni( Some(src_vpcd), src_ip, dst_ip, IpProtoKey::Tcp(TcpProtoKey { src_port, dst_port }), - )); + ); flow_table .insert(FlowInfo::new(flow_key, far_future)) .expect("insert under capacity should succeed"); } // One more insert must fail with CapacityExceeded. - let overflow_key = FlowKey(FlowKeyData::new( + let overflow_key = FlowKey::uni( Some(src_vpcd), src_ip, dst_ip, @@ -683,7 +683,7 @@ mod tests { src_port: TcpPort::new_checked(9999).unwrap(), dst_port: TcpPort::new_checked(80).unwrap(), }), - )); + ); assert!(matches!( flow_table.insert(FlowInfo::new(overflow_key, far_future)), Err(FlowTableError::CapacityExceeded) @@ -709,7 +709,7 @@ mod tests { let two_seconds = Duration::from_secs(2); let flow_keys: Vec<_> = (0u16..2u16) .map(|i| { - FlowKey(FlowKeyData::new( + FlowKey::uni( Some(VpcDiscriminant::VNI( Vni::new_checked(u32::from(i) + 1).unwrap(), )), @@ -719,7 +719,7 @@ mod tests { src_port: TcpPort::new_checked(1000 + i).unwrap(), dst_port: TcpPort::new_checked(2000 + i).unwrap(), }), - )) + ) }) .collect(); @@ -825,7 +825,7 @@ mod tests { let flow_table = Arc::new(FlowTable::default()); let five_seconds_from_now = Instant::now() + Duration::from_secs(5); - let flow_key1 = FlowKey(FlowKeyData::new( + let flow_key1 = FlowKey::uni( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -833,9 +833,9 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); - let flow_key2 = FlowKey(FlowKeyData::new( + let flow_key2 = FlowKey::uni( Some(VpcDiscriminant::VNI(Vni::new_checked(10).unwrap())), "10.2.3.4".parse::().unwrap(), "40.5.6.7".parse::().unwrap(), @@ -843,7 +843,7 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_table_clone1 = flow_table.clone(); let flow_table_clone2 = flow_table.clone(); diff --git a/nat/src/portfw/flow_state.rs b/nat/src/portfw/flow_state.rs index 705c953ca3..0666b57228 100644 --- a/nat/src/portfw/flow_state.rs +++ b/nat/src/portfw/flow_state.rs @@ -120,14 +120,12 @@ pub(crate) fn build_portfw_flow_keys( .unwrap_or(current_flow_key); // Build the key for the reverse path - let proto = current_flow_key.data().proto(); - let src_port = current_flow_key.data().src_port().ok_or(())?; + let proto = current_flow_key.proto(); + let src_port = current_flow_key.src_port().ok_or(())?; let mut key_forward_dnated = current_flow_key; - key_forward_dnated.data_mut().set_dst_ip(new_dst_ip.inner()); - key_forward_dnated - .data_mut() - .set_ip_proto_key(IpProtoKey::from((proto, src_port, new_dst_port))); + key_forward_dnated.set_dst_ip(new_dst_ip.inner()); + key_forward_dnated.set_ip_proto_key(IpProtoKey::from((proto, src_port, new_dst_port))); let key_reverse = key_forward_dnated.reverse(Some(dst_vpcd)); Ok((initial_flow_key, key_reverse)) diff --git a/nat/src/stateful/flows.rs b/nat/src/stateful/flows.rs index a102ddd8de..63332634c5 100644 --- a/nat/src/stateful/flows.rs +++ b/nat/src/stateful/flows.rs @@ -61,9 +61,9 @@ fn re_reserve_ip_and_port( port: NatPort, ) -> Result<(), ()> { let flow_key = flow_info.flowkey(); - let proto = flow_key.data().proto(); + let proto = flow_key.proto(); let dst_vpcd = flow_info.get_dst_vpcd().unwrap_or_else(|| unreachable!()); - let src_ip = *flow_key.data().src_ip(); + let src_ip = *flow_key.src_ip(); let port_u16 = port.as_u16(); debug!("Attempting to reserve {ip} {port_u16} {proto}..."); @@ -104,7 +104,7 @@ pub(crate) fn check_masquerading_flow( return; }; let dst_vpcd = flow_info.get_dst_vpcd().unwrap_or_else(|| unreachable!()); - let src_vpcd = flow_key.data().src_vpcd().unwrap_or_else(|| unreachable!()); + let src_vpcd = flow_key.src_vpcd().unwrap_or_else(|| unreachable!()); debug!("Checking flow {}", flow_info.logfmt()); diff --git a/nat/src/stateful/nf.rs b/nat/src/stateful/nf.rs index 1572b1640f..4de9e83eee 100644 --- a/nat/src/stateful/nf.rs +++ b/nat/src/stateful/nf.rs @@ -219,8 +219,8 @@ impl StatefulNat { } fn get_reverse_mapping(flow_key: &FlowKey) -> Result<(IpAddr, NatPort), StatefulNatError> { - let src_ip = *flow_key.data().src_ip(); - let src_port = match flow_key.data().proto_key_info() { + let src_ip = *flow_key.src_ip(); + let src_port = match flow_key.proto_key_info() { IpProtoKey::Tcp(tcp) => tcp.src_port.into(), IpProtoKey::Udp(udp) => udp.src_port.into(), IpProtoKey::Icmp(icmp) => NatPort::Identifier(Self::get_icmp_query_id(icmp)?), @@ -313,12 +313,12 @@ impl StatefulNat { // So we want: // - tuple r.init = (src: f.nated.dst, dst: f.nated.src) // - mapping r.nated = (src: f.init.dst, dst: f.init.src) - let reverse_src_addr = *flow_key.data().dst_ip(); + let reverse_src_addr = *flow_key.dst_ip(); let reverse_dst_addr = alloc.allocation.ip(); let dst_port = alloc.allocation.port(); // Reverse the forward protocol key and adjust ports to use the allocated values. - let mut reverse_proto_key = flow_key.data().proto_key_info().reverse(); + let mut reverse_proto_key = flow_key.proto_key_info().reverse(); match reverse_proto_key { IpProtoKey::Tcp(_) | IpProtoKey::Udp(_) => { reverse_proto_key @@ -388,9 +388,9 @@ impl StatefulNat { .unwrap_or(current_flow_key); // Create a new session and translate the address - let src_ip = *initial_flow_key.data().src_ip(); + let src_ip = *initial_flow_key.src_ip(); let alloc = allocator - .allocate(dst_vpcd, src_ip, initial_flow_key.data().proto()) + .allocate(dst_vpcd, src_ip, initial_flow_key.proto()) .map_err(StatefulNatError::AllocationFailure)?; // Forbid addresses we won't know how to translate. This is a work around of a larger change diff --git a/net/src/flows/display.rs b/net/src/flows/display.rs index 56d03fec93..4617cbef16 100644 --- a/net/src/flows/display.rs +++ b/net/src/flows/display.rs @@ -4,13 +4,13 @@ //! Flow keys use super::flow_info::{FlowInfo, FlowInfoLocked}; -use super::flow_key::{FlowKey, FlowKeyData}; +use super::flow_key::FlowKey; use concurrency::sync::Weak; use std::fmt::Display; use std::time::Instant; -impl Display for FlowKeyData { +impl Display for FlowKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(vpcd) = self.src_vpcd() { write!(f, "from {vpcd},")?; @@ -31,12 +31,6 @@ impl Display for FlowKeyData { } } -impl Display for FlowKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.data()) - } -} - impl Display for FlowInfoLocked { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(data) = &self.dst_vpcd { diff --git a/net/src/flows/flow_key.rs b/net/src/flows/flow_key.rs index 11016f35d3..b51788e0ad 100644 --- a/net/src/flows/flow_key.rs +++ b/net/src/flows/flow_key.rs @@ -454,26 +454,29 @@ impl HashDst for IpProtoKey { } #[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] -pub struct FlowKeyData { +pub struct FlowKey { src_vpcd: Option, src_ip: IpAddr, dst_ip: IpAddr, proto_key_info: IpProtoKey, } -impl FlowKeyData { +impl FlowKey { + /// Create a unidirectional flow key + /// + /// packets with src -> dst will match, but dst -> src will not #[must_use] - pub fn new( + pub fn uni( src_vpcd: Option, src_ip: IpAddr, dst_ip: IpAddr, - ip_proto_key: IpProtoKey, + proto_key_info: IpProtoKey, ) -> Self { Self { src_vpcd, src_ip, dst_ip, - proto_key_info: ip_proto_key, + proto_key_info, } } @@ -571,7 +574,7 @@ impl FlowKeyData { } } -impl Hash for FlowKeyData { +impl Hash for FlowKey { fn hash(&self, state: &mut H) { self.src_vpcd.hash(state); self.src_ip.hash(state); @@ -581,45 +584,12 @@ impl Hash for FlowKeyData { } } -#[derive(Debug, Clone, Copy, Eq, PartialEq, PartialOrd, Ord, Hash)] -pub struct FlowKey(pub FlowKeyData); - -impl FlowKey { - #[must_use] - pub fn data(&self) -> &FlowKeyData { - &self.0 - } - #[must_use] - pub fn data_mut(&mut self) -> &mut FlowKeyData { - &mut self.0 - } - - /// Create a unidirectional flow key - /// - /// packets with src -> dst will match, but dst -> src will not - #[must_use] - pub fn uni( - src_vpcd: Option, - src_ip: IpAddr, - dst_ip: IpAddr, - proto_key_info: IpProtoKey, - ) -> FlowKey { - FlowKey(FlowKeyData::new(src_vpcd, src_ip, dst_ip, proto_key_info)) - } - - // Creates the flow key with src and dst swapped - #[must_use] - pub fn reverse(&self, src_vpcd: Option) -> FlowKey { - FlowKey(self.0.reverse(src_vpcd)) - } -} - /// Wrapper to specify unidirectional `FlowKey` creation #[repr(transparent)] #[derive(Debug)] pub struct Uni(pub T); -fn flow_key_data_from_packet(packet: &Packet) -> Option { +fn flow_key_from_packet(packet: &Packet) -> Option { let ip = packet.headers().try_ip()?; let src_ip = ip.src_addr(); let dst_ip = ip.dst_addr(); @@ -641,21 +611,13 @@ fn flow_key_data_from_packet(packet: &Packet) -> Opti }; let src_vpcd = packet.meta().src_vpcd; - Some(FlowKeyData::new(src_vpcd, src_ip, dst_ip, ip_proto_key)) + Some(FlowKey::uni(src_vpcd, src_ip, dst_ip, ip_proto_key)) } impl TryFrom>> for FlowKey { type Error = FlowKeyError; fn try_from(packet: Uni<&Packet>) -> Result { - let packet = packet.0; - let FlowKeyData { - src_vpcd, - src_ip, - dst_ip, - proto_key_info, - } = flow_key_data_from_packet(packet).ok_or(FlowKeyError::NoFlowKeyData)?; - - Ok(FlowKey::uni(src_vpcd, src_ip, dst_ip, proto_key_info)) + flow_key_from_packet(packet.0).ok_or(FlowKeyError::NoFlowKeyData) } } @@ -720,8 +682,8 @@ pub fn flowkey_embedded_in_icmp_error( #[cfg(any(test, feature = "bolero"))] mod contract { use super::{ - EmbeddedPacketData, FlowKey, FlowKeyData, IcmpProtoKey, InnerIcmpProtoKey, InnerIpProtoKey, - IpProtoKey, TcpProtoKey, UdpProtoKey, + EmbeddedPacketData, FlowKey, IcmpProtoKey, InnerIcmpProtoKey, InnerIpProtoKey, IpProtoKey, + TcpProtoKey, UdpProtoKey, }; use crate::ip::UnicastIpAddr; use crate::ipv4::addr::UnicastIpv4Addr; @@ -830,7 +792,7 @@ mod contract { } } - impl TypeGenerator for FlowKeyData { + impl TypeGenerator for FlowKey { fn generate(driver: &mut D) -> Option { let src_vpcd = driver.produce(); let v6 = driver.produce::()?; @@ -847,19 +809,7 @@ mod contract { ) }; let proto_key_info = super::IpProtoKey::generate(driver)?; - Some(FlowKeyData { - src_vpcd, - src_ip, - dst_ip, - proto_key_info, - }) - } - } - - impl TypeGenerator for FlowKey { - fn generate(driver: &mut D) -> Option { - let data = FlowKeyData::generate(driver)?; - Some(FlowKey(data)) + Some(FlowKey::uni(src_vpcd, src_ip, dst_ip, proto_key_info)) } } } @@ -891,11 +841,11 @@ mod tests { ); let reverse_flow_key = flow_key.reverse(None); - assert_eq!(flow_key.data().src_ip, reverse_flow_key.data().dst_ip); - assert_eq!(flow_key.data().dst_ip, reverse_flow_key.data().src_ip); + assert_eq!(flow_key.src_ip, reverse_flow_key.dst_ip); + assert_eq!(flow_key.dst_ip, reverse_flow_key.src_ip); assert_eq!( - flow_key.data().proto_key_info, - reverse_flow_key.data().proto_key_info.reverse() + flow_key.proto_key_info, + reverse_flow_key.proto_key_info.reverse() ); } @@ -931,12 +881,11 @@ mod tests { /// This function panics if the packet has a different transport protocol than the flow key. /// It also panics if the packet IP address family does not match the flow key. fn set_packet_fields(packet: &mut Packet, flow_key: &FlowKey) { - let flow_key_data = flow_key.data(); packet - .set_ip_source(flow_key_data.src_ip.try_into().unwrap()) + .set_ip_source(flow_key.src_ip.try_into().unwrap()) .unwrap(); - packet.set_ip_destination(flow_key_data.dst_ip).unwrap(); - match flow_key_data.proto_key_info { + packet.set_ip_destination(flow_key.dst_ip).unwrap(); + match flow_key.proto_key_info { IpProtoKey::Tcp(tcp) => { packet.set_tcp_source_port(tcp.src_port).unwrap(); packet.set_tcp_destination_port(tcp.dst_port).unwrap(); diff --git a/net/src/lib.rs b/net/src/lib.rs index ffb50a43d2..6f9e3f938f 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -43,6 +43,4 @@ pub mod vxlan; // re-export #[cfg(unix)] -pub use flows::flow_key::{ - self, FlowKey, FlowKeyData, IcmpProtoKey, IpProtoKey, TcpProtoKey, UdpProtoKey, -}; +pub use flows::flow_key::{self, FlowKey, IcmpProtoKey, IpProtoKey, TcpProtoKey, UdpProtoKey}; From 77e380384544e8a196629e3d157a4564d1fc9e66 Mon Sep 17 00:00:00 2001 From: Quentin Monnet Date: Wed, 20 May 2026 22:48:29 +0100 Subject: [PATCH 4/5] refactor(net): Rename FlowKey::uni() to FlowKey::new() FlowKey used to support a bidirectional variant alongside the unidirectional one, and the builder was named uni() to disambiguate. Bidirectional keys were dropped a while ago, and the remaining builder named "uni()" is now confusing. Rename it to "new()" instead. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Quentin Monnet --- flow-entry/src/flow_table/table.rs | 20 ++++++++++---------- nat/src/stateful/nf.rs | 4 ++-- net/src/flows/flow_key.rs | 19 ++++++++++--------- 3 files changed, 22 insertions(+), 21 deletions(-) diff --git a/flow-entry/src/flow_table/table.rs b/flow-entry/src/flow_table/table.rs index 8700392150..e3e191afcb 100644 --- a/flow-entry/src/flow_table/table.rs +++ b/flow-entry/src/flow_table/table.rs @@ -439,7 +439,7 @@ mod tests { let five_seconds_from_now = now + five_seconds; let flow_table = FlowTable::default(); - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -469,7 +469,7 @@ mod tests { let one_second = Duration::from_secs(1); let flow_table = FlowTable::default(); - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(42).unwrap())), "10.0.0.1".parse::().unwrap(), "10.0.0.2".parse::().unwrap(), @@ -501,7 +501,7 @@ mod tests { let second_expiry_time = now + Duration::from_secs(10); let flow_table = FlowTable::default(); - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -586,7 +586,7 @@ mod tests { let mut flow_keys = vec![]; for src_port in 1..=NUM_FLOWS { - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -627,7 +627,7 @@ mod tests { let now = Instant::now(); let deadline = now + Duration::from_secs(2); - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -663,7 +663,7 @@ mod tests { for i in 1u16..=2 { let src_port = TcpPort::new_checked(1000 + i).unwrap(); let dst_port = TcpPort::new_checked(80).unwrap(); - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(src_vpcd), src_ip, dst_ip, @@ -675,7 +675,7 @@ mod tests { } // One more insert must fail with CapacityExceeded. - let overflow_key = FlowKey::uni( + let overflow_key = FlowKey::new( Some(src_vpcd), src_ip, dst_ip, @@ -709,7 +709,7 @@ mod tests { let two_seconds = Duration::from_secs(2); let flow_keys: Vec<_> = (0u16..2u16) .map(|i| { - FlowKey::uni( + FlowKey::new( Some(VpcDiscriminant::VNI( Vni::new_checked(u32::from(i) + 1).unwrap(), )), @@ -825,7 +825,7 @@ mod tests { let flow_table = Arc::new(FlowTable::default()); let five_seconds_from_now = Instant::now() + Duration::from_secs(5); - let flow_key1 = FlowKey::uni( + let flow_key1 = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -835,7 +835,7 @@ mod tests { }), ); - let flow_key2 = FlowKey::uni( + let flow_key2 = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(10).unwrap())), "10.2.3.4".parse::().unwrap(), "40.5.6.7".parse::().unwrap(), diff --git a/nat/src/stateful/nf.rs b/nat/src/stateful/nf.rs index 4de9e83eee..caeb97138d 100644 --- a/nat/src/stateful/nf.rs +++ b/nat/src/stateful/nf.rs @@ -198,7 +198,7 @@ impl StatefulNat { dst_ip: IpAddr, proto_key_info: IpProtoKey, ) -> Option<(NatTranslate, Duration)> { - let flow_key = FlowKey::uni(src_vpcd, src_ip, dst_ip, proto_key_info); + let flow_key = FlowKey::new(src_vpcd, src_ip, dst_ip, proto_key_info); let flow_info = self.flow_table.lookup(&flow_key)?; let value = flow_info.locked.read(); let state = value.nat_state.as_ref()?.extract_ref::()?; @@ -339,7 +339,7 @@ impl StatefulNat { } } - Ok(FlowKey::uni( + Ok(FlowKey::new( Some(dst_vpc_id), reverse_src_addr, reverse_dst_addr, diff --git a/net/src/flows/flow_key.rs b/net/src/flows/flow_key.rs index b51788e0ad..fbf57671ff 100644 --- a/net/src/flows/flow_key.rs +++ b/net/src/flows/flow_key.rs @@ -462,11 +462,12 @@ pub struct FlowKey { } impl FlowKey { - /// Create a unidirectional flow key + /// Create a flow key. /// - /// packets with src -> dst will match, but dst -> src will not + /// Flow keys are unidirectional: packets with src -> dst will match, but + /// dst -> src will not. #[must_use] - pub fn uni( + pub fn new( src_vpcd: Option, src_ip: IpAddr, dst_ip: IpAddr, @@ -611,7 +612,7 @@ fn flow_key_from_packet(packet: &Packet) -> Option TryFrom>> for FlowKey { @@ -676,7 +677,7 @@ pub fn flowkey_embedded_in_icmp_error( return Err(FlowKeyError::EmbeddedMissingIcmpId); } }; - Ok(FlowKey::uni(None, src_ip, dst_ip, proto_key)) + Ok(FlowKey::new(None, src_ip, dst_ip, proto_key)) } #[cfg(any(test, feature = "bolero"))] @@ -809,7 +810,7 @@ mod contract { ) }; let proto_key_info = super::IpProtoKey::generate(driver)?; - Some(FlowKey::uni(src_vpcd, src_ip, dst_ip, proto_key_info)) + Some(FlowKey::new(src_vpcd, src_ip, dst_ip, proto_key_info)) } } } @@ -830,7 +831,7 @@ mod tests { #[test] fn test_flow_key_reverse() { - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -851,7 +852,7 @@ mod tests { #[test] fn test_flow_key_uni_hash() { - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( None, "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -1062,7 +1063,7 @@ mod tests { }; if let Some(proto) = proto { let (flow_key, mut packet) = - (FlowKey::uni(src_vpcd, src_ip, dst_ip, proto), packet); + (FlowKey::new(src_vpcd, src_ip, dst_ip, proto), packet); set_packet_fields(&mut packet, &flow_key); Some((Some(flow_key), packet)) } else { From 3199fcf5887bdebf18c83aaf3f185ddd81ccaaf6 Mon Sep 17 00:00:00 2001 From: Quentin Monnet Date: Wed, 20 May 2026 22:51:30 +0100 Subject: [PATCH 5/5] refactor(net): Drop the Uni wrapper around packets Uni existed to tag a packet for unidirectional FlowKey extraction back when bidirectional keys were also supported. Bidirectional keys are gone, so Uni is not longer necessary, it simply adds dead weight that every call site has to thread through, and makes the code less clear. Convert the TryFrom impl to accept &Packet directly, delete the Uni struct, and drop the wrapper at the call sites. Co-Authored-By: Claude Opus 4.7 (1M context) Signed-off-by: Quentin Monnet --- flow-entry/src/flow_table/nf_lookup.rs | 15 +++--- flow-filter/src/lib.rs | 2 +- flow-filter/src/tests.rs | 3 +- nat/src/portfw/flow_state.rs | 3 +- nat/src/stateful/nf.rs | 4 +- nat/src/stateful/test.rs | 3 +- net/src/flows/flow_key.rs | 65 ++++++++++++-------------- net/src/packet/mod.rs | 3 +- 8 files changed, 45 insertions(+), 53 deletions(-) diff --git a/flow-entry/src/flow_table/nf_lookup.rs b/flow-entry/src/flow_table/nf_lookup.rs index a609e48da9..d57827b9c6 100644 --- a/flow-entry/src/flow_table/nf_lookup.rs +++ b/flow-entry/src/flow_table/nf_lookup.rs @@ -12,7 +12,6 @@ use pipeline::NetworkFunction; use crate::flow_table::FlowTable; use net::FlowKey; -use net::flow_key; use tracectl::trace_target; trace_target!("flow-lookup", LevelFilter::INFO, &["pipeline"]); @@ -40,7 +39,7 @@ impl NetworkFunction for FlowLookup { input.filter_map(move |mut packet| { let nfi = &self.name; if !packet.is_done() && packet.meta().is_overlay() && packet.meta().dst_vpcd.is_none() { - if let Ok(flow_key) = FlowKey::try_from(flow_key::Uni(&packet)) { + if let Ok(flow_key) = FlowKey::try_from(&packet) { if let Some(flow_info) = self.flow_table.lookup(&flow_key) { debug!("{nfi}: Tagging packet with flow info for flow key {flow_key}",); packet.meta_mut().flow_info = Some(flow_info); @@ -101,7 +100,7 @@ mod test { packet.meta_mut().set_overlay(true); // Insert matching flow entry - let flow_key = FlowKey::try_from(net::flow_key::Uni(&packet)).unwrap(); + let flow_key = FlowKey::try_from(&packet).unwrap(); let flow_info = FlowInfo::new(flow_key, Instant::now() + Duration::from_secs(10)); flow_table.insert(flow_info).unwrap(); @@ -130,7 +129,7 @@ mod test { input: Input, ) -> impl Iterator> + 'a { input.filter_map(move |packet| { - let flow_key = FlowKey::try_from(net::flow_key::Uni(&packet)).unwrap(); + let flow_key = FlowKey::try_from(&packet).unwrap(); let flow_info = FlowInfo::new(flow_key, Instant::now() + self.timeout); self.flow_table .insert(flow_info) @@ -193,8 +192,8 @@ mod test { packet_2.meta_mut().set_overlay(true); // build keys for the packets - let key_1 = FlowKey::try_from(net::flow_key::Uni(&packet_1)).unwrap(); - let key_2 = FlowKey::try_from(net::flow_key::Uni(&packet_2)).unwrap(); + let key_1 = FlowKey::try_from(&packet_1).unwrap(); + let key_2 = FlowKey::try_from(&packet_2).unwrap(); // create a pair of related flow entries; flow_2 will get a longer timeout let expires_at = tokio::time::Instant::now().into_std() + Duration::from_secs(2); @@ -251,8 +250,8 @@ mod test { let mut packet_2 = build_test_udp_ipv4_packet("192.168.1.1", "20.0.0.1", 500, 80); packet_1.meta_mut().set_overlay(true); packet_2.meta_mut().set_overlay(true); - let key_1 = FlowKey::try_from(net::flow_key::Uni(&packet_1)).unwrap(); - let key_2 = FlowKey::try_from(net::flow_key::Uni(&packet_2)).unwrap(); + let key_1 = FlowKey::try_from(&packet_1).unwrap(); + let key_2 = FlowKey::try_from(&packet_2).unwrap(); let input = vec![packet_1, packet_2]; let out: Vec<_> = pipeline.process(input.into_iter()).collect(); let packet_1 = &out[0]; diff --git a/flow-filter/src/lib.rs b/flow-filter/src/lib.rs index cb8aebb99f..acfc52905e 100644 --- a/flow-filter/src/lib.rs +++ b/flow-filter/src/lib.rs @@ -363,7 +363,7 @@ impl FlowFilter { return; } - let Ok(flow_key) = FlowKey::try_from(net::flow_key::Uni(&*packet)) else { + let Ok(flow_key) = FlowKey::try_from(&*packet) else { return; }; diff --git a/flow-filter/src/tests.rs b/flow-filter/src/tests.rs index 228bed35a5..50e38fca76 100644 --- a/flow-filter/src/tests.rs +++ b/flow-filter/src/tests.rs @@ -12,7 +12,6 @@ use config::external::overlay::vpcpeering::{VpcExpose, VpcManifest, VpcPeering, use lpm::prefix::{L4Protocol, PortRange, Prefix, PrefixWithOptionalPorts}; use net::FlowKey; use net::buffer::{PacketBufferMut, TestBuffer}; -use net::flow_key::Uni; use net::flows::{FlowInfo, FlowStatus}; use net::headers::{Net, TryHeadersMut, TryIpMut}; use net::ip::NextHeader; @@ -191,7 +190,7 @@ fn fake_flow_session( set_port_fw_state: bool, ) { // build flow key - let flow_key = FlowKey::try_from(Uni(&*packet)).unwrap(); + let flow_key = FlowKey::try_from(&*packet).unwrap(); // Create flow_info with dst_vpcd and NAT info and attach it to the packet let flow_info = FlowInfo::new(flow_key, Instant::now() + Duration::from_secs(60)); diff --git a/nat/src/portfw/flow_state.rs b/nat/src/portfw/flow_state.rs index 0666b57228..0b5c8c5677 100644 --- a/nat/src/portfw/flow_state.rs +++ b/nat/src/portfw/flow_state.rs @@ -6,7 +6,6 @@ #![allow(clippy::single_match_else)] use net::buffer::PacketBufferMut; -use net::flow_key::Uni; use net::flows::{ExtractRef, FlowStatus}; use net::ip::UnicastIpAddr; use net::packet::{Packet, VpcDiscriminant}; @@ -107,7 +106,7 @@ pub(crate) fn build_portfw_flow_keys( dst_vpcd: VpcDiscriminant, // destination VPC to forward to ) -> Result<(FlowKey, FlowKey), ()> { // Extract flow key for the current packet - let current_flow_key = FlowKey::try_from(Uni(&*packet)).map_err(|_| ())?; + let current_flow_key = FlowKey::try_from(&*packet).map_err(|_| ())?; // Retrieve initial flow key for the current packet (before any other NAT translation); if // we don't have the information, we didn't populate it because we don't need it and fall diff --git a/nat/src/stateful/nf.rs b/nat/src/stateful/nf.rs index caeb97138d..0f91082d03 100644 --- a/nat/src/stateful/nf.rs +++ b/nat/src/stateful/nf.rs @@ -16,7 +16,7 @@ use crate::stateful::state::MasqueradeState; use concurrency::sync::{Arc, Weak}; use flow_entry::flow_table::table::{FlowTable, FlowTableError}; use net::buffer::PacketBufferMut; -use net::flow_key::{IcmpProtoKey, Uni}; +use net::flow_key::IcmpProtoKey; use net::flows::{ExtractRef, FlowInfo}; use net::headers::{TryIp, TryTcp}; use net::ip::UnicastIpAddr; @@ -375,7 +375,7 @@ impl StatefulNat { // Extract flow key for the current packet let current_flow_key = - FlowKey::try_from(Uni(&*packet)).map_err(|_| StatefulNatError::FlowKeyError)?; + FlowKey::try_from(&*packet).map_err(|_| StatefulNatError::FlowKeyError)?; // Retrieve initial flow key for the current packet (before any other NAT translation); if // we don't have the information, we didn't populate it because we don't need it and fall diff --git a/nat/src/stateful/test.rs b/nat/src/stateful/test.rs index e509e580d3..759c909f36 100644 --- a/nat/src/stateful/test.rs +++ b/nat/src/stateful/test.rs @@ -18,7 +18,6 @@ use flow_entry::flow_table::{FlowLookup, FlowTable}; use flow_filter::{FlowFilter, FlowFilterTable, FlowFilterTableWriter}; use net::buffer::{PacketBufferMut, TestBuffer}; use net::eth::mac::Mac; -use net::flow_key::Uni; use net::flows::FlowStatus; use net::flows::flow_info_item::ExtractRef; use net::headers::TryTcpMut; @@ -366,7 +365,7 @@ fn check_packet( } fn flow_lookup(flow_table: &FlowTable, packet: &mut Packet) { - let flow_key = FlowKey::try_from(Uni(&*packet)).unwrap(); + let flow_key = FlowKey::try_from(&*packet).unwrap(); if let Some(flow_info) = flow_table.lookup(&flow_key) { packet.meta_mut().flow_info = Some(flow_info); } diff --git a/net/src/flows/flow_key.rs b/net/src/flows/flow_key.rs index fbf57671ff..38dd6de575 100644 --- a/net/src/flows/flow_key.rs +++ b/net/src/flows/flow_key.rs @@ -585,40 +585,37 @@ impl Hash for FlowKey { } } -/// Wrapper to specify unidirectional `FlowKey` creation -#[repr(transparent)] -#[derive(Debug)] -pub struct Uni(pub T); - -fn flow_key_from_packet(packet: &Packet) -> Option { - let ip = packet.headers().try_ip()?; - let src_ip = ip.src_addr(); - let dst_ip = ip.dst_addr(); - - let transport = packet.headers().try_transport()?; - let ip_proto_key = match transport { - Transport::Tcp(tcp) => IpProtoKey::Tcp(TcpProtoKey { - src_port: tcp.source(), - dst_port: tcp.destination(), - }), - Transport::Udp(udp) => IpProtoKey::Udp(UdpProtoKey { - src_port: udp.source(), - dst_port: udp.destination(), - }), - Transport::Icmp4(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v4(packet, icmp)), - Transport::Icmp6(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v6(packet, icmp)), - #[allow(unreachable_patterns)] - _ => return None, - }; +impl TryFrom<&Packet> for FlowKey { + type Error = FlowKeyError; + fn try_from(packet: &Packet) -> Result { + let ip = packet + .headers() + .try_ip() + .ok_or(FlowKeyError::NoFlowKeyData)?; + let src_ip = ip.src_addr(); + let dst_ip = ip.dst_addr(); - let src_vpcd = packet.meta().src_vpcd; - Some(FlowKey::new(src_vpcd, src_ip, dst_ip, ip_proto_key)) -} + let transport = packet + .headers() + .try_transport() + .ok_or(FlowKeyError::NoFlowKeyData)?; + let ip_proto_key = match transport { + Transport::Tcp(tcp) => IpProtoKey::Tcp(TcpProtoKey { + src_port: tcp.source(), + dst_port: tcp.destination(), + }), + Transport::Udp(udp) => IpProtoKey::Udp(UdpProtoKey { + src_port: udp.source(), + dst_port: udp.destination(), + }), + Transport::Icmp4(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v4(packet, icmp)), + Transport::Icmp6(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v6(packet, icmp)), + #[allow(unreachable_patterns)] + _ => return Err(FlowKeyError::NoFlowKeyData), + }; -impl TryFrom>> for FlowKey { - type Error = FlowKeyError; - fn try_from(packet: Uni<&Packet>) -> Result { - flow_key_from_packet(packet.0).ok_or(FlowKeyError::NoFlowKeyData) + let src_vpcd = packet.meta().src_vpcd; + Ok(FlowKey::new(src_vpcd, src_ip, dst_ip, ip_proto_key)) } } @@ -1078,7 +1075,7 @@ mod tests { .with_generator(FlowKeyAndPacket) .for_each(|(flow_key, packet)| match flow_key { Some(_) => { - let gen_flow_key = FlowKey::try_from(Uni(packet)).unwrap(); + let gen_flow_key = FlowKey::try_from(packet).unwrap(); assert_eq!( gen_flow_key, flow_key.unwrap(), @@ -1087,7 +1084,7 @@ mod tests { ); } None => { - assert!(FlowKey::try_from(Uni(packet)).is_err()); + assert!(FlowKey::try_from(packet).is_err()); } }); } diff --git a/net/src/packet/mod.rs b/net/src/packet/mod.rs index 7fc8edb07e..c696a887d5 100644 --- a/net/src/packet/mod.rs +++ b/net/src/packet/mod.rs @@ -23,7 +23,6 @@ use crate::FlowKey; use crate::buffer::{Headroom, PacketBufferMut, Prepend, Tailroom, TrimFromStart}; use crate::eth::Eth; use crate::eth::EthError; -use crate::flow_key::Uni; use crate::flows::{FlowInfo, FlowStatus}; use crate::headers::{ EmbeddedHeaders, Headers, Net, Transport, TryEmbeddedHeaders, TryEmbeddedHeadersMut, @@ -97,7 +96,7 @@ impl Packet { /// Update the flow key of this packet based on its current state pub fn update_flow_key(&mut self) { - self.meta.flow_key = FlowKey::try_from(Uni(&*self)).ok().map(Box::new); + self.meta.flow_key = FlowKey::try_from(&*self).ok().map(Box::new); } /// Get the length of the packet's payload