diff --git a/flow-entry/src/flow_table/nf_lookup.rs b/flow-entry/src/flow_table/nf_lookup.rs index a609e48da..d57827b9c 100644 --- a/flow-entry/src/flow_table/nf_lookup.rs +++ b/flow-entry/src/flow_table/nf_lookup.rs @@ -12,7 +12,6 @@ use pipeline::NetworkFunction; use crate::flow_table::FlowTable; use net::FlowKey; -use net::flow_key; use tracectl::trace_target; trace_target!("flow-lookup", LevelFilter::INFO, &["pipeline"]); @@ -40,7 +39,7 @@ impl NetworkFunction for FlowLookup { input.filter_map(move |mut packet| { let nfi = &self.name; if !packet.is_done() && packet.meta().is_overlay() && packet.meta().dst_vpcd.is_none() { - if let Ok(flow_key) = FlowKey::try_from(flow_key::Uni(&packet)) { + if let Ok(flow_key) = FlowKey::try_from(&packet) { if let Some(flow_info) = self.flow_table.lookup(&flow_key) { debug!("{nfi}: Tagging packet with flow info for flow key {flow_key}",); packet.meta_mut().flow_info = Some(flow_info); @@ -101,7 +100,7 @@ mod test { packet.meta_mut().set_overlay(true); // Insert matching flow entry - let flow_key = FlowKey::try_from(net::flow_key::Uni(&packet)).unwrap(); + let flow_key = FlowKey::try_from(&packet).unwrap(); let flow_info = FlowInfo::new(flow_key, Instant::now() + Duration::from_secs(10)); flow_table.insert(flow_info).unwrap(); @@ -130,7 +129,7 @@ mod test { input: Input, ) -> impl Iterator> + 'a { input.filter_map(move |packet| { - let flow_key = FlowKey::try_from(net::flow_key::Uni(&packet)).unwrap(); + let flow_key = FlowKey::try_from(&packet).unwrap(); let flow_info = FlowInfo::new(flow_key, Instant::now() + self.timeout); self.flow_table .insert(flow_info) @@ -193,8 +192,8 @@ mod test { packet_2.meta_mut().set_overlay(true); // build keys for the packets - let key_1 = FlowKey::try_from(net::flow_key::Uni(&packet_1)).unwrap(); - let key_2 = FlowKey::try_from(net::flow_key::Uni(&packet_2)).unwrap(); + let key_1 = FlowKey::try_from(&packet_1).unwrap(); + let key_2 = FlowKey::try_from(&packet_2).unwrap(); // create a pair of related flow entries; flow_2 will get a longer timeout let expires_at = tokio::time::Instant::now().into_std() + Duration::from_secs(2); @@ -251,8 +250,8 @@ mod test { let mut packet_2 = build_test_udp_ipv4_packet("192.168.1.1", "20.0.0.1", 500, 80); packet_1.meta_mut().set_overlay(true); packet_2.meta_mut().set_overlay(true); - let key_1 = FlowKey::try_from(net::flow_key::Uni(&packet_1)).unwrap(); - let key_2 = FlowKey::try_from(net::flow_key::Uni(&packet_2)).unwrap(); + let key_1 = FlowKey::try_from(&packet_1).unwrap(); + let key_2 = FlowKey::try_from(&packet_2).unwrap(); let input = vec![packet_1, packet_2]; let out: Vec<_> = pipeline.process(input.into_iter()).collect(); let packet_1 = &out[0]; diff --git a/flow-entry/src/flow_table/table.rs b/flow-entry/src/flow_table/table.rs index 4a0ed52c7..e3e191afc 100644 --- a/flow-entry/src/flow_table/table.rs +++ b/flow-entry/src/flow_table/table.rs @@ -423,7 +423,7 @@ mod tests { use net::tcp::TcpPort; use net::vxlan::Vni; - use net::{FlowKey, FlowKeyData, IpProtoKey, TcpProtoKey}; + use net::{FlowKey, IpProtoKey, TcpProtoKey}; #[concurrency_mode(std)] mod std_tests { @@ -439,7 +439,7 @@ mod tests { let five_seconds_from_now = now + five_seconds; let flow_table = FlowTable::default(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -447,7 +447,7 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, five_seconds_from_now); @@ -469,7 +469,7 @@ mod tests { let one_second = Duration::from_secs(1); let flow_table = FlowTable::default(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(42).unwrap())), "10.0.0.1".parse::().unwrap(), "10.0.0.2".parse::().unwrap(), @@ -477,7 +477,7 @@ mod tests { src_port: TcpPort::new_checked(1234).unwrap(), dst_port: TcpPort::new_checked(5678).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, now + two_seconds); flow_table.insert(flow_info).unwrap(); @@ -501,7 +501,7 @@ mod tests { let second_expiry_time = now + Duration::from_secs(10); let flow_table = FlowTable::default(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -509,7 +509,7 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); // Insert first entry. let first_arc = Arc::new(FlowInfo::new(flow_key, first_expiry_time)); @@ -586,7 +586,7 @@ mod tests { let mut flow_keys = vec![]; for src_port in 1..=NUM_FLOWS { - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -594,7 +594,7 @@ mod tests { src_port: TcpPort::new_checked(src_port).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, deadline); flow_table.insert(flow_info).unwrap(); flow_keys.push(flow_key); @@ -627,7 +627,7 @@ mod tests { let now = Instant::now(); let deadline = now + Duration::from_secs(2); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -635,7 +635,7 @@ mod tests { src_port: TcpPort::new_checked(1).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_info = FlowInfo::new(flow_key, deadline); flow_table.insert(flow_info).unwrap(); @@ -663,19 +663,19 @@ mod tests { for i in 1u16..=2 { let src_port = TcpPort::new_checked(1000 + i).unwrap(); let dst_port = TcpPort::new_checked(80).unwrap(); - let flow_key = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key = FlowKey::new( Some(src_vpcd), src_ip, dst_ip, IpProtoKey::Tcp(TcpProtoKey { src_port, dst_port }), - )); + ); flow_table .insert(FlowInfo::new(flow_key, far_future)) .expect("insert under capacity should succeed"); } // One more insert must fail with CapacityExceeded. - let overflow_key = FlowKey::Unidirectional(FlowKeyData::new( + let overflow_key = FlowKey::new( Some(src_vpcd), src_ip, dst_ip, @@ -683,7 +683,7 @@ mod tests { src_port: TcpPort::new_checked(9999).unwrap(), dst_port: TcpPort::new_checked(80).unwrap(), }), - )); + ); assert!(matches!( flow_table.insert(FlowInfo::new(overflow_key, far_future)), Err(FlowTableError::CapacityExceeded) @@ -709,7 +709,7 @@ mod tests { let two_seconds = Duration::from_secs(2); let flow_keys: Vec<_> = (0u16..2u16) .map(|i| { - FlowKey::Unidirectional(FlowKeyData::new( + FlowKey::new( Some(VpcDiscriminant::VNI( Vni::new_checked(u32::from(i) + 1).unwrap(), )), @@ -719,7 +719,7 @@ mod tests { src_port: TcpPort::new_checked(1000 + i).unwrap(), dst_port: TcpPort::new_checked(2000 + i).unwrap(), }), - )) + ) }) .collect(); @@ -825,7 +825,7 @@ mod tests { let flow_table = Arc::new(FlowTable::default()); let five_seconds_from_now = Instant::now() + Duration::from_secs(5); - let flow_key1 = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key1 = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -833,9 +833,9 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); - let flow_key2 = FlowKey::Unidirectional(FlowKeyData::new( + let flow_key2 = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(10).unwrap())), "10.2.3.4".parse::().unwrap(), "40.5.6.7".parse::().unwrap(), @@ -843,7 +843,7 @@ mod tests { src_port: TcpPort::new_checked(1025).unwrap(), dst_port: TcpPort::new_checked(2048).unwrap(), }), - )); + ); let flow_table_clone1 = flow_table.clone(); let flow_table_clone2 = flow_table.clone(); diff --git a/flow-filter/src/lib.rs b/flow-filter/src/lib.rs index cb8aebb99..acfc52905 100644 --- a/flow-filter/src/lib.rs +++ b/flow-filter/src/lib.rs @@ -363,7 +363,7 @@ impl FlowFilter { return; } - let Ok(flow_key) = FlowKey::try_from(net::flow_key::Uni(&*packet)) else { + let Ok(flow_key) = FlowKey::try_from(&*packet) else { return; }; diff --git a/flow-filter/src/tests.rs b/flow-filter/src/tests.rs index 228bed35a..50e38fca7 100644 --- a/flow-filter/src/tests.rs +++ b/flow-filter/src/tests.rs @@ -12,7 +12,6 @@ use config::external::overlay::vpcpeering::{VpcExpose, VpcManifest, VpcPeering, use lpm::prefix::{L4Protocol, PortRange, Prefix, PrefixWithOptionalPorts}; use net::FlowKey; use net::buffer::{PacketBufferMut, TestBuffer}; -use net::flow_key::Uni; use net::flows::{FlowInfo, FlowStatus}; use net::headers::{Net, TryHeadersMut, TryIpMut}; use net::ip::NextHeader; @@ -191,7 +190,7 @@ fn fake_flow_session( set_port_fw_state: bool, ) { // build flow key - let flow_key = FlowKey::try_from(Uni(&*packet)).unwrap(); + let flow_key = FlowKey::try_from(&*packet).unwrap(); // Create flow_info with dst_vpcd and NAT info and attach it to the packet let flow_info = FlowInfo::new(flow_key, Instant::now() + Duration::from_secs(60)); diff --git a/nat/src/icmp_handler/icmp_error_msg.rs b/nat/src/icmp_handler/icmp_error_msg.rs index 344fb32fd..a0b6e7a39 100644 --- a/nat/src/icmp_handler/icmp_error_msg.rs +++ b/nat/src/icmp_handler/icmp_error_msg.rs @@ -15,7 +15,7 @@ use net::headers::{ use net::icmp_any::TruncatedIcmpAny; use net::icmp_any::{IcmpAnyChecksumErrorPlaceholder, IcmpAnyChecksumPayload}; use net::ipv4::Ipv4; -use net::packet::Packet; +use net::packet::{DoneReason, Packet}; use std::net::IpAddr; use std::num::NonZero; @@ -41,6 +41,25 @@ pub enum IcmpErrorMsgError { NoTranslationPossible, } +impl From<&IcmpErrorMsgError> for DoneReason { + fn from(error: &IcmpErrorMsgError) -> Self { + match error { + IcmpErrorMsgError::NoIpHeader => DoneReason::NotIp, + IcmpErrorMsgError::InvalidPort(_) => DoneReason::Malformed, + IcmpErrorMsgError::NotUnicast(_) => DoneReason::NatFailure, + IcmpErrorMsgError::InvalidIpVersion | IcmpErrorMsgError::NoTranslationPossible => { + DoneReason::InternalFailure + } + IcmpErrorMsgError::BadChecksumIcmp(_) | IcmpErrorMsgError::BadChecksumInnerIpv4(_) => { + DoneReason::InvalidChecksum + } + IcmpErrorMsgError::NoEmbeddedHeaders | IcmpErrorMsgError::NoInnerIpHeader => { + DoneReason::Filtered + } + } + } +} + // # Return // // * `Ok(())` if checksums are valid and we can translate the inner packet diff --git a/nat/src/portfw/flow_state.rs b/nat/src/portfw/flow_state.rs index 705c953ca..0b5c8c567 100644 --- a/nat/src/portfw/flow_state.rs +++ b/nat/src/portfw/flow_state.rs @@ -6,7 +6,6 @@ #![allow(clippy::single_match_else)] use net::buffer::PacketBufferMut; -use net::flow_key::Uni; use net::flows::{ExtractRef, FlowStatus}; use net::ip::UnicastIpAddr; use net::packet::{Packet, VpcDiscriminant}; @@ -107,7 +106,7 @@ pub(crate) fn build_portfw_flow_keys( dst_vpcd: VpcDiscriminant, // destination VPC to forward to ) -> Result<(FlowKey, FlowKey), ()> { // Extract flow key for the current packet - let current_flow_key = FlowKey::try_from(Uni(&*packet)).map_err(|_| ())?; + let current_flow_key = FlowKey::try_from(&*packet).map_err(|_| ())?; // Retrieve initial flow key for the current packet (before any other NAT translation); if // we don't have the information, we didn't populate it because we don't need it and fall @@ -120,14 +119,12 @@ pub(crate) fn build_portfw_flow_keys( .unwrap_or(current_flow_key); // Build the key for the reverse path - let proto = current_flow_key.data().proto(); - let src_port = current_flow_key.data().src_port().ok_or(())?; + let proto = current_flow_key.proto(); + let src_port = current_flow_key.src_port().ok_or(())?; let mut key_forward_dnated = current_flow_key; - key_forward_dnated.data_mut().set_dst_ip(new_dst_ip.inner()); - key_forward_dnated - .data_mut() - .set_ip_proto_key(IpProtoKey::from((proto, src_port, new_dst_port))); + key_forward_dnated.set_dst_ip(new_dst_ip.inner()); + key_forward_dnated.set_ip_proto_key(IpProtoKey::from((proto, src_port, new_dst_port))); let key_reverse = key_forward_dnated.reverse(Some(dst_vpcd)); Ok((initial_flow_key, key_reverse)) diff --git a/nat/src/stateful/allocation.rs b/nat/src/stateful/allocation.rs index 78e605271..c5426ff7a 100644 --- a/nat/src/stateful/allocation.rs +++ b/nat/src/stateful/allocation.rs @@ -5,6 +5,7 @@ use crate::port::NatPortError; use net::ip::NextHeader; +use net::packet::DoneReason; use std::fmt::{Debug, Display}; use std::time::Duration; @@ -34,6 +35,23 @@ pub enum AllocatorError { Denied, } +impl From<&AllocatorError> for DoneReason { + fn from(error: &AllocatorError) -> Self { + match error { + AllocatorError::UnsupportedProtocol(_) => DoneReason::NatUnsupportedProto, + AllocatorError::NoFreeIp + | AllocatorError::NoPortBlock + | AllocatorError::NoFreePort(_) => DoneReason::NatOutOfResources, + AllocatorError::PortAllocationFailed(_) + | AllocatorError::PortReservationFailed(_) + | AllocatorError::MissingDiscriminant + | AllocatorError::UnsupportedDiscriminant => DoneReason::NatFailure, + AllocatorError::InternalIssue(_) => DoneReason::InternalFailure, + AllocatorError::Denied => DoneReason::Filtered, + } + } +} + /// `AllocationResult` is a struct to represent the result of an allocation. /// It contains the allocated IP address and port to masquerade a packet, /// and the time for the allocation (flow timeout). diff --git a/nat/src/stateful/flows.rs b/nat/src/stateful/flows.rs index a102ddd8d..63332634c 100644 --- a/nat/src/stateful/flows.rs +++ b/nat/src/stateful/flows.rs @@ -61,9 +61,9 @@ fn re_reserve_ip_and_port( port: NatPort, ) -> Result<(), ()> { let flow_key = flow_info.flowkey(); - let proto = flow_key.data().proto(); + let proto = flow_key.proto(); let dst_vpcd = flow_info.get_dst_vpcd().unwrap_or_else(|| unreachable!()); - let src_ip = *flow_key.data().src_ip(); + let src_ip = *flow_key.src_ip(); let port_u16 = port.as_u16(); debug!("Attempting to reserve {ip} {port_u16} {proto}..."); @@ -104,7 +104,7 @@ pub(crate) fn check_masquerading_flow( return; }; let dst_vpcd = flow_info.get_dst_vpcd().unwrap_or_else(|| unreachable!()); - let src_vpcd = flow_key.data().src_vpcd().unwrap_or_else(|| unreachable!()); + let src_vpcd = flow_key.src_vpcd().unwrap_or_else(|| unreachable!()); debug!("Checking flow {}", flow_info.logfmt()); diff --git a/nat/src/stateful/nf.rs b/nat/src/stateful/nf.rs index ba5e759de..0f91082d0 100644 --- a/nat/src/stateful/nf.rs +++ b/nat/src/stateful/nf.rs @@ -16,7 +16,7 @@ use crate::stateful::state::MasqueradeState; use concurrency::sync::{Arc, Weak}; use flow_entry::flow_table::table::{FlowTable, FlowTableError}; use net::buffer::PacketBufferMut; -use net::flow_key::{IcmpProtoKey, Uni}; +use net::flow_key::IcmpProtoKey; use net::flows::{ExtractRef, FlowInfo}; use net::headers::{TryIp, TryTcp}; use net::ip::UnicastIpAddr; @@ -198,7 +198,7 @@ impl StatefulNat { dst_ip: IpAddr, proto_key_info: IpProtoKey, ) -> Option<(NatTranslate, Duration)> { - let flow_key = FlowKey::uni(src_vpcd, src_ip, dst_ip, proto_key_info); + let flow_key = FlowKey::new(src_vpcd, src_ip, dst_ip, proto_key_info); let flow_info = self.flow_table.lookup(&flow_key)?; let value = flow_info.locked.read(); let state = value.nat_state.as_ref()?.extract_ref::()?; @@ -219,8 +219,8 @@ impl StatefulNat { } fn get_reverse_mapping(flow_key: &FlowKey) -> Result<(IpAddr, NatPort), StatefulNatError> { - let src_ip = *flow_key.data().src_ip(); - let src_port = match flow_key.data().proto_key_info() { + let src_ip = *flow_key.src_ip(); + let src_port = match flow_key.proto_key_info() { IpProtoKey::Tcp(tcp) => tcp.src_port.into(), IpProtoKey::Udp(udp) => udp.src_port.into(), IpProtoKey::Icmp(icmp) => NatPort::Identifier(Self::get_icmp_query_id(icmp)?), @@ -313,12 +313,12 @@ impl StatefulNat { // So we want: // - tuple r.init = (src: f.nated.dst, dst: f.nated.src) // - mapping r.nated = (src: f.init.dst, dst: f.init.src) - let reverse_src_addr = *flow_key.data().dst_ip(); + let reverse_src_addr = *flow_key.dst_ip(); let reverse_dst_addr = alloc.allocation.ip(); let dst_port = alloc.allocation.port(); // Reverse the forward protocol key and adjust ports to use the allocated values. - let mut reverse_proto_key = flow_key.data().proto_key_info().reverse(); + let mut reverse_proto_key = flow_key.proto_key_info().reverse(); match reverse_proto_key { IpProtoKey::Tcp(_) | IpProtoKey::Udp(_) => { reverse_proto_key @@ -339,7 +339,7 @@ impl StatefulNat { } } - Ok(FlowKey::uni( + Ok(FlowKey::new( Some(dst_vpc_id), reverse_src_addr, reverse_dst_addr, @@ -375,7 +375,7 @@ impl StatefulNat { // Extract flow key for the current packet let current_flow_key = - FlowKey::try_from(Uni(&*packet)).map_err(|_| StatefulNatError::FlowKeyError)?; + FlowKey::try_from(&*packet).map_err(|_| StatefulNatError::FlowKeyError)?; // Retrieve initial flow key for the current packet (before any other NAT translation); if // we don't have the information, we didn't populate it because we don't need it and fall @@ -388,9 +388,9 @@ impl StatefulNat { .unwrap_or(current_flow_key); // Create a new session and translate the address - let src_ip = *initial_flow_key.data().src_ip(); + let src_ip = *initial_flow_key.src_ip(); let alloc = allocator - .allocate(dst_vpcd, src_ip, initial_flow_key.data().proto()) + .allocate(dst_vpcd, src_ip, initial_flow_key.proto()) .map_err(StatefulNatError::AllocationFailure)?; // Forbid addresses we won't know how to translate. This is a work around of a larger change @@ -484,7 +484,7 @@ impl StatefulNat { // TODO: Check whether the packet is fragmented if let Err(error) = self.masquerade_packet(packet) { - packet.done(translate_error(&error)); + packet.done((&error).into()); debug!("Did not masquerade packet: {error}"); } else { packet.meta_mut().set_checksum_refresh(true); @@ -492,39 +492,22 @@ impl StatefulNat { } } -fn translate_error(error: &StatefulNatError) -> DoneReason { - match error { - StatefulNatError::BadTransportHeader - | StatefulNatError::AllocationFailure(AllocatorError::UnsupportedProtocol(_)) => { - DoneReason::NatUnsupportedProto - } - - StatefulNatError::FlowKeyError | StatefulNatError::InvalidPort(_) => DoneReason::Malformed, - - StatefulNatError::AllocationFailure( - AllocatorError::NoFreeIp | AllocatorError::NoPortBlock | AllocatorError::NoFreePort(_), - ) => DoneReason::NatOutOfResources, - - StatefulNatError::CapacityExceeded => DoneReason::FlowCapacityExceeded, - StatefulNatError::NoAllocator - | StatefulNatError::UnexpectedKeyVariant - | StatefulNatError::IcmpUnsupportedCategory - | StatefulNatError::IcmpError - | StatefulNatError::AllocationFailure( - AllocatorError::PortAllocationFailed(_) - | AllocatorError::PortReservationFailed(_) - | AllocatorError::MissingDiscriminant - | AllocatorError::UnsupportedDiscriminant, - ) - | StatefulNatError::NatError(_) => DoneReason::NatFailure, - - StatefulNatError::AllocationFailure(AllocatorError::InternalIssue(_)) => { - DoneReason::InternalFailure +impl From<&StatefulNatError> for DoneReason { + fn from(error: &StatefulNatError) -> Self { + match error { + StatefulNatError::BadTransportHeader => DoneReason::NatUnsupportedProto, + StatefulNatError::FlowKeyError | StatefulNatError::InvalidPort(_) => { + DoneReason::Malformed + } + StatefulNatError::CapacityExceeded => DoneReason::FlowCapacityExceeded, + StatefulNatError::NoAllocator + | StatefulNatError::UnexpectedKeyVariant + | StatefulNatError::IcmpUnsupportedCategory + | StatefulNatError::IcmpError + | StatefulNatError::NatError(_) => DoneReason::NatFailure, + StatefulNatError::Bug(_) | StatefulNatError::IntendedDrop(_) => DoneReason::Filtered, + StatefulNatError::AllocationFailure(inner) => inner.into(), } - - StatefulNatError::AllocationFailure(AllocatorError::Denied) - | StatefulNatError::Bug(_) - | StatefulNatError::IntendedDrop(_) => DoneReason::Filtered, } } diff --git a/nat/src/stateful/test.rs b/nat/src/stateful/test.rs index e509e580d..759c909f3 100644 --- a/nat/src/stateful/test.rs +++ b/nat/src/stateful/test.rs @@ -18,7 +18,6 @@ use flow_entry::flow_table::{FlowLookup, FlowTable}; use flow_filter::{FlowFilter, FlowFilterTable, FlowFilterTableWriter}; use net::buffer::{PacketBufferMut, TestBuffer}; use net::eth::mac::Mac; -use net::flow_key::Uni; use net::flows::FlowStatus; use net::flows::flow_info_item::ExtractRef; use net::headers::TryTcpMut; @@ -366,7 +365,7 @@ fn check_packet( } fn flow_lookup(flow_table: &FlowTable, packet: &mut Packet) { - let flow_key = FlowKey::try_from(Uni(&*packet)).unwrap(); + let flow_key = FlowKey::try_from(&*packet).unwrap(); if let Some(flow_info) = flow_table.lookup(&flow_key) { packet.meta_mut().flow_info = Some(flow_info); } diff --git a/nat/src/stateless/nf.rs b/nat/src/stateless/nf.rs index a4dac63d3..80c7bdc44 100644 --- a/nat/src/stateless/nf.rs +++ b/nat/src/stateless/nf.rs @@ -322,7 +322,7 @@ impl StatelessNat { match self.translate(nat_tables, packet, src_vni, dst_vni) { Err(error) => { debug!("{nfi}: Translation failed: {error}"); - packet.done(translate_error(&error)); + packet.done((&error).into()); } Ok(modified) => { if modified { @@ -336,30 +336,16 @@ impl StatelessNat { } } -fn translate_error(error: &StatelessNatError) -> DoneReason { - match error { - StatelessNatError::NoIpHeader - | StatelessNatError::IcmpErrorMsg(IcmpErrorMsgError::NoIpHeader) => DoneReason::NotIp, - - StatelessNatError::MissingTable(_) => DoneReason::Unroutable, - - StatelessNatError::IcmpErrorMsg(IcmpErrorMsgError::InvalidPort(_)) => DoneReason::Malformed, - - StatelessNatError::IcmpErrorMsg(IcmpErrorMsgError::NotUnicast(_)) => DoneReason::NatFailure, - - StatelessNatError::FailedToSetDestIp(_) - | StatelessNatError::FailedToSetSourceIp(_) - | StatelessNatError::IcmpErrorMsg( - IcmpErrorMsgError::InvalidIpVersion | IcmpErrorMsgError::NoTranslationPossible, - ) => DoneReason::InternalFailure, - - StatelessNatError::IcmpErrorMsg( - IcmpErrorMsgError::BadChecksumIcmp(_) | IcmpErrorMsgError::BadChecksumInnerIpv4(_), - ) => DoneReason::InvalidChecksum, - - StatelessNatError::IcmpErrorMsg( - IcmpErrorMsgError::NoEmbeddedHeaders | IcmpErrorMsgError::NoInnerIpHeader, - ) => DoneReason::Filtered, +impl From<&StatelessNatError> for DoneReason { + fn from(error: &StatelessNatError) -> Self { + match error { + StatelessNatError::NoIpHeader => DoneReason::NotIp, + StatelessNatError::MissingTable(_) => DoneReason::Unroutable, + StatelessNatError::FailedToSetSourceIp(_) | StatelessNatError::FailedToSetDestIp(_) => { + DoneReason::InternalFailure + } + StatelessNatError::IcmpErrorMsg(inner) => inner.into(), + } } } diff --git a/net/src/flows/display.rs b/net/src/flows/display.rs index 63b9695fc..4617cbef1 100644 --- a/net/src/flows/display.rs +++ b/net/src/flows/display.rs @@ -4,13 +4,13 @@ //! Flow keys use super::flow_info::{FlowInfo, FlowInfoLocked}; -use super::flow_key::{FlowKey, FlowKeyData}; +use super::flow_key::FlowKey; use concurrency::sync::Weak; use std::fmt::Display; use std::time::Instant; -impl Display for FlowKeyData { +impl Display for FlowKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(vpcd) = self.src_vpcd() { write!(f, "from {vpcd},")?; @@ -31,14 +31,6 @@ impl Display for FlowKeyData { } } -impl Display for FlowKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FlowKey::Unidirectional(data) => write!(f, "{data}"), - } - } -} - impl Display for FlowInfoLocked { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(data) = &self.dst_vpcd { diff --git a/net/src/flows/flow_key.rs b/net/src/flows/flow_key.rs index 6b3ad16b7..38dd6de57 100644 --- a/net/src/flows/flow_key.rs +++ b/net/src/flows/flow_key.rs @@ -454,26 +454,30 @@ impl HashDst for IpProtoKey { } #[derive(Debug, Clone, Copy, PartialEq, Eq, Ord, PartialOrd)] -pub struct FlowKeyData { +pub struct FlowKey { src_vpcd: Option, src_ip: IpAddr, dst_ip: IpAddr, proto_key_info: IpProtoKey, } -impl FlowKeyData { +impl FlowKey { + /// Create a flow key. + /// + /// Flow keys are unidirectional: packets with src -> dst will match, but + /// dst -> src will not. #[must_use] pub fn new( src_vpcd: Option, src_ip: IpAddr, dst_ip: IpAddr, - ip_proto_key: IpProtoKey, + proto_key_info: IpProtoKey, ) -> Self { Self { src_vpcd, src_ip, dst_ip, - proto_key_info: ip_proto_key, + proto_key_info, } } @@ -571,7 +575,7 @@ impl FlowKeyData { } } -impl Hash for FlowKeyData { +impl Hash for FlowKey { fn hash(&self, state: &mut H) { self.src_vpcd.hash(state); self.src_ip.hash(state); @@ -581,106 +585,37 @@ impl Hash for FlowKeyData { } } -#[derive(Debug, Clone, Copy, Eq, PartialOrd, Ord)] -pub enum FlowKey { - Unidirectional(FlowKeyData), -} - -impl FlowKey { - #[must_use] - pub fn data(&self) -> &FlowKeyData { - match self { - FlowKey::Unidirectional(data) => data, - } - } - #[must_use] - pub fn data_mut(&mut self) -> &mut FlowKeyData { - match self { - FlowKey::Unidirectional(data) => data, - } - } - - /// Create a unidirectional flow key - /// - /// packets with src -> dst will match, but dst -> src will not - #[must_use] - pub fn uni( - src_vpcd: Option, - src_ip: IpAddr, - dst_ip: IpAddr, - proto_key_info: IpProtoKey, - ) -> FlowKey { - FlowKey::Unidirectional(FlowKeyData::new(src_vpcd, src_ip, dst_ip, proto_key_info)) - } - - // Creates the flow key with src and dst swapped - #[must_use] - pub fn reverse(&self, src_vpcd: Option) -> FlowKey { - match self { - FlowKey::Unidirectional(data) => FlowKey::Unidirectional(data.reverse(src_vpcd)), - } - } -} - -// The FlowKey Eq is symmetric, src == src or src == dst -impl PartialEq for FlowKey { - fn eq(&self, other: &Self) -> bool { - match (self, other) { - (FlowKey::Unidirectional(a), FlowKey::Unidirectional(b)) => a == b, - } - } -} - -impl Hash for FlowKey { - fn hash(&self, state: &mut H) { - match self { - FlowKey::Unidirectional(a) => a.hash(state), - } - } -} - -/// Wrapper to specify unidirectional `FlowKey` creation -#[repr(transparent)] -#[derive(Debug)] -pub struct Uni(pub T); - -fn flow_key_data_from_packet(packet: &Packet) -> Option { - let ip = packet.headers().try_ip()?; - let src_ip = ip.src_addr(); - let dst_ip = ip.dst_addr(); - - let transport = packet.headers().try_transport()?; - let ip_proto_key = match transport { - Transport::Tcp(tcp) => IpProtoKey::Tcp(TcpProtoKey { - src_port: tcp.source(), - dst_port: tcp.destination(), - }), - Transport::Udp(udp) => IpProtoKey::Udp(UdpProtoKey { - src_port: udp.source(), - dst_port: udp.destination(), - }), - Transport::Icmp4(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v4(packet, icmp)), - Transport::Icmp6(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v6(packet, icmp)), - #[allow(unreachable_patterns)] - _ => return None, - }; - - let src_vpcd = packet.meta().src_vpcd; - Some(FlowKeyData::new(src_vpcd, src_ip, dst_ip, ip_proto_key)) -} - -impl TryFrom>> for FlowKey { +impl TryFrom<&Packet> for FlowKey { type Error = FlowKeyError; - fn try_from(packet: Uni<&Packet>) -> Result { - let packet = packet.0; - let FlowKeyData { - src_vpcd, - src_ip, - dst_ip, - proto_key_info, - } = flow_key_data_from_packet(packet).ok_or(FlowKeyError::NoFlowKeyData)?; + fn try_from(packet: &Packet) -> Result { + let ip = packet + .headers() + .try_ip() + .ok_or(FlowKeyError::NoFlowKeyData)?; + let src_ip = ip.src_addr(); + let dst_ip = ip.dst_addr(); - Ok(FlowKey::uni(src_vpcd, src_ip, dst_ip, proto_key_info)) + let transport = packet + .headers() + .try_transport() + .ok_or(FlowKeyError::NoFlowKeyData)?; + let ip_proto_key = match transport { + Transport::Tcp(tcp) => IpProtoKey::Tcp(TcpProtoKey { + src_port: tcp.source(), + dst_port: tcp.destination(), + }), + Transport::Udp(udp) => IpProtoKey::Udp(UdpProtoKey { + src_port: udp.source(), + dst_port: udp.destination(), + }), + Transport::Icmp4(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v4(packet, icmp)), + Transport::Icmp6(icmp) => IpProtoKey::Icmp(IcmpProtoKey::new_icmp_v6(packet, icmp)), + #[allow(unreachable_patterns)] + _ => return Err(FlowKeyError::NoFlowKeyData), + }; + + let src_vpcd = packet.meta().src_vpcd; + Ok(FlowKey::new(src_vpcd, src_ip, dst_ip, ip_proto_key)) } } @@ -739,14 +674,14 @@ pub fn flowkey_embedded_in_icmp_error( return Err(FlowKeyError::EmbeddedMissingIcmpId); } }; - Ok(FlowKey::uni(None, src_ip, dst_ip, proto_key)) + Ok(FlowKey::new(None, src_ip, dst_ip, proto_key)) } #[cfg(any(test, feature = "bolero"))] mod contract { use super::{ - EmbeddedPacketData, FlowKey, FlowKeyData, IcmpProtoKey, InnerIcmpProtoKey, InnerIpProtoKey, - IpProtoKey, TcpProtoKey, UdpProtoKey, + EmbeddedPacketData, FlowKey, IcmpProtoKey, InnerIcmpProtoKey, InnerIpProtoKey, IpProtoKey, + TcpProtoKey, UdpProtoKey, }; use crate::ip::UnicastIpAddr; use crate::ipv4::addr::UnicastIpv4Addr; @@ -855,7 +790,7 @@ mod contract { } } - impl TypeGenerator for FlowKeyData { + impl TypeGenerator for FlowKey { fn generate(driver: &mut D) -> Option { let src_vpcd = driver.produce(); let v6 = driver.produce::()?; @@ -872,19 +807,7 @@ mod contract { ) }; let proto_key_info = super::IpProtoKey::generate(driver)?; - Some(FlowKeyData { - src_vpcd, - src_ip, - dst_ip, - proto_key_info, - }) - } - } - - impl TypeGenerator for FlowKey { - fn generate(driver: &mut D) -> Option { - let data = FlowKeyData::generate(driver)?; - Some(FlowKey::Unidirectional(data)) + Some(FlowKey::new(src_vpcd, src_ip, dst_ip, proto_key_info)) } } } @@ -905,7 +828,7 @@ mod tests { #[test] fn test_flow_key_reverse() { - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( Some(VpcDiscriminant::VNI(Vni::new_checked(1).unwrap())), "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -916,17 +839,17 @@ mod tests { ); let reverse_flow_key = flow_key.reverse(None); - assert_eq!(flow_key.data().src_ip, reverse_flow_key.data().dst_ip); - assert_eq!(flow_key.data().dst_ip, reverse_flow_key.data().src_ip); + assert_eq!(flow_key.src_ip, reverse_flow_key.dst_ip); + assert_eq!(flow_key.dst_ip, reverse_flow_key.src_ip); assert_eq!( - flow_key.data().proto_key_info, - reverse_flow_key.data().proto_key_info.reverse() + flow_key.proto_key_info, + reverse_flow_key.proto_key_info.reverse() ); } #[test] fn test_flow_key_uni_hash() { - let flow_key = FlowKey::uni( + let flow_key = FlowKey::new( None, "1.2.3.4".parse::().unwrap(), "4.5.6.7".parse::().unwrap(), @@ -956,12 +879,11 @@ mod tests { /// This function panics if the packet has a different transport protocol than the flow key. /// It also panics if the packet IP address family does not match the flow key. fn set_packet_fields(packet: &mut Packet, flow_key: &FlowKey) { - let flow_key_data = flow_key.data(); packet - .set_ip_source(flow_key_data.src_ip.try_into().unwrap()) + .set_ip_source(flow_key.src_ip.try_into().unwrap()) .unwrap(); - packet.set_ip_destination(flow_key_data.dst_ip).unwrap(); - match flow_key_data.proto_key_info { + packet.set_ip_destination(flow_key.dst_ip).unwrap(); + match flow_key.proto_key_info { IpProtoKey::Tcp(tcp) => { packet.set_tcp_source_port(tcp.src_port).unwrap(); packet.set_tcp_destination_port(tcp.dst_port).unwrap(); @@ -1138,7 +1060,7 @@ mod tests { }; if let Some(proto) = proto { let (flow_key, mut packet) = - (FlowKey::uni(src_vpcd, src_ip, dst_ip, proto), packet); + (FlowKey::new(src_vpcd, src_ip, dst_ip, proto), packet); set_packet_fields(&mut packet, &flow_key); Some((Some(flow_key), packet)) } else { @@ -1152,8 +1074,8 @@ mod tests { bolero::check!() .with_generator(FlowKeyAndPacket) .for_each(|(flow_key, packet)| match flow_key { - Some(FlowKey::Unidirectional(_)) => { - let gen_flow_key = FlowKey::try_from(Uni(packet)).unwrap(); + Some(_) => { + let gen_flow_key = FlowKey::try_from(packet).unwrap(); assert_eq!( gen_flow_key, flow_key.unwrap(), @@ -1162,7 +1084,7 @@ mod tests { ); } None => { - assert!(FlowKey::try_from(Uni(packet)).is_err()); + assert!(FlowKey::try_from(packet).is_err()); } }); } diff --git a/net/src/lib.rs b/net/src/lib.rs index ffb50a43d..6f9e3f938 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -43,6 +43,4 @@ pub mod vxlan; // re-export #[cfg(unix)] -pub use flows::flow_key::{ - self, FlowKey, FlowKeyData, IcmpProtoKey, IpProtoKey, TcpProtoKey, UdpProtoKey, -}; +pub use flows::flow_key::{self, FlowKey, IcmpProtoKey, IpProtoKey, TcpProtoKey, UdpProtoKey}; diff --git a/net/src/packet/mod.rs b/net/src/packet/mod.rs index 7fc8edb07..c696a887d 100644 --- a/net/src/packet/mod.rs +++ b/net/src/packet/mod.rs @@ -23,7 +23,6 @@ use crate::FlowKey; use crate::buffer::{Headroom, PacketBufferMut, Prepend, Tailroom, TrimFromStart}; use crate::eth::Eth; use crate::eth::EthError; -use crate::flow_key::Uni; use crate::flows::{FlowInfo, FlowStatus}; use crate::headers::{ EmbeddedHeaders, Headers, Net, Transport, TryEmbeddedHeaders, TryEmbeddedHeadersMut, @@ -97,7 +96,7 @@ impl Packet { /// Update the flow key of this packet based on its current state pub fn update_flow_key(&mut self) { - self.meta.flow_key = FlowKey::try_from(Uni(&*self)).ok().map(Box::new); + self.meta.flow_key = FlowKey::try_from(&*self).ok().map(Box::new); } /// Get the length of the packet's payload