diff --git a/Cargo.lock b/Cargo.lock index d672317338..2c73a87f90 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1008,17 +1008,6 @@ dependencies = [ "hybrid-array", ] -[[package]] -name = "ctrlc" -version = "3.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0b1fab2ae45819af2d0731d60f2afe17227ebb1a1538a236da84c93e9a60162" -dependencies = [ - "dispatch2", - "nix 0.31.3", - "windows-sys 0.61.2", -] - [[package]] name = "darling" version = "0.20.11" @@ -1116,13 +1105,13 @@ dependencies = [ "arrayvec", "axum", "axum-server", - "ctrlc", "dataplane-args", "dataplane-concurrency", "dataplane-config", "dataplane-flow-entry", "dataplane-flow-filter", "dataplane-id", + "dataplane-lifecycle", "dataplane-mgmt", "dataplane-nat", "dataplane-net", @@ -1138,7 +1127,6 @@ dependencies = [ "linkme", "metrics", "metrics-exporter-prometheus", - "mio", "n-vm", "netdev", "nix 0.31.3", @@ -1154,6 +1142,18 @@ dependencies = [ "tracing-subscriber", ] +[[package]] +name = "dataplane-acl" +version = "0.21.0" +dependencies = [ + "arrayvec", + "dataplane-concurrency", + "dataplane-lookup", + "dataplane-match-action", + "dataplane-net", + "thiserror", +] + [[package]] name = "dataplane-args" version = "0.21.0" @@ -1255,8 +1255,10 @@ version = "0.21.0" dependencies = [ "bolero", "dataplane-concurrency", + "dataplane-dpdk", "dataplane-dpdk-sys", "dataplane-dpdk-sysroot-helper", + "dataplane-dpdk-test-macros", "dataplane-errno", "dataplane-id", "dataplane-net", @@ -1279,6 +1281,16 @@ dependencies = [ name = "dataplane-dpdk-sysroot-helper" version = "0.21.0" +[[package]] +name = "dataplane-dpdk-test-macros" +version = "0.21.0" +dependencies = [ + "proc-macro-crate 3.5.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "dataplane-errno" version = "0.21.0" @@ -1286,6 +1298,10 @@ dependencies = [ "thiserror", ] +[[package]] +name = "dataplane-fixed-size" +version = "0.21.0" + [[package]] name = "dataplane-flow-entry" version = "0.21.0" @@ -1447,6 +1463,20 @@ dependencies = [ "thiserror", ] +[[package]] +name = "dataplane-lifecycle" +version = "0.21.0" +dependencies = [ + "dataplane-concurrency", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "dataplane-lookup" +version = "0.21.0" + [[package]] name = "dataplane-lpm" version = "0.21.0" @@ -1462,6 +1492,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "dataplane-match-action" +version = "0.21.0" +dependencies = [ + "arrayvec", + "bolero", + "dataplane-fixed-size", + "dataplane-match-action-derive", +] + +[[package]] +name = "dataplane-match-action-derive" +version = "0.21.0" +dependencies = [ + "proc-macro-crate 3.5.0", + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "dataplane-mgmt" version = "0.21.0" @@ -1479,6 +1529,7 @@ dependencies = [ "dataplane-interface-manager", "dataplane-k8s-intf", "dataplane-k8s-less", + "dataplane-lifecycle", "dataplane-lpm", "dataplane-nat", "dataplane-net", @@ -1552,6 +1603,7 @@ dependencies = [ "bytecheck", "dataplane-common", "dataplane-concurrency", + "dataplane-fixed-size", "dataplane-id", "derive_builder", "downcast-rs", @@ -1605,6 +1657,7 @@ dependencies = [ "dataplane-concurrency", "dataplane-config", "dataplane-left-right-tlcache", + "dataplane-lifecycle", "dataplane-lpm", "dataplane-net", "dataplane-tracectl", @@ -5562,6 +5615,7 @@ dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", "pin-project-lite", "slab", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 04a96edf26..0928e8f3d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ + "acl", "args", "cli", "common", @@ -11,7 +12,9 @@ members = [ "dpdk", "dpdk-sys", "dpdk-sysroot-helper", + "dpdk-test-macros", "errno", + "fixed-size", "flow-entry", "flow-filter", "hardware", @@ -21,7 +24,11 @@ members = [ "k8s-intf", "k8s-less", "left-right-tlcache", + "lifecycle", + "lookup", "lpm", + "match-action", + "match-action-derive", "mgmt", "nat", "net", @@ -58,6 +65,7 @@ repository = "https://github.com/githedgehog/dataplane/" # justifying why it is workspace-wide. # Internal +acl = { path = "./acl", package = "dataplane-acl", features = [] } args = { path = "./args", package = "dataplane-args", features = [] } cli = { path = "./cli", package = "dataplane-cli", features = [] } common = { path = "./common", package = "dataplane-common", features = [] } @@ -67,8 +75,10 @@ config = { path = "./config", package = "dataplane-config", features = [] } dpdk = { path = "./dpdk", package = "dataplane-dpdk", features = [] } dpdk-sys = { path = "./dpdk-sys", package = "dataplane-dpdk-sys", features = [] } dpdk-sysroot-helper = { path = "./dpdk-sysroot-helper", package = "dataplane-dpdk-sysroot-helper", features = [] } +dpdk-test-macros = { path = "./dpdk-test-macros", package = "dataplane-dpdk-test-macros", features = [] } dplane-rpc = { git = "https://github.com/githedgehog/dplane-rpc.git", branch = "pr/daniel-noland/bumps", features = [] } errno = { path = "./errno", package = "dataplane-errno", features = [] } +fixed-size = { path = "./fixed-size", package = "dataplane-fixed-size", features = [] } flow-entry = { path = "./flow-entry", package = "dataplane-flow-entry", features = [] } flow-filter = { path = "./flow-filter", package = "dataplane-flow-filter", features = [] } hardware = { path = "./hardware", package = "dataplane-hardware", features = [] } @@ -78,7 +88,11 @@ interface-manager = { path = "./interface-manager", package = "dataplane-interfa k8s-intf = { path = "./k8s-intf", package = "dataplane-k8s-intf", default-features = false, features = [] } k8s-less = { path = "./k8s-less", package = "dataplane-k8s-less", features = [] } left-right-tlcache = { path = "./left-right-tlcache", package = "dataplane-left-right-tlcache", features = [] } +lifecycle = { path = "./lifecycle", package = "dataplane-lifecycle", features = [] } +lookup = { path = "./lookup", package = "dataplane-lookup", features = [] } lpm = { path = "./lpm", package = "dataplane-lpm", features = [] } +match-action = { path = "./match-action", package = "dataplane-match-action", features = [] } +match-action-derive = { path = "./match-action-derive", package = "dataplane-match-action-derive", features = [] } mgmt = { path = "./mgmt", package = "dataplane-mgmt", features = [] } nat = { path = "./nat", package = "dataplane-nat", features = [] } net = { path = "./net", package = "dataplane-net", features = [] } @@ -114,7 +128,6 @@ clap = { version = "4.6.1", default-features = true, features = [] } color-eyre = { version = "0.6.5", default-features = false, features = [] } colored = { version = "3.1.1", default-features = false, features = [] } crossbeam-utils = { version = "0.8.21", default-features = false, features = [] } -ctrlc = { version = "3.5.2", default-features = false, features = [] } dashmap = { version = "6.2.1", default-features = false, features = [] } derive_builder = { version = "0.20.2", default-features = false, features = [] } dotenvy = { version = "0.15.7", default-features = false, features = [] } @@ -249,6 +262,15 @@ overflow-checks = true # modified to use conditional compilation to work in wasm/miri. # miss: packages that are not logically hopeless, or pointless in wasm/miri, but which currently just happen to contain # logic which can and should eventually be factored out or abstracted into something suitable for wasm/miri. +[workspace.metadata.package.acl] +package = "dataplane-acl" +# Default features enable the DPDK `rte_acl` backend, which pulls in +# `dpdk-sys` (bindgen against the system DPDK headers). miri can't +# build that path on the cross target, and the reference backend's +# unit tests run fine outside the miri profile. +miri = false # hopeless + pointless +wasm = false # hopeless + pointless + [workspace.metadata.package.args] package = "dataplane-args" miri = true @@ -280,6 +302,11 @@ package = "dataplane-dpdk-sys" miri = false # hopeless + pointless wasm = false # hopeless + pointless +[workspace.metadata.package.fixed-size] +package = "dataplane-fixed-size" +miri = true +wasm = true + [workspace.metadata.package.flow-entry] package = "dataplane-flow-entry" miri = true @@ -315,6 +342,23 @@ package = "dataplane-k8s-less" miri = true wasm = false # split +[workspace.metadata.package.lookup] +package = "dataplane-lookup" +miri = true +wasm = false # split (std collections) + +[workspace.metadata.package.match-action] +package = "dataplane-match-action" +miri = true +wasm = true + +[workspace.metadata.package.match-action-derive] +package = "dataplane-match-action-derive" +# Proc-macro crate: runs at the host toolchain, not the miri target. +# Excluded from miri to keep the target dep graph clean. +miri = false # hopeless + pointless +wasm = false # hopeless + pointless + [workspace.metadata.package.mgmt] package = "dataplane-mgmt" miri = false diff --git a/acl/Cargo.toml b/acl/Cargo.toml new file mode 100644 index 0000000000..72976b9255 --- /dev/null +++ b/acl/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "dataplane-acl" +edition.workspace = true +license.workspace = true +publish.workspace = true +version.workspace = true + +[dependencies] +arrayvec = { workspace = true, default-features = true } +concurrency = { workspace = true, features = [] } +lookup = { workspace = true, features = [] } +match-action = { workspace = true, features = ["derive"] } +net = { workspace = true, features = [] } +thiserror = { workspace = true } diff --git a/acl/src/lib.rs b/acl/src/lib.rs new file mode 100644 index 0000000000..26995e27e0 --- /dev/null +++ b/acl/src/lib.rs @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +#![deny( + unsafe_code, + clippy::all, + clippy::pedantic, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic +)] +#![allow(missing_docs)] // shape settling; doc once stable + +//! Match-action classifier backends for [`match_action::MatchKey`] +//! tables, behind the [`lookup::Lookup`] interface. +//! +//! - [`reference`](mod@reference): linear-scan software classifier; +//! differential oracle and a mutable cascade front. Always built. +//! +//! The production `rte_acl` backend lands behind a follow-up `dpdk` +//! feature gate. +//! +//! [`lookup::Lookup`]: lookup::Lookup +//! [`match_action::MatchKey`]: match_action::MatchKey + +pub mod reference; diff --git a/acl/src/reference/dyn_table.rs b/acl/src/reference/dyn_table.rs new file mode 100644 index 0000000000..b5384c4def --- /dev/null +++ b/acl/src/reference/dyn_table.rs @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use match_action::FieldSpec; + +use super::table::RefRule; +#[derive(Debug, PartialEq, Eq, thiserror::Error)] +pub enum DynShapeError { + #[error("specs are empty")] + EmptySpecs, + #[error( + "spec {idx} offset {offset} disagrees with cumulative size {expected_offset} \ + of fields 0..{idx}" + )] + OffsetMismatch { + idx: usize, + offset: usize, + expected_offset: usize, + }, + #[error("spec {idx} has zero size")] + ZeroSize { idx: usize }, + #[error("rule {rule} has {actual} predicates, specs has {expected}")] + FieldCountMismatch { + rule: usize, + expected: usize, + actual: usize, + }, + #[error("rule {rule} field {field}: predicate width {actual} != spec size {expected}")] + PredicateWidthMismatch { + rule: usize, + field: usize, + expected: usize, + actual: usize, + }, +} +#[derive(Clone, Debug)] +pub struct DynReferenceTable { + specs: Vec, + key_size: usize, + rules: Vec>, +} + +impl DynReferenceTable { + pub fn new(specs: Vec, rules: Vec>) -> Result { + let key_size = validate_specs(&specs)?; + for (rule_idx, rule) in rules.iter().enumerate() { + if rule.fields().len() != specs.len() { + return Err(DynShapeError::FieldCountMismatch { + rule: rule_idx, + expected: specs.len(), + actual: rule.fields().len(), + }); + } + for (field_idx, (pred, spec)) in rule.fields().iter().zip(&specs).enumerate() { + if pred.width() != spec.size { + return Err(DynShapeError::PredicateWidthMismatch { + rule: rule_idx, + field: field_idx, + expected: spec.size, + actual: pred.width(), + }); + } + } + } + Ok(Self { + specs, + key_size, + rules, + }) + } + #[must_use] + pub fn key_size(&self) -> usize { + self.key_size + } + #[must_use] + pub fn specs(&self) -> &[FieldSpec] { + &self.specs + } + #[must_use] + pub fn len(&self) -> usize { + self.rules.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.rules.is_empty() + } + #[must_use] + pub fn lookup_bytes(&self, key: &[u8]) -> Option<&A> { + assert_eq!(key.len(), self.key_size, "key length must equal key_size"); + self.rules + .iter() + .find(|rule| rule.matches_packed(&self.specs, key)) + .map(RefRule::action) + } + #[must_use] + pub fn matches_bytes(&self, key: &[u8]) -> Vec<&RefRule> { + assert_eq!(key.len(), self.key_size, "key length must equal key_size"); + self.rules + .iter() + .filter(|rule| rule.matches_packed(&self.specs, key)) + .collect() + } +} +fn validate_specs(specs: &[FieldSpec]) -> Result { + if specs.is_empty() { + return Err(DynShapeError::EmptySpecs); + } + let mut cursor = 0usize; + for (idx, spec) in specs.iter().enumerate() { + if spec.size == 0 { + return Err(DynShapeError::ZeroSize { idx }); + } + if spec.offset != cursor { + return Err(DynShapeError::OffsetMismatch { + idx, + offset: spec.offset, + expected_offset: cursor, + }); + } + cursor += spec.size; + } + Ok(cursor) +} + +#[cfg(test)] +mod tests { + use super::*; + use match_action::{FieldKind, FieldPredicate}; + + fn spec(name: &'static str, kind: FieldKind, size: usize, offset: usize) -> FieldSpec { + FieldSpec { + name, + kind, + size, + offset, + } + } + + fn make_rule_5tuple( + proto: u8, + src: [u8; 4], + src_len: u8, + dport_lo: u16, + dport_hi: u16, + action: u32, + ) -> RefRule { + use match_action::predicate::{Exact, FieldBytes, Prefix, Range}; + let proto_b: FieldBytes = [proto].iter().copied().collect(); + let src_b: FieldBytes = src.iter().copied().collect(); + let dlo: FieldBytes = dport_lo.to_be_bytes().iter().copied().collect(); + let dhi: FieldBytes = dport_hi.to_be_bytes().iter().copied().collect(); + RefRule::new( + vec![ + FieldPredicate::Exact(Exact::new(proto_b)), + FieldPredicate::Prefix(Prefix::new(src_b, src_len)), + FieldPredicate::Range(Range::new(dlo, dhi)), + ], + action, + ) + } + + fn five_tuple_specs() -> Vec { + vec![ + spec("proto", FieldKind::Exact, 1, 0), + spec("src", FieldKind::Prefix, 4, 1), + spec("dport", FieldKind::Range, 2, 5), + ] + } + + #[test] + fn lookup_bytes_hits_and_misses() { + let table = DynReferenceTable::new( + five_tuple_specs(), + vec![make_rule_5tuple(6, [10, 0, 0, 0], 8, 22, 22, 0xAA)], + ) + .expect("valid shape"); + assert_eq!(table.key_size(), 1 + 4 + 2); + + let mut key = vec![6u8]; + key.extend_from_slice(&[10, 1, 2, 3]); + key.extend_from_slice(&22u16.to_be_bytes()); + assert_eq!(table.lookup_bytes(&key), Some(&0xAA)); + let mut key = vec![6u8]; + key.extend_from_slice(&[11, 0, 0, 0]); + key.extend_from_slice(&22u16.to_be_bytes()); + assert_eq!(table.lookup_bytes(&key), None); + let mut key = vec![6u8]; + key.extend_from_slice(&[10, 1, 2, 3]); + key.extend_from_slice(&80u16.to_be_bytes()); + assert_eq!(table.lookup_bytes(&key), None); + } + + #[test] + fn matches_bytes_is_nonlossy() { + use match_action::predicate::{Exact, FieldBytes}; + let broad = RefRule::new( + vec![ + FieldPredicate::Exact(Exact::new([6u8].iter().copied().collect::())), + FieldPredicate::Prefix(match_action::predicate::Prefix::new( + [0u8; 4].iter().copied().collect::(), + 0, + )), + FieldPredicate::Range(match_action::predicate::Range::new( + 0u16.to_be_bytes().iter().copied().collect::(), + u16::MAX + .to_be_bytes() + .iter() + .copied() + .collect::(), + )), + ], + 0xBB, + ); + let narrow = make_rule_5tuple(6, [10, 0, 0, 0], 8, 22, 22, 0xCC); + let table = + DynReferenceTable::new(five_tuple_specs(), vec![broad, narrow]).expect("valid shape"); + + let mut key = vec![6u8]; + key.extend_from_slice(&[10, 1, 2, 3]); + key.extend_from_slice(&22u16.to_be_bytes()); + let m = table.matches_bytes(&key); + assert_eq!(m.len(), 2); + assert_eq!(m[0].action(), &0xBB); + assert_eq!(m[1].action(), &0xCC); + } + + #[test] + fn rejects_offset_mismatch() { + let bad = vec![ + spec("proto", FieldKind::Exact, 1, 0), + spec("src", FieldKind::Prefix, 4, 0), + ]; + let err = DynReferenceTable::<()>::new(bad, vec![]).unwrap_err(); + assert!(matches!( + err, + DynShapeError::OffsetMismatch { + idx: 1, + offset: 0, + expected_offset: 1 + } + )); + } + + #[test] + fn rejects_predicate_width_mismatch() { + use match_action::predicate::{Exact, FieldBytes}; + let bad_proto: FieldBytes = [0u8, 0].iter().copied().collect(); + let rule = RefRule::new(vec![FieldPredicate::Exact(Exact::new(bad_proto))], 0u32); + let specs = vec![spec("proto", FieldKind::Exact, 1, 0)]; + let err = DynReferenceTable::new(specs, vec![rule]).unwrap_err(); + assert!(matches!( + err, + DynShapeError::PredicateWidthMismatch { + rule: 0, + field: 0, + expected: 1, + actual: 2 + } + )); + } + + #[test] + fn rejects_empty_specs() { + let err = DynReferenceTable::::new(vec![], vec![]).unwrap_err(); + assert_eq!(err, DynShapeError::EmptySpecs); + } + + #[test] + fn rejects_zero_size_spec() { + let bad = vec![spec("x", FieldKind::Exact, 0, 0)]; + let err = DynReferenceTable::::new(bad, vec![]).unwrap_err(); + assert_eq!(err, DynShapeError::ZeroSize { idx: 0 }); + } + #[test] + fn dyn_table_agrees_with_typed_table() { + use crate::reference::ReferenceTable; + use core::net::Ipv4Addr; + use lookup::Lookup; + use match_action::{ExactSpec, MatchKey, PrefixSpec, RangeSpec}; + + #[derive(MatchKey)] + struct K { + #[exact] + proto: u8, + #[prefix] + src: Ipv4Addr, + #[range] + dport: u16, + } + + let rule_fields = KRule { + proto: ExactSpec::new(6), + src: PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 8), + dport: RangeSpec::exact(22), + } + .into_backend_fields::(); + + let typed = ReferenceTable::::new(vec![RefRule::new(rule_fields.clone(), 0xAA)]); + let dynamic = DynReferenceTable::new( + K::field_specs().to_vec(), + vec![RefRule::new(rule_fields, 0xAA)], + ) + .expect("valid shape"); + + for (key, label) in &[ + ( + K { + proto: 6, + src: "10.1.2.3".parse().unwrap(), + dport: 22, + }, + "hit", + ), + ( + K { + proto: 6, + src: "11.0.0.0".parse().unwrap(), + dport: 22, + }, + "src miss", + ), + ( + K { + proto: 17, + src: "10.1.2.3".parse().unwrap(), + dport: 22, + }, + "proto miss", + ), + ] { + let bytes = key.as_key(); + assert_eq!( + typed.lookup(key).copied(), + dynamic.lookup_bytes(&bytes).copied(), + "typed vs dynamic disagree on {label}", + ); + } + } +} diff --git a/acl/src/reference/mod.rs b/acl/src/reference/mod.rs new file mode 100644 index 0000000000..0fd27fea59 --- /dev/null +++ b/acl/src/reference/mod.rs @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +pub mod dyn_table; +pub mod table; + +pub use dyn_table::{DynReferenceTable, DynShapeError}; +pub use match_action::{Erased, FieldPredicate}; +pub use table::{RefRule, ReferenceTable}; diff --git a/acl/src/reference/table.rs b/acl/src/reference/table.rs new file mode 100644 index 0000000000..80463f3480 --- /dev/null +++ b/acl/src/reference/table.rs @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use core::marker::PhantomData; + +use lookup::Lookup; +use match_action::{FieldPredicate, FieldSpec, MatchKey}; +const MAX_KEY_BYTES: usize = 256; +#[derive(Clone, Debug)] +pub struct RefRule { + fields: Vec, + action: A, +} + +impl RefRule { + #[must_use] + pub fn new(fields: Vec, action: A) -> Self { + Self { fields, action } + } + + pub fn action(&self) -> &A { + &self.action + } + #[must_use] + pub fn fields(&self) -> &[FieldPredicate] { + &self.fields + } + + pub(crate) fn matches_packed(&self, specs: &[FieldSpec], buf: &[u8]) -> bool { + debug_assert_eq!(self.fields.len(), specs.len()); + self.fields + .iter() + .zip(specs) + .all(|(pred, spec)| pred.matches(&buf[spec.offset..spec.offset + spec.size])) + } +} +#[derive(Clone, Debug)] +pub struct ReferenceTable { + rules: Vec>, + _key: PhantomData K>, +} + +impl ReferenceTable { + #[must_use] + pub fn new(rules: Vec>) -> Self { + Self { + rules, + _key: PhantomData, + } + } + + #[must_use] + pub fn empty() -> Self { + Self::new(Vec::new()) + } + + #[must_use] + pub fn len(&self) -> usize { + self.rules.len() + } + + #[must_use] + pub fn is_empty(&self) -> bool { + self.rules.is_empty() + } + fn pack(key: &K) -> Option<[u8; MAX_KEY_BYTES]> { + if K::KEY_SIZE > MAX_KEY_BYTES { + return None; + } + let mut buf = [0u8; MAX_KEY_BYTES]; + key.as_key_into(&mut buf[..K::KEY_SIZE]); + Some(buf) + } + #[must_use] + pub fn matches(&self, key: &K) -> Vec<&RefRule> { + let Some(buf) = Self::pack(key) else { + return Vec::new(); + }; + let specs = K::field_specs(); + self.rules + .iter() + .filter(|rule| rule.matches_packed(specs, &buf)) + .collect() + } +} + +impl Lookup for ReferenceTable { + fn lookup(&self, key: &K) -> Option<&A> { + let buf = Self::pack(key)?; + let specs = K::field_specs(); + self.rules + .iter() + .find(|rule| rule.matches_packed(specs, &buf)) + .map(RefRule::action) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use core::net::Ipv4Addr; + use match_action::{Erased, ExactSpec, MatchKey, PrefixSpec, RangeSpec}; + + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + enum Verdict { + Allow, + Drop, + } + + #[derive(MatchKey)] + struct FiveTuple { + #[exact] + proto: u8, + #[prefix] + src_ip: Ipv4Addr, + #[prefix] + dst_ip: Ipv4Addr, + #[range] + src_port: u16, + #[range] + dst_port: u16, + } + + fn drop_10_8_to_22() -> RefRule { + RefRule::new( + FiveTupleRule { + proto: ExactSpec::new(6), + src_ip: PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 8), + dst_ip: PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 0), + src_port: RangeSpec::new(0, u16::MAX), + dst_port: RangeSpec::exact(22), + } + .into_backend_fields::(), + Verdict::Drop, + ) + } + + #[test] + fn single_rule_hit_and_miss() { + let table = ReferenceTable::new(vec![drop_10_8_to_22()]); + + assert_eq!( + table.lookup(&FiveTuple { + proto: 6, + src_ip: "10.1.2.3".parse().unwrap(), + dst_ip: "192.168.1.1".parse().unwrap(), + src_port: 54321, + dst_port: 22, + }), + Some(&Verdict::Drop), + ); + assert_eq!( + table.lookup(&FiveTuple { + proto: 6, + src_ip: "11.0.0.1".parse().unwrap(), + dst_ip: "192.168.1.1".parse().unwrap(), + src_port: 54321, + dst_port: 22, + }), + None, + ); + assert_eq!( + table.lookup(&FiveTuple { + proto: 6, + src_ip: "10.1.2.3".parse().unwrap(), + dst_ip: "192.168.1.1".parse().unwrap(), + src_port: 54321, + dst_port: 80, + }), + None, + ); + } + + #[test] + fn empty_table_always_misses() { + let table: ReferenceTable = ReferenceTable::empty(); + assert!(table.is_empty()); + assert_eq!( + table.lookup(&FiveTuple { + proto: 6, + src_ip: Ipv4Addr::UNSPECIFIED, + dst_ip: Ipv4Addr::UNSPECIFIED, + src_port: 0, + dst_port: 0, + }), + None, + ); + } + fn allow_all_tcp() -> RefRule { + RefRule::new( + FiveTupleRule { + proto: ExactSpec::new(6), + src_ip: PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 0), + dst_ip: PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 0), + src_port: RangeSpec::new(0, u16::MAX), + dst_port: RangeSpec::new(0, u16::MAX), + } + .into_backend_fields::(), + Verdict::Allow, + ) + } + + fn overlapping_packet() -> FiveTuple { + FiveTuple { + proto: 6, + src_ip: "10.1.2.3".parse().unwrap(), + dst_ip: "192.168.1.1".parse().unwrap(), + src_port: 54321, + dst_port: 22, + } + } + + #[test] + fn positional_precedence_first_match_wins() { + let table = ReferenceTable::new(vec![allow_all_tcp(), drop_10_8_to_22()]); + assert_eq!(table.lookup(&overlapping_packet()), Some(&Verdict::Allow)); + } + + #[test] + fn matches_is_nonlossy_and_retains_shadowed_losers() { + let table = ReferenceTable::new(vec![allow_all_tcp(), drop_10_8_to_22()]); + let matched = table.matches(&overlapping_packet()); + + assert_eq!(matched.len(), 2); + assert_eq!(matched[0].action(), &Verdict::Allow); + assert_eq!(matched[1].action(), &Verdict::Drop); + } +} diff --git a/dataplane/Cargo.toml b/dataplane/Cargo.toml index 94fefe5fc5..83d86fcba4 100644 --- a/dataplane/Cargo.toml +++ b/dataplane/Cargo.toml @@ -19,7 +19,6 @@ axum = { workspace = true, features = ["http1", "tokio"] } axum-server = { workspace = true } concurrency = { workspace = true } config = { workspace = true } -ctrlc = { workspace = true, features = ["termination"] } dyn-iter = { workspace = true } flow-entry = { workspace = true } flow-filter = { workspace = true } @@ -27,11 +26,11 @@ futures = { workspace = true } hyper = { workspace = true } hyper-util = { workspace = true } id = { workspace = true } +lifecycle = { workspace = true } linkme = { workspace = true } metrics = { workspace = true } metrics-exporter-prometheus = { workspace = true } mgmt = { workspace = true } -mio = { workspace = true, features = ["os-ext", "net"] } nat = { workspace = true } net = { workspace = true, features = ["test_buffer"] } nix = { workspace = true, features = ["socket", "hostname"] } @@ -46,7 +45,7 @@ rtnetlink = { workspace = true, features = ["default", "tokio"] } serde = { workspace = true, features = ["derive"] } stats = { workspace = true } thiserror = { workspace = true } -tokio = { workspace = true } +tokio = { workspace = true, features = ["rt", "rt-multi-thread"] } tracectl = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true, default-features = true } diff --git a/dataplane/src/drivers/kernel/mod.rs b/dataplane/src/drivers/kernel/mod.rs index 88a5594358..f5914de3ad 100644 --- a/dataplane/src/drivers/kernel/mod.rs +++ b/dataplane/src/drivers/kernel/mod.rs @@ -18,6 +18,9 @@ mod worker; use concurrency::sync::Arc; use concurrency::thread; +#[allow(unused_imports)] // used under loom/shuttle backends +use concurrency::thread::BuilderExt; +use lifecycle::Subsystem; use net::buffer::test_buffer::TestBuffer; use pipeline::DynPipeline; use tracectl::trace_target; @@ -25,88 +28,87 @@ use tracectl::trace_target; use tracing::{debug, error, info, trace, warn}; use super::DriverError; -use super::tokio_util::run_in_local_tokio_runtime; use kif::{Kif, bring_kifs_up}; use worker::Worker; trace_target!("kernel-driver", LevelFilter::INFO, &["driver"]); -/// Main structure representing the kernel driver. -/// This driver: -/// * receives raw frames via `AF_PACKET`, parses to `Packet` -/// * selects a worker by symmetric flow hash -/// * workers run independent pipelines and send processed packets back -/// * dispatcher serializes & transmits on the chosen outgoing interface +/// AF_PACKET-based kernel driver. Spawns N workers with symmetric-hash +/// fanout and per-worker pipelines. pub struct DriverKernel; #[allow(clippy::cast_possible_truncation)] impl DriverKernel { - /// Spawn `workers` processing threads, each with its own pipeline instance. - /// - /// Returns: - /// - `Arc>>>` one sender per worker (dispatcher -> worker) - /// - `Receiver>` a single queue for processed packets (worker -> dispatcher) - fn spawn_workers( + /// Spawn `num_workers` worker threads into `scope`, each with its own + /// pipeline. Bails on the first spawn failure; workers that did spawn + /// drain via the scope join. + fn spawn_workers_scoped<'scope>( + scope: &'scope thread::Scope<'scope, '_>, + workers_subsystem: &Subsystem, num_workers: usize, setup_pipeline: &Arc DynPipeline>, interfaces: &[Kif], - ) -> Vec>> { + ) -> Result>>, std::io::Error> + { info!("Spawning {num_workers} workers"); - let mut workers = Vec::new(); - for wid in 0..num_workers { - let builder = thread::Builder::new().name(format!("dp-worker-{wid}")); - let mut worker = Worker::new(wid, num_workers, setup_pipeline); - match worker.start(builder, interfaces) { - Ok(handle) => workers.push(handle), - Err(e) => { - error!("Failed to start worker {wid}: {e}"); - } - } - } - workers + (0..num_workers) + .map(|wid| { + let builder = thread::Builder::new().name(format!("dp-worker-{wid}")); + Worker::new(wid, num_workers, setup_pipeline, workers_subsystem.clone()) + .start(scope, builder, interfaces) + }) + .collect() } - /// Starts the kernel driver, spawns worker threads, and runs the dispatcher loop. - /// - /// - `args`: kernel driver CLI parameters (e.g., `--interface` list) - /// - `workers`: number of worker threads / pipelines - /// - `setup_pipeline`: factory returning a **fresh** `DynPipeline` per worker + /// Spawn worker threads + supervisor into `scope`. The scope joins + /// all driver threads on closure return. /// /// # Errors - /// Returns [`DriverError`] in case the driver fails to start successfully. - pub fn start( - stop_tx: std::sync::mpsc::Sender, + /// Returns [`DriverError`] on interface setup or thread spawn failure. + pub fn start<'scope>( + scope: &'scope thread::Scope<'scope, '_>, + workers_subsystem: &Subsystem, args: impl IntoIterator + Clone>, num_workers: usize, setup_pipeline: &Arc DynPipeline>, ) -> Result<(), DriverError> { + // A current_thread runtime built inside another tokio runtime + // panics; catch nesting in debug. + debug_assert!( + tokio::runtime::Handle::try_current().is_err(), + "DriverKernel::start must not be invoked from within a tokio runtime context" + ); + info!("Collecting interfaces from config"); let interfaces = kif::get_interfaces(args)?; - // ensure that the kernel interfaces for rx/tx are up - run_in_local_tokio_runtime(async || bring_kifs_up(interfaces.as_slice()).await)?; + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()? + .block_on(bring_kifs_up(interfaces.as_slice()))?; - // Spawn workers - let worker_handles = - Self::spawn_workers(num_workers, setup_pipeline, interfaces.as_slice()); + let worker_handles = Self::spawn_workers_scoped( + scope, + workers_subsystem, + num_workers, + setup_pipeline, + interfaces.as_slice(), + )?; - let control_builder = thread::Builder::new().name("kernel-driver-controller".to_string()); - control_builder.spawn(move || { + // The supervisor just joins-and-logs; worker fatal reporting is + // handled by the `ExitGuard` inside each worker thread. + let supervisor_builder = + thread::Builder::new().name("kernel-driver-supervisor".to_string()); + supervisor_builder.spawn_scoped(scope, move || { for (id, handle) in worker_handles.into_iter().enumerate() { - info!("Waiting for workers to finish"); + info!("Waiting for worker {id} to finish"); match handle.join() { - Ok(result) => match result { - Ok(()) => info!("Worker {id} exited successfully"), - Err(e) => error!("Worker {id} exited with error: {e}"), - }, - Err(e) => error!("Unable to spawn worker {id} error: {e:?}"), + Ok(Ok(())) => info!("Worker {id} exited successfully"), + Ok(Err(e)) => error!("Worker {id} exited with error: {e}"), + Err(panic_payload) => error!("Worker {id} panicked: {panic_payload:?}"), } } - - // Exiting with error as it's not expected for all workers to finish - error!("All workers finished unexpectedly"); - #[allow(clippy::expect_used)] - stop_tx.send(1).expect("Failed to send stop signal"); + info!("All workers joined"); })?; Ok(()) diff --git a/dataplane/src/drivers/kernel/worker.rs b/dataplane/src/drivers/kernel/worker.rs index 48a80935eb..235f3b325d 100644 --- a/dataplane/src/drivers/kernel/worker.rs +++ b/dataplane/src/drivers/kernel/worker.rs @@ -15,6 +15,9 @@ use tokio::sync::Mutex; use concurrency::sync::Arc; use concurrency::thread; +#[allow(unused_imports)] // used under loom/shuttle backends +use concurrency::thread::BuilderExt; +use lifecycle::Subsystem; use net::buffer::test_buffer::TestBuffer; use net::interface::InterfaceIndex; use net::packet::{DoneReason, Packet}; @@ -22,7 +25,6 @@ use pipeline::{DynPipeline, NetworkFunction}; use crate::drivers::kernel::fanout::{PacketFanoutType, set_packet_fanout}; use crate::drivers::kernel::kif::Kif; -use crate::drivers::tokio_util::run_in_local_tokio_runtime; use tracing::{debug, error, info, trace, warn}; @@ -126,6 +128,7 @@ pub struct Worker { id: WorkerId, total_workers: usize, setup_pipeline: Arc DynPipeline>, + subsystem: Subsystem, } impl Worker { @@ -133,28 +136,69 @@ impl Worker { id: WorkerId, total_workers: usize, setup_pipeline: &Arc DynPipeline>, + subsystem: Subsystem, ) -> Self { Worker { id, total_workers, setup_pipeline: setup_pipeline.clone(), + subsystem, } } - pub fn start( + #[allow(clippy::too_many_lines)] + pub fn start<'scope>( &mut self, + scope: &'scope thread::Scope<'scope, '_>, thread_builder: thread::Builder, interfaces: &[Kif], - ) -> Result>, io::Error> { + ) -> Result>, io::Error> { let id = self.id; let total_workers = self.total_workers; let setup = self.setup_pipeline.clone(); - let interfaces = interfaces.iter().map(Kif::clone).collect::>(); + let subsystem = self.subsystem.clone(); + let cancel = subsystem.cancel_token(); + let interfaces = interfaces.to_vec(); + + let handle_res = thread_builder.spawn_scoped(scope, move || { + // Drop-guard so panic-unwind, early-`?`, and unexpected normal + // return all reach report_fatal. Disarmed on the graceful path. + struct ExitGuard { + subsystem: Subsystem, + id: WorkerId, + armed: bool, + } + impl ExitGuard { + fn disarm(&mut self) { + self.armed = false; + } + } + impl Drop for ExitGuard { + fn drop(&mut self) { + if !self.armed || self.subsystem.is_cancelled() { + return; + } + let reason = if std::thread::panicking() { + format!("worker {} panicked", self.id) + } else { + format!("worker {} exited unexpectedly", self.id) + }; + self.subsystem.report_fatal(&reason); + } + } - let handle_res = thread_builder.spawn(move || { info!(worker = id, "Worker started"); + let mut guard = ExitGuard { + subsystem: subsystem.clone(), + id, + armed: true, + }; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build_local(tokio::runtime::LocalOptions::default())?; - run_in_local_tokio_runtime(async || { + let result = rt.block_on(async { let (readers, if_table) = match build_interface_table(id, total_workers, interfaces.as_slice()) { Ok(table) => table, @@ -166,27 +210,39 @@ impl Worker { let setup = setup.clone(); let if_table = if_table.clone(); + let cancel = cancel.clone(); let mut reader_handles = tokio::task::JoinSet::new(); for intf in readers { let setup = setup.clone(); let if_table = if_table.clone(); + let cancel = cancel.clone(); reader_handles.spawn_local(async move { let intf = intf; let mut pipeline = setup(); loop { - tracing::debug!(worker = id, "awaiting packets"); + debug!(worker = id, "awaiting packets"); - let packets_vec = match read_packets_from_interface(id, &intf).await { - Ok(packets) => packets, - Err(e) => { - error!( + let packets_vec = tokio::select! { + () = cancel.cancelled() => { + info!( worker = id, rx_intf_name = intf.if_name, - "Error reading packets from interface: {e}" + "cancellation observed; exiting reader" ); - vec![] + break; + } + result = read_packets_from_interface(id, &intf) => match result { + Ok(packets) => packets, + Err(e) => { + error!( + worker = id, + rx_intf_name = intf.if_name, + "Error reading packets from interface: {e}" + ); + vec![] + } } }; @@ -198,7 +254,6 @@ impl Worker { intf.if_name ); - // Try to receive everything else that is in the buffer let packets = packets_vec.into_iter(); let mut count = 0; @@ -238,8 +293,13 @@ impl Worker { } Ok::<(), io::Error>(()) - })?; - info!(worker = id, "Worker exited"); + }); + + if subsystem.is_cancelled() { + guard.disarm(); + } + info!(worker = id, "worker exited"); + result?; Ok::<(), io::Error>(()) })?; Ok(handle_res) diff --git a/dataplane/src/drivers/mod.rs b/dataplane/src/drivers/mod.rs index 80731da0e5..74990160a7 100644 --- a/dataplane/src/drivers/mod.rs +++ b/dataplane/src/drivers/mod.rs @@ -4,7 +4,6 @@ use thiserror::Error; pub mod kernel; -mod tokio_util; #[derive(Error, Debug)] pub enum DriverError { diff --git a/dataplane/src/drivers/tokio_util.rs b/dataplane/src/drivers/tokio_util.rs deleted file mode 100644 index 0ad91dbf51..0000000000 --- a/dataplane/src/drivers/tokio_util.rs +++ /dev/null @@ -1,50 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 -// Copyright Open Network Fabric Authors - -use tokio::runtime::{Builder, LocalOptions}; - -/// Executes a function inside a local (non-Send) tokio runtime. -/// The runtime will be torn down when the function returns. -/// -/// # Panics -/// If it fails to create a current thread runtime. -pub fn run_in_local_tokio_runtime(f: F) -> R -where - F: FnOnce() -> Fut, - Fut: std::future::Future, -{ - let current_runtime = tokio::runtime::Handle::try_current(); - assert!( - current_runtime.is_err(), - "Expected no active tokio runtime, but found: {:?}", - current_runtime.unwrap_err() - ); - - let rt = Builder::new_current_thread() - .enable_all() - .build_local(LocalOptions::default()) - .expect("Failed to create current thread runtime"); - - rt.block_on(f()) -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio::time::{Duration, sleep}; - - #[test] - fn test_run_in_tokio_runtime_pure() { - let result = run_in_local_tokio_runtime(|| async { 42 }); - assert_eq!(result, 42); - } - - #[test] - fn test_run_in_tokio_runtime_async() { - let result = run_in_local_tokio_runtime(|| async { - sleep(Duration::from_millis(100)).await; - 42 - }); - assert_eq!(result, 42); - } -} diff --git a/dataplane/src/packet_processor/mod.rs b/dataplane/src/packet_processor/mod.rs index 3142d99d9b..233e0fcaa5 100644 --- a/dataplane/src/packet_processor/mod.rs +++ b/dataplane/src/packet_processor/mod.rs @@ -49,6 +49,9 @@ where /// Start a router and provide the associated pipeline pub(crate) fn start_router( + mgmt: &lifecycle::Subsystem, + mgmt_handle: &tokio::runtime::Handle, + router: &lifecycle::Subsystem, params: RouterParams, ) -> Result, RouterError> { let vpcmapw = VpcMapWriter::::new(); @@ -83,7 +86,7 @@ pub(crate) fn start_router( }; // create router - let router = Router::new(params, Some(cli_sources))?; + let router = Router::new(mgmt, mgmt_handle, router, params, Some(cli_sources))?; let iftr_factory = router.get_iftabler_factory(); let fibtr_factory = router.get_fibtr_factory(); let atabler_factory = router.get_atabler_factory(); diff --git a/dataplane/src/runtime.rs b/dataplane/src/runtime.rs index fccabf8c84..94b4814169 100644 --- a/dataplane/src/runtime.rs +++ b/dataplane/src/runtime.rs @@ -2,11 +2,12 @@ // Copyright Open Network Fabric Authors use crate::packet_processor::start_router; -use crate::statistics::MetricsServer; +use crate::statistics::spawn_metrics; use args::{CmdArgs, Parser}; use crate::drivers::kernel::DriverKernel; -use mgmt::{ConfigProcessorParams, MgmtParams, start_mgmt}; +use lifecycle::{Shutdown, default_deadlines, spawn_shutdown_watchdog}; +use mgmt::{ConfigProcessorParams, LaunchError, MgmtParams, run_mgmt}; use nix::unistd::gethostname; use pyroscope::backend::{BackendConfig, PprofConfig, pprof_backend}; @@ -178,14 +179,20 @@ pub fn main() { process_tracing_cmds(&args); - let (stop_tx, stop_rx) = std::sync::mpsc::channel(); - let ctrlc_stop_tx = stop_tx.clone(); - ctrlc::set_handler(move || { - ctrlc_stop_tx - .send(0) - .expect("Error sending shutdown signal"); - }) - .expect("failed to set SIGINT handler"); + let shutdown = Shutdown::new(); + + let mgmt_runtime = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name("mgmt-rt") + .build() + .expect("Failed to build mgmt runtime"); + let mgmt_handle = mgmt_runtime.handle().clone(); + + lifecycle::spawn_signal_handler(&mgmt_handle, shutdown.root.clone()) + .expect("failed to install signal handler"); + + spawn_shutdown_watchdog(shutdown.root.clone(), default_deadlines::TOTAL, 124) + .expect("failed to spawn shutdown watchdog"); /* router parameters */ let mut binding = RouterParamsBuilder::default(); @@ -195,7 +202,6 @@ pub fn main() { .frr_agent_path(args.frr_agent_path()) .dp_status(dp_status.clone()); - // Only set BMP when it's enabled (strip_option setter expects the inner type) if let Some(server) = bmp_server_params { rp_builder = rp_builder.bmp(server); } @@ -205,65 +211,101 @@ pub fn main() { panic!("Bad router configuration"); }; - // start the router; returns control-plane handles and a pipeline factory - let setup = start_router(router_params).expect("failed to start router"); - - MetricsServer::new(args.metrics_address(), setup.stats); + let mut setup = start_router( + &shutdown.mgmt, + &mgmt_handle, + &shutdown.router, + router_params, + ) + .expect("failed to start router"); + + spawn_metrics( + &shutdown.metrics, + &mgmt_handle, + args.metrics_address(), + setup.stats, + ); - // pipeline builder let pipeline_factory = setup.pipeline; - /* start management: main thread will be blocked until ready or failure */ - if let Err(e) = start_mgmt(MgmtParams { - config_dir: args.config_dir().cloned(), - hostname: gwname.clone(), - processor_params: ConfigProcessorParams { - router_ctl: setup.router.get_ctl_tx(), - pipeline_data: pipeline_factory().get_data(), - flow_table: setup.flow_table, - vpcmapw: setup.vpcmapw, - nattablesw: setup.nattablesw, - natallocatorw: setup.natallocatorw, - flowfilterw: setup.flowfiltertablesw, - portfw_w: setup.portfw_w, - vpc_stats_store: setup.vpc_stats_store, - dp_status_r: dp_status.clone(), - bmp_options: bmp_client_opts, - }, - }) { - error!("Failed to start mgmt: {e}. Stopping dataplane..."); - std::process::exit(-1); - } - info!("Management is running now"); - - /* start driver with the provided pipeline builder */ - let e = match args.driver_name() { - "dpdk" => { - info!("Using driver DPDK..."); - todo!(); - } - "kernel" => { - info!("Using driver kernel..."); - DriverKernel::start( - stop_tx.clone(), - args.kernel_interfaces(), - args.kernel_num_workers(), - &pipeline_factory, - ) - } - other => { - error!("Unknown driver '{other}'. Aborting..."); - panic!("Packet processing pipeline failed to start. Aborting..."); + concurrency::thread::scope(|scope| { + let mgmt_result = run_mgmt( + &mgmt_handle, + &shutdown.mgmt, + MgmtParams { + config_dir: args.config_dir().cloned(), + hostname: gwname.clone(), + processor_params: ConfigProcessorParams { + router_ctl: setup.router.get_ctl_tx(), + pipeline_data: pipeline_factory().get_data(), + flow_table: setup.flow_table, + vpcmapw: setup.vpcmapw, + nattablesw: setup.nattablesw, + natallocatorw: setup.natallocatorw, + flowfilterw: setup.flowfiltertablesw, + portfw_w: setup.portfw_w, + vpc_stats_store: setup.vpc_stats_store, + dp_status_r: dp_status.clone(), + bmp_options: bmp_client_opts, + }, + }, + ); + + match mgmt_result { + Ok(()) => { + info!("Management is running now"); + + let driver_result = match args.driver_name() { + "dpdk" => { + info!("Using driver DPDK..."); + todo!(); + } + "kernel" => { + info!("Using driver kernel..."); + Some(DriverKernel::start( + scope, + &shutdown.workers, + args.kernel_interfaces(), + args.kernel_num_workers(), + &pipeline_factory, + )) + } + other => { + error!("Unknown driver '{other}'. Stopping dataplane..."); + shutdown.fail(); + None + } + }; + + if let Some(Err(e)) = driver_result { + error!("Failed to start driver: {e}"); + shutdown.fail(); + } + } + Err(LaunchError::Cancelled) => { + // Don't call shutdown.fail() — that flips the fatal flag + // and turns a graceful SIGINT into a non-zero exit, which + // systemd would restart-loop. + info!("Mgmt init cancelled; proceeding to shutdown"); + } + Err(e) => { + error!("Failed to start mgmt: {e}. Stopping dataplane..."); + shutdown.fail(); + } } - }; - if let Err(e) = e { - error!("Failed to start driver: {e}"); - std::process::exit(-1); - } + mgmt_handle.block_on(shutdown.root.cancelled()); + info!("Shutting down dataplane"); + mgmt_handle.block_on(shutdown.drain_in_order()); + }); + + let exit_code = i32::from(shutdown.is_fatal()); + + // Router::stop()'s BMP abort needs mgmt_runtime live, so stop router + // before shutting the runtime down. + setup.router.stop(); + mgmt_runtime.shutdown_timeout(Duration::from_secs(2)); - let exit_code = stop_rx.recv().expect("failed to receive stop signal"); - info!("Shutting down dataplane"); if let Some(running) = agent_running { match running.stop() { Ok(ready) => ready.shutdown(), @@ -272,28 +314,3 @@ pub fn main() { } std::process::exit(exit_code); } - -#[cfg(false)] // disabled until dpdk-sys refactor is complete -#[cfg(test)] -mod test { - use n_vm::in_vm; - - #[test] - #[in_vm] - fn root_filesystem_in_vm_is_read_only() { - let error = std::fs::File::create_new("/some.file").unwrap_err(); - assert_eq!(error.kind(), std::io::ErrorKind::ReadOnlyFilesystem); - } - - #[test] - #[in_vm] - fn run_filesystem_in_vm_is_read_write() { - std::fs::File::create_new("/run/some.file").unwrap(); - } - - #[test] - #[in_vm] - fn tmp_filesystem_in_vm_is_read_write() { - std::fs::File::create_new("/tmp/some.file").unwrap(); - } -} diff --git a/dataplane/src/statistics/mod.rs b/dataplane/src/statistics/mod.rs index 453367697d..0aad75a28a 100644 --- a/dataplane/src/statistics/mod.rs +++ b/dataplane/src/statistics/mod.rs @@ -2,9 +2,9 @@ // Copyright Open Network Fabric Authors use axum::{Router, response::Response, routing::get}; +use lifecycle::Subsystem; use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle}; use stats::StatsCollector; -use std::thread::JoinHandle; use std::time::Duration; use tracing::{error, info}; @@ -45,60 +45,68 @@ async fn metrics_handler( .unwrap() } -#[derive(Debug)] -pub struct MetricsServer { - #[allow(unused)] // temporary - handle: JoinHandle<()>, -} - -impl MetricsServer { - // TODO: convert to scoped thread - #[tracing::instrument(level = "info", skip(stats))] - pub fn new(addr: std::net::SocketAddr, stats: StatsCollector) -> Self { - MetricsServer { - handle: std::thread::Builder::new() - .name("metrics-server".to_string()) - .spawn(move || { - info!("Starting metrics server thread"); - - // create tokio runtime - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .expect("runtime creation failed for metrics server"); - - // block thread to run metrics HTTP server - rt.block_on(Self::run(addr, stats)); - }) - .unwrap(), - } - } - - #[tracing::instrument(level = "info", skip(stats))] - async fn run(addr: std::net::SocketAddr, stats: StatsCollector) { - let PrometheusHandler { handle } = PrometheusHandler::new(); +/// Spawn the `/metrics` endpoint on `addr`, a 30s upkeep ticker, and the +/// stats collector onto `handle`, tracked under `metrics`. Uses +/// [`Subsystem::spawn_on`] — a dead metrics endpoint should not take down +/// the dataplane. +pub fn spawn_metrics( + metrics: &Subsystem, + handle: &tokio::runtime::Handle, + addr: std::net::SocketAddr, + stats: StatsCollector, +) { + let PrometheusHandler { + handle: prom_handle, + } = PrometheusHandler::new(); - let upkeep_handle = handle.clone(); - tokio::spawn(async move { - // avgerage prometheus scraper is between 15 and 60 secs, - // so run upkeep every 30 secs is a reasonable default + let upkeep_handle = prom_handle.clone(); + let upkeep_cancel = metrics.cancel_token(); + metrics.spawn_on( + async move { let mut ticker = tokio::time::interval(Duration::from_secs(30)); loop { - ticker.tick().await; - // run_upkeep is synchronous; call it periodically. - upkeep_handle.run_upkeep(); + tokio::select! { + () = upkeep_cancel.cancelled() => break, + _ = ticker.tick() => { + upkeep_handle.run_upkeep(); + } + } } - }); - tokio::spawn(stats.run()); - let app = Router::new() - .route("/metrics", get(metrics_handler)) - .with_state(handle); + }, + handle, + ); - info!("metrics server listening on {}", addr); + let stats_cancel = metrics.cancel_token(); + metrics.spawn_on( + async move { + tokio::select! { + () = stats_cancel.cancelled() => {} + () = stats.run() => {} + } + }, + handle, + ); - if let Err(e) = axum_server::bind(addr).serve(app.into_make_service()).await { - error!("metrics server error: {}", e); - } - } + let server_cancel = metrics.cancel_token(); + metrics.spawn_on( + async move { + let app = Router::new() + .route("/metrics", get(metrics_handler)) + .with_state(prom_handle); + + info!("metrics server listening on {}", addr); + + tokio::select! { + () = server_cancel.cancelled() => { + info!("metrics server shutdown requested"); + } + res = axum_server::bind(addr).serve(app.into_make_service()) => { + if let Err(e) = res { + error!("metrics server error: {}", e); + } + } + } + }, + handle, + ); } diff --git a/dpdk-test-macros/Cargo.toml b/dpdk-test-macros/Cargo.toml new file mode 100644 index 0000000000..6f245a49df --- /dev/null +++ b/dpdk-test-macros/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "dataplane-dpdk-test-macros" +edition.workspace = true +license.workspace = true +publish.workspace = true +version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro-crate = { workspace = true, default-features = true } +proc-macro2 = { workspace = true, default-features = true } +quote = { workspace = true, default-features = true } +syn = { workspace = true, default-features = true, features = ["full"] } diff --git a/dpdk-test-macros/src/lib.rs b/dpdk-test-macros/src/lib.rs new file mode 100644 index 0000000000..c86f286fff --- /dev/null +++ b/dpdk-test-macros/src/lib.rs @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use proc_macro::TokenStream; +use proc_macro_crate::{FoundCrate, crate_name}; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::quote; +use syn::{Ident, ItemFn, parse_macro_input, parse_quote}; +fn dpdk_crate_path() -> TokenStream2 { + match crate_name("dataplane-dpdk") { + Ok(FoundCrate::Itself) => quote! { crate }, + Ok(FoundCrate::Name(name)) => { + let ident = Ident::new(&name, Span::call_site()); + quote! { ::#ident } + } + Err(_) => { + let ident = Ident::new("dataplane_dpdk", Span::call_site()); + quote! { ::#ident } + } + } +} + +#[proc_macro_attribute] +pub fn with_eal(args: TokenStream, input: TokenStream) -> TokenStream { + if !args.is_empty() { + let err: TokenStream2 = + syn::Error::new(Span::call_site(), "#[with_eal] takes no arguments").to_compile_error(); + return err.into(); + } + + let mut input_fn = parse_macro_input!(input as ItemFn); + let dpdk = dpdk_crate_path(); + let init_stmt: syn::Stmt = parse_quote! { + let _eal = #dpdk::test_support::start_eal(); + }; + input_fn.block.stmts.insert(0, init_stmt); + + quote! { #input_fn }.into() +} diff --git a/dpdk/Cargo.toml b/dpdk/Cargo.toml index fdee013955..8ef27dd586 100644 --- a/dpdk/Cargo.toml +++ b/dpdk/Cargo.toml @@ -8,13 +8,17 @@ version.workspace = true [features] default = ["serde"] serde = ["dep:serde"] +test = ["dep:id", "dep:nix", "dep:dpdk-test-macros"] [dependencies] concurrency = { workspace = true } dpdk-sys = { workspace = true } +dpdk-test-macros = { workspace = true, optional = true } errno = { workspace = true } +id = { workspace = true, optional = true } net = { workspace = true } +nix = { workspace = true, optional = true, features = ["sched"] } serde = { workspace = true, optional = true, features = ["std"] } thiserror = { workspace = true } @@ -24,7 +28,5 @@ tracing = { workspace = true, features = ["attributes"] } dpdk-sysroot-helper = { workspace = true } [dev-dependencies] -id = { workspace = true } - bolero = { workspace = true, default-features = false, features = ["std"] } -nix = { workspace = true, features = ["sched"] } +dataplane-dpdk = { path = ".", features = ["test"] } diff --git a/dpdk/src/acl/config.rs b/dpdk/src/acl/config.rs index ddbc307016..94a51b10d8 100644 --- a/dpdk/src/acl/config.rs +++ b/dpdk/src/acl/config.rs @@ -46,7 +46,7 @@ use super::rule::Rule; /// while DPDK strides through rules at `rule_size = size_of::>()` over /// `Rule<3>`-sized slots -- the exact OOB read the const generic is meant to /// rule out. Keeping `N` on the type closes that gap statically and is -/// consistent with how [`AclBuildConfig`] is parameterised. +/// consistent with how [`AclBuildConfig`] is parameterized. /// /// # Construction /// @@ -905,34 +905,30 @@ impl AclBuildConfig { }) } - /// Compute the buffer-size requirement at construction time. + /// Compute DPDK's minimum input buffer size. /// - /// See [`min_input_size`][AclBuildConfig::min_input_size] for the - /// formula and rationale. Factored out so that `new` can call it - /// once and cache the result; the public accessor returns the cached - /// value. - /// - /// Precondition: all fields' `offset + 4` fit in `u32`. This is - /// guaranteed by the `FieldExtentOverflow` check in - /// [`new`][AclBuildConfig::new], so the plain `+` below cannot - /// overflow. - fn compute_min_input_size(field_defs: &[FieldDef; N]) -> usize { + /// `offset + 4` must fit in `u32`; [`AclBuildConfig::new`] validates this. + #[must_use] + pub const fn compute_min_input_size(field_defs: &[FieldDef; N]) -> usize { let mut max_load_end: u32 = 0; - for def in field_defs { + let mut i = 0; + while i < N { + let def = &field_defs[i]; let ii = def.input_index(); let mut group_offset = def.offset(); - for other in field_defs { + let mut j = 0; + while j < N { + let other = &field_defs[j]; if other.input_index() == ii && other.offset() < group_offset { group_offset = other.offset(); } + j += 1; } - // No saturation: `new`'s FieldExtentOverflow check has - // already verified `def.offset() + 4 <= u32::MAX` for every - // def, and `group_offset <= def.offset()`. let load_end = group_offset + 4; if load_end > max_load_end { max_load_end = load_end; } + i += 1; } max_load_end as usize } @@ -1425,6 +1421,19 @@ mod tests { ); } + #[test] + fn compute_min_input_size_works_in_const_context() { + const DEFS: [FieldDef; 2] = [ + FieldDef::new(FieldType::Bitmask, FieldSize::One, 0, 0, 0), + FieldDef::new(FieldType::Mask, FieldSize::Four, 1, 9, 100), + ]; + const MIN_INPUT_SIZE: usize = AclBuildConfig::compute_min_input_size(&DEFS); + assert_eq!(MIN_INPUT_SIZE, 104); + + let cfg = AclBuildConfig::new(1, DEFS, 0).expect("config should validate"); + assert_eq!(cfg.min_input_size(), MIN_INPUT_SIZE); + } + /// Property: `AclCreateParams::new` accepts a name iff it is non-empty /// ASCII without interior NUL bytes and of length `<= MAX_ACL_NAME_LEN`. /// Verifies the four error variants are mutually exclusive and that the diff --git a/dpdk/src/acl/mod.rs b/dpdk/src/acl/mod.rs index 85b58f59e5..2f3a558dc6 100644 --- a/dpdk/src/acl/mod.rs +++ b/dpdk/src/acl/mod.rs @@ -85,7 +85,7 @@ //! NonZero::new(1024).unwrap(), //! )?; //! let build_cfg = AclBuildConfig::new(1, field_defs, 0)?; -//! let mut ctx = AclContext::::new(params, build_cfg)?; +//! let mut ctx = AclContext::new(params, build_cfg)?; //! //! // 2. Add rules -- Rule<5> is enforced by the type system. //! let rule = Rule::new( @@ -169,147 +169,23 @@ pub use error::{ // Module-level utilities pub use context::dump_all_contexts; -/// End-to-end integration tests for the ACL wrapper, exercising real -/// `rte_acl_*` calls against a live EAL. -/// -/// # EAL configuration (shared by every test in this module) -/// -/// All tests initialize EAL via [`start_eal`][self::tests::start_eal], which -/// passes a fixed set of flags plus two dynamic values: -/// -/// - `--no-huge --in-memory` -- back EAL with anonymous memory instead of -/// hugetlbfs. Keeps the tests runnable on any host without manual hugepage -/// configuration. -/// - `--lcores 0@({allowed_cpus})` -- a single logical lcore (the main), -/// floated across whatever physical CPUs `sched_getaffinity` reports as -/// available to the process. No workers means -/// `rte_eal_mp_remote_launch` has no per-worker readiness flag to read, so -/// we sidestep a benign-but-flagged data race that ThreadSanitizer reports -/// against DPDK's lcore startup, and we also avoid spawning unused worker -/// threads. Floating (instead of pinning to physical CPU 0) keeps the -/// tests honest about cgroups, taskset, and container CPU restrictions. -/// - `--file-prefix ` -- a per-init unique identifier so that -/// concurrent forked test processes do not fight over the EAL runtime -/// configuration namespace. Necessary alongside `--in-memory` because EAL -/// still creates per-process control state in the runtime dir. -/// - `--no-pci --no-telemetry --no-shconf --no-hpet` -- disable everything we -/// do not need so the tests start quickly and have no shared-config files -/// to clean up. -/// -/// # Running once per process -/// -/// `eal::init` may only be called once per process. Every test in this -/// module funnels through the [`EAL`][self::tests::EAL] `OnceLock`, so -/// the init happens exactly once regardless of how the harness schedules -/// tests: nextest's per-test process fork (the workspace default) runs -/// the lazy init once per fork; a single-process runner (`cargo test -/// --test-threads=1` or an in-process parallel harness) runs it once for -/// the lifetime of the process. -/// -/// # Running locally -/// -/// ```text -/// just setup-roots # rebuild DPDK + wrapper -/// # re-enter `nix-shell` so DATAPLANE_SYSROOT picks up the new sysroot -/// cargo nextest run -p dataplane-dpdk acl::tests -/// ``` #[cfg(test)] mod tests { use core::num::NonZero; - use concurrency::sync::OnceLock; - use crate::acl::*; - use crate::eal::Eal; use crate::socket::SocketId; + use crate::with_eal; - /// Number of fields used by all lifecycle tests in this module. const NUM_FIELDS: usize = 2; - /// Process-wide EAL initialized on first use, shared by every test. - /// - /// `eal::init` may only be called once per process. Nextest's default - /// per-test process forking makes a per-test `init` trivially safe - /// (each forked process re-initializes EAL exactly once), but a - /// single-process test runner -- `cargo test --test-threads=1`, an - /// in-process parallel harness, or any future configuration that drops - /// the fork -- would call init twice and fail. Funneling every test - /// through this lazy [`OnceLock`] makes the tests correct under both - /// modes: per-process forking initializes once per fork (cheap), - /// in-process initializes once for the lifetime of the process. - /// - /// The `Eal` value is intentionally leaked into the static for the - /// lifetime of the process; DPDK has no clean teardown path, and the - /// `Eal` Drop would (per [`crate::eal::init`]) be unable to free DPDK - /// allocations through the system allocator after the allocator swap. - static EAL: OnceLock = OnceLock::new(); - - /// Lazily initialize EAL on first call. - /// - /// Each test calls this in place of `eal::init`; subsequent calls - /// return the shared `&'static Eal` without re-initializing DPDK. - fn start_eal() -> &'static Eal { - // DPDK pins lcores, but that is generally not what we actually want in a test environment. - // Instead, we need to allocate just lcore 0 (main) and pin it to "everything we legally have access to." - fn allowed_cpus() -> String { - use nix::sched::{CpuSet, sched_getaffinity}; - use nix::unistd::Pid; - let set = sched_getaffinity(Pid::from_raw(0)).expect("sched_getaffinity"); - (0..CpuSet::count()) - .filter(|&i| set.is_set(i).unwrap_or(false)) - .map(|x| x.to_string()) - .collect::>() - .join(",") - } - // concurrent executions of DPDK EAL can fight over allocations and file resources. - // You can prevent that with a unique prefix on the hugepage files it allocates (if any). - let eal_id = format!("{}", id::Id::::new()); - let core_pinning = format!("0@({})", allowed_cpus()); - // EAL arguments used the first time EAL is initialized in this process. - let args: &[&str] = &[ - "--no-huge", - "--no-pci", - "--in-memory", - "--no-telemetry", - "--no-shconf", - "--no-hpet", - "--iova-mode=va", - "--file-prefix", - &eal_id, - // Restrict EAL to a single lcore (the main). Without workers, - // rte_eal_mp_remote_launch has no readiness flags to read and there is - // no DPDK-internal init race for ThreadSanitizer to flag. Also avoids - // spawning unused worker threads. - // - // The `0@()` form means "logical lcore 0, floated across - // the listed physical CPUs": DPDK schedules lcore 0 onto any of - // them rather than pinning to a single CPU. Floating instead of - // pinning keeps the tests honest about cgroups, taskset, and - // container affinity restrictions. - "--lcores", - &core_pinning, - ]; - - EAL.get_or_init(|| super::super::eal::init(args.iter().copied())) - } - - /// Standard field layout used by the lifecycle tests. - /// - /// DPDK ACL requires the first field in the rule definition to be one byte - /// long (it is consumed during trie setup). All subsequent fields must be - /// grouped into sets of 4 consecutive bytes via `input_index`. fn standard_field_defs() -> [FieldDef; NUM_FIELDS] { [ - // Field 0: 1-byte entry at offset 0 (required by DPDK to be 1 byte). FieldDef::new(FieldType::Bitmask, FieldSize::One, 0, 0, 0), - // Field 1: 4-byte Mask field at offset 4, input_index 1. FieldDef::new(FieldType::Mask, FieldSize::Four, 1, 1, 4), ] } - /// Build a rule that exact-matches the given 32-bit value in field 1. - /// - /// `userdata` becomes the classify result for matching inputs. fn exact_match_rule(value: u32, userdata: u32) -> Rule { Rule::new( RuleData { @@ -317,50 +193,30 @@ mod tests { priority: Priority::new(1).unwrap(), userdata: NonZero::new(userdata).expect("userdata must be non-zero"), }, - [ - // Wildcard entry byte: field 0 is FieldType::Bitmask - // (per standard_field_defs). mask = 0 makes the - // predicate `(input & 0) == 0`, which is trivially true - // for any input -- so this field matches any byte at - // offset 0. - AclField::from_u8(0, 0), - // Field 1 is FieldType::Mask; mask_range is interpreted - // as a prefix length, so 32 means "compare all 32 bits". - AclField::from_u32(value, 32), - ], + [AclField::from_u8(0, 0), AclField::from_u32(value, 32)], ) } - /// Build an 8-byte input buffer carrying `value` at offset 4 in network byte - /// order, suitable for the field layout returned by [`standard_field_defs`]. fn input_buffer(value: u32) -> [u8; 8] { let mut buf = [0u8; 8]; buf[4..8].copy_from_slice(&value.to_be_bytes()); buf } - /// Build the default `AclBuildConfig` used across the lifecycle tests - /// (`num_categories = 1`, the standard 2-field layout, no max_size). fn standard_build_config() -> AclBuildConfig { AclBuildConfig::new(1, standard_field_defs(), 0).expect("build config") } - /// End-to-end classify smoke test: build a tiny ACL context, run a real - /// `rte_acl_classify` call, and verify the match / no-match outcomes. - /// See the [module-level docs](self) for the EAL setup that applies to - /// every test here. + #[with_eal] #[test] fn classify_smoke() { - let _eal = start_eal(); - let params = AclCreateParams::::new( "test_acl", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); ctx.add_rules(&[exact_match_rule(0xDEAD_BEEF, 1)]) .expect("add rules"); @@ -381,14 +237,9 @@ mod tests { assert_eq!(results[1], 0, "expected no match for 0x00000000"); } - /// Reset round-trip: build, classify, reset back to Configuring, swap - /// in a new rule, rebuild (no config supplied -- it lives on the - /// context), and verify the new rule's userdata wins. Also asserts - /// that the build config survives the reset. + #[with_eal] #[test] fn reset_round_trip() { - let _eal = start_eal(); - let original_cfg = standard_build_config(); let params = AclCreateParams::::new( "reset_round_trip", @@ -396,10 +247,8 @@ mod tests { NonZero::new(16).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, original_cfg.clone()).expect("new context"); + let mut ctx = AclContext::new(params, original_cfg.clone()).expect("new context"); - // First build cycle: match 0xAAAAAAAA -> userdata 1. ctx.add_rules(&[exact_match_rule(0xAAAA_AAAA, 1)]) .expect("add rules (first)"); let ctx = ctx.build().map_err(|f| f.error).expect("build (first)"); @@ -416,8 +265,6 @@ mod tests { unsafe { ctx.classify(&data_ptrs, &mut results, 1) }.expect("classify (first)"); assert_eq!(results[0], 1, "first build should match 0xAAAAAAAA"); - // Reset back to Configuring (config carries through) and load a - // different rule. let mut ctx = ctx.reset(); assert_eq!( ctx.build_config(), @@ -441,25 +288,17 @@ mod tests { assert_eq!(results[1], 0, "second build must not retain the first rule"); } - /// `add_rules` rejects a rule whose [`FieldType::Mask`] field carries a - /// prefix length larger than the field's bit width. Without this - /// wrapper-side check, DPDK's `RTE_ACL_MASKLEN_TO_BITMASK` would - /// perform a C shift by an out-of-range amount (UB). + #[with_eal] #[test] fn add_rules_rejects_out_of_range_prefix_length() { - let _eal = start_eal(); - let params = AclCreateParams::::new( "prefix_len_validate", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); - // Field 1 in standard_field_defs is a 4-byte Mask field, so the - // maximum legal prefix length is 32. 33 is out of range. let bad_rule: Rule = Rule::new( RuleData { category_mask: CategoryMask::new(1).unwrap(), @@ -490,25 +329,20 @@ mod tests { ); } - /// `set_default_algorithm` happy path: build, switch to a specific - /// algorithm, and classify. Uses `Default` which is always supported. + #[with_eal] #[test] fn set_default_algorithm_then_classify() { - let _eal = start_eal(); - let params = AclCreateParams::::new( "set_algo", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); ctx.add_rules(&[exact_match_rule(0xCAFE_BABE, 7)]) .expect("add rules"); let mut ctx = ctx.build().map_err(|f| f.error).expect("build"); - // `Default` is always available on any CPU DPDK runs on. ctx.set_default_algorithm(ClassifyAlgorithm::Default) .expect("set_default_algorithm"); @@ -520,22 +354,16 @@ mod tests { assert_eq!(results[0], 7); } - /// `classify` must reject `categories` values that would overflow DPDK's - /// per-thread runtime arrays sized to `RTE_ACL_MAX_CATEGORIES`, even when - /// the user's `results` slice is generous enough to satisfy the - /// per-element length check. + #[with_eal] #[test] fn classify_categories_validated_before_ffi() { - let _eal = start_eal(); - let params = AclCreateParams::::new( "cat_validation", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); ctx.add_rules(&[exact_match_rule(0xAAAA_AAAA, 1)]) .expect("add rules"); let ctx = ctx.build().map_err(|f| f.error).expect("build"); @@ -543,42 +371,31 @@ mod tests { let buf = input_buffer(0xAAAA_AAAA); let data_ptrs: Vec<*const u8> = vec![buf.as_ptr()]; - // results slice large enough to pass the length check, but categories - // out of range -- must still be rejected. let mut results = vec![0u32; 64]; - // categories = 0 // SAFETY: see classify_smoke. let r = unsafe { ctx.classify(&data_ptrs, &mut results, 0) }; assert!(matches!(r, Err(AclClassifyError::InvalidArgs))); - // categories > MAX_CATEGORIES (= 16) // SAFETY: see classify_smoke. let r = unsafe { ctx.classify(&data_ptrs, &mut results, MAX_CATEGORIES + 1) }; assert!(matches!(r, Err(AclClassifyError::InvalidArgs))); - // categories > 1 but not a multiple of RESULTS_MULTIPLIER (= 4) // SAFETY: see classify_smoke. let r = unsafe { ctx.classify(&data_ptrs, &mut results, 3) }; assert!(matches!(r, Err(AclClassifyError::InvalidArgs))); } - /// Creating a second [`AclContext`] with a name already registered in - /// DPDK's global ACL list must fail with [`AclCreateError::AlreadyExists`] - /// rather than silently aliasing the first context (which would - /// double-free on drop). + #[with_eal] #[test] fn duplicate_name_rejected() { - let _eal = start_eal(); - let params_a = AclCreateParams::::new( "dup_name", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - let _ctx_a = - AclContext::::new(params_a, standard_build_config()).expect("first new"); + let _ctx_a = AclContext::new(params_a, standard_build_config()).expect("first new"); let params_b = AclCreateParams::::new( "dup_name", @@ -586,7 +403,7 @@ mod tests { NonZero::new(16).unwrap(), ) .expect("create params (dup)"); - let err = AclContext::::new(params_b, standard_build_config()) + let err = AclContext::new(params_b, standard_build_config()) .expect_err("second new with same name must fail"); assert!( matches!(err, AclCreateError::AlreadyExists { ref name } if name == "dup_name"), @@ -594,33 +411,20 @@ mod tests { ); } - /// Recovery after `add_rules` overflows `max_rule_num`: the context must - /// remain usable. We submit one rule successfully, then submit more rules - /// than the remaining capacity allows, expect the error, and finally build - /// and classify against the first rule. + #[with_eal] #[test] fn add_rules_after_overflow_failure() { - let _eal = start_eal(); - - // `max_rule_num` of 1: a second add_rules call with any rule will - // overflow. let params = AclCreateParams::::new( "overflow_recover", SocketId::ANY, NonZero::new(1).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); ctx.add_rules(&[exact_match_rule(0x1111_1111, 1)]) .expect("first add_rules should succeed"); - // Attempting to add another rule must fail: capacity is exhausted. - // DPDK signals "no room left in the rule list" with -ENOMEM, which - // the wrapper maps to AclAddRulesError::OutOfMemory. Pin the variant - // so a future change in mapping or DPDK's behaviour surfaces as a - // test failure rather than silently passing through. let extra = exact_match_rule(0x2222_2222, 2); let err = ctx .add_rules(&[extra]) @@ -630,7 +434,6 @@ mod tests { "expected OutOfMemory from capacity exhaustion, got {err:?}", ); - // Context must still be usable: build + classify against the first rule. let ctx = ctx .build() .map_err(|f| f.error) @@ -644,25 +447,17 @@ mod tests { assert_eq!(results[0], 1); } - /// Build failure recovery: when `build()` fails, the wrapper returns - /// the original `Configuring` context inside `AclBuildFailure`. The - /// caller must be able to keep using it (add rules, retry). We force - /// the failure by calling `build()` with no rules added (DPDK rejects - /// `num_rules == 0` with `-EINVAL`). + #[with_eal] #[test] fn build_failure_returns_usable_context() { - let _eal = start_eal(); - let params = AclCreateParams::::new( "build_failure_recovery", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - let ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let ctx = AclContext::new(params, standard_build_config()).expect("new context"); - // First build with zero rules must fail. let failure = ctx.build().expect_err("build() with no rules must fail"); assert!( matches!(failure.error, AclBuildError::InvalidConfig), @@ -670,7 +465,6 @@ mod tests { failure.error, ); - // Recover the context, add a rule, build again -- must succeed. let mut ctx = failure.context; ctx.add_rules(&[exact_match_rule(0xDEAD_BEEF, 1)]) .expect("add rules after recovery"); @@ -687,24 +481,16 @@ mod tests { assert_eq!(results[0], 1); } - /// `add_rules` rejects a rule whose `category_mask` has bits set at - /// positions `>= config.num_categories()`. DPDK would silently mask - /// off those bits at build time, narrowing the rule's intended - /// category set; we surface this at `add_rules` time instead. + #[with_eal] #[test] fn add_rules_rejects_category_mask_beyond_num_categories() { - let _eal = start_eal(); - let params = AclCreateParams::::new( "cat_mask_validate", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - // standard_build_config uses num_categories = 1, so only bit 0 is - // legal. Build a rule with bit 1 also set. - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); let bad_rule: Rule = Rule::new( RuleData { @@ -733,21 +519,12 @@ mod tests { ); } - /// Concurrent classify under `Arc>>`: spawns - /// several worker threads, each calling - /// [`AclContext::classify`][crate::acl::AclContext::classify] in a - /// tight loop, and verifies every thread sees the correct match. - /// Exercises the per-state `Sync` impl on [`Built`] and ensures - /// the wrapper's "share across classification threads" claim isn't - /// vacuous. Test runs with N=4 workers and M=1000 iterations each - /// to give the OS scheduler a chance to interleave. + #[with_eal] #[test] fn classify_concurrent_arc_shared() { use concurrency::sync::Arc; use concurrency::thread; - let _eal = start_eal(); - const WORKERS: usize = 4; const ITERS_PER_WORKER: usize = 1000; @@ -757,8 +534,7 @@ mod tests { NonZero::new(16).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); ctx.add_rules(&[exact_match_rule(0xDEAD_BEEF, 1)]) .expect("add rules"); let ctx: Arc>> = @@ -768,8 +544,6 @@ mod tests { .map(|worker| { let ctx = Arc::clone(&ctx); thread::spawn(move || { - // Each worker owns its own buffers; classify is the - // only place we share state across threads. let matching = input_buffer(0xDEAD_BEEF); let non_matching = input_buffer(0); for _ in 0..ITERS_PER_WORKER { @@ -793,22 +567,16 @@ mod tests { } } - /// `classify_with_algorithm` with a non-`Default` algorithm: locks in - /// the special-casing in [`AclContext::classify_with_algorithm`] by - /// dispatching through the `Scalar` variant (always available on every - /// CPU DPDK runs on) and verifying classification still works. + #[with_eal] #[test] fn classify_with_algorithm_scalar() { - let _eal = start_eal(); - let params = AclCreateParams::::new( "classify_alg_scalar", SocketId::ANY, NonZero::new(16).unwrap(), ) .expect("create params"); - let mut ctx = - AclContext::::new(params, standard_build_config()).expect("new context"); + let mut ctx = AclContext::new(params, standard_build_config()).expect("new context"); ctx.add_rules(&[exact_match_rule(0xFEED_FACE, 9)]) .expect("add rules"); let ctx = ctx.build().map_err(|f| f.error).expect("build"); diff --git a/dpdk/src/lib.rs b/dpdk/src/lib.rs index 26ea73a048..6216f4aa3a 100644 --- a/dpdk/src/lib.rs +++ b/dpdk/src/lib.rs @@ -42,3 +42,9 @@ pub mod mem; pub mod queue; pub mod ring; pub mod socket; + +#[cfg(any(test, feature = "test"))] +pub mod test_support; + +#[cfg(feature = "test")] +pub use dpdk_test_macros::with_eal; diff --git a/dpdk/src/test_support.rs b/dpdk/src/test_support.rs new file mode 100644 index 0000000000..83bef312ee --- /dev/null +++ b/dpdk/src/test_support.rs @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use concurrency::sync::OnceLock; + +use crate::eal::Eal; + +static EAL: OnceLock = OnceLock::new(); +#[must_use] +pub fn start_eal() -> &'static Eal { + EAL.get_or_init(|| { + let cpus = allowed_cpus(); + let eal_id = format!("{}", id::Id::::new()); + let core_pinning = format!("0@({cpus})"); + let args: &[&str] = &[ + "--no-huge", + "--no-pci", + "--in-memory", + "--no-telemetry", + "--no-shconf", + "--no-hpet", + "--iova-mode=va", + "--file-prefix", + &eal_id, + "--lcores", + &core_pinning, + ]; + crate::eal::init(args.iter().copied()) + }) +} +#[allow(clippy::expect_used)] +fn allowed_cpus() -> String { + use nix::sched::{CpuSet, sched_getaffinity}; + use nix::unistd::Pid; + let set = sched_getaffinity(Pid::from_raw(0)).expect("sched_getaffinity"); + (0..CpuSet::count()) + .filter(|&i| set.is_set(i).unwrap_or(false)) + .map(|x| x.to_string()) + .collect::>() + .join(",") +} diff --git a/fixed-size/Cargo.toml b/fixed-size/Cargo.toml new file mode 100644 index 0000000000..3e0596e851 --- /dev/null +++ b/fixed-size/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "dataplane-fixed-size" +edition.workspace = true +license.workspace = true +publish.workspace = true +version.workspace = true + +[dependencies] diff --git a/fixed-size/src/lib.rs b/fixed-size/src/lib.rs new file mode 100644 index 0000000000..1a0617612b --- /dev/null +++ b/fixed-size/src/lib.rs @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +#![no_std] +#![deny( + unsafe_code, + clippy::all, + clippy::pedantic, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic +)] + +use core::net::{Ipv4Addr, Ipv6Addr}; +pub trait FixedSize: Copy { + const SIZE: usize; + fn write_be(&self, out: &mut [u8]); +} + +impl FixedSize for u8 { + const SIZE: usize = 1; + fn write_be(&self, out: &mut [u8]) { + out[0] = *self; + } +} + +impl FixedSize for u16 { + const SIZE: usize = 2; + fn write_be(&self, out: &mut [u8]) { + out[..Self::SIZE].copy_from_slice(&self.to_be_bytes()); + } +} + +impl FixedSize for u32 { + const SIZE: usize = 4; + fn write_be(&self, out: &mut [u8]) { + out[..Self::SIZE].copy_from_slice(&self.to_be_bytes()); + } +} + +impl FixedSize for u64 { + const SIZE: usize = 8; + fn write_be(&self, out: &mut [u8]) { + out[..Self::SIZE].copy_from_slice(&self.to_be_bytes()); + } +} + +impl FixedSize for u128 { + const SIZE: usize = 16; + fn write_be(&self, out: &mut [u8]) { + out[..Self::SIZE].copy_from_slice(&self.to_be_bytes()); + } +} + +impl FixedSize for Ipv4Addr { + const SIZE: usize = 4; + fn write_be(&self, out: &mut [u8]) { + out[..Self::SIZE].copy_from_slice(&self.octets()); + } +} + +impl FixedSize for Ipv6Addr { + const SIZE: usize = 16; + fn write_be(&self, out: &mut [u8]) { + out[..Self::SIZE].copy_from_slice(&self.octets()); + } +} diff --git a/lifecycle/Cargo.toml b/lifecycle/Cargo.toml new file mode 100644 index 0000000000..36e5e9d272 --- /dev/null +++ b/lifecycle/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "dataplane-lifecycle" +version.workspace = true +edition.workspace = true +license.workspace = true +publish.workspace = true +repository.workspace = true + +[dependencies] +concurrency = { workspace = true } +# Base tokio features for cross-platform builds (incl. wasm32-wasip1, which +# rejects features outside the supported wasm set with a compile_error!). +# `rt` is required for the runtime/Handle/JoinHandle APIs we use directly +# (`spawn_signal_handler`, `Subsystem::spawn_on`/`spawn_fatal_on_exit`, the +# watchdog). Don't rely on transitive unification via tokio-util. +tokio = { workspace = true, features = ["macros", "rt", "time"] } +tokio-util = { workspace = true, features = ["rt"] } +tracing = { workspace = true } + +# spawn_signal_handler is cfg(unix); the "signal" feature is only enabled +# on unix targets to keep wasm builds of the lifecycle library clean. +[target.'cfg(unix)'.dependencies] +tokio = { workspace = true, features = ["signal"] } + +[dev-dependencies] +tokio = { workspace = true, features = ["rt", "macros", "time"] } diff --git a/lifecycle/src/lib.rs b/lifecycle/src/lib.rs new file mode 100644 index 0000000000..3cd587e760 --- /dev/null +++ b/lifecycle/src/lib.rs @@ -0,0 +1,534 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +//! Process-lifecycle primitives for the dataplane binary. +//! +//! [`Shutdown`] bundles a root [`CancellationToken`] and one [`Subsystem`] +//! per long-lived component. Each subsystem owns a cancel token and a +//! [`TaskTracker`]; [`Shutdown::drain_in_order`] drains them in topological +//! order with per-subsystem deadlines. + +#![deny( + unsafe_code, + missing_docs, + clippy::all, + clippy::pedantic, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic +)] + +use concurrency::sync::Arc; +use concurrency::sync::atomic::{AtomicBool, Ordering}; +use std::future::Future; +use std::time::Duration; + +use tokio::task::JoinHandle; +use tokio::time::error::Elapsed; +use tracing::{error, info, warn}; + +pub use tokio_util::sync::CancellationToken; +pub use tokio_util::task::TaskTracker; + +/// A named, drainable subsystem. Cheap to clone. +#[derive(Clone, Debug)] +pub struct Subsystem { + /// Stable name used in shutdown logs. + pub name: &'static str, + cancel: CancellationToken, + tasks: TaskTracker, + root: CancellationToken, + fatal: Arc, +} + +impl Subsystem { + /// Tests/ad-hoc only. Production code: use [`Shutdown::new`] so all + /// subsystems share one fatal flag. + #[doc(hidden)] + #[must_use] + pub fn new(name: &'static str, root: CancellationToken) -> Self { + Self::with_fatal(name, root, Arc::new(AtomicBool::new(false))) + } + + /// Construct a subsystem with an explicit shared fatal flag. + #[must_use] + pub fn with_fatal(name: &'static str, root: CancellationToken, fatal: Arc) -> Self { + Self { + name, + cancel: CancellationToken::new(), + tasks: TaskTracker::new(), + root, + fatal, + } + } + + /// Clone of this subsystem's cancellation token. + #[must_use] + pub fn cancel_token(&self) -> CancellationToken { + self.cancel.clone() + } + + /// True if this subsystem's cancellation token has been tripped. + #[must_use] + pub fn is_cancelled(&self) -> bool { + self.cancel.is_cancelled() + } + + /// Clone of the process-wide root cancellation token. Use for startup + /// work — the per-subsystem cancel is tripped after startup returns. + #[must_use] + pub fn root_token(&self) -> CancellationToken { + self.root.clone() + } + + /// Spawn an async task on `handle`, tracked under this subsystem. + pub fn spawn_on(&self, future: F, handle: &tokio::runtime::Handle) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tasks.spawn_on(future, handle) + } + + /// Spawn `future`; if it exits (normally or by panic) before any + /// shutdown is requested, call [`Self::report_fatal`]. Use for tasks + /// whose unexpected exit means the subsystem is broken; for tasks + /// where silent exit is fine, use [`Self::spawn_on`]. + pub fn spawn_fatal_on_exit(&self, reason: &str, future: F, handle: &tokio::runtime::Handle) + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + let cancel = self.cancel.clone(); + let root = self.root.clone(); + let subsystem = self.clone(); + let reason = reason.to_owned(); + // Spawn `inner` detached on the runtime so panics surface via its + // JoinHandle; only the wrapper is tracked. + let mut inner = handle.spawn(future); + self.tasks.spawn_on( + async move { + tokio::select! { + () = cancel.cancelled() => { + inner.abort(); + let _ = (&mut inner).await; + } + result = &mut inner => { + // Root counts as graceful: during SIGINT, root + // trips before this subsystem's cancel. + if root.is_cancelled() || cancel.is_cancelled() { + return; + } + match result { + Ok(_) => subsystem + .report_fatal(&format!("{reason} exited without cancellation")), + Err(e) if e.is_panic() => subsystem + .report_fatal(&format!("{reason} panicked: {e}")), + Err(_) => {} + } + } + } + }, + handle, + ); + } + + /// Set the fatal flag, trip this subsystem's cancel, trip the root. + /// Idempotent. Logs at error. + pub fn report_fatal(&self, reason: &str) { + error!(subsystem = self.name, reason, "fatal; tripping shutdown"); + self.fatal.store(true, Ordering::Relaxed); + self.cancel.cancel(); + self.root.cancel(); + } + + /// Cancel this subsystem and wait for tracked tokio tasks. Idempotent. + /// + /// Thread-based subsystems (workers, RIO) are not tracked here; their + /// joins happen at scope-close. The watchdog is their hard bound. + /// + /// # Errors + /// Returns [`Elapsed`] if any tracked task is still running after + /// `deadline`. Cancel is tripped and tracker closed either way. + pub async fn drain(&self, deadline: Duration) -> Result<(), Elapsed> { + self.cancel.cancel(); + self.tasks.close(); + tokio::time::timeout(deadline, self.tasks.wait()).await + } +} + +/// Default drain deadlines. Per-subsystem deadlines bound only the +/// tokio tasks tracked by each [`Subsystem`]; [`TOTAL`] is the absolute +/// process-level ceiling enforced by [`spawn_shutdown_watchdog`]. +pub mod default_deadlines { + use std::time::Duration; + /// Drain workers' tokio tasks. + pub const WORKERS: Duration = Duration::from_secs(5); + /// Drain RIO's tokio tasks. + pub const ROUTER: Duration = Duration::from_secs(5); + /// Drain mgmt's tasks (config processor, status updater, watcher). + pub const MGMT: Duration = Duration::from_secs(5); + /// Drain metrics; short — a stuck scrape is fine to abandon. + pub const METRICS: Duration = Duration::from_secs(2); + /// Hard process-wide ceiling. Independent of the sum above. + pub const TOTAL: Duration = Duration::from_secs(15); +} + +/// Root lifecycle bundle owned by `main`. +#[derive(Debug)] +pub struct Shutdown { + /// Tripped by `SIGINT`/`SIGTERM` or any subsystem's + /// [`Subsystem::report_fatal`]. + pub root: CancellationToken, + fatal: Arc, + /// Data-plane workers. + pub workers: Subsystem, + /// Routing/control I/O. + pub router: Subsystem, + /// Management plane. + pub mgmt: Subsystem, + /// Prometheus endpoint and stats collection. + pub metrics: Subsystem, +} + +impl Shutdown { + /// Build a `Shutdown` with subsystems pre-wired to one root and one + /// fatal flag. + #[must_use] + pub fn new() -> Self { + let root = CancellationToken::new(); + let fatal = Arc::new(AtomicBool::new(false)); + Self { + workers: Subsystem::with_fatal("workers", root.clone(), fatal.clone()), + router: Subsystem::with_fatal("router", root.clone(), fatal.clone()), + mgmt: Subsystem::with_fatal("mgmt", root.clone(), fatal.clone()), + metrics: Subsystem::with_fatal("metrics", root.clone(), fatal.clone()), + root, + fatal, + } + } + + /// Set the fatal flag and trip the root. Idempotent. + pub fn fail(&self) { + self.fatal.store(true, Ordering::Relaxed); + self.root.cancel(); + } + + /// True if any subsystem reported fatal or `main` called + /// [`Shutdown::fail`]. Read after drain to choose the exit code. + #[must_use] + pub fn is_fatal(&self) -> bool { + self.fatal.load(Ordering::Relaxed) + } + + /// Drain in order: workers, router, metrics, mgmt. Workers stop + /// touching packets before the control plane goes away. Subsystems + /// that miss their deadline are logged and abandoned. + pub async fn drain_in_order(&self) { + Self::drain_one(&self.workers, default_deadlines::WORKERS).await; + Self::drain_one(&self.router, default_deadlines::ROUTER).await; + Self::drain_one(&self.metrics, default_deadlines::METRICS).await; + Self::drain_one(&self.mgmt, default_deadlines::MGMT).await; + } + + async fn drain_one(sub: &Subsystem, deadline: Duration) { + if sub.drain(deadline).await.is_ok() { + info!(subsystem = sub.name, "drained cleanly"); + } else { + warn!( + subsystem = sub.name, + deadline_ms = u64::try_from(deadline.as_millis()).unwrap_or(u64::MAX), + "drain timed out; abandoning" + ); + } + } +} + +impl Default for Shutdown { + fn default() -> Self { + Self::new() + } +} + +/// Spawn a task on `handle` that trips `root` on `SIGINT`/`SIGTERM`, and +/// also exits if `root` was tripped through another path. +/// +/// # Errors +/// Returns [`std::io::Error`] if either signal handler install fails. +#[cfg(unix)] +pub fn spawn_signal_handler( + handle: &tokio::runtime::Handle, + root: CancellationToken, +) -> std::io::Result<()> { + use tokio::signal::unix::{SignalKind, signal}; + + // Install inside the runtime context so the handlers register with + // its signal driver, not just the EnterGuard. + let (mut sigint, mut sigterm) = { + let _guard = handle.enter(); + ( + signal(SignalKind::interrupt())?, + signal(SignalKind::terminate())?, + ) + }; + + handle.spawn(async move { + tokio::select! { + _ = sigint.recv() => info!("SIGINT received; tripping shutdown"), + _ = sigterm.recv() => info!("SIGTERM received; tripping shutdown"), + () = root.cancelled() => {} + } + root.cancel(); + }); + + Ok(()) +} + +/// Spawn a detached OS thread that calls [`std::process::exit`] `deadline` +/// after `root` is cancelled. Independent of the mgmt runtime so it still +/// fires if the runtime wedges. This is the only bound on a worker thread +/// blocked inside an I/O call that doesn't observe cancellation. +/// +/// # Errors +/// Returns [`std::io::Error`] if spawning fails. A runtime-build failure +/// inside the thread is logged and disarms the watchdog (the process then +/// has no hard shutdown ceiling); treat disarm logs as a startup warning. +pub fn spawn_shutdown_watchdog( + root: CancellationToken, + deadline: Duration, + exit_code: i32, +) -> std::io::Result<()> { + use std::io::Write; + std::thread::Builder::new() + .name("shutdown-watchdog".to_string()) + .spawn(move || { + let rt = match tokio::runtime::Builder::new_current_thread() + .enable_time() + .build() + { + Ok(rt) => rt, + Err(e) => { + error!(error = %e, "shutdown watchdog runtime failed to start; disarmed"); + return; + } + }; + rt.block_on(root.cancelled()); + drop(rt); + std::thread::sleep(deadline); + error!( + deadline_ms = u64::try_from(deadline.as_millis()).unwrap_or(u64::MAX), + exit_code, "shutdown exceeded total deadline; aborting" + ); + // process::exit skips destructors, so flush stderr explicitly. + let _ = std::io::stderr().flush(); + std::process::exit(exit_code); + }) + .map(|_| ()) +} + +#[cfg(test)] +mod tests { + use super::*; + use concurrency::sync::Arc; + use concurrency::sync::atomic::{AtomicBool, Ordering}; + + #[tokio::test] + async fn drain_completes_when_tasks_observe_cancel() { + let shutdown = Shutdown::new(); + let mgmt = shutdown.mgmt.clone(); + let cancel = mgmt.cancel_token(); + let observed = Arc::new(AtomicBool::new(false)); + let observed_in_task = observed.clone(); + + let handle = tokio::runtime::Handle::current(); + mgmt.spawn_on( + async move { + cancel.cancelled().await; + observed_in_task.store(true, Ordering::SeqCst); + }, + &handle, + ); + + let result = mgmt.drain(Duration::from_millis(500)).await; + assert!(result.is_ok()); + assert!(observed.load(Ordering::SeqCst)); + } + + #[tokio::test] + async fn drain_times_out_when_task_ignores_cancel() { + let shutdown = Shutdown::new(); + let mgmt = shutdown.mgmt.clone(); + + let handle = tokio::runtime::Handle::current(); + mgmt.spawn_on( + async move { + tokio::time::sleep(Duration::from_mins(1)).await; + }, + &handle, + ); + + let result = mgmt.drain(Duration::from_millis(50)).await; + assert!(result.is_err()); + assert!(mgmt.is_cancelled()); + assert!(mgmt.tasks.is_closed()); + } + + #[tokio::test] + async fn report_fatal_trips_root_self_cancel_and_fatal_flag() { + let shutdown = Shutdown::new(); + assert!(!shutdown.is_fatal()); + shutdown.workers.report_fatal("synthetic test failure"); + + assert!(shutdown.root.is_cancelled()); + assert!(shutdown.is_fatal()); + assert!(shutdown.workers.is_cancelled()); + assert!(!shutdown.mgmt.is_cancelled()); + assert!(!shutdown.router.is_cancelled()); + assert!(!shutdown.metrics.is_cancelled()); + } + + #[tokio::test] + async fn shutdown_fail_sets_fatal_and_trips_root() { + let shutdown = Shutdown::new(); + assert!(!shutdown.is_fatal()); + assert!(!shutdown.root.is_cancelled()); + + shutdown.fail(); + + assert!(shutdown.is_fatal()); + assert!(shutdown.root.is_cancelled()); + } + + #[tokio::test] + async fn standalone_subsystem_has_its_own_fatal_flag() { + let root = CancellationToken::new(); + let a = Subsystem::new("a", root.clone()); + let b = Subsystem::new("b", root); + a.report_fatal("isolated"); + assert!(a.fatal.load(Ordering::Relaxed)); + assert!(!b.fatal.load(Ordering::Relaxed)); + } + + #[tokio::test] + async fn subsystem_cancels_are_independent_of_root() { + let shutdown = Shutdown::new(); + shutdown.root.cancel(); + + assert!(shutdown.root.is_cancelled()); + assert!(!shutdown.workers.is_cancelled()); + assert!(!shutdown.mgmt.is_cancelled()); + } + + #[tokio::test] + async fn subsystem_root_token_observes_signal_handler_cancel() { + let shutdown = Shutdown::new(); + let mgmt_root = shutdown.mgmt.root_token(); + assert!(!mgmt_root.is_cancelled()); + + shutdown.fail(); + assert!(mgmt_root.is_cancelled()); + } + + #[tokio::test] + async fn drain_is_idempotent() { + let shutdown = Shutdown::new(); + let mgmt = shutdown.mgmt.clone(); + + let first = mgmt.drain(Duration::from_millis(50)).await; + let second = mgmt.drain(Duration::from_millis(50)).await; + assert!(first.is_ok()); + assert!(second.is_ok()); + } + + #[tokio::test] + async fn spawn_fatal_on_exit_trips_root_on_normal_return() { + let shutdown = Shutdown::new(); + let handle = tokio::runtime::Handle::current(); + shutdown + .mgmt + .spawn_fatal_on_exit("test task", async {}, &handle); + + tokio::time::timeout(Duration::from_millis(500), shutdown.root.cancelled()) + .await + .expect("root should trip on task exit"); + assert!(shutdown.is_fatal()); + } + + #[tokio::test] + async fn spawn_fatal_on_exit_trips_root_on_panic() { + let shutdown = Shutdown::new(); + let handle = tokio::runtime::Handle::current(); + shutdown.mgmt.spawn_fatal_on_exit( + "test task", + async { + panic!("synthetic panic"); + }, + &handle, + ); + + tokio::time::timeout(Duration::from_millis(500), shutdown.root.cancelled()) + .await + .expect("root should trip on task panic"); + assert!(shutdown.is_fatal()); + } + + #[tokio::test] + async fn spawn_fatal_on_exit_does_not_trip_when_root_cancelled_first() { + // Simulates: SIGINT trips root before drain_in_order reaches mgmt. + // A supervised mgmt task exiting in that window must not flip fatal. + let shutdown = Shutdown::new(); + let handle = tokio::runtime::Handle::current(); + shutdown.root.cancel(); + shutdown + .mgmt + .spawn_fatal_on_exit("test task", async {}, &handle); + + shutdown + .mgmt + .drain(Duration::from_millis(500)) + .await + .unwrap(); + assert!(!shutdown.is_fatal()); + } + + #[tokio::test] + async fn spawn_fatal_on_exit_does_not_trip_when_cancelled_first() { + let shutdown = Shutdown::new(); + let handle = tokio::runtime::Handle::current(); + let cancel = shutdown.mgmt.cancel_token(); + shutdown.mgmt.spawn_fatal_on_exit( + "test task", + async move { + cancel.cancelled().await; + }, + &handle, + ); + + shutdown + .mgmt + .drain(Duration::from_millis(500)) + .await + .unwrap(); + assert!(!shutdown.is_fatal()); + } + + #[tokio::test] + async fn drain_in_order_completes_when_all_subsystems_observe_cancel() { + let shutdown = Shutdown::new(); + let handle = tokio::runtime::Handle::current(); + for sub in [ + &shutdown.workers, + &shutdown.router, + &shutdown.mgmt, + &shutdown.metrics, + ] { + let cancel = sub.cancel_token(); + sub.spawn_on(async move { cancel.cancelled().await }, &handle); + } + shutdown.drain_in_order().await; + assert!(shutdown.workers.is_cancelled()); + assert!(shutdown.router.is_cancelled()); + assert!(shutdown.mgmt.is_cancelled()); + assert!(shutdown.metrics.is_cancelled()); + } +} diff --git a/lookup/Cargo.toml b/lookup/Cargo.toml new file mode 100644 index 0000000000..0d5a4f2662 --- /dev/null +++ b/lookup/Cargo.toml @@ -0,0 +1,8 @@ +[package] +name = "dataplane-lookup" +edition.workspace = true +license.workspace = true +publish.workspace = true +version.workspace = true + +[dependencies] diff --git a/lookup/src/lib.rs b/lookup/src/lib.rs new file mode 100644 index 0000000000..4bf4358a4c --- /dev/null +++ b/lookup/src/lib.rs @@ -0,0 +1,190 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +#![deny( + unsafe_code, + clippy::all, + clippy::pedantic, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic +)] +#![allow(missing_docs)] + +use std::collections::{BTreeMap, HashMap}; +use std::hash::Hash; +pub trait Projection { + fn project(self) -> T; +} +impl Projection> for Option { + fn project(self) -> Option { + self + } +} +pub trait Lookup { + fn lookup(&self, key: &K) -> Option<&A>; + fn classify(&self, source: S) -> Option<&A> + where + S: Projection, + { + self.lookup(&source.project()) + } + fn classify_opt(&self, source: S) -> Option<&A> + where + S: Projection>, + { + source.project().and_then(|key| self.lookup(&key)) + } +} + +impl Lookup for BTreeMap { + fn lookup(&self, key: &K) -> Option<&V> { + BTreeMap::get(self, key) + } +} + +impl Lookup for HashMap { + fn lookup(&self, key: &K) -> Option<&V> { + HashMap::get(self, key) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct Pkt { + src: u32, + dst: u32, + sport: u16, + dport: u16, + } + + impl Projection<(u32, u32)> for &Pkt { + fn project(self) -> (u32, u32) { + (self.src, self.dst) + } + } + + impl Projection<(u32, u32, u16, u16)> for &Pkt { + fn project(self) -> (u32, u32, u16, u16) { + (self.src, self.dst, self.sport, self.dport) + } + } + + impl<'a> Projection<(&'a u32, &'a u32)> for &'a Pkt { + fn project(self) -> (&'a u32, &'a u32) { + (&self.src, &self.dst) + } + } + impl Projection> for &Pkt { + fn project(self) -> Option<(u32, u32)> { + (self.src != 0).then_some((self.src, self.dst)) + } + } + + #[derive(Debug, PartialEq, Eq)] + enum Action { + Allow, + Drop, + } + + #[test] + fn classify_picks_the_two_tuple_projection_from_the_table_type() { + let mut table: BTreeMap<(u32, u32), Action> = BTreeMap::new(); + table.insert((10, 20), Action::Drop); + let pkt = Pkt { + src: 10, + dst: 20, + sport: 22, + dport: 80, + }; + assert_eq!(table.classify(&pkt), Some(&Action::Drop)); + } + + #[test] + fn classify_picks_the_four_tuple_projection_from_the_table_type() { + let mut table: BTreeMap<(u32, u32, u16, u16), Action> = BTreeMap::new(); + table.insert((10, 20, 22, 80), Action::Allow); + let pkt = Pkt { + src: 10, + dst: 20, + sport: 22, + dport: 80, + }; + assert_eq!(table.classify(&pkt), Some(&Action::Allow)); + } + + #[test] + fn borrowed_tuple_projection_threads_lifetime() { + let pkt = Pkt { + src: 10, + dst: 20, + sport: 0, + dport: 0, + }; + let (src, dst): (&u32, &u32) = (&pkt).project(); + assert_eq!(*src, 10); + assert_eq!(*dst, 20); + } + + #[test] + fn miss_returns_none() { + let table: BTreeMap<(u32, u32), Action> = BTreeMap::new(); + let pkt = Pkt { + src: 1, + dst: 2, + sport: 3, + dport: 4, + }; + assert_eq!(table.classify(&pkt), None); + } + + #[test] + fn classify_opt_looks_up_when_projection_yields_some() { + let mut table: BTreeMap<(u32, u32), Action> = BTreeMap::new(); + table.insert((10, 20), Action::Drop); + let pkt = Pkt { + src: 10, + dst: 20, + sport: 0, + dport: 0, + }; + assert_eq!(table.classify_opt(&pkt), Some(&Action::Drop)); + } + + #[test] + fn classify_opt_short_circuits_when_projection_yields_none() { + let mut table: BTreeMap<(u32, u32), Action> = BTreeMap::new(); + table.insert((0, 20), Action::Drop); + let pkt = Pkt { + src: 0, + dst: 20, + sport: 0, + dport: 0, + }; + assert_eq!(table.classify_opt(&pkt), None); + } + + #[test] + fn classify_opt_accepts_a_computed_option_via_identity() { + let mut table: BTreeMap<(u32, u32), Action> = BTreeMap::new(); + table.insert((10, 20), Action::Drop); + let built: Option<(u32, u32)> = Some((10, 20)); + assert_eq!(table.classify_opt(built), Some(&Action::Drop)); + assert_eq!(table.classify_opt(None::<(u32, u32)>), None); + } + + #[test] + fn hashmap_backend_works_the_same_way() { + let mut table: HashMap<(u32, u32), Action> = HashMap::new(); + table.insert((10, 20), Action::Drop); + let pkt = Pkt { + src: 10, + dst: 20, + sport: 0, + dport: 0, + }; + assert_eq!(table.classify(&pkt), Some(&Action::Drop)); + } +} diff --git a/match-action-derive/Cargo.toml b/match-action-derive/Cargo.toml new file mode 100644 index 0000000000..127352aee4 --- /dev/null +++ b/match-action-derive/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "dataplane-match-action-derive" +edition.workspace = true +license.workspace = true +publish.workspace = true +version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro-crate = { workspace = true, default-features = true } +proc-macro2 = { workspace = true, default-features = true } +quote = { workspace = true, default-features = true } +syn = { workspace = true, default-features = true, features = ["full"] } + +[features] +bolero = [] diff --git a/match-action-derive/src/lib.rs b/match-action-derive/src/lib.rs new file mode 100644 index 0000000000..aab781d6da --- /dev/null +++ b/match-action-derive/src/lib.rs @@ -0,0 +1,324 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use proc_macro::TokenStream; +use proc_macro_crate::{FoundCrate, crate_name}; +use proc_macro2::{Span, TokenStream as TokenStream2}; +use quote::quote; +use syn::{ + Attribute, Data, DeriveInput, Field, Fields, GenericParam, Ident, TypeParamBound, + parse_macro_input, parse_quote, spanned::Spanned, +}; +fn match_action_crate_path() -> TokenStream2 { + match crate_name("dataplane-match-action") { + Ok(FoundCrate::Itself) => { + let ident = Ident::new("dataplane_match_action", Span::call_site()); + quote! { ::#ident } + } + Ok(FoundCrate::Name(name)) => { + let ident = Ident::new(&name, Span::call_site()); + quote! { ::#ident } + } + Err(_) => { + let ident = Ident::new("dataplane_match_action", Span::call_site()); + quote! { ::#ident } + } + } +} +#[derive(Debug, Copy, Clone)] +enum Kind { + Prefix, + Mask, + Range, + Exact, +} + +impl Kind { + fn variant_ident(self) -> Ident { + let name = match self { + Self::Prefix => "Prefix", + Self::Mask => "Mask", + Self::Range => "Range", + Self::Exact => "Exact", + }; + Ident::new(name, Span::call_site()) + } + fn spec_ident(self) -> Ident { + let name = match self { + Self::Prefix => "PrefixSpec", + Self::Mask => "MaskSpec", + Self::Range => "RangeSpec", + Self::Exact => "ExactSpec", + }; + Ident::new(name, Span::call_site()) + } +} + +#[proc_macro_derive(MatchKey, attributes(prefix, mask, range, exact))] +pub fn derive_match_key(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + match expand(&input) { + Ok(tokens) => tokens.into(), + Err(e) => e.to_compile_error().into(), + } +} + +fn expand(input: &DeriveInput) -> syn::Result { + let crate_path = match_action_crate_path(); + let key_ident = &input.ident; + let key_vis = &input.vis; + let mut generics = input.generics.clone(); + let fixed_size_bound: TypeParamBound = parse_quote!(#crate_path::FixedSize); + for param in &mut generics.params { + if let GenericParam::Type(tp) = param { + tp.bounds.push(fixed_size_bound.clone()); + } + } + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let is_generic = input + .generics + .params + .iter() + .any(|p| matches!(p, GenericParam::Type(_) | GenericParam::Const(_))); + + let fields = match &input.data { + Data::Struct(s) => match &s.fields { + Fields::Named(named) => &named.named, + Fields::Unnamed(_) => { + return Err(syn::Error::new( + input.span(), + "MatchKey derive requires named fields", + )); + } + Fields::Unit => { + return Err(syn::Error::new( + input.span(), + "MatchKey derive requires at least one field", + )); + } + }, + _ => { + return Err(syn::Error::new( + input.span(), + "MatchKey derive only supports structs", + )); + } + }; + + if fields.is_empty() { + return Err(syn::Error::new( + input.span(), + "MatchKey derive requires at least one field", + )); + } + let kinds: Vec = fields + .iter() + .map(parse_field_kind) + .collect::>()?; + + let n = fields.len(); + let n_literal = syn::Index::from(n); + let size_exprs: Vec = fields + .iter() + .map(|f| { + let ty = &f.ty; + quote! { <#ty as #crate_path::FixedSize>::SIZE } + }) + .collect(); + let mut boundaries: Vec = Vec::with_capacity(n + 1); + boundaries.push(quote! { 0usize }); + let mut acc: Vec = Vec::new(); + for size in &size_exprs { + acc.push(size.clone()); + boundaries.push(quote! { #(#acc)+* }); + } + let key_size_expr = &boundaries[n]; + let mut spec_entries: Vec = Vec::with_capacity(n); + for (i, field) in fields.iter().enumerate() { + let name = field + .ident + .as_ref() + .ok_or_else(|| syn::Error::new(field.span(), "unnamed field"))?; + let name_str = name.to_string(); + let off = &boundaries[i]; + let size = &size_exprs[i]; + let kind_variant = kinds[i].variant_ident(); + spec_entries.push(quote! { + #crate_path::FieldSpec { + name: #name_str, + kind: #crate_path::FieldKind::#kind_variant, + size: #size, + offset: #off, + } + }); + } + let mut writers: Vec = Vec::with_capacity(n); + for (i, field) in fields.iter().enumerate() { + let name = field + .ident + .as_ref() + .ok_or_else(|| syn::Error::new(field.span(), "unnamed field"))?; + let ty = &field.ty; + let start = &boundaries[i]; + let end = &boundaries[i + 1]; + writers.push(quote! { + <#ty as #crate_path::FixedSize>::write_be( + &self.#name, + &mut out[#start..#end], + ); + }); + } + let rule_ident = Ident::new(&format!("{key_ident}Rule"), key_ident.span()); + let mut rule_fields: Vec = Vec::with_capacity(n); + let mut rule_field_bounds: Vec = Vec::with_capacity(n); + let mut rule_field_converts: Vec = Vec::with_capacity(n); + let mut rule_field_accepts: Vec = Vec::with_capacity(n); + let mut rule_field_universal: Vec = Vec::with_capacity(n); + let mut rule_field_accept_bounds: Vec = Vec::with_capacity(n); + let mut rule_field_universal_bounds: Vec = Vec::with_capacity(n); + for (i, field) in fields.iter().enumerate() { + let name = field + .ident + .as_ref() + .ok_or_else(|| syn::Error::new(field.span(), "unnamed field"))?; + let ty = &field.ty; + let spec = kinds[i].spec_ident(); + rule_fields.push(quote! { + pub #name: #crate_path::#spec<#ty> + }); + rule_field_bounds.push(quote! { + #crate_path::#spec<#ty>: #crate_path::IntoBackendField<__MaB> + }); + rule_field_converts.push(quote! { + <#crate_path::#spec<#ty> as #crate_path::IntoBackendField<__MaB>>::into_backend_field(self.#name) + }); + rule_field_accepts.push(quote! { + <#crate_path::#spec<#ty> as #crate_path::Accepts<#ty>>::accepts(&self.#name, &key.#name) + }); + rule_field_universal.push(quote! { + <#crate_path::#spec<#ty> as #crate_path::IsUniversal>::is_universal(&self.#name) + }); + rule_field_accept_bounds.push(quote! { + #crate_path::#spec<#ty>: #crate_path::Accepts<#ty> + }); + rule_field_universal_bounds.push(quote! { + #crate_path::#spec<#ty>: #crate_path::IsUniversal + }); + } + let as_key_impl = if is_generic { + quote! {} + } else { + quote! { + impl #impl_generics #key_ident #ty_generics #where_clause { + #[must_use] + pub fn as_key(&self) -> [u8; ::KEY_SIZE] { + let mut buf = [0u8; ::KEY_SIZE]; + ::as_key_into(self, &mut buf); + buf + } + } + } + }; + let existing_predicates: Vec<_> = where_clause + .map(|wc| wc.predicates.iter().collect()) + .unwrap_or_default(); + let merged_where_accepts = quote! { + where + #(#existing_predicates,)* + #(#rule_field_accept_bounds,)* + }; + let merged_where_universal = quote! { + where + #(#existing_predicates,)* + #(#rule_field_universal_bounds,)* + }; + + let expanded = quote! { + const _: () = { + impl #impl_generics #key_ident #ty_generics #where_clause { + pub const FIELD_SPECS: &'static [#crate_path::FieldSpec] = &[ + #(#spec_entries),* + ]; + } + + impl #impl_generics #crate_path::MatchKey for #key_ident #ty_generics #where_clause { + const N: usize = #n_literal; + const KEY_SIZE: usize = #key_size_expr; + + fn field_specs() -> &'static [#crate_path::FieldSpec] { + Self::FIELD_SPECS + } + + fn as_key_into(&self, out: &mut [u8]) { + assert!( + out.len() >= Self::KEY_SIZE, + "as_key_into: output buffer shorter than KEY_SIZE", + ); + #(#writers)* + } + } + + #as_key_impl + }; + #[derive(::core::marker::Copy, ::core::clone::Clone, ::core::fmt::Debug)] + #key_vis struct #rule_ident #generics { + #(#rule_fields),* + } + impl #impl_generics #rule_ident #ty_generics #where_clause { + pub fn into_backend_fields<__MaB>(self) -> ::std::vec::Vec<<__MaB as #crate_path::Backend>::Field> + where + __MaB: #crate_path::Backend, + #(#rule_field_bounds),* + { + ::std::vec![ + #(#rule_field_converts),* + ] + } + } + impl #impl_generics #rule_ident #ty_generics + #merged_where_accepts + { + #[must_use] + pub fn accepts(&self, key: &#key_ident #ty_generics) -> bool { + #(#rule_field_accepts) && * + } + } + impl #impl_generics #rule_ident #ty_generics + #merged_where_universal + { + #[must_use] + pub fn is_universal(&self) -> bool { + #(#rule_field_universal) && * + } + } + }; + + Ok(expanded) +} +fn parse_field_kind(field: &Field) -> syn::Result { + let mut found: Option<(Kind, &Attribute)> = None; + for attr in &field.attrs { + let kind = if attr.path().is_ident("prefix") { + Some(Kind::Prefix) + } else if attr.path().is_ident("mask") { + Some(Kind::Mask) + } else if attr.path().is_ident("range") { + Some(Kind::Range) + } else if attr.path().is_ident("exact") { + Some(Kind::Exact) + } else { + None + }; + if let Some(k) = kind { + if found.is_some() { + return Err(syn::Error::new( + attr.span(), + "multiple match-flavor attributes on a single field; \ + expected at most one of #[prefix], #[mask], #[range], #[exact]", + )); + } + found = Some((k, attr)); + } + } + Ok(found.map_or(Kind::Exact, |(k, _)| k)) +} diff --git a/match-action/Cargo.toml b/match-action/Cargo.toml new file mode 100644 index 0000000000..8cb35fb73f --- /dev/null +++ b/match-action/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "dataplane-match-action" +edition.workspace = true +license.workspace = true +publish.workspace = true +version.workspace = true + +[dependencies] +arrayvec = { workspace = true, default-features = true } +bolero = { workspace = true, optional = true } +fixed-size = { workspace = true, features = [] } +match-action-derive = { workspace = true, optional = true } + +[features] +default = ["derive"] +derive = ["dep:match-action-derive"] +bolero = ["dep:bolero", "match-action-derive?/bolero"] diff --git a/match-action/src/field.rs b/match-action/src/field.rs new file mode 100644 index 0000000000..2e31b16331 --- /dev/null +++ b/match-action/src/field.rs @@ -0,0 +1,4 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +pub use fixed_size::FixedSize; diff --git a/match-action/src/generator.rs b/match-action/src/generator.rs new file mode 100644 index 0000000000..10cc1d67b6 --- /dev/null +++ b/match-action/src/generator.rs @@ -0,0 +1,508 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use core::net::{Ipv4Addr, Ipv6Addr}; + +use bolero::{Driver, ValueGenerator, generator::constant}; + +use crate::IsUniversal; +use crate::rule::{ExactSpec, MaskSpec, PrefixSpec, RangeSpec}; +pub struct GuardedMisses { + inner: G, + universal: bool, +} + +impl ValueGenerator for GuardedMisses { + type Output = G::Output; + fn generate(&self, d: &mut D) -> Option { + if self.universal { + return None; + } + self.inner.generate(d) + } +} +pub trait FieldHit { + fn hits(&self) -> impl ValueGenerator; +} +pub trait FieldMiss { + fn misses(&self) -> impl ValueGenerator; +} +macro_rules! high_mask_fn { + ($name:ident, $i:ty, $bits:expr) => { + #[inline] + fn $name(len: u32) -> $i { + if len == 0 { + 0 + } else if len >= $bits { + <$i>::MAX + } else { + !((1 as $i) << ($bits - len)).wrapping_sub(1) + } + } + }; +} +high_mask_fn!(high_mask_u8, u8, 8); +high_mask_fn!(high_mask_u16, u16, 16); +high_mask_fn!(high_mask_u32, u32, 32); +high_mask_fn!(high_mask_u64, u64, 64); +high_mask_fn!(high_mask_u128, u128, 128); +macro_rules! impl_specs_for { + ($t:ty, $i:ty, $high_mask:ident) => { + impl FieldHit<$t> for ExactSpec<$t> { + fn hits(&self) -> impl ValueGenerator { + constant(self.value) + } + } + impl FieldMiss<$t> for ExactSpec<$t> { + fn misses(&self) -> impl ValueGenerator { + let target: $i = self.value.into(); + GuardedMisses { + inner: bolero::produce::<$i>() + .filter_gen(move |x| *x != target) + .map_gen(<$t>::from), + universal: IsUniversal::is_universal(self), + } + } + } + + impl FieldHit<$t> for PrefixSpec<$t> { + fn hits(&self) -> impl ValueGenerator { + let value: $i = self.value.into(); + let len = u32::from(self.len); + bolero::produce::<$i>().map_gen(move |rand| { + let high = $high_mask(len); + <$t>::from((value & high) | (rand & !high)) + }) + } + } + impl FieldMiss<$t> for PrefixSpec<$t> { + fn misses(&self) -> impl ValueGenerator { + let value: $i = self.value.into(); + let len = u32::from(self.len); + GuardedMisses { + inner: bolero::produce::<$i>().filter_map_gen(move |rand| { + let high = $high_mask(len); + ((rand & high) != (value & high)).then(|| <$t>::from(rand)) + }), + universal: IsUniversal::is_universal(self), + } + } + } + + impl FieldHit<$t> for MaskSpec<$t> { + fn hits(&self) -> impl ValueGenerator { + let v: $i = self.value.into(); + let m: $i = self.mask.into(); + bolero::produce::<$i>().map_gen(move |rand| <$t>::from((v & m) | (rand & !m))) + } + } + impl FieldMiss<$t> for MaskSpec<$t> { + fn misses(&self) -> impl ValueGenerator { + let v: $i = self.value.into(); + let m: $i = self.mask.into(); + GuardedMisses { + inner: bolero::produce::<$i>().filter_map_gen(move |rand| { + ((rand & m) != (v & m)).then(|| <$t>::from(rand)) + }), + universal: IsUniversal::is_universal(self), + } + } + } + + impl FieldHit<$t> for RangeSpec<$t> { + fn hits(&self) -> impl ValueGenerator { + let min: $i = self.min.into(); + let max: $i = self.max.into(); + (min..=max).map_gen(<$t>::from) + } + } + impl FieldMiss<$t> for RangeSpec<$t> { + fn misses(&self) -> impl ValueGenerator { + let lo: $i = self.min.into(); + let hi: $i = self.max.into(); + GuardedMisses { + inner: bolero::produce::<$i>().filter_map_gen(move |rand| { + (rand < lo || rand > hi).then(|| <$t>::from(rand)) + }), + universal: IsUniversal::is_universal(self), + } + } + } + }; +} + +impl_specs_for!(u8, u8, high_mask_u8); +impl_specs_for!(u16, u16, high_mask_u16); +impl_specs_for!(u32, u32, high_mask_u32); +impl_specs_for!(u64, u64, high_mask_u64); +impl_specs_for!(Ipv4Addr, u32, high_mask_u32); +impl_specs_for!(Ipv6Addr, u128, high_mask_u128); + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Accepts, IsUniversal}; + + #[test] + fn exact_spec_hits_accepts() { + let spec = ExactSpec::new(42u16); + bolero::check!() + .with_generator(spec.hits()) + .for_each(|v| assert!(spec.accepts(v))); + } + + #[test] + fn exact_spec_misses_rejected() { + let spec = ExactSpec::new(42u16); + bolero::check!() + .with_generator(spec.misses()) + .for_each(|v| assert!(!spec.accepts(v))); + } + + #[test] + fn prefix_spec_u32_hits() { + let spec = PrefixSpec::new(0x0A00_0000u32, 8); + bolero::check!() + .with_generator(spec.hits()) + .for_each(|v| assert_eq!(*v & 0xFF00_0000, 0x0A00_0000, "got {v:08x}")); + } + + #[test] + fn prefix_spec_u32_misses() { + let spec = PrefixSpec::new(0x0A00_0000u32, 8); + bolero::check!() + .with_generator(spec.misses()) + .for_each(|v| assert_ne!(*v & 0xFF00_0000, 0x0A00_0000, "got {v:08x}")); + } + + #[test] + fn prefix_spec_zero_len_is_universal() { + assert!(PrefixSpec::new(0xDEAD_BEEF_u32, 0).is_universal()); + } + + #[test] + fn mask_spec_u16_hits_match_under_mask() { + let spec = MaskSpec::new(0xABCDu16, 0xFF00u16); + bolero::check!() + .with_generator(spec.hits()) + .for_each(|v| assert_eq!(*v & 0xFF00, 0xAB00, "got {v:04x}")); + } + + #[test] + fn mask_spec_u16_misses_disagree_under_mask() { + let spec = MaskSpec::new(0xABCDu16, 0xFF00u16); + bolero::check!() + .with_generator(spec.misses()) + .for_each(|v| assert_ne!(*v & 0xFF00, 0xAB00, "got {v:04x}")); + } + + #[test] + fn mask_spec_zero_mask_is_universal() { + assert!(MaskSpec::new(0xDEADu16, 0u16).is_universal()); + } + + #[test] + fn range_spec_u16_hits_in_range() { + let spec = RangeSpec::new(100u16, 200u16); + bolero::check!() + .with_generator(spec.hits()) + .for_each(|v| assert!((100..=200).contains(v))); + } + + #[test] + fn range_spec_u16_misses_outside_range() { + let spec = RangeSpec::new(100u16, 200u16); + bolero::check!() + .with_generator(spec.misses()) + .for_each(|v| assert!(!(100..=200).contains(v))); + } + + #[test] + fn range_spec_full_domain_is_universal() { + assert!(RangeSpec::new(0u16, u16::MAX).is_universal()); + } + + #[test] + fn ipv4_prefix_hits() { + let spec = PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 8); + bolero::check!() + .with_generator(spec.hits()) + .for_each(|v| assert_eq!(v.octets()[0], 10, "got {v}")); + } + + #[test] + fn ipv4_prefix_misses() { + let spec = PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 8); + bolero::check!() + .with_generator(spec.misses()) + .for_each(|v| assert_ne!(v.octets()[0], 10)); + } + + #[test] + fn ipv6_prefix_hits_on_high_chunk() { + let spec = PrefixSpec::new("2001:db8::".parse::().unwrap(), 32); + bolero::check!() + .with_generator(spec.hits()) + .for_each(|v| assert_eq!(&v.octets()[0..4], &[0x20, 0x01, 0x0d, 0xb8], "got {v}")); + } +} + +use crate::FieldPredicate; +use crate::predicate::mask_matches; +use core::ops::Bound; +#[must_use] +pub fn predicate_is_universal(pred: &FieldPredicate) -> bool { + if let Some((_, len)) = pred.as_prefix() { + len == 0 + } else if let Some((_, mask)) = pred.as_mask() { + mask.iter().all(|&b| b == 0) + } else if let Some((min, max)) = pred.as_range() { + min.iter().all(|&b| b == 0) && max.iter().all(|&b| b == u8::MAX) + } else { + false + } +} +#[must_use] +pub fn predicate_hits_bytes(pred: FieldPredicate) -> PredicateHitsBytes { + PredicateHitsBytes { pred } +} +#[must_use] +pub fn predicate_misses_bytes(pred: FieldPredicate) -> PredicateMissesBytes { + let universal = predicate_is_universal(&pred); + PredicateMissesBytes { pred, universal } +} + +pub struct PredicateHitsBytes { + pred: FieldPredicate, +} + +impl ValueGenerator for PredicateHitsBytes { + type Output = Vec; + fn generate(&self, d: &mut D) -> Option> { + if let Some(value) = self.pred.as_exact() { + Some(value.to_vec()) + } else if let Some((value, prefix_len)) = self.pred.as_prefix() { + let mut buf = draw_bytes(d, value.len())?; + splat_prefix(&mut buf, value, prefix_len); + Some(buf) + } else if let Some((value, mask)) = self.pred.as_mask() { + let mut buf = draw_bytes(d, value.len())?; + splat_under_mask(&mut buf, value, mask); + Some(buf) + } else if let Some((min, max)) = self.pred.as_range() { + let lo = be_to_u32(min); + let hi = be_to_u32(max); + let v = (lo..=hi).generate(d)?; + Some(u32_to_be(v, min.len())) + } else { + None + } + } +} + +pub struct PredicateMissesBytes { + pred: FieldPredicate, + universal: bool, +} + +impl ValueGenerator for PredicateMissesBytes { + type Output = Vec; + fn generate(&self, d: &mut D) -> Option> { + if self.universal { + return None; + } + if let Some(value) = self.pred.as_exact() { + let buf = draw_bytes(d, value.len())?; + (buf.as_slice() != value).then_some(buf) + } else if let Some((value, prefix_len)) = self.pred.as_prefix() { + let buf = draw_bytes(d, value.len())?; + (!prefix_matches(&buf, value, prefix_len)).then_some(buf) + } else if let Some((value, mask)) = self.pred.as_mask() { + let buf = draw_bytes(d, value.len())?; + (!mask_matches(&buf, value, mask)).then_some(buf) + } else if let Some((min, max)) = self.pred.as_range() { + let buf = draw_bytes(d, min.len())?; + (buf.as_slice() < min || buf.as_slice() > max).then_some(buf) + } else { + None + } + } +} + +fn draw_bytes(d: &mut D, width: usize) -> Option> { + let mut buf = vec![0u8; width]; + for byte in &mut buf { + *byte = d.gen_u8(Bound::Unbounded, Bound::Unbounded)?; + } + Some(buf) +} +fn splat_prefix(buf: &mut [u8], value: &[u8], len: u8) { + debug_assert_eq!(buf.len(), value.len()); + let full_bytes = usize::from(len / 8); + let trailing_bits = u32::from(len % 8); + buf[..full_bytes].copy_from_slice(&value[..full_bytes]); + if trailing_bits > 0 && full_bytes < buf.len() { + let mask: u8 = !((1u8 << (8 - trailing_bits)) - 1); + buf[full_bytes] = (value[full_bytes] & mask) | (buf[full_bytes] & !mask); + } +} +fn splat_under_mask(buf: &mut [u8], value: &[u8], mask: &[u8]) { + debug_assert_eq!(buf.len(), value.len()); + debug_assert_eq!(buf.len(), mask.len()); + for ((b, &v), &m) in buf.iter_mut().zip(value).zip(mask) { + *b = (*b & !m) | (v & m); + } +} +fn prefix_matches(field: &[u8], value: &[u8], len: u8) -> bool { + let full_bytes = usize::from(len / 8); + if field[..full_bytes] != value[..full_bytes] { + return false; + } + let trailing_bits = u32::from(len % 8); + if trailing_bits == 0 { + return true; + } + let mask: u8 = !((1u8 << (8 - trailing_bits)) - 1); + (field[full_bytes] & mask) == (value[full_bytes] & mask) +} +fn be_to_u32(bytes: &[u8]) -> u32 { + assert!( + bytes.len() <= 4, + "be_to_u32: width {} exceeds 4 bytes", + bytes.len(), + ); + let mut buf = [0u8; 4]; + let off = 4 - bytes.len(); + buf[off..].copy_from_slice(bytes); + u32::from_be_bytes(buf) +} +fn u32_to_be(value: u32, width: usize) -> Vec { + assert!(width <= 4, "u32_to_be: width {width} exceeds 4 bytes"); + let buf = value.to_be_bytes(); + buf[4 - width..].to_vec() +} + +#[cfg(test)] +mod byte_tests { + use super::*; + use crate::predicate::{Exact, FieldBytes, Mask, Prefix, Range}; + + fn fb(bytes: &[u8]) -> FieldBytes { + bytes.iter().copied().collect() + } + + #[test] + fn is_universal_classifies_each_kind() { + assert!(!predicate_is_universal(&FieldPredicate::Exact(Exact::new( + fb(&[0]) + )))); + assert!(predicate_is_universal(&FieldPredicate::Prefix( + Prefix::new(fb(&[0xAB, 0xCD]), 0) + ))); + assert!(!predicate_is_universal(&FieldPredicate::Prefix( + Prefix::new(fb(&[0xAB, 0xCD]), 4) + ))); + assert!(predicate_is_universal(&FieldPredicate::Mask(Mask::new( + fb(&[0xAB, 0xCD]), + fb(&[0, 0]) + )))); + assert!(!predicate_is_universal(&FieldPredicate::Mask(Mask::new( + fb(&[0xAB, 0xCD]), + fb(&[0xFF, 0]) + )))); + assert!(predicate_is_universal(&FieldPredicate::Range(Range::new( + fb(&[0, 0]), + fb(&[0xFF, 0xFF]) + )))); + assert!(!predicate_is_universal(&FieldPredicate::Range(Range::new( + fb(&[0, 1]), + fb(&[0xFF, 0xFF]) + )))); + } + + #[test] + fn exact_hits_returns_value_bytes() { + let pred = FieldPredicate::Exact(Exact::new(fb(&[1, 2, 3, 4]))); + bolero::check!() + .with_generator(predicate_hits_bytes(pred.clone())) + .for_each(|v| assert_eq!(v.as_slice(), &[1, 2, 3, 4])); + } + + #[test] + fn exact_misses_avoid_value() { + let pred = FieldPredicate::Exact(Exact::new(fb(&[1, 2]))); + bolero::check!() + .with_generator(predicate_misses_bytes(pred.clone())) + .for_each(|v| assert_ne!(v.as_slice(), &[1, 2])); + } + + #[test] + fn prefix_hits_preserve_top_bits() { + let pred = FieldPredicate::Prefix(Prefix::new(fb(&[0xAB, 0xCD]), 12)); + bolero::check!() + .with_generator(predicate_hits_bytes(pred.clone())) + .for_each(|v| { + assert_eq!(v[0], 0xAB); + assert_eq!(v[1] & 0xF0, 0xC0); + }); + } + + #[test] + fn prefix_misses_differ_in_top_bits() { + let pred = FieldPredicate::Prefix(Prefix::new(fb(&[0xAB, 0xCD]), 12)); + bolero::check!() + .with_generator(predicate_misses_bytes(pred.clone())) + .for_each(|v| { + let top_matches = v[0] == 0xAB && (v[1] & 0xF0) == 0xC0; + assert!(!top_matches); + }); + } + + #[test] + fn mask_hits_match_under_mask() { + let pred = FieldPredicate::Mask(Mask::new(fb(&[0xAB, 0xCD]), fb(&[0xFF, 0x00]))); + bolero::check!() + .with_generator(predicate_hits_bytes(pred.clone())) + .for_each(|v| { + assert_eq!(v[0], 0xAB); + }); + } + + #[test] + fn mask_misses_disagree_under_mask() { + let pred = FieldPredicate::Mask(Mask::new(fb(&[0xAB, 0xCD]), fb(&[0xFF, 0x00]))); + bolero::check!() + .with_generator(predicate_misses_bytes(pred.clone())) + .for_each(|v| { + assert_ne!(v[0], 0xAB); + }); + } + + #[test] + fn range_hits_in_range() { + let pred = FieldPredicate::Range(Range::new( + fb(&100u16.to_be_bytes()), + fb(&200u16.to_be_bytes()), + )); + bolero::check!() + .with_generator(predicate_hits_bytes(pred.clone())) + .for_each(|v| { + let x = u16::from_be_bytes([v[0], v[1]]); + assert!((100..=200).contains(&x)); + }); + } + + #[test] + fn range_misses_outside_range() { + let pred = FieldPredicate::Range(Range::new( + fb(&100u16.to_be_bytes()), + fb(&200u16.to_be_bytes()), + )); + bolero::check!() + .with_generator(predicate_misses_bytes(pred.clone())) + .for_each(|v| { + let x = u16::from_be_bytes([v[0], v[1]]); + assert!(!(100..=200).contains(&x)); + }); + } +} diff --git a/match-action/src/lib.rs b/match-action/src/lib.rs new file mode 100644 index 0000000000..25f18f766c --- /dev/null +++ b/match-action/src/lib.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +#![deny( + unsafe_code, + clippy::all, + clippy::pedantic, + clippy::unwrap_used, + clippy::expect_used, + clippy::panic +)] +#![allow(missing_docs)] +#![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)] + +pub mod field; +pub mod predicate; +pub mod rule; + +#[cfg(feature = "bolero")] +pub mod generator; + +pub use field::FixedSize; +pub use predicate::{Erased, FieldBytes, FieldPredicate, MAX_FIELD_BYTES}; +pub use rule::{ + Accepts, Backend, ExactSpec, IntoBackendField, IsUniversal, MaskSpec, PrefixSpec, RangeSpec, + RuleField, +}; + +#[cfg(feature = "bolero")] +pub use generator::{FieldHit, FieldMiss}; + +#[cfg(feature = "derive")] +pub use match_action_derive::MatchKey; +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum FieldKind { + Prefix, + Mask, + Range, + Exact, +} +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct FieldSpec { + pub name: &'static str, + pub kind: FieldKind, + pub size: usize, + pub offset: usize, +} +pub trait MatchKey: Sized { + const N: usize; + const KEY_SIZE: usize; + fn field_specs() -> &'static [FieldSpec]; + fn as_key_into(&self, out: &mut [u8]); +} diff --git a/match-action/src/predicate.rs b/match-action/src/predicate.rs new file mode 100644 index 0000000000..bd1716264f --- /dev/null +++ b/match-action/src/predicate.rs @@ -0,0 +1,354 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use arrayvec::ArrayVec; + +use crate::field::FixedSize; +use crate::rule::{ + Accepts, Backend, ExactSpec, IntoBackendField, IsUniversal, MaskSpec, PrefixSpec, RangeSpec, +}; +pub const MAX_FIELD_BYTES: usize = 16; +pub type FieldBytes = ArrayVec; +#[derive(Copy, Clone, Debug, Default)] +pub struct Erased; + +impl Backend for Erased { + type Field = FieldPredicate; +} +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Exact { + value: FieldBytes, +} + +impl Exact { + #[must_use] + pub fn new(value: FieldBytes) -> Self { + Self { value } + } + + fn matches(&self, field: &[u8]) -> bool { + assert_eq!(field.len(), self.value.len(), "field width mismatch"); + field == self.value.as_slice() + } +} +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Prefix { + value: FieldBytes, + len: u8, +} + +impl Prefix { + #[must_use] + pub fn new(value: FieldBytes, len: u8) -> Self { + assert!( + usize::from(len) <= value.len() * 8, + "prefix length {len} exceeds field width of {} bits", + value.len() * 8, + ); + Self { value, len } + } + fn matches(&self, field: &[u8]) -> bool { + assert_eq!(field.len(), self.value.len(), "field width mismatch"); + mask_matches(field, &self.value, &prefix_mask(field.len(), self.len)) + } +} +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Mask { + value: FieldBytes, + mask: FieldBytes, +} + +impl Mask { + #[must_use] + pub fn new(value: FieldBytes, mask: FieldBytes) -> Self { + assert_eq!(value.len(), mask.len(), "mask width must equal value width"); + Self { value, mask } + } + fn matches(&self, field: &[u8]) -> bool { + assert_eq!(field.len(), self.value.len(), "field width mismatch"); + mask_matches(field, &self.value, &self.mask) + } +} +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Range { + min: FieldBytes, + max: FieldBytes, +} + +impl Range { + #[must_use] + pub fn new(min: FieldBytes, max: FieldBytes) -> Self { + assert_eq!(min.len(), max.len(), "range bounds must be equal width"); + Self { min, max } + } + fn matches(&self, field: &[u8]) -> bool { + assert_eq!(field.len(), self.min.len(), "field width mismatch"); + field >= self.min.as_slice() && field <= self.max.as_slice() + } +} +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FieldPredicate { + Exact(Exact), + Prefix(Prefix), + Mask(Mask), + Range(Range), +} + +impl From for FieldPredicate { + fn from(p: Exact) -> Self { + Self::Exact(p) + } +} + +impl From for FieldPredicate { + fn from(p: Prefix) -> Self { + Self::Prefix(p) + } +} + +impl From for FieldPredicate { + fn from(p: Mask) -> Self { + Self::Mask(p) + } +} + +impl From for FieldPredicate { + fn from(p: Range) -> Self { + Self::Range(p) + } +} + +impl FieldPredicate { + #[must_use] + pub fn matches(&self, field: &[u8]) -> bool { + match self { + FieldPredicate::Exact(p) => p.matches(field), + FieldPredicate::Prefix(p) => p.matches(field), + FieldPredicate::Mask(p) => p.matches(field), + FieldPredicate::Range(p) => p.matches(field), + } + } + #[must_use] + pub fn width(&self) -> usize { + match self { + FieldPredicate::Exact(p) => p.value.len(), + FieldPredicate::Prefix(p) => p.value.len(), + FieldPredicate::Mask(p) => p.value.len(), + FieldPredicate::Range(p) => p.min.len(), + } + } + + #[must_use] + pub fn as_exact(&self) -> Option<&[u8]> { + match self { + FieldPredicate::Exact(p) => Some(&p.value), + _ => None, + } + } + #[must_use] + pub fn as_prefix(&self) -> Option<(&[u8], u8)> { + match self { + FieldPredicate::Prefix(p) => Some((&p.value, p.len)), + _ => None, + } + } + #[must_use] + pub fn as_mask(&self) -> Option<(&[u8], &[u8])> { + match self { + FieldPredicate::Mask(p) => Some((&p.value, &p.mask)), + _ => None, + } + } + #[must_use] + pub fn as_range(&self) -> Option<(&[u8], &[u8])> { + match self { + FieldPredicate::Range(p) => Some((&p.min, &p.max)), + _ => None, + } + } +} +fn be_bytes(value: &T) -> FieldBytes { + let mut buf = [0u8; MAX_FIELD_BYTES]; + value.write_be(&mut buf); + buf[..T::SIZE].iter().copied().collect() +} + +impl IntoBackendField for ExactSpec { + fn into_backend_field(self) -> FieldPredicate { + FieldPredicate::Exact(Exact::new(be_bytes(&self.value))) + } +} + +impl IntoBackendField for PrefixSpec { + fn into_backend_field(self) -> FieldPredicate { + FieldPredicate::Prefix(Prefix::new(be_bytes(&self.value), self.len)) + } +} + +impl IntoBackendField for MaskSpec { + fn into_backend_field(self) -> FieldPredicate { + FieldPredicate::Mask(Mask::new(be_bytes(&self.value), be_bytes(&self.mask))) + } +} + +impl IntoBackendField for RangeSpec { + fn into_backend_field(self) -> FieldPredicate { + FieldPredicate::Range(Range::new(be_bytes(&self.min), be_bytes(&self.max))) + } +} +impl Accepts for ExactSpec { + fn accepts(&self, value: &T) -> bool { + Exact::new(be_bytes(&self.value)).matches(&be_bytes(value)) + } +} + +impl Accepts for PrefixSpec { + fn accepts(&self, value: &T) -> bool { + Prefix::new(be_bytes(&self.value), self.len).matches(&be_bytes(value)) + } +} + +impl Accepts for MaskSpec { + fn accepts(&self, value: &T) -> bool { + Mask::new(be_bytes(&self.value), be_bytes(&self.mask)).matches(&be_bytes(value)) + } +} + +impl Accepts for RangeSpec { + fn accepts(&self, value: &T) -> bool { + Range::new(be_bytes(&self.min), be_bytes(&self.max)).matches(&be_bytes(value)) + } +} +impl IsUniversal for MaskSpec { + fn is_universal(&self) -> bool { + be_bytes(&self.mask).iter().all(|b| *b == 0) + } +} + +impl IsUniversal for RangeSpec { + fn is_universal(&self) -> bool { + let lo = be_bytes(&self.min); + let hi = be_bytes(&self.max); + lo.iter().all(|b| *b == 0) && hi.iter().all(|b| *b == u8::MAX) + } +} +#[inline] +pub(crate) fn mask_matches(field: &[u8], value: &[u8], mask: &[u8]) -> bool { + assert_eq!(field.len(), value.len()); + assert_eq!(field.len(), mask.len()); + field + .iter() + .zip(value) + .zip(mask) + .all(|((f, v), m)| (f & m) == (v & m)) +} +#[inline] +fn prefix_mask(nbytes: usize, len: u8) -> FieldBytes { + assert!( + nbytes <= MAX_FIELD_BYTES, + "field width {nbytes} exceeds MAX_FIELD_BYTES {MAX_FIELD_BYTES}", + ); + assert!( + usize::from(len) <= nbytes * 8, + "prefix length {len} exceeds {nbytes}-byte field", + ); + let mut out = FieldBytes::new(); + let mut remaining = usize::from(len); + for _ in 0..nbytes { + let bits = remaining.min(8); + let byte = if bits == 0 { 0 } else { 0xFFu8 << (8 - bits) }; + out.push(byte); + remaining -= bits; + } + out +} + +#[cfg(test)] +mod tests { + use super::*; + use core::net::Ipv4Addr; + + fn bytes(slice: &[u8]) -> FieldBytes { + slice.iter().copied().collect() + } + + #[test] + fn exact_matches_only_equal_bytes() { + let f = ExactSpec::new(6u8).into_backend_field(); + assert!(f.matches(&[6])); + assert!(!f.matches(&[7])); + } + + #[test] + fn prefix_mask_sets_top_bits() { + assert_eq!(prefix_mask(4, 24).as_slice(), &[0xFF, 0xFF, 0xFF, 0x00]); + assert_eq!(prefix_mask(4, 20).as_slice(), &[0xFF, 0xFF, 0xF0, 0x00]); + assert_eq!(prefix_mask(4, 0).as_slice(), &[0x00, 0x00, 0x00, 0x00]); + assert_eq!(prefix_mask(4, 32).as_slice(), &[0xFF, 0xFF, 0xFF, 0xFF]); + } + + #[test] + #[should_panic(expected = "prefix length")] + fn prefix_mask_panics_on_over_long_len() { + let _ = prefix_mask(4, 33); + } + + #[test] + #[should_panic(expected = "MAX_FIELD_BYTES")] + fn prefix_mask_panics_on_oversized_field() { + let _ = prefix_mask(MAX_FIELD_BYTES + 1, 8); + } + + #[test] + fn prefix_matches_on_high_bits_only() { + let f = PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 8).into_backend_field(); + assert!(f.matches(&Ipv4Addr::new(10, 1, 2, 3).octets())); + assert!(f.matches(&Ipv4Addr::new(10, 255, 255, 255).octets())); + assert!(!f.matches(&Ipv4Addr::new(11, 0, 0, 0).octets())); + } + + #[test] + fn prefix_len_zero_is_wildcard() { + let f = PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 0).into_backend_field(); + assert!(f.matches(&Ipv4Addr::new(1, 2, 3, 4).octets())); + assert!(f.matches(&Ipv4Addr::UNSPECIFIED.octets())); + } + + #[test] + #[should_panic(expected = "prefix length")] + fn over_long_prefix_len_panics() { + let _ = Prefix::new(bytes(&[10, 0, 0, 1]), 200); + } + + #[test] + fn mask_matches_required_bits() { + let f = MaskSpec::new(0xABu8, 0xF0u8).into_backend_field(); + assert!(f.matches(&[0xA0])); + assert!(f.matches(&[0xAF])); + assert!(!f.matches(&[0xB0])); + } + + #[test] + fn range_is_inclusive_both_ends() { + let f = RangeSpec::new(80u16, 8080u16).into_backend_field(); + assert!(f.matches(&80u16.to_be_bytes())); + assert!(f.matches(&8080u16.to_be_bytes())); + assert!(f.matches(&443u16.to_be_bytes())); + assert!(!f.matches(&79u16.to_be_bytes())); + assert!(!f.matches(&8081u16.to_be_bytes())); + } + + #[test] + #[should_panic(expected = "field width")] + fn exact_field_width_mismatch_panics() { + let f = FieldPredicate::Exact(Exact::new(bytes(&[1, 2, 3, 4]))); + let _ = f.matches(&[1, 2]); + } + + #[test] + #[should_panic(expected = "field width")] + fn range_field_width_mismatch_panics() { + let r = FieldPredicate::Range(Range::new(bytes(&[0, 0]), bytes(&[255, 255]))); + let _ = r.matches(&[0, 0, 0]); + } +} diff --git a/match-action/src/rule.rs b/match-action/src/rule.rs new file mode 100644 index 0000000000..41cd5299af --- /dev/null +++ b/match-action/src/rule.rs @@ -0,0 +1,148 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use crate::{FieldKind, FixedSize}; +pub trait RuleField { + const KIND: FieldKind; + type Value: FixedSize; +} +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct ExactSpec { + pub value: T, +} + +impl ExactSpec { + #[must_use] + pub const fn new(value: T) -> Self { + Self { value } + } +} + +impl RuleField for ExactSpec { + const KIND: FieldKind = FieldKind::Exact; + type Value = T; +} +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct PrefixSpec { + pub value: T, + pub len: u8, +} + +impl PrefixSpec { + #[must_use] + pub fn new(value: T, len: u8) -> Self { + let bits = T::SIZE + .checked_mul(8) + .and_then(|b| u8::try_from(b).ok()) + .unwrap_or(u8::MAX); + assert!( + len <= bits, + "prefix length {len} exceeds field width of {bits} bits", + ); + Self { value, len } + } +} + +impl RuleField for PrefixSpec { + const KIND: FieldKind = FieldKind::Prefix; + type Value = T; +} +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct MaskSpec { + pub value: T, + pub mask: T, +} + +impl MaskSpec { + #[must_use] + pub const fn new(value: T, mask: T) -> Self { + Self { value, mask } + } +} + +impl RuleField for MaskSpec { + const KIND: FieldKind = FieldKind::Mask; + type Value = T; +} +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub struct RangeSpec { + pub min: T, + pub max: T, +} + +impl RangeSpec { + #[must_use] + pub const fn new(min: T, max: T) -> Self { + Self { min, max } + } + #[must_use] + pub const fn exact(value: T) -> Self { + Self { + min: value, + max: value, + } + } +} + +impl RuleField for RangeSpec { + const KIND: FieldKind = FieldKind::Range; + type Value = T; +} +impl From> for RangeSpec { + fn from(range: core::ops::RangeInclusive) -> Self { + let (min, max) = range.into_inner(); + Self { min, max } + } +} +pub trait Backend { + type Field; +} +pub trait IntoBackendField { + fn into_backend_field(self) -> B::Field; +} +pub trait Accepts { + fn accepts(&self, value: &T) -> bool; +} +pub trait IsUniversal { + fn is_universal(&self) -> bool; +} + +impl IsUniversal for ExactSpec { + fn is_universal(&self) -> bool { + false + } +} + +impl IsUniversal for PrefixSpec { + fn is_universal(&self) -> bool { + self.len == 0 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use core::net::{Ipv4Addr, Ipv6Addr}; + + #[test] + fn prefix_spec_accepts_max_v4_length() { + let _ = PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 32); + } + + #[test] + #[should_panic(expected = "prefix length 33 exceeds field width of 32 bits")] + fn prefix_spec_rejects_v4_over_32() { + let _ = PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 33); + } + + #[test] + fn prefix_spec_accepts_max_v6_length() { + let _ = PrefixSpec::new(Ipv6Addr::UNSPECIFIED, 128); + } + + #[test] + #[should_panic(expected = "prefix length 129 exceeds field width of 128 bits")] + fn prefix_spec_rejects_v6_over_128() { + let _ = PrefixSpec::new(Ipv6Addr::UNSPECIFIED, 129); + } +} diff --git a/match-action/tests/derive_roundtrip.rs b/match-action/tests/derive_roundtrip.rs new file mode 100644 index 0000000000..fa826004ac --- /dev/null +++ b/match-action/tests/derive_roundtrip.rs @@ -0,0 +1,245 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use core::net::Ipv4Addr; + +use dataplane_match_action::{ + ExactSpec, FieldKind, FixedSize, MatchKey, PrefixSpec, RangeSpec, RuleField, +}; +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +struct IpProto(u8); + +impl FixedSize for IpProto { + const SIZE: usize = 1; + fn write_be(&self, out: &mut [u8]) { + out[0] = self.0; + } +} + +#[derive(MatchKey)] +struct FiveTuple { + #[exact] + proto: IpProto, + #[prefix] + src_ip: Ipv4Addr, + #[prefix] + dst_ip: Ipv4Addr, + #[range] + src_port: u16, + #[range] + dst_port: u16, +} + +#[test] +fn n_and_key_size_match_field_layout() { + assert_eq!(FiveTuple::N, 5); + assert_eq!(FiveTuple::KEY_SIZE, 13); +} + +#[test] +fn field_specs_match_declaration_order() { + let specs = FiveTuple::field_specs(); + assert_eq!(specs.len(), 5); + + assert_eq!(specs[0].name, "proto"); + assert_eq!(specs[0].kind, FieldKind::Exact); + assert_eq!(specs[0].size, 1); + assert_eq!(specs[0].offset, 0); + + assert_eq!(specs[1].name, "src_ip"); + assert_eq!(specs[1].kind, FieldKind::Prefix); + assert_eq!(specs[1].size, 4); + assert_eq!(specs[1].offset, 1); + + assert_eq!(specs[2].name, "dst_ip"); + assert_eq!(specs[2].kind, FieldKind::Prefix); + assert_eq!(specs[2].size, 4); + assert_eq!(specs[2].offset, 5); + + assert_eq!(specs[3].name, "src_port"); + assert_eq!(specs[3].kind, FieldKind::Range); + assert_eq!(specs[3].size, 2); + assert_eq!(specs[3].offset, 9); + + assert_eq!(specs[4].name, "dst_port"); + assert_eq!(specs[4].kind, FieldKind::Range); + assert_eq!(specs[4].size, 2); + assert_eq!(specs[4].offset, 11); +} + +#[test] +fn key_packs_big_endian_at_field_offsets() { + let key = FiveTuple { + proto: IpProto(6), + src_ip: Ipv4Addr::new(10, 0, 1, 2), + dst_ip: Ipv4Addr::new(192, 168, 5, 7), + src_port: 54321, + dst_port: 22, + }; + let arr: [u8; FiveTuple::KEY_SIZE] = key.as_key(); + let mut buf = [0u8; FiveTuple::KEY_SIZE]; + key.as_key_into(&mut buf); + assert_eq!(arr, buf); + + assert_eq!(arr[0], 6); + assert_eq!(&arr[1..5], &[10, 0, 1, 2]); + assert_eq!(&arr[5..9], &[192, 168, 5, 7]); + assert_eq!(&arr[9..11], &54321u16.to_be_bytes()); + assert_eq!(&arr[11..13], &22u16.to_be_bytes()); +} + +#[test] +fn as_key_into_does_not_touch_bytes_past_the_key() { + let key = FiveTuple { + proto: IpProto(17), + src_ip: Ipv4Addr::UNSPECIFIED, + dst_ip: Ipv4Addr::UNSPECIFIED, + src_port: 0, + dst_port: 0, + }; + let mut buf = [0xFFu8; 64]; + key.as_key_into(&mut buf); + assert_eq!(buf[0], 17); + assert_eq!(buf[FiveTuple::KEY_SIZE], 0xFF); + assert_eq!(buf[63], 0xFF); +} + +#[test] +fn derive_emits_parallel_rule_struct() { + let rule = FiveTupleRule { + proto: ExactSpec::new(IpProto(6)), + src_ip: PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 24), + dst_ip: PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 0), + src_port: RangeSpec::new(0, u16::MAX), + dst_port: RangeSpec::exact(80), + }; + + assert_eq!(rule.proto.value, IpProto(6)); + assert_eq!(rule.src_ip.len, 24); + assert_eq!(rule.dst_ip.len, 0); + assert_eq!(rule.src_port.min, 0); + assert_eq!(rule.src_port.max, u16::MAX); + assert_eq!(rule.dst_port.min, 80); + assert_eq!(rule.dst_port.max, 80); +} + +#[test] +fn rule_field_kinds_match_match_key_attrs() { + assert_eq!( as RuleField>::KIND, FieldKind::Exact); + assert_eq!( as RuleField>::KIND, FieldKind::Prefix); + assert_eq!( as RuleField>::KIND, FieldKind::Range); +} + +#[test] +fn single_field_key_works() { + #[derive(MatchKey)] + #[allow(dead_code)] + struct Mono { + #[exact] + only: u32, + } + + assert_eq!(Mono::N, 1); + assert_eq!(Mono::KEY_SIZE, 4); + let specs = Mono::field_specs(); + assert_eq!(specs[0].name, "only"); + assert_eq!(specs[0].offset, 0); + assert_eq!(specs[0].kind, FieldKind::Exact); + + let m = Mono { only: 0xDEAD_BEEF }; + let arr = m.as_key(); + assert_eq!(arr, 0xDEAD_BEEFu32.to_be_bytes()); +} + +#[test] +fn range_spec_from_inclusive_range() { + let r: RangeSpec = (80..=8080).into(); + assert_eq!(r.min, 80); + assert_eq!(r.max, 8080); + + let single: RangeSpec = (22..=22).into(); + assert_eq!(single, RangeSpec::exact(22)); +} + +#[test] +fn fields_without_attribute_default_to_exact() { + #[derive(MatchKey)] + #[allow(dead_code)] + struct AllExact { + a: u8, + b: u32, + } + + let specs = AllExact::field_specs(); + assert_eq!(specs.len(), 2); + assert_eq!(specs[0].kind, FieldKind::Exact); + assert_eq!(specs[1].kind, FieldKind::Exact); + assert_eq!(AllExact::KEY_SIZE, 5); + let _rule = AllExactRule { + a: ExactSpec::new(6u8), + b: ExactSpec::new(0x0A00_0001u32), + }; + + let key = AllExact { + a: 6, + b: 0x0A00_0001, + }; + let bytes = key.as_key(); + assert_eq!(bytes[0], 6); + assert_eq!(&bytes[1..5], &0x0A00_0001u32.to_be_bytes()); +} + +#[test] +fn generic_match_key_instantiates_for_v4_and_v6() { + use core::net::Ipv6Addr; + #[derive(MatchKey)] + #[allow(dead_code)] + struct TwoTuple { + #[prefix] + src: Addr, + #[prefix] + dst: Addr, + } + assert_eq!(>::N, 2); + assert_eq!(>::KEY_SIZE, 8); + let v4_specs = >::field_specs(); + assert_eq!(v4_specs[0].size, 4); + assert_eq!(v4_specs[1].offset, 4); + assert_eq!(v4_specs[0].kind, FieldKind::Prefix); + assert_eq!(>::N, 2); + assert_eq!(>::KEY_SIZE, 32); + let v6_specs = >::field_specs(); + assert_eq!(v6_specs[0].size, 16); + assert_eq!(v6_specs[1].offset, 16); + let _v4_rule = TwoTupleRule:: { + src: PrefixSpec::new(Ipv4Addr::new(10, 0, 0, 0), 8), + dst: PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 0), + }; + let v4 = TwoTuple:: { + src: Ipv4Addr::new(10, 0, 1, 2), + dst: Ipv4Addr::new(192, 168, 5, 7), + }; + let mut buf = [0u8; 8]; + v4.as_key_into(&mut buf); + assert_eq!(&buf[0..4], &[10, 0, 1, 2]); + assert_eq!(&buf[4..8], &[192, 168, 5, 7]); +} +#[test] +fn derive_accepts_explicit_where_clause() { + #[derive(MatchKey)] + #[allow(dead_code)] + struct WithWhere + where + Addr: FixedSize, + { + #[prefix] + src: Addr, + } + + assert_eq!(>::N, 1); + assert_eq!(>::KEY_SIZE, 4); + + let _rule = WithWhereRule:: { + src: PrefixSpec::new(Ipv4Addr::UNSPECIFIED, 0), + }; +} diff --git a/mgmt/Cargo.toml b/mgmt/Cargo.toml index b474f55db1..b8e9d9fa41 100644 --- a/mgmt/Cargo.toml +++ b/mgmt/Cargo.toml @@ -26,6 +26,7 @@ id = { workspace = true } interface-manager = { workspace = true } k8s-intf = { workspace = true, features = ["client"] } k8s-less = { workspace = true } +lifecycle = { workspace = true } lpm = { workspace = true } nat = { workspace = true } net = { workspace = true } diff --git a/mgmt/src/lib.rs b/mgmt/src/lib.rs index 9847426ab7..245990a966 100644 --- a/mgmt/src/lib.rs +++ b/mgmt/src/lib.rs @@ -7,7 +7,7 @@ mod processor; mod tests; pub mod vpc_manager; -pub use processor::launch::{MgmtParams, start_mgmt}; +pub use processor::launch::{LaunchError, MgmtParams, run_mgmt}; pub use processor::proc::ConfigProcessorParams; use tracectl::trace_target; diff --git a/mgmt/src/processor/launch.rs b/mgmt/src/processor/launch.rs index 95fcf1623c..149c8d6b18 100644 --- a/mgmt/src/processor/launch.rs +++ b/mgmt/src/processor/launch.rs @@ -9,6 +9,7 @@ use crate::processor::proc::ConfigProcessor; use crate::processor::proc::ConfigProcessorParams; use concurrency::sync::Arc; +use lifecycle::{CancellationToken, Subsystem}; use tracing::{debug, error, info, warn}; #[derive(Debug, thiserror::Error)] @@ -16,21 +17,11 @@ pub enum LaunchError { #[error("IO error: {0}")] IoError(std::io::Error), #[error("Error in K8s client task: {0}")] - K8sClientError(K8sClientError), - #[error("Error starting/waiting for K8s client task: {0}")] - K8sClientJoinError(tokio::task::JoinError), - #[error("K8s client exited prematurely")] - PrematureK8sClientExit, - #[error("Config processor exited prematurely")] - PrematureProcessorExit, - - #[error("Error in Config Processor task: {0}")] - ProcessorError(std::io::Error), - #[error("Error starting/waiting for Config Processor task: {0}")] - ProcessorJoinError(tokio::task::JoinError), - + K8sClientError(#[from] K8sClientError), #[error("Error in k8s-less mode: {0}")] K8LessError(#[from] K8sLessError), + #[error("Mgmt init cancelled before completion")] + Cancelled, } pub struct MgmtParams { @@ -42,96 +33,262 @@ pub struct MgmtParams { use std::time::Duration; const K8S_STATUS_UPD: Duration = Duration::from_secs(15); const K8S_INIT_RETRY_TIME: Duration = Duration::from_secs(5); -const K8S_INIT_MAX_ATTEMPTS: u8 = 10; +const K8S_INIT_MAX_RETRIES: u8 = 10; -async fn k8s_mgmt_init(k8s_client: &K8sClient) -> Result<(), K8sClientError> { - let mut retries = K8S_INIT_MAX_ATTEMPTS; +/// Run `init` under `cancel`. Returns [`LaunchError::Cancelled`] on cancel. +async fn init_cancellable(init: F, cancel: &CancellationToken) -> Result<(), LaunchError> +where + F: std::future::Future>, + LaunchError: From, +{ + tokio::select! { + r = init => r.map_err(LaunchError::from), + () = cancel.cancelled() => { + info!("Mgmt init cancelled"); + Err(LaunchError::Cancelled) + } + } +} +/// Retry k8s init up to `K8S_INIT_MAX_RETRIES` times with +/// `K8S_INIT_RETRY_TIME` backoff. Attempt and backoff both observe `cancel`. +async fn k8s_mgmt_init( + k8s_client: &K8sClient, + cancel: &CancellationToken, +) -> Result<(), LaunchError> { + let mut retries = K8S_INIT_MAX_RETRIES; debug!("Initializing k8s client..."); - while let Err(e) = k8s_client.init().await { - warn!("Could not initialize k8s state. Will retry {retries} more times"); - tokio::time::sleep(K8S_INIT_RETRY_TIME).await; - if retries == 0 { - error!("Maximum k8s initialization attempts reached. Giving up..."); - return Err(e); + loop { + match init_cancellable(k8s_client.init(), cancel).await { + Ok(()) => break, + Err(LaunchError::Cancelled) => return Err(LaunchError::Cancelled), + Err(e) if retries == 0 => { + error!("Maximum k8s initialization attempts reached. Giving up..."); + return Err(e); + } + Err(_) => { + warn!("Could not initialize k8s state. Will retry {retries} more times"); + retries -= 1; + tokio::select! { + () = tokio::time::sleep(K8S_INIT_RETRY_TIME) => {} + () = cancel.cancelled() => { + info!("K8s init cancelled during retry backoff"); + return Err(LaunchError::Cancelled); + } + } + } } - retries -= 1; } info!("K8s initialization succeeded"); Ok(()) } -/// Start the mgmt service. If the k8s interface is not ready, this may take up to -/// K8S_INIT_RETRY_TIME * K8S_INIT_MAX_ATTEMPTS seconds to complete. -pub fn start_mgmt(params: MgmtParams) -> Result, LaunchError> { - let (tx, rx) = tokio::sync::oneshot::channel(); - - let handle = std::thread::Builder::new() - .name("mgmt".to_string()) - .spawn(move || { - debug!("Starting dataplane management thread..."); - - /* create tokio runtime */ - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .expect("Tokio runtime creation failed"); - - if let Some(config_dir) = ¶ms.config_dir { - warn!("Running in k8s-less mode...."); - rt.block_on(async { - let (processor, client) = ConfigProcessor::new(params.processor_params); - let k8sless = - Arc::new(K8sLess::new(params.hostname.as_str(), config_dir, client)); - let k8sless1 = k8sless.clone(); - - let init_result = k8sless.init().await.map_err(LaunchError::K8LessError); - let init_failed = init_result.is_err(); - tx.send(init_result).expect("Main thread gone"); - if init_failed { - return; - } +/// Init mgmt synchronously on `handle`, then spawn the long-lived tasks +/// (config processor, status updater, config watcher) tracked under +/// `mgmt`. Init observes `mgmt.root_token()` so SIGINT during init returns +/// [`LaunchError::Cancelled`] within cancel latency. +/// +/// # Errors +/// Returns [`LaunchError`] on init failure. [`LaunchError::Cancelled`] is +/// a clean-shutdown signal — callers must not flip the fatal flag for it. +pub fn run_mgmt( + handle: &tokio::runtime::Handle, + mgmt: &Subsystem, + params: MgmtParams, +) -> Result<(), LaunchError> { + if let Some(config_dir) = ¶ms.config_dir { + warn!("Running in k8s-less mode...."); + handle.block_on(run_k8s_less( + handle, + mgmt, + params.hostname.as_str(), + config_dir, + params.processor_params, + )) + } else { + debug!("Will start watching k8s for configuration changes"); + handle.block_on(run_k8s( + handle, + mgmt, + params.hostname.as_str(), + params.processor_params, + )) + } +} - tokio::spawn(async { processor.run().await }); - tokio::spawn(async move { k8sless.start_status_update(&K8S_STATUS_UPD).await }); - let _ = K8sLess::start_config_watch(k8sless1).await; - }) - } else { - debug!("Will start watching k8s for configuration changes"); - rt.block_on(async { - let (processor, client) = ConfigProcessor::new(params.processor_params); - let k8s_client = Arc::new(K8sClient::new(params.hostname.as_str(), client)); - let k8s_client1 = k8s_client.clone(); - - let init_result = k8s_mgmt_init(&k8s_client) - .await - .map_err(LaunchError::K8sClientError); - - let init_failed = init_result.is_err(); - tx.send(init_result).expect("Main thread gone"); - if init_failed { - return; - } +async fn run_k8s_less( + handle: &tokio::runtime::Handle, + mgmt: &Subsystem, + hostname: &str, + config_dir: &str, + processor_params: ConfigProcessorParams, +) -> Result<(), LaunchError> { + let (processor, client) = ConfigProcessor::new(processor_params); + let k8sless = Arc::new(K8sLess::new(hostname, config_dir, client)); + let k8sless_for_watch = k8sless.clone(); + + init_cancellable(k8sless.init(), &mgmt.root_token()).await?; - tokio::spawn(async { processor.run().await }); - tokio::spawn(async move { - k8s_client1.k8s_start_status_update(&K8S_STATUS_UPD).await - }); - let _ = - tokio::spawn(async { K8sClient::k8s_start_config_watch(k8s_client).await }) - .await; - }) + mgmt.spawn_fatal_on_exit("k8s-less config processor", processor.run(), handle); + let k8sless_for_status = k8sless.clone(); + mgmt.spawn_fatal_on_exit( + "k8s-less status updater", + async move { + k8sless_for_status + .start_status_update(&K8S_STATUS_UPD) + .await + }, + handle, + ); + mgmt.spawn_fatal_on_exit( + "k8s-less config watcher", + async move { + match K8sLess::start_config_watch(k8sless_for_watch).await { + Ok(()) => warn!("k8s-less config watcher returned Ok unexpectedly"), + Err(e) => error!("k8s-less config watcher failed: {e}"), } - unreachable!() - }) - .map_err(LaunchError::IoError)?; - - match rx - .blocking_recv() - .map_err(|_| LaunchError::PrematureProcessorExit)? - { - Ok(()) => Ok(handle), - Err(e) => Err(e), + }, + handle, + ); + + Ok(()) +} + +async fn run_k8s( + handle: &tokio::runtime::Handle, + mgmt: &Subsystem, + hostname: &str, + processor_params: ConfigProcessorParams, +) -> Result<(), LaunchError> { + let (processor, client) = ConfigProcessor::new(processor_params); + let k8s_client = Arc::new(K8sClient::new(hostname, client)); + let k8s_client_for_status = k8s_client.clone(); + + k8s_mgmt_init(&k8s_client, &mgmt.root_token()).await?; + + mgmt.spawn_fatal_on_exit("k8s config processor", processor.run(), handle); + mgmt.spawn_fatal_on_exit( + "k8s status updater", + async move { + k8s_client_for_status + .k8s_start_status_update(&K8S_STATUS_UPD) + .await + }, + handle, + ); + mgmt.spawn_fatal_on_exit( + "k8s config watcher", + async move { + K8sClient::k8s_start_config_watch(k8s_client).await; + }, + handle, + ); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::processor::k8s_less_client::K8sLessError; + use lifecycle::Shutdown; + use std::time::Duration; + + #[tokio::test] + async fn init_cancellable_returns_cancelled_on_pre_tripped_token() { + let cancel = CancellationToken::new(); + cancel.cancel(); + + let result: Result<(), LaunchError> = init_cancellable( + async { + // Long sleep so a missing cancel arm surfaces as a test timeout. + tokio::time::sleep(Duration::from_secs(60)).await; + Ok::<(), K8sLessError>(()) + }, + &cancel, + ) + .await; + + assert!(matches!(result, Err(LaunchError::Cancelled))); + } + + #[tokio::test] + async fn init_cancellable_returns_cancelled_when_tripped_mid_init() { + let cancel = CancellationToken::new(); + let cancel_for_task = cancel.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(20)).await; + cancel_for_task.cancel(); + }); + + let result: Result<(), LaunchError> = init_cancellable( + async { + tokio::time::sleep(Duration::from_secs(60)).await; + Ok::<(), K8sLessError>(()) + }, + &cancel, + ) + .await; + + assert!(matches!(result, Err(LaunchError::Cancelled))); + } + + #[tokio::test] + async fn init_cancellable_returns_ok_when_init_completes_first() { + let cancel = CancellationToken::new(); + let result: Result<(), LaunchError> = + init_cancellable(async { Ok::<(), K8sLessError>(()) }, &cancel).await; + assert!(result.is_ok()); + assert!(!cancel.is_cancelled()); + } + + #[tokio::test] + async fn init_cancellable_propagates_init_error() { + let cancel = CancellationToken::new(); + let result: Result<(), LaunchError> = init_cancellable( + async { Err::<(), K8sLessError>(K8sLessError::Internal("synthetic".into())) }, + &cancel, + ) + .await; + assert!(matches!(result, Err(LaunchError::K8LessError(_)))); + } + + /// Locks in the main.rs contract: SIGTERM during k8s init must yield + /// exit 0 (else systemd restart-loops the unit). Mirrors the match + /// arms in runtime.rs. + #[tokio::test] + async fn cancelled_launch_error_yields_zero_exit_code_at_call_site() { + let shutdown = Shutdown::new(); + shutdown.root.cancel(); + let mgmt_result: Result<(), LaunchError> = Err(LaunchError::Cancelled); + + match mgmt_result { + Ok(()) => {} + Err(LaunchError::Cancelled) => {} + Err(_) => { + shutdown.fail(); + } + } + + assert!(!shutdown.is_fatal()); + assert_eq!(i32::from(shutdown.is_fatal()), 0); + } + + #[tokio::test] + async fn non_cancelled_launch_error_yields_nonzero_exit_code_at_call_site() { + let shutdown = Shutdown::new(); + let mgmt_result: Result<(), LaunchError> = + Err(LaunchError::IoError(std::io::Error::other("synthetic"))); + + match mgmt_result { + Ok(()) => {} + Err(LaunchError::Cancelled) => {} + Err(_) => { + shutdown.fail(); + } + } + + assert!(shutdown.is_fatal()); + assert_eq!(i32::from(shutdown.is_fatal()), 1); } } diff --git a/mgmt/src/tests/mgmt.rs b/mgmt/src/tests/mgmt.rs index 508f6099cb..0eb6c9477a 100644 --- a/mgmt/src/tests/mgmt.rs +++ b/mgmt/src/tests/mgmt.rs @@ -435,7 +435,10 @@ pub mod test { .expect("Should succeed due to defaults"); /* start router */ - let router = Router::new(router_params, None); + let test_mgmt = lifecycle::Subsystem::new("mgmt", lifecycle::CancellationToken::new()); + let test_router = lifecycle::Subsystem::new("router", lifecycle::CancellationToken::new()); + let handle = tokio::runtime::Handle::current(); + let router = Router::new(&test_mgmt, &handle, &test_router, router_params, None); if let Err(e) = &router { error!("New router failed: {e}"); panic!(); diff --git a/net/Cargo.toml b/net/Cargo.toml index fac84a9665..bbda3162aa 100644 --- a/net/Cargo.toml +++ b/net/Cargo.toml @@ -27,6 +27,7 @@ concurrency = { workspace = true } derive_builder = { workspace = true, features = ["alloc"] } downcast-rs = { workspace = true, features = ["sync"] } etherparse = { workspace = true, features = ["std"] } +fixed-size = { workspace = true, features = [] } id = { workspace = true } multi_index_map = { workspace = true, default-features = false, features = ["serde"] } rapidhash = { workspace = true } diff --git a/net/src/fixed_size.rs b/net/src/fixed_size.rs new file mode 100644 index 0000000000..f14effa170 --- /dev/null +++ b/net/src/fixed_size.rs @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright Open Network Fabric Authors + +use fixed_size::FixedSize; + +use crate::ipv4::UnicastIpv4Addr; +use crate::tcp::TcpPort; +use crate::udp::UdpPort; +use crate::vxlan::Vni; + +impl FixedSize for TcpPort { + const SIZE: usize = 2; + fn write_be(&self, out: &mut [u8]) { + self.as_u16().write_be(out); + } +} + +impl FixedSize for UdpPort { + const SIZE: usize = 2; + fn write_be(&self, out: &mut [u8]) { + self.as_u16().write_be(out); + } +} + +impl FixedSize for UnicastIpv4Addr { + const SIZE: usize = 4; + fn write_be(&self, out: &mut [u8]) { + self.inner().write_be(out); + } +} + +impl FixedSize for Vni { + const SIZE: usize = 4; + fn write_be(&self, out: &mut [u8]) { + self.as_u32().write_be(out); + } +} + +#[cfg(test)] +mod tests { + use core::net::Ipv4Addr; + + use super::*; + + #[test] + fn ports_write_two_big_endian_bytes() { + assert_eq!(::SIZE, 2); + assert_eq!(::SIZE, 2); + let mut buf = [0u8; 2]; + TcpPort::new_checked(443).unwrap().write_be(&mut buf); + assert_eq!(buf, 443u16.to_be_bytes()); + UdpPort::new_checked(4789).unwrap().write_be(&mut buf); + assert_eq!(buf, 4789u16.to_be_bytes()); + } + + #[test] + fn unicast_v4_writes_four_octets() { + assert_eq!(::SIZE, 4); + let mut buf = [0u8; 4]; + UnicastIpv4Addr::new(Ipv4Addr::new(10, 0, 1, 2)) + .unwrap() + .write_be(&mut buf); + assert_eq!(buf, [10, 0, 1, 2]); + } + + #[test] + fn vni_writes_four_bytes_with_zero_high_byte() { + assert_eq!(::SIZE, 4); + let mut buf = [0u8; 4]; + Vni::new_checked(0x00AB_CDEF).unwrap().write_be(&mut buf); + assert_eq!(buf, [0x00, 0xAB, 0xCD, 0xEF]); + } +} diff --git a/net/src/lib.rs b/net/src/lib.rs index e1683d4295..e76f84ebe3 100644 --- a/net/src/lib.rs +++ b/net/src/lib.rs @@ -19,6 +19,8 @@ pub mod addr_parse_error; pub mod buffer; pub mod checksum; pub mod eth; +/// `FixedSize` impls bridging `net` field types into match-action keys. +mod fixed_size; #[cfg(unix)] pub mod flows; pub mod headers; diff --git a/routing/Cargo.toml b/routing/Cargo.toml index 74e4459638..d3c830eb40 100644 --- a/routing/Cargo.toml +++ b/routing/Cargo.toml @@ -20,6 +20,7 @@ config = { workspace = true } concurrency = { workspace = true } dplane-rpc = { workspace = true } left-right-tlcache = { workspace = true } +lifecycle = { workspace = true } lpm = { workspace = true } net = { workspace = true } tracectl = { workspace = true } diff --git a/routing/src/bmp/mod.rs b/routing/src/bmp/mod.rs index 44b5b99146..6e9b5ce9c4 100644 --- a/routing/src/bmp/mod.rs +++ b/routing/src/bmp/mod.rs @@ -9,6 +9,7 @@ pub use server::{BmpServer, BmpServerConfig}; use concurrency::sync::Arc; use config::internal::status::DataplaneStatus; +use lifecycle::Subsystem; use tokio::sync::RwLock; use tokio::task::JoinHandle; use tracing::{error, info}; @@ -16,12 +17,16 @@ use tracing::{error, info}; use tracectl::trace_target; trace_target!("bmp", LevelFilter::INFO, &[]); -/// Spawn BMP server in background +/// Spawn the BMP server on `handle`, tracked under `mgmt` so it drains +/// with the rest of mgmt's tasks. +#[must_use] pub fn spawn_background( + mgmt: &Subsystem, + handle: &tokio::runtime::Handle, bind: std::net::SocketAddr, dp_status: Arc>, ) -> JoinHandle<()> { - // The future we want to run + let cancel = mgmt.cancel_token(); let fut = async move { info!("starting BMP server on {}", bind); let cfg = BmpServerConfig { @@ -29,19 +34,16 @@ pub fn spawn_background( ..Default::default() }; let srv = BmpServer::new(cfg, handler::StatusHandler::new(dp_status)); - if let Err(e) = srv.run().await { - error!("bmp server terminated: {e:#}"); + tokio::select! { + () = cancel.cancelled() => { + info!("BMP server shutdown requested"); + } + res = srv.run() => { + if let Err(e) = res { + error!("bmp server terminated: {e:#}"); + } + } } }; - - if let Ok(handle) = tokio::runtime::Handle::try_current() { - handle.spawn(fut) - } else { - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .expect("failed to build Tokio runtime for BMP"); - let rt_static: &'static tokio::runtime::Runtime = Box::leak(Box::new(rt)); - rt_static.spawn(fut) - } + mgmt.spawn_on(fut, handle) } diff --git a/routing/src/frr/test.rs b/routing/src/frr/test.rs index d83322e6fb..b53ce9847c 100644 --- a/routing/src/frr/test.rs +++ b/routing/src/frr/test.rs @@ -149,7 +149,12 @@ pub mod tests { .expect("Should succeed due to defaults"); /* start router */ - let mut router = Router::new(router_params, None).unwrap(); + let mgmt = lifecycle::Subsystem::new("mgmt", lifecycle::CancellationToken::new()); + let router_subsystem = + lifecycle::Subsystem::new("router", lifecycle::CancellationToken::new()); + let handle = tokio::runtime::Handle::current(); + let mut router = + Router::new(&mgmt, &handle, &router_subsystem, router_params, None).unwrap(); let mut ctl = router.get_ctl_tx(); /* start fake frr agent */ diff --git a/routing/src/router/ctl.rs b/routing/src/router/ctl.rs index 3fda28e73e..e504a7f33d 100644 --- a/routing/src/router/ctl.rs +++ b/routing/src/router/ctl.rs @@ -45,7 +45,6 @@ impl Drop for LockGuard { } pub(crate) enum RouterCtlMsg { - Finish, Lock(RouterCtlReplyTx), Unlock(RouterCtlReplyTx), GuardedUnlock, @@ -241,10 +240,6 @@ fn handle_config_history(rio: &mut Rio, history: Arc>) { /// Handle a request from the control channel pub(crate) fn handle_ctl_msg(rio: &mut Rio, db: &mut RoutingDb) { match rio.ctl_rx.try_recv() { - Ok(RouterCtlMsg::Finish) => { - info!("Got request to shutdown. Au revoir ..."); - rio.run = false; - } Ok(RouterCtlMsg::Lock(reply_to)) => handle_lock(rio, true, Some(reply_to)), Ok(RouterCtlMsg::Unlock(reply_to)) => handle_lock(rio, false, Some(reply_to)), Ok(RouterCtlMsg::GuardedUnlock) => handle_lock(rio, false, None), diff --git a/routing/src/router/mod.rs b/routing/src/router/mod.rs index a4bffdcb16..fa300c4151 100644 --- a/routing/src/router/mod.rs +++ b/routing/src/router/mod.rs @@ -144,6 +144,9 @@ impl Router { /// Start a `Router` #[allow(clippy::new_without_default)] pub fn new( + mgmt: &lifecycle::Subsystem, + mgmt_handle: &tokio::runtime::Handle, + router: &lifecycle::Subsystem, params: RouterParams, cli_sources: Option, ) -> Result { @@ -163,15 +166,16 @@ impl Router { let rioconf = Self::build_rio_config(¶ms)?; debug!("{name}: Starting router IO..."); - let rio_handle = start_rio(&rioconf, fibtw, iftw, atabler, cli_sources)?; + let rio_handle = start_rio(router, &rioconf, fibtw, iftw, atabler, cli_sources)?; - // Start BMP server in background if configured, always with mandatory dp_status let bmp_handle = if let Some(bmp_params) = ¶ms.bmp { debug!( "{name}: Starting BMP server on {} (interval={:?})", bmp_params.bind_addr, bmp_params.stats_interval ); Some(bmp::spawn_background( + mgmt, + mgmt_handle, bmp_params.bind_addr, params.dp_status.clone(), )) @@ -200,7 +204,9 @@ impl Router { } self.resolver.stop(); - // Abort BMP server task if running (Tokio handle). + // BMP is also tracked under the mgmt subsystem, which normally + // drains it cleanly via `Shutdown::drain_in_order`. This abort is + // a safety net for the case where the mgmt drain hit its deadline. if let Some(handle) = self.bmp_handle.take() { handle.abort(); } diff --git a/routing/src/router/rio.rs b/routing/src/router/rio.rs index 6916adb8f1..be2e3f2587 100644 --- a/routing/src/router/rio.rs +++ b/routing/src/router/rio.rs @@ -22,17 +22,18 @@ use cli::IoCache; use cli::cliproto::{CLI_RX_BUFF_SIZE, CliRequest}; use config::{GwConfigMeta, ValidatedGwConfig}; use dplane_rpc::socks::RpcCachedSock; +use lifecycle::{CancellationToken, Subsystem}; use mio::unix::SourceFd; use mio::{Events, Interest, Poll, Token}; use concurrency::sync::Arc; +use concurrency::thread::{self, JoinHandle}; use nix::sys::socket::{getsockopt, setsockopt, sockopt::SndBuf}; use std::fs; use std::os::fd::AsRawFd; use std::os::unix::fs::PermissionsExt; use std::os::unix::net::UnixDatagram; -use std::thread::{self, JoinHandle}; use std::time::{Duration, Instant}; use tokio::sync::mpsc::{Receiver, Sender, channel}; @@ -44,31 +45,28 @@ const CTL_CHANNEL_CAPACITY: usize = 100; /// An object to control a router IO, [`Rio`] pub(crate) struct RioHandle { + cancel: CancellationToken, ctl: Sender, handle: Option>, } impl RioHandle { - /// Terminate the router IO loop / thread + /// Trip the router cancel and join the RIO thread. Idempotent — a + /// second call after the thread has been joined returns `Ok(())`. + /// Worst-case exit latency is one poll timeout (1 second). /// /// # Errors - /// Fails if the channel has been dropped or the thread cannot be joined + /// Fails if the thread panicked during join. pub(crate) fn finish(&mut self) -> Result<(), RouterError> { debug!("Requesting router IO to stop.."); - self.ctl - .try_send(RouterCtlMsg::Finish) - .map_err(|_| RouterError::Internal("Error sending over ctl channel"))?; - - let handle = self.handle.take(); - if let Some(handle) = handle { - debug!("Waiting for the router IO to terminate.."); - handle - .join() - .map_err(|_| RouterError::Internal("Error joining thread"))?; - debug!("Router IO ended successfully"); - Ok(()) - } else { - Err(RouterError::Internal("No handle")) - } + self.cancel.cancel(); + + let Some(handle) = self.handle.take() else { + return Ok(()); + }; + handle + .join() + .map_err(|_| RouterError::Internal("Error joining thread"))?; + Ok(()) } #[must_use] pub(crate) fn get_ctl_tx(&self) -> RouterCtlSender { @@ -113,7 +111,6 @@ pub(crate) const FRRMISOCK: Token = Token(2); pub(crate) struct Rio { #[allow(unused)] pub(crate) name: String, - pub(crate) run: bool, pub(crate) frozen: bool, pub(crate) cp_sock_path: String, pub(crate) cli_sock_path: String, @@ -185,7 +182,6 @@ impl Rio { Ok(Rio { name: conf.name.clone(), - run: true, frozen: false, cp_sock_path, cli_sock_path, @@ -342,6 +338,7 @@ impl Rio { #[allow(clippy::missing_errors_doc, clippy::too_many_lines)] pub(crate) fn start_rio( + router: &Subsystem, conf: &RioConf, fibtw: FibTableWriter, iftw: IfTableWriter, @@ -351,9 +348,34 @@ pub(crate) fn start_rio( let mut rio = Rio::new(conf)?; let ctl_tx = rio.ctl_tx.clone(); let cli_sources = cli_sources.unwrap_or_default(); + let cancel = router.cancel_token(); + let loop_cancel = cancel.clone(); + let guard_subsystem = router.clone(); /* router IO loop */ let rio_loop = move || { + // Drop-guard so panic-unwind or unexpected loop exit trips + // report_fatal. + struct ExitGuard { + subsystem: Subsystem, + } + impl Drop for ExitGuard { + fn drop(&mut self) { + if self.subsystem.is_cancelled() { + return; + } + let reason = if std::thread::panicking() { + "RIO thread panicked" + } else { + "RIO thread exited unexpectedly" + }; + self.subsystem.report_fatal(reason); + } + } + let _guard = ExitGuard { + subsystem: guard_subsystem, + }; + info!("CPI: Listening at {}.", &rio.cp_sock_path); info!("CLI: Listening at {}.", &rio.cli_sock_path); info!("FRRMI: will connect to {}.", &rio.frrmi.get_remote()); @@ -366,7 +388,9 @@ pub(crate) fn start_rio( revent!(RouterEvent::Started); info!("Entering router IO loop...."); - while rio.run { + // Observe the router subsystem cancellation between poll cycles. + // Worst-case exit latency is the poll timeout (1 second). + while !loop_cancel.is_cancelled() { if let Err(e) = rio.poller.poll(&mut events, Some(Duration::from_secs(1))) { error!("Poller error!: {e}"); continue; @@ -467,6 +491,7 @@ pub(crate) fn start_rio( .map_err(|_| RouterError::Internal("Failure spawning thread"))?; Ok(RioHandle { + cancel, ctl: ctl_tx, handle: Some(handle), }) @@ -479,9 +504,14 @@ mod tests { use crate::fib::fibtable::FibTableWriter; use crate::interfaces::iftablerw::IfTableWriter; use crate::router::rio::{RioConf, start_rio}; - use std::thread; + use concurrency::thread; + use lifecycle::{CancellationToken, Subsystem}; use std::time::Duration; + fn test_router_subsystem() -> Subsystem { + Subsystem::new("router", CancellationToken::new()) + } + #[test] #[cfg_attr(emulated, ignore = "binds Unix domain sockets at /tmp/hh_*.sock")] fn test_rio_ctl() { @@ -508,7 +538,9 @@ mod tests { let (_atablew, atabler) = AtableWriter::new(); /* start CPI */ - let mut cpi = start_rio(&conf, fibtw, iftw, atabler, None).expect("Should succeed"); + let router = test_router_subsystem(); + let mut cpi = + start_rio(&router, &conf, fibtw, iftw, atabler, None).expect("Should succeed"); thread::sleep(Duration::from_secs(3)); assert_eq!(cpi.finish(), Ok(())); } @@ -533,7 +565,8 @@ mod tests { let (_atablew, atabler) = AtableWriter::new(); /* start router IO */ - let rio = start_rio(&conf, fibtw, iftw, atabler, None); + let router = test_router_subsystem(); + let rio = start_rio(&router, &conf, fibtw, iftw, atabler, None); assert!(rio.is_err_and(|e| matches!(e, RouterError::InvalidPath(_)))); } }