From 14eb60ee44fdb6847d4461ad3b5e1a2f1bd247b2 Mon Sep 17 00:00:00 2001 From: leon-xd Date: Tue, 28 Apr 2026 16:17:07 -0700 Subject: [PATCH 1/4] initial implementation of DerivesMap --- Cargo.lock | 7 +- Cargo.toml | 3 +- crates/wdk-build/Cargo.toml | 2 + crates/wdk-build/src/bindgen.rs | 3 + crates/wdk-build/src/derives.rs | 738 ++++++++++++++++++++++++++++++++ crates/wdk-build/src/lib.rs | 2 + 6 files changed, 752 insertions(+), 3 deletions(-) create mode 100644 crates/wdk-build/src/derives.rs diff --git a/Cargo.lock b/Cargo.lock index d38b627fc..85d36a2e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -106,9 +106,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "bindgen" -version = "0.71.1" +version = "0.72.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +checksum = "993776b509cfb49c750f11b8f07a46fa23e0a1386ffc01fb1e7d343efc387895" dependencies = [ "bitflags", "cexpr", @@ -1146,6 +1146,7 @@ dependencies = [ "anyhow", "assert_fs", "bindgen", + "bitflags", "camino", "cargo_metadata", "cfg-if", @@ -1157,6 +1158,7 @@ dependencies = [ "semver", "serde", "serde_json", + "syn", "thiserror", "tracing", "windows", @@ -1191,6 +1193,7 @@ dependencies = [ "cargo_metadata", "cc", "cfg-if", + "regex", "rustversion", "serde_json", "thiserror", diff --git a/Cargo.toml b/Cargo.toml index ee152973f..fd4eb116c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,7 +38,8 @@ wdk-sys = { path = "crates/wdk-sys", version = "0.5.1" } anyhow = "1.0.97" assert_cmd = "2.0.17" assert_fs = "1.1.3" -bindgen = "0.71.0" +bindgen = "0.72.1" +bitflags = "2.6.0" camino = "1.1.9" cargo_metadata = "0.19.2" cc = "1.2.39" diff --git a/crates/wdk-build/Cargo.toml b/crates/wdk-build/Cargo.toml index 221d74b34..fd09b82d9 100644 --- a/crates/wdk-build/Cargo.toml +++ b/crates/wdk-build/Cargo.toml @@ -21,6 +21,7 @@ nightly = [] [dependencies] anyhow.workspace = true bindgen.workspace = true +bitflags.workspace = true camino.workspace = true cargo_metadata.workspace = true cfg-if.workspace = true @@ -32,6 +33,7 @@ rustversion.workspace = true semver.workspace = true serde = { features = ["derive"], workspace = true } serde_json.workspace = true +syn = { features = ["parsing"], workspace = true } thiserror.workspace = true tracing.workspace = true windows = { features = [ diff --git a/crates/wdk-build/src/bindgen.rs b/crates/wdk-build/src/bindgen.rs index f815cd5c6..8b6ab76a8 100644 --- a/crates/wdk-build/src/bindgen.rs +++ b/crates/wdk-build/src/bindgen.rs @@ -138,6 +138,9 @@ impl BuilderExt for Builder { // Defaults enums to generate as a set of constants contained in a module (default value // is EnumVariation::Consts which generates enums as global constants) .default_enum_style(bindgen::EnumVariation::ModuleConsts) + // `size_t`/`ssize_t` are pointer-width on every supported Windows + // driver target (x64, ARM64, x86), matching Rust's `usize`/`isize`. + .size_t_is_usize(true) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .parse_callbacks(Box::new(WdkCallbacks::new(config))) .formatter(bindgen::Formatter::Prettyplease) diff --git a/crates/wdk-build/src/derives.rs b/crates/wdk-build/src/derives.rs new file mode 100644 index 000000000..1af3add5b --- /dev/null +++ b/crates/wdk-build/src/derives.rs @@ -0,0 +1,738 @@ +// Copyright (c) Microsoft Corporation +// License: MIT OR Apache-2.0 + +//! Parses bindgen-emitted Rust source to recover the set of derives bindgen +//! applied to each generated type. Used by the per-subsystem bindgen pipeline +//! to answer `blocklisted_type_implements_trait` for base types. + +use std::{ + collections::HashMap, + path::{Path as FsPath, PathBuf}, + sync::Arc, +}; + +use bindgen::callbacks::{DeriveTrait, ImplementsTrait, ParseCallbacks}; +use syn::{Attribute, Item, ItemUse, Path, PathArguments, Type, UseTree}; +use thiserror::Error; + +/// Rust language primitives that can appear as a bare identifier in a `pub type +/// X = Y;` target. +const PRIMITIVES: &[&str] = &[ + "bool", "char", "f32", "f64", "i8", "i16", "i32", "i64", "i128", "isize", "u8", "u16", "u32", + "u64", "u128", "usize", +]; + +/// C stdint names that bindgen lowers to Rust integer primitives internally. +/// Bindgen never emits these as `pub type` aliases, so they have to be seeded +/// into the map directly. Mirrors bindgen 0.72.1's `is_stdint_type` allowlist — +/// re-verify on bindgen upgrades. +const STDINT_NAMES: &[&str] = &[ + "int8_t", + "uint8_t", + "int16_t", + "uint16_t", + "int32_t", + "uint32_t", + "int64_t", + "uint64_t", + "uintptr_t", + "intptr_t", + "ptrdiff_t", + "size_t", + "ssize_t", +]; + +/// Errors returned when parsing a bindgen-emitted source file into a +/// [`DerivesMap`]. +#[derive(Debug, Error)] +pub enum DerivesError { + #[error("failed to read {path}", path = path.display())] + Io { + path: PathBuf, + #[source] + source: std::io::Error, + }, + + #[error("failed to parse source as Rust")] + Parse(#[source] syn::Error), + + #[error("unhandled syn node: {node}")] + UnhandledSynCase { node: String }, + + #[error("malformed shape: {reason}: {node}")] + MalformedShape { reason: String, node: String }, + + #[error("alias cycle among: {names:?}")] + AliasCycle { names: Vec }, + + #[error("alias targets not found: {names:?}")] + UnresolvedAlias { names: Vec }, +} + +bitflags::bitflags! { + #[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] + struct DerivesSet: u8 { + const COPY = 1 << 0; + const DEBUG = 1 << 1; + const DEFAULT = 1 << 2; + const HASH = 1 << 3; + const PARTIAL_EQ_OR_PARTIAL_ORD = 1 << 4; + } +} + +impl DerivesSet { + const fn implements(self, derive_trait: DeriveTrait) -> bool { + let flag = match derive_trait { + DeriveTrait::Copy => Self::COPY, + DeriveTrait::Debug => Self::DEBUG, + DeriveTrait::Default => Self::DEFAULT, + DeriveTrait::Hash => Self::HASH, + DeriveTrait::PartialEqOrPartialOrd => Self::PARTIAL_EQ_OR_PARTIAL_ORD, + }; + self.contains(flag) + } +} + +impl From> for DerivesSet { + /// Build a `DerivesSet` from a list of derive trait names. + fn from(derives: Vec) -> Self { + let mut set = Self::empty(); + for derive in &derives { + set |= match derive.as_str() { + "Copy" => Self::COPY, + "Debug" => Self::DEBUG, + "Default" => Self::DEFAULT, + "Hash" => Self::HASH, + "PartialEq" | "PartialOrd" => Self::PARTIAL_EQ_OR_PARTIAL_ORD, + _ => Self::empty(), + }; + } + set + } +} + +enum DerivesSource { + Direct(DerivesSet), + Alias(String), +} + +/// Bindgen parse callback for `blocklisted_type_implements_trait` from a +/// pre-built [`DerivesMap`]. +#[derive(Debug)] +pub struct BaseDerivesCallback { + map: Arc, +} + +impl BaseDerivesCallback { + #[must_use] + pub const fn new(map: Arc) -> Self { + Self { map } + } +} + +impl ParseCallbacks for BaseDerivesCallback { + fn blocklisted_type_implements_trait( + &self, + name: &str, + derive_trait: DeriveTrait, + ) -> Option { + Some(if self.map.satisfies(name, derive_trait) { + ImplementsTrait::Yes + } else { + ImplementsTrait::No + }) + } +} + +/// Map storing Rust source type names to the set of derives the type +/// implements. +#[derive(Debug)] +pub struct DerivesMap { + types: HashMap, +} + +impl DerivesMap { + /// Reads a Rust source file from disk and parses its derive + /// information. See [`DerivesMap::from_source`] for the parsing behavior. + /// + /// # Errors + /// + /// Returns: + /// - [`DerivesError::Io`] if the file cannot be read + /// - any variant returned by [`DerivesMap::from_source`] if the contents + /// cannot be parsed + pub fn from_file(path: &FsPath) -> Result { + let source = std::fs::read_to_string(path).map_err(|source| DerivesError::Io { + path: path.to_path_buf(), + source, + })?; + Self::from_source(&source) + } + + /// Returns whether `name`'s recorded derive set contains `derive_trait`. + /// Returns `false` if `name` is not recorded. + #[must_use] + pub fn satisfies(&self, name: &str, derive_trait: DeriveTrait) -> bool { + self.types + .get(name) + .is_some_and(|&set| set.implements(derive_trait)) + } + + /// Parses a Rust source file and records the derive set for every + /// top-level `struct`, `union`, `enum`, and type alias. Unknown derive + /// idents are ignored. + /// + /// # Errors + /// + /// Returns: + /// - [`DerivesError::Parse`] if `source` is not valid Rust + /// - [`DerivesError::UnhandledSynCase`] or [`DerivesError::MalformedShape`] + /// if a classified construct does not match any recognized bindgen output + /// shape + /// - [`DerivesError::UnresolvedAlias`] or [`DerivesError::AliasCycle`] if + /// an alias cannot be resolved to a recorded type + fn from_source(source: &str) -> Result { + let file = syn::parse_str::(source).map_err(DerivesError::Parse)?; + let mut derives_map = Self::with_std_types(); + + let mut aliases: HashMap = HashMap::default(); + for (key, source) in idents_and_derives_for_items(&file.items)? { + match source { + DerivesSource::Direct(derives_set) => { + derives_map.types.insert(key, derives_set); + } + DerivesSource::Alias(aliased_to) => { + aliases.insert(key, aliased_to); + } + } + } + + derives_map.resolve_aliases(&aliases)?; + + Ok(derives_map) + } + + fn with_std_types() -> Self { + Self { + types: STDINT_NAMES + .iter() + .map(|&n| (n.to_owned(), DerivesSet::all())) + .collect(), + } + } + + /// Resolve every alias in `aliases` by walking its chain to a recorded + /// type and copying that type's derive set onto each alias along the way. + /// + /// # Errors + /// + /// Returns: + /// - [`DerivesError::UnresolvedAlias`] if a chain terminates at a name that + /// is neither a recorded type nor a queued alias + /// - [`DerivesError::AliasCycle`] if a chain revisits a name it has already + /// walked through + fn resolve_aliases(&mut self, aliases: &HashMap) -> Result<(), DerivesError> { + for key in aliases.keys() { + if self.types.contains_key(key) { + continue; + } + + let mut curr = key; + let mut walked = vec![curr]; + while !self.types.contains_key(curr) { + let Some(next) = aliases.get(curr) else { + return Err(DerivesError::UnresolvedAlias { + names: walked.into_iter().cloned().collect(), + }); + }; + if walked.contains(&next) { + return Err(DerivesError::AliasCycle { + names: walked.into_iter().cloned().collect(), + }); + } + walked.push(next); + curr = next; + } + + let target_derive_set = *self + .types + .get(curr) + .expect("`self.types.contains_key(curr)` just returned true"); + + for new_derive_key in walked { + self.types.insert(new_derive_key.clone(), target_derive_set); + } + } + + Ok(()) + } +} + +/// Classify the type-defining [`syn::Item`]s in `items`, returning their +/// type names and [`DerivesSource`]s. +/// +/// # Bindgen shapes +/// +/// Struct / Union / Enum: derives come from the `#[derive(...)]` attrs: +/// +/// ```ignore +/// #[derive(Debug, Default, Copy, Clone, Hash, PartialOrd, Ord, PartialEq, Eq)] +/// pub struct _DMF_MODULE_DESCRIPTOR { pub Size: u32, /* ... */ } +/// ``` +/// +/// Type alias / Module / Use: dispatched to the corresponding classifier. +/// +/// Impl / Const: bindgen helper blocks and anonymous layout assertions. +/// Neither contributes derive information; both are ignored. +/// +/// # Errors +/// +/// Returns: +/// - [`DerivesError::UnhandledSynCase`] for `Item` variants other than +/// Struct/Union/Enum/Type/Mod/Use/Impl/Const +/// - any error propagated from the per-shape classifiers +fn idents_and_derives_for_items( + items: &[Item], +) -> Result, DerivesError> { + let mut derives: Vec<(String, DerivesSource)> = vec![]; + + for item in items { + match item { + Item::Struct(s) => derives.push(( + s.ident.to_string(), + DerivesSource::Direct(derives_from_attrs(&s.attrs).into()), + )), + Item::Union(u) => derives.push(( + u.ident.to_string(), + DerivesSource::Direct(derives_from_attrs(&u.attrs).into()), + )), + Item::Enum(e) => derives.push(( + e.ident.to_string(), + DerivesSource::Direct(derives_from_attrs(&e.attrs).into()), + )), + Item::Type(t) => derives.push((t.ident.to_string(), derives_for_type(&t.ty)?)), + Item::Mod(m) => derives.extend(idents_and_derives_for_mod(m)?), + Item::Use(u) => derives.push(ident_and_derives_for_use(u)?), + Item::Impl(_) | Item::Const(_) => {} + other => { + return Err(DerivesError::UnhandledSynCase { + node: format!("{other:?}"), + }); + } + } + } + Ok(derives) +} + +/// Collects the derive trait names from a `#[derive(...)]` attribute list. +fn derives_from_attrs(attrs: &[Attribute]) -> Vec { + attrs + .iter() + .filter(|attr| attr.path().is_ident("derive")) + .filter_map(|attr| { + attr.parse_args_with( + syn::punctuated::Punctuated::::parse_terminated, + ) + .ok() + }) + .flatten() + .filter_map(|path| { + path.segments + .into_iter() + .next_back() + .map(|seg| seg.ident.to_string()) + }) + .collect() +} + +/// Classify a [`syn::Type`] into the [`DerivesSource`] it represents. +/// +/// # Bindgen shapes +/// +/// ```ignore +/// pub type DMFMODULE = *mut DMFMODULE__; // Type::Ptr +/// +/// pub type __C_ASSERT__ = [::core::ffi::c_char; 1usize]; // Type::Array +/// +/// +/// pub type EVT_DMF_CALLBACK = ::core::option::Option< // Type::Path (Option) +/// unsafe extern "C" fn(/* ... */) -> NTSTATUS, +/// >; +/// +/// pub type DMF_TIME_FIELDS = _DMF_TIME_FIELDS; // Type::Path (named) +/// +/// pub type WCHAR = u16; // Type::Path (primitive) +/// ``` +/// +/// # Errors +/// +/// Returns: +/// - [`DerivesError::UnhandledSynCase`] if `ty` is a `syn::Type` variant other +/// than Ptr/Path/Array +/// - [`DerivesError::MalformedShape`] if the path has no segments +/// - [`DerivesError::UnhandledSynCase`] if the path has generic arguments +fn derives_for_type(ty: &Type) -> Result { + match ty { + Type::Ptr(_) => Ok(DerivesSource::Direct(DerivesSet::all())), + Type::Array(arr) => derives_for_type(&arr.elem), + Type::Path(tp) => { + if path_is_option(&tp.path) && inner_is_bare_fn(&tp.path) { + return Ok(DerivesSource::Direct(DerivesSet::all())); + } + + let Some(last) = tp.path.segments.last() else { + return Err(DerivesError::MalformedShape { + reason: "alias path has no segments".to_owned(), + node: format!("{tp:?}"), + }); + }; + + let PathArguments::None = last.arguments else { + return Err(DerivesError::UnhandledSynCase { + node: format!("{:?}", last.arguments), + }); + }; + + if PRIMITIVES.iter().any(|&p| last.ident == p) || path_is_core_ffi_type(&tp.path) { + return Ok(DerivesSource::Direct(DerivesSet::all())); + } + + Ok(DerivesSource::Alias(last.ident.to_string())) + } + other => Err(DerivesError::UnhandledSynCase { + node: format!("{other:?}"), + }), + } +} + +/// Classify the type-defining items inside a [`syn::ItemMod`] (bindgen's +/// C-enum-as-module pattern), returning their prefixed type names and +/// [`DerivesSource`]s. +/// +/// Registers the inner `Type` under a compound key like +/// `_INTERFACE_TYPE::Type` so other types can link to it via an alias. +/// +/// # Bindgen shapes +/// +/// ```ignore +/// pub mod _INTERFACE_TYPE { +/// pub type Type = ::core::ffi::c_int; +/// pub const Isa: Type = 1; +/// pub const Eisa: Type = 2; +/// // ... +/// } +/// pub use self::_INTERFACE_TYPE::Type as INTERFACE_TYPE; +/// ``` +/// +/// # Errors +/// +/// Returns any error propagated from [`idents_and_derives_for_items`] on +/// the module's inner items. +fn idents_and_derives_for_mod( + m: &syn::ItemMod, +) -> Result, DerivesError> { + let Some((_, mod_items)) = &m.content else { + return Ok(vec![]); + }; + let prefix = format!("{}::", m.ident); + + let mut mod_items_derives = idents_and_derives_for_items(mod_items)?; + + for (key, _) in &mut mod_items_derives { + key.insert_str(0, &prefix); + } + Ok(mod_items_derives) +} + +/// Classify a [`syn::ItemUse`] (bindgen's `pub use self::_FOO::Type as +/// FOO;` rename), returning the type name and the classified +/// [`DerivesSource`]. +/// +/// # Bindgen shapes +/// +/// ```ignore +/// pub use self::_INTERFACE_TYPE::Type as INTERFACE_TYPE; +/// pub use self::_POWER_STATE_TYPE::Type as POWER_STATE_TYPE; +/// pub use self::_DEVICE_POWER_STATE::Type as DEVICE_POWER_STATE; +/// ``` +/// +/// # Errors +/// +/// Returns: +/// - [`DerivesError::UnhandledSynCase`] for `UseTree` variants other than +/// `Path`/`Rename` +fn ident_and_derives_for_use(item_use: &ItemUse) -> Result<(String, DerivesSource), DerivesError> { + let mut segments: Vec = Vec::new(); + let mut use_tree = &item_use.tree; + + while let UseTree::Path(path) = use_tree { + let seg = path.ident.to_string(); + if seg != "self" { + segments.push(seg); + } + use_tree = &path.tree; + } + + let UseTree::Rename(use_rename) = use_tree else { + return Err(DerivesError::UnhandledSynCase { + node: format!("{use_tree:?}"), + }); + }; + + segments.push(use_rename.ident.to_string()); + Ok(( + use_rename.rename.to_string(), + DerivesSource::Alias(segments.join("::")), + )) +} + +/// True when the last segment of `path` has a bare-fn type as its first +/// generic argument. +fn inner_is_bare_fn(path: &Path) -> bool { + let Some(last) = path.segments.last() else { + return false; + }; + let PathArguments::AngleBracketed(args) = &last.arguments else { + return false; + }; + matches!( + args.args.first(), + Some(syn::GenericArgument::Type(Type::BareFn(_))) + ) +} + +/// True when `path` ends in `core::option::Option`. +fn path_is_option(path: &Path) -> bool { + let segs = &path.segments; + segs.len() >= 3 + && segs[segs.len() - 3].ident == "core" + && segs[segs.len() - 2].ident == "option" + && segs[segs.len() - 1].ident == "Option" +} + +/// True when `path` ends in `core::ffi::*`. +fn path_is_core_ffi_type(path: &Path) -> bool { + let segs = &path.segments; + segs.len() >= 3 && segs[segs.len() - 3].ident == "core" && segs[segs.len() - 2].ident == "ffi" +} + +#[cfg(test)] +mod tests { + use super::*; + + fn parse(src: &str) -> DerivesMap { + DerivesMap::from_source(src).expect("parses") + } + + #[test] + #[allow(clippy::too_many_lines)] + fn parses_representative_bindgen_output() { + // Shapes observed in real bindgen output for wdk-sys: + // - POD struct with the common four-trait derive + // - Union with only Copy/Clone (Rust unions can't auto-derive Debug/Default) + // - Bindgen's `__BindgenUnionField` wrapper — PartialEq without PartialOrd + // - Bindgen's `__IncompleteArrayField` wrapper — the full nine-trait derive + // - Type alias chain: `PodAliasChain = PodAlias = Pod` should inherit Pod's + // derives. + let src = r#" + #[repr(C)] + #[derive(Debug, Default, Copy, Clone)] + pub struct Pod { pub x: u32 } + + #[repr(C)] + #[derive(Copy, Clone)] + pub union Uni { pub a: u32, pub b: u64 } + + #[derive(PartialEq, Copy, Clone, Debug, Hash)] + pub struct UnionField; + + #[derive(Copy, Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] + pub struct ArrayField; + + pub type PodAlias = Pod; + pub type PodAliasChain = PodAlias; + + pub type UCHAR = ::core::ffi::c_uchar; + pub type ULONG = ::core::ffi::c_ulong; + pub type PVOID = *mut ::core::ffi::c_void; + pub type PULONG = *mut ULONG; + + // Option: fn contributes all-except-Default, Option adds Default back — ends up with all 5. + pub type OptFn = ::core::option::Option u32>; + + // Bindgen module-enum pattern: inner `Type` aliases a primitive, and a use-rename re-exports it under a friendly name. The re-export must resolve to the inner `Type`'s derive set. + pub mod _INTERFACE_TYPE { + pub type Type = ::core::ffi::c_int; + pub const Isa: Type = 1; + } + pub use self::_INTERFACE_TYPE::Type as INTERFACE_TYPE; + "#; + let map = parse(src); + + assert!(map.satisfies("Pod", DeriveTrait::Debug)); + assert!(map.satisfies("Pod", DeriveTrait::Default)); + assert!(map.satisfies("Pod", DeriveTrait::Copy)); + assert!(!map.satisfies("Pod", DeriveTrait::Hash)); + assert!(!map.satisfies("Pod", DeriveTrait::PartialEqOrPartialOrd)); + + assert!(map.satisfies("Uni", DeriveTrait::Copy)); + assert!(!map.satisfies("Uni", DeriveTrait::Debug)); + assert!(!map.satisfies("Uni", DeriveTrait::Default)); + + // PartialEq alone now satisfies the grouped bindgen query. + assert!(map.satisfies("UnionField", DeriveTrait::PartialEqOrPartialOrd)); + assert!(map.satisfies("UnionField", DeriveTrait::Hash)); + + assert!(map.satisfies("ArrayField", DeriveTrait::PartialEqOrPartialOrd)); + assert!(map.satisfies("ArrayField", DeriveTrait::Hash)); + + // Alias chain resolves through to Pod's derives. + assert!(map.satisfies("PodAlias", DeriveTrait::Debug)); + assert!(map.satisfies("PodAlias", DeriveTrait::Default)); + assert!(map.satisfies("PodAliasChain", DeriveTrait::Debug)); + assert!(map.satisfies("PodAliasChain", DeriveTrait::Default)); + + // Primitive-target aliases: terminal shapes get the full standard derive set + // directly, without chain resolution. + for trait_ in [ + DeriveTrait::Copy, + DeriveTrait::Debug, + DeriveTrait::Default, + DeriveTrait::Hash, + DeriveTrait::PartialEqOrPartialOrd, + ] { + assert!(map.satisfies("UCHAR", trait_)); + assert!(map.satisfies("ULONG", trait_)); + assert!(map.satisfies("PVOID", trait_)); + assert!(map.satisfies("PULONG", trait_)); + } + + // Unknown type name: returns false, does not panic. + assert!(!map.satisfies("Nonexistent", DeriveTrait::Debug)); + + // Option — fn gives 4, Option adds Default → all 5. + for trait_ in [ + DeriveTrait::Copy, + DeriveTrait::Debug, + DeriveTrait::Default, + DeriveTrait::Hash, + DeriveTrait::PartialEqOrPartialOrd, + ] { + assert!(map.satisfies("OptFn", trait_)); + } + + // Module-enum pattern — both the compound key (`_INTERFACE_TYPE::Type`) and the + // re-exported friendly name (`INTERFACE_TYPE`) inherit the primitive's full + // derive set. + for trait_ in [ + DeriveTrait::Copy, + DeriveTrait::Debug, + DeriveTrait::Default, + DeriveTrait::Hash, + DeriveTrait::PartialEqOrPartialOrd, + ] { + assert!(map.satisfies("_INTERFACE_TYPE::Type", trait_)); + assert!(map.satisfies("INTERFACE_TYPE", trait_)); + } + } + + /// Every seeded stdint name derives the full standard set. Guards the + /// hand-maintained `STDINT_NAMES` list against accidental deletion and + /// keeps the `satisfies` result shape in sync with the seed. + #[test] + fn stdint_names_all_derive_standard_set() { + let map = parse(""); + for name in STDINT_NAMES { + for trait_ in [ + DeriveTrait::Copy, + DeriveTrait::Debug, + DeriveTrait::Default, + DeriveTrait::Hash, + DeriveTrait::PartialEqOrPartialOrd, + ] { + assert!( + map.satisfies(name, trait_), + "stdint {name} missing {trait_:?}" + ); + } + } + } + + /// A cyclic alias pair (`A = B; B = A;`) must surface as `AliasCycle` — + /// the chain-walking loop detects it when a step revisits a name already + /// in the walked set. + #[test] + fn alias_cycle_terminates() { + let src = r" + pub type A = B; + pub type B = A; + "; + let err = DerivesMap::from_source(src).expect_err("cycle must error"); + match err { + DerivesError::AliasCycle { mut names } => { + names.sort(); + assert_eq!(names, vec!["A".to_owned(), "B".to_owned()]); + } + other => panic!("expected AliasCycle, got {other:?}"), + } + } + + /// An alias whose target is neither a recorded type nor another pending + /// alias must surface as `UnresolvedAlias`. + #[test] + fn unresolvable_alias_errors() { + let src = r" + pub type UnknownAlias = SomeUnparsedType; + "; + let err = DerivesMap::from_source(src).expect_err("unresolvable must error"); + match err { + DerivesError::UnresolvedAlias { names } => { + assert_eq!( + names, + vec!["UnknownAlias".to_owned(), "SomeUnparsedType".to_owned()] + ); + } + other => panic!("expected UnresolvedAlias, got {other:?}"), + } + } + + /// `BaseDerivesCallback` must translate `bool` into the bindgen + /// `Some(Yes)` / `Some(No)` answers expected for blocklisted types. + #[test] + fn base_callback_known_positive_returns_yes() { + let src = r" + #[derive(Copy, Clone, Debug)] + pub struct Pod; + "; + let map = Arc::new(parse(src)); + let cb = BaseDerivesCallback::new(map); + assert!(matches!( + cb.blocklisted_type_implements_trait("Pod", DeriveTrait::Debug), + Some(ImplementsTrait::Yes) + )); + } + + #[test] + fn base_callback_known_negative_returns_no() { + let src = r" + #[derive(Copy, Clone)] + pub struct Pod; + "; + let map = Arc::new(parse(src)); + let cb = BaseDerivesCallback::new(map); + assert!(matches!( + cb.blocklisted_type_implements_trait("Pod", DeriveTrait::Debug), + Some(ImplementsTrait::No) + )); + } + + #[test] + fn base_callback_unknown_returns_no() { + let map = Arc::new(parse("")); + let cb = BaseDerivesCallback::new(map); + assert!(matches!( + cb.blocklisted_type_implements_trait("Nonexistent", DeriveTrait::Debug), + Some(ImplementsTrait::No) + )); + } +} diff --git a/crates/wdk-build/src/lib.rs b/crates/wdk-build/src/lib.rs index 710708e5b..470b41298 100644 --- a/crates/wdk-build/src/lib.rs +++ b/crates/wdk-build/src/lib.rs @@ -28,6 +28,8 @@ pub mod metadata; mod utils; mod bindgen; +#[doc(hidden)] +pub mod derives; use cargo_metadata::MetadataCommand; use serde::{Deserialize, Serialize}; From 6a5371cc8b9a86ff940f3276cb544d38200fc192 Mon Sep 17 00:00:00 2001 From: leon-xd Date: Tue, 28 Apr 2026 17:21:36 -0700 Subject: [PATCH 2/4] fix: drop regex from Cargo.lock to match wdk-sys manifest --- Cargo.lock | 1 - 1 file changed, 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 85d36a2e4..e1766700b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1193,7 +1193,6 @@ dependencies = [ "cargo_metadata", "cc", "cfg-if", - "regex", "rustversion", "serde_json", "thiserror", From d4aea8d33d7e5710fbd5f21599d10ee509eb53e2 Mon Sep 17 00:00:00 2001 From: leon-xd Date: Tue, 28 Apr 2026 17:48:35 -0700 Subject: [PATCH 3/4] clean up tests with assert_derives --- crates/wdk-build/src/derives.rs | 123 ++++++++++++++------------------ 1 file changed, 52 insertions(+), 71 deletions(-) diff --git a/crates/wdk-build/src/derives.rs b/crates/wdk-build/src/derives.rs index 1af3add5b..6cf6e1dad 100644 --- a/crates/wdk-build/src/derives.rs +++ b/crates/wdk-build/src/derives.rs @@ -524,9 +524,32 @@ mod tests { DerivesMap::from_source(src).expect("parses") } + const ALL_TRAITS: &[DeriveTrait] = &[ + DeriveTrait::Copy, + DeriveTrait::Debug, + DeriveTrait::Default, + DeriveTrait::Hash, + DeriveTrait::PartialEqOrPartialOrd, + ]; + + /// Assert that `map` reports `satisfies(name, t) == true` for exactly the + /// traits in `expected`, and `false` for every other trait in + /// [`ALL_TRAITS`]. + fn assert_derives(map: &DerivesMap, name: &str, expected: &[DeriveTrait]) { + for &t in ALL_TRAITS { + let want = expected.contains(&t); + let got = map.satisfies(name, t); + assert_eq!( + got, want, + "{name}: satisfies({t:?}) = {got}, expected {want}" + ); + } + } + #[test] - #[allow(clippy::too_many_lines)] fn parses_representative_bindgen_output() { + use DeriveTrait::{Copy, Debug, Default, Hash, PartialEqOrPartialOrd}; + // Shapes observed in real bindgen output for wdk-sys: // - POD struct with the common four-trait derive // - Union with only Copy/Clone (Rust unions can't auto-derive Debug/Default) @@ -569,71 +592,40 @@ mod tests { "#; let map = parse(src); - assert!(map.satisfies("Pod", DeriveTrait::Debug)); - assert!(map.satisfies("Pod", DeriveTrait::Default)); - assert!(map.satisfies("Pod", DeriveTrait::Copy)); - assert!(!map.satisfies("Pod", DeriveTrait::Hash)); - assert!(!map.satisfies("Pod", DeriveTrait::PartialEqOrPartialOrd)); - - assert!(map.satisfies("Uni", DeriveTrait::Copy)); - assert!(!map.satisfies("Uni", DeriveTrait::Debug)); - assert!(!map.satisfies("Uni", DeriveTrait::Default)); - - // PartialEq alone now satisfies the grouped bindgen query. - assert!(map.satisfies("UnionField", DeriveTrait::PartialEqOrPartialOrd)); - assert!(map.satisfies("UnionField", DeriveTrait::Hash)); - - assert!(map.satisfies("ArrayField", DeriveTrait::PartialEqOrPartialOrd)); - assert!(map.satisfies("ArrayField", DeriveTrait::Hash)); + assert_derives(&map, "Pod", &[Copy, Debug, Default]); + assert_derives(&map, "Uni", &[Copy]); + assert_derives( + &map, + "UnionField", + &[Copy, Debug, Hash, PartialEqOrPartialOrd], + ); + assert_derives( + &map, + "ArrayField", + &[Copy, Debug, Default, Hash, PartialEqOrPartialOrd], + ); // Alias chain resolves through to Pod's derives. - assert!(map.satisfies("PodAlias", DeriveTrait::Debug)); - assert!(map.satisfies("PodAlias", DeriveTrait::Default)); - assert!(map.satisfies("PodAliasChain", DeriveTrait::Debug)); - assert!(map.satisfies("PodAliasChain", DeriveTrait::Default)); - - // Primitive-target aliases: terminal shapes get the full standard derive set - // directly, without chain resolution. - for trait_ in [ - DeriveTrait::Copy, - DeriveTrait::Debug, - DeriveTrait::Default, - DeriveTrait::Hash, - DeriveTrait::PartialEqOrPartialOrd, - ] { - assert!(map.satisfies("UCHAR", trait_)); - assert!(map.satisfies("ULONG", trait_)); - assert!(map.satisfies("PVOID", trait_)); - assert!(map.satisfies("PULONG", trait_)); + assert_derives(&map, "PodAlias", &[Copy, Debug, Default]); + assert_derives(&map, "PodAliasChain", &[Copy, Debug, Default]); + + // Primitive-target aliases: terminal shapes get the full standard derive + // set directly, without chain resolution. + for name in ["UCHAR", "ULONG", "PVOID", "PULONG"] { + assert_derives(&map, name, ALL_TRAITS); } - // Unknown type name: returns false, does not panic. - assert!(!map.satisfies("Nonexistent", DeriveTrait::Debug)); + // Unknown type name: returns false for every trait, does not panic. + assert_derives(&map, "Nonexistent", &[]); // Option — fn gives 4, Option adds Default → all 5. - for trait_ in [ - DeriveTrait::Copy, - DeriveTrait::Debug, - DeriveTrait::Default, - DeriveTrait::Hash, - DeriveTrait::PartialEqOrPartialOrd, - ] { - assert!(map.satisfies("OptFn", trait_)); - } + assert_derives(&map, "OptFn", ALL_TRAITS); - // Module-enum pattern — both the compound key (`_INTERFACE_TYPE::Type`) and the - // re-exported friendly name (`INTERFACE_TYPE`) inherit the primitive's full - // derive set. - for trait_ in [ - DeriveTrait::Copy, - DeriveTrait::Debug, - DeriveTrait::Default, - DeriveTrait::Hash, - DeriveTrait::PartialEqOrPartialOrd, - ] { - assert!(map.satisfies("_INTERFACE_TYPE::Type", trait_)); - assert!(map.satisfies("INTERFACE_TYPE", trait_)); - } + // Module-enum pattern — both the compound key (`_INTERFACE_TYPE::Type`) and + // the re-exported friendly name (`INTERFACE_TYPE`) inherit the primitive's + // full derive set. + assert_derives(&map, "_INTERFACE_TYPE::Type", ALL_TRAITS); + assert_derives(&map, "INTERFACE_TYPE", ALL_TRAITS); } /// Every seeded stdint name derives the full standard set. Guards the @@ -643,18 +635,7 @@ mod tests { fn stdint_names_all_derive_standard_set() { let map = parse(""); for name in STDINT_NAMES { - for trait_ in [ - DeriveTrait::Copy, - DeriveTrait::Debug, - DeriveTrait::Default, - DeriveTrait::Hash, - DeriveTrait::PartialEqOrPartialOrd, - ] { - assert!( - map.satisfies(name, trait_), - "stdint {name} missing {trait_:?}" - ); - } + assert_derives(&map, name, ALL_TRAITS); } } From ff10314c2bb7536114770027ce43ce7ce2d4ac3c Mon Sep 17 00:00:00 2001 From: leon-xd Date: Thu, 30 Apr 2026 11:18:57 -0700 Subject: [PATCH 4/4] added robust unit tests, integration test, added doc comments --- crates/wdk-build/Cargo.toml | 2 +- crates/wdk-build/src/derives.rs | 600 +++++++++++++++++++++--------- crates/wdk-build/tests/derives.rs | 146 ++++++++ 3 files changed, 565 insertions(+), 183 deletions(-) create mode 100644 crates/wdk-build/tests/derives.rs diff --git a/crates/wdk-build/Cargo.toml b/crates/wdk-build/Cargo.toml index fd09b82d9..d8b2d6183 100644 --- a/crates/wdk-build/Cargo.toml +++ b/crates/wdk-build/Cargo.toml @@ -33,7 +33,7 @@ rustversion.workspace = true semver.workspace = true serde = { features = ["derive"], workspace = true } serde_json.workspace = true -syn = { features = ["parsing"], workspace = true } +syn = { features = ["extra-traits", "full", "parsing"], workspace = true } thiserror.workspace = true tracing.workspace = true windows = { features = [ diff --git a/crates/wdk-build/src/derives.rs b/crates/wdk-build/src/derives.rs index 6cf6e1dad..0c58d2988 100644 --- a/crates/wdk-build/src/derives.rs +++ b/crates/wdk-build/src/derives.rs @@ -46,27 +46,53 @@ const STDINT_NAMES: &[&str] = &[ /// [`DerivesMap`]. #[derive(Debug, Error)] pub enum DerivesError { + /// Reading the bindgen-emitted source file from disk failed. #[error("failed to read {path}", path = path.display())] Io { + /// Path to the file that could not be read. path: PathBuf, + /// Underlying I/O error from the filesystem operation. #[source] source: std::io::Error, }, + /// `syn` failed to parse the source as Rust. #[error("failed to parse source as Rust")] Parse(#[source] syn::Error), + /// Encountered a top-level [`syn::Item`] variant this parser does not + /// handle. #[error("unhandled syn node: {node}")] - UnhandledSynCase { node: String }, + UnhandledSynCase { + /// Debug-formatted representation of the unhandled node. + node: String, + }, + /// A recognized item kind whose internal shape did not match what the + /// parser expects from bindgen output. #[error("malformed shape: {reason}: {node}")] - MalformedShape { reason: String, node: String }, + MalformedShape { + /// Why the node shape is considered malformed. + reason: String, + /// Debug-formatted representation of the malformed node. + node: String, + }, + /// Alias chain visited the same name twice while walking aliases to + /// their target type. #[error("alias cycle among: {names:?}")] - AliasCycle { names: Vec }, + AliasCycle { + /// Names participating in the detected cycle, in walk order. + names: Vec, + }, - #[error("alias targets not found: {names:?}")] - UnresolvedAlias { names: Vec }, + /// Alias chain terminated at a name that is neither a recorded type nor + /// another pending alias. + #[error("alias target not found: {target}")] + UnresolvedAlias { + /// The unresolved target name. + target: String, + }, } bitflags::bitflags! { @@ -111,6 +137,7 @@ impl From> for DerivesSet { } } +#[derive(Debug)] enum DerivesSource { Direct(DerivesSet), Alias(String), @@ -124,6 +151,7 @@ pub struct BaseDerivesCallback { } impl BaseDerivesCallback { + /// Wrap a shared [`DerivesMap`] for use as a `bindgen` [`ParseCallbacks`]. #[must_use] pub const fn new(map: Arc) -> Self { Self { map } @@ -242,7 +270,7 @@ impl DerivesMap { while !self.types.contains_key(curr) { let Some(next) = aliases.get(curr) else { return Err(DerivesError::UnresolvedAlias { - names: walked.into_iter().cloned().collect(), + target: curr.clone(), }); }; if walked.contains(&next) { @@ -520,200 +548,408 @@ fn path_is_core_ffi_type(path: &Path) -> bool { mod tests { use super::*; - fn parse(src: &str) -> DerivesMap { - DerivesMap::from_source(src).expect("parses") - } - - const ALL_TRAITS: &[DeriveTrait] = &[ - DeriveTrait::Copy, - DeriveTrait::Debug, - DeriveTrait::Default, - DeriveTrait::Hash, - DeriveTrait::PartialEqOrPartialOrd, - ]; - - /// Assert that `map` reports `satisfies(name, t) == true` for exactly the - /// traits in `expected`, and `false` for every other trait in - /// [`ALL_TRAITS`]. - fn assert_derives(map: &DerivesMap, name: &str, expected: &[DeriveTrait]) { - for &t in ALL_TRAITS { - let want = expected.contains(&t); - let got = map.satisfies(name, t); - assert_eq!( - got, want, - "{name}: satisfies({t:?}) = {got}, expected {want}" - ); + #[track_caller] + fn assert_direct_full(source: DerivesSource) { + match source { + DerivesSource::Direct(set) => assert_eq!(set, DerivesSet::all()), + DerivesSource::Alias(name) => panic!("expected Direct(all), got Alias({name:?})"), } } - #[test] - fn parses_representative_bindgen_output() { - use DeriveTrait::{Copy, Debug, Default, Hash, PartialEqOrPartialOrd}; - - // Shapes observed in real bindgen output for wdk-sys: - // - POD struct with the common four-trait derive - // - Union with only Copy/Clone (Rust unions can't auto-derive Debug/Default) - // - Bindgen's `__BindgenUnionField` wrapper — PartialEq without PartialOrd - // - Bindgen's `__IncompleteArrayField` wrapper — the full nine-trait derive - // - Type alias chain: `PodAliasChain = PodAlias = Pod` should inherit Pod's - // derives. - let src = r#" - #[repr(C)] - #[derive(Debug, Default, Copy, Clone)] - pub struct Pod { pub x: u32 } - - #[repr(C)] - #[derive(Copy, Clone)] - pub union Uni { pub a: u32, pub b: u64 } - - #[derive(PartialEq, Copy, Clone, Debug, Hash)] - pub struct UnionField; - - #[derive(Copy, Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] - pub struct ArrayField; - - pub type PodAlias = Pod; - pub type PodAliasChain = PodAlias; - - pub type UCHAR = ::core::ffi::c_uchar; - pub type ULONG = ::core::ffi::c_ulong; - pub type PVOID = *mut ::core::ffi::c_void; - pub type PULONG = *mut ULONG; - - // Option: fn contributes all-except-Default, Option adds Default back — ends up with all 5. - pub type OptFn = ::core::option::Option u32>; - - // Bindgen module-enum pattern: inner `Type` aliases a primitive, and a use-rename re-exports it under a friendly name. The re-export must resolve to the inner `Type`'s derive set. - pub mod _INTERFACE_TYPE { - pub type Type = ::core::ffi::c_int; - pub const Isa: Type = 1; + #[track_caller] + fn assert_alias(source: DerivesSource, expected: &str) { + match source { + DerivesSource::Alias(s) => assert_eq!(s, expected), + DerivesSource::Direct(set) => { + panic!("expected Alias({expected:?}), got Direct({set:?})") } - pub use self::_INTERFACE_TYPE::Type as INTERFACE_TYPE; - "#; - let map = parse(src); - - assert_derives(&map, "Pod", &[Copy, Debug, Default]); - assert_derives(&map, "Uni", &[Copy]); - assert_derives( - &map, - "UnionField", - &[Copy, Debug, Hash, PartialEqOrPartialOrd], - ); - assert_derives( - &map, - "ArrayField", - &[Copy, Debug, Default, Hash, PartialEqOrPartialOrd], - ); - - // Alias chain resolves through to Pod's derives. - assert_derives(&map, "PodAlias", &[Copy, Debug, Default]); - assert_derives(&map, "PodAliasChain", &[Copy, Debug, Default]); - - // Primitive-target aliases: terminal shapes get the full standard derive - // set directly, without chain resolution. - for name in ["UCHAR", "ULONG", "PVOID", "PULONG"] { - assert_derives(&map, name, ALL_TRAITS); - } - - // Unknown type name: returns false for every trait, does not panic. - assert_derives(&map, "Nonexistent", &[]); - - // Option — fn gives 4, Option adds Default → all 5. - assert_derives(&map, "OptFn", ALL_TRAITS); - - // Module-enum pattern — both the compound key (`_INTERFACE_TYPE::Type`) and - // the re-exported friendly name (`INTERFACE_TYPE`) inherit the primitive's - // full derive set. - assert_derives(&map, "_INTERFACE_TYPE::Type", ALL_TRAITS); - assert_derives(&map, "INTERFACE_TYPE", ALL_TRAITS); + } } - /// Every seeded stdint name derives the full standard set. Guards the - /// hand-maintained `STDINT_NAMES` list against accidental deletion and - /// keeps the `satisfies` result shape in sync with the seed. - #[test] - fn stdint_names_all_derive_standard_set() { - let map = parse(""); - for name in STDINT_NAMES { - assert_derives(&map, name, ALL_TRAITS); + mod path_checks { + use syn::parse_str; + + use super::*; + + #[test] + fn path_is_option_recognizes_full_path() { + let p: Path = parse_str("::core::option::Option").unwrap(); + assert!(path_is_option(&p)); + let p: Path = parse_str("core::option::Option").unwrap(); + assert!(path_is_option(&p)); + } + + #[test] + fn path_is_option_rejects_short_or_wrong_paths() { + let p: Path = parse_str("Option").unwrap(); + assert!(!path_is_option(&p)); + let p: Path = parse_str("std::option::Option").unwrap(); + assert!(!path_is_option(&p)); + let p: Path = parse_str("core::ffi::c_void").unwrap(); + assert!(!path_is_option(&p)); + } + + #[test] + fn path_is_core_ffi_type_recognizes() { + let p: Path = parse_str("::core::ffi::c_void").unwrap(); + assert!(path_is_core_ffi_type(&p)); + let p: Path = parse_str("core::ffi::c_int").unwrap(); + assert!(path_is_core_ffi_type(&p)); + } + + #[test] + fn path_is_core_ffi_type_rejects_non_ffi() { + let p: Path = parse_str("core::option::Option").unwrap(); + assert!(!path_is_core_ffi_type(&p)); + let p: Path = parse_str("std::ffi::CStr").unwrap(); + assert!(!path_is_core_ffi_type(&p)); + let p: Path = parse_str("c_int").unwrap(); + assert!(!path_is_core_ffi_type(&p)); + } + + #[test] + fn inner_is_bare_fn_true_for_option_fn() { + let p: Path = + parse_str("::core::option::Option u32>").unwrap(); + assert!(inner_is_bare_fn(&p)); + } + + #[test] + fn inner_is_bare_fn_false_for_other_generics() { + let p: Path = parse_str("Option").unwrap(); + assert!(!inner_is_bare_fn(&p)); + let p: Path = parse_str("Vec").unwrap(); + assert!(!inner_is_bare_fn(&p)); + } + + #[test] + fn inner_is_bare_fn_false_for_no_generics() { + let p: Path = parse_str("u32").unwrap(); + assert!(!inner_is_bare_fn(&p)); } } - /// A cyclic alias pair (`A = B; B = A;`) must surface as `AliasCycle` — - /// the chain-walking loop detects it when a step revisits a name already - /// in the walked set. - #[test] - fn alias_cycle_terminates() { - let src = r" - pub type A = B; - pub type B = A; - "; - let err = DerivesMap::from_source(src).expect_err("cycle must error"); - match err { - DerivesError::AliasCycle { mut names } => { - names.sort(); - assert_eq!(names, vec!["A".to_owned(), "B".to_owned()]); + mod classifiers { + use syn::parse_str; + + use super::*; + + #[test] + fn derives_from_attrs_extracts_idents() { + let item: syn::ItemStruct = + parse_str("#[derive(Copy, Clone, Debug)] pub struct S;").unwrap(); + let derives = derives_from_attrs(&item.attrs); + assert_eq!(derives, vec!["Copy", "Clone", "Debug"]); + } + + #[test] + fn derives_from_attrs_ignores_non_derive_attrs() { + let item: syn::ItemStruct = + parse_str("#[repr(C)] #[derive(Copy)] #[allow(dead_code)] pub struct S;").unwrap(); + let derives = derives_from_attrs(&item.attrs); + assert_eq!(derives, vec!["Copy"]); + } + + #[test] + fn derives_from_attrs_uses_last_path_segment() { + let item: syn::ItemStruct = + parse_str("#[derive(::core::marker::Copy)] pub struct S;").unwrap(); + let derives = derives_from_attrs(&item.attrs); + assert_eq!(derives, vec!["Copy"]); + } + + #[test] + fn derives_from_attrs_no_derives_returns_empty() { + let item: syn::ItemStruct = parse_str("#[repr(C)] pub struct S;").unwrap(); + assert!(derives_from_attrs(&item.attrs).is_empty()); + } + + #[test] + fn derives_for_type_pointer_gets_all() { + let ty: Type = parse_str("*mut u32").unwrap(); + assert_direct_full(derives_for_type(&ty).unwrap()); + let ty: Type = parse_str("*const ::core::ffi::c_void").unwrap(); + assert_direct_full(derives_for_type(&ty).unwrap()); + } + + #[test] + fn derives_for_type_array_recurses_into_element() { + let ty: Type = parse_str("[u32; 4]").unwrap(); + assert_direct_full(derives_for_type(&ty).unwrap()); + + let ty: Type = parse_str("[SomeAlias; 8]").unwrap(); + assert_alias(derives_for_type(&ty).unwrap(), "SomeAlias"); + } + + #[test] + fn derives_for_type_primitive_path_gets_all() { + let ty: Type = parse_str("u32").unwrap(); + assert_direct_full(derives_for_type(&ty).unwrap()); + } + + #[test] + fn derives_for_type_core_ffi_path_gets_all() { + let ty: Type = parse_str("::core::ffi::c_int").unwrap(); + assert_direct_full(derives_for_type(&ty).unwrap()); + } + + #[test] + fn derives_for_type_option_fn_gets_all() { + let ty: Type = + parse_str("::core::option::Option u32>").unwrap(); + assert_direct_full(derives_for_type(&ty).unwrap()); + } + + #[test] + fn derives_for_type_named_alias_returns_alias_source() { + let ty: Type = parse_str("SomeAlias").unwrap(); + assert_alias(derives_for_type(&ty).unwrap(), "SomeAlias"); + } + + #[test] + fn derives_for_type_path_with_unsupported_generics_is_unhandled() { + // Vec is not the Option shape, so the `PathArguments::None` + // check fires and surfaces UnhandledSynCase. + let ty: Type = parse_str("Vec").unwrap(); + match derives_for_type(&ty).unwrap_err() { + DerivesError::UnhandledSynCase { .. } => {} + other => panic!("expected UnhandledSynCase, got {other:?}"), } - other => panic!("expected AliasCycle, got {other:?}"), + } + + #[test] + fn derives_for_type_unsupported_variant_is_unhandled() { + let ty: Type = parse_str("(u32, u64)").unwrap(); + assert!(matches!( + derives_for_type(&ty), + Err(DerivesError::UnhandledSynCase { .. }) + )); + + let ty: Type = parse_str("&u32").unwrap(); + assert!(matches!( + derives_for_type(&ty), + Err(DerivesError::UnhandledSynCase { .. }) + )); + + let ty: Type = parse_str("dyn Send").unwrap(); + assert!(matches!( + derives_for_type(&ty), + Err(DerivesError::UnhandledSynCase { .. }) + )); + } + + #[test] + fn ident_and_derives_for_use_self_path_rename() { + let item: ItemUse = parse_str("pub use self::_FOO::Type as FOO;").unwrap(); + let (key, source) = ident_and_derives_for_use(&item).unwrap(); + assert_eq!(key, "FOO"); + assert_alias(source, "_FOO::Type"); + } + + #[test] + fn ident_and_derives_for_use_no_self_segment() { + let item: ItemUse = parse_str("pub use _FOO::Type as FOO;").unwrap(); + let (key, source) = ident_and_derives_for_use(&item).unwrap(); + assert_eq!(key, "FOO"); + assert_alias(source, "_FOO::Type"); + } + + #[test] + fn ident_and_derives_for_use_glob_is_unhandled() { + let item: ItemUse = parse_str("pub use foo::*;").unwrap(); + assert!(matches!( + ident_and_derives_for_use(&item), + Err(DerivesError::UnhandledSynCase { .. }) + )); + } + + #[test] + fn ident_and_derives_for_use_no_rename_is_unhandled() { + let item: ItemUse = parse_str("pub use foo::Bar;").unwrap(); + assert!(matches!( + ident_and_derives_for_use(&item), + Err(DerivesError::UnhandledSynCase { .. }) + )); + } + + #[test] + fn ident_and_derives_for_use_group_is_unhandled() { + let item: ItemUse = parse_str("pub use foo::{Bar, Baz};").unwrap(); + assert!(matches!( + ident_and_derives_for_use(&item), + Err(DerivesError::UnhandledSynCase { .. }) + )); + } + + #[test] + fn idents_and_derives_for_mod_prefixes_inner_idents() { + let m: syn::ItemMod = + parse_str("pub mod _OUTER { pub type Type = ::core::ffi::c_int; }").unwrap(); + let mut result = idents_and_derives_for_mod(&m).unwrap(); + assert_eq!(result.len(), 1); + let (key, source) = result.remove(0); + assert_eq!(key, "_OUTER::Type"); + assert_direct_full(source); + } + + #[test] + fn idents_and_derives_for_mod_empty_content_returns_empty() { + // External mod declaration (no inline body) — `m.content` is `None`. + let m: syn::ItemMod = parse_str("pub mod foo;").unwrap(); + assert!(idents_and_derives_for_mod(&m).unwrap().is_empty()); + } + + #[test] + fn unsupported_item_kind_surfaces_unhandled_syn_case() { + // Item::Trait is not part of the supported Struct/Union/Enum/Type/Mod/ + // Use/Impl/Const set, so the catch-all arm fires. + assert!(matches!( + DerivesMap::from_source("pub trait T {}"), + Err(DerivesError::UnhandledSynCase { .. }) + )); } } - /// An alias whose target is neither a recorded type nor another pending - /// alias must surface as `UnresolvedAlias`. - #[test] - fn unresolvable_alias_errors() { - let src = r" - pub type UnknownAlias = SomeUnparsedType; - "; - let err = DerivesMap::from_source(src).expect_err("unresolvable must error"); - match err { - DerivesError::UnresolvedAlias { names } => { - assert_eq!( - names, - vec!["UnknownAlias".to_owned(), "SomeUnparsedType".to_owned()] - ); + mod alias_resolution { + use super::*; + + #[test] + fn resolve_aliases_chain_of_three_inherits_target_set() { + // A → B → C, where C is the only recorded type. + let mut map = DerivesMap::with_std_types(); + map.types.insert("C".into(), DerivesSet::all()); + let mut aliases = HashMap::new(); + aliases.insert("A".into(), "B".into()); + aliases.insert("B".into(), "C".into()); + map.resolve_aliases(&aliases).unwrap(); + assert_eq!(map.types.get("A"), Some(&DerivesSet::all())); + assert_eq!(map.types.get("B"), Some(&DerivesSet::all())); + } + + #[test] + fn resolve_aliases_skips_already_recorded_keys() { + let mut map = DerivesMap::with_std_types(); + map.types.insert("A".into(), DerivesSet::COPY); + let mut aliases = HashMap::new(); + // A is already recorded; the alias entry must be skipped (no overwrite). + aliases.insert("A".into(), "NeverResolved".into()); + map.resolve_aliases(&aliases).unwrap(); + assert_eq!(map.types.get("A"), Some(&DerivesSet::COPY)); + } + + #[test] + fn resolve_aliases_empty_input_is_noop() { + let mut map = DerivesMap::with_std_types(); + let snapshot = map.types.clone(); + map.resolve_aliases(&HashMap::new()).unwrap(); + assert_eq!(map.types, snapshot); + } + + /// Every seeded stdint name derives the full standard set. Guards the + /// hand-maintained `STDINT_NAMES` list against accidental deletion and + /// keeps the `satisfies` result shape in sync with the seed. + #[test] + fn stdint_names_all_derive_standard_set() { + let map = DerivesMap::from_source("").expect("parses"); + for name in STDINT_NAMES { + for trait_ in [ + DeriveTrait::Copy, + DeriveTrait::Debug, + DeriveTrait::Default, + DeriveTrait::Hash, + DeriveTrait::PartialEqOrPartialOrd, + ] { + assert!( + map.satisfies(name, trait_), + "stdint {name} missing {trait_:?}" + ); + } } - other => panic!("expected UnresolvedAlias, got {other:?}"), } - } - /// `BaseDerivesCallback` must translate `bool` into the bindgen - /// `Some(Yes)` / `Some(No)` answers expected for blocklisted types. - #[test] - fn base_callback_known_positive_returns_yes() { - let src = r" - #[derive(Copy, Clone, Debug)] - pub struct Pod; - "; - let map = Arc::new(parse(src)); - let cb = BaseDerivesCallback::new(map); - assert!(matches!( - cb.blocklisted_type_implements_trait("Pod", DeriveTrait::Debug), - Some(ImplementsTrait::Yes) - )); - } + /// A cyclic alias pair (`A = B; B = A;`) must surface as `AliasCycle` — + /// the chain-walking loop detects it when a step revisits a name + /// already in the walked set. + #[test] + fn alias_cycle_terminates() { + let src = r" + pub type A = B; + pub type B = A; + "; + let err = DerivesMap::from_source(src).expect_err("cycle must error"); + match err { + DerivesError::AliasCycle { mut names } => { + names.sort(); + assert_eq!(names, vec!["A".to_owned(), "B".to_owned()]); + } + other => panic!("expected AliasCycle, got {other:?}"), + } + } - #[test] - fn base_callback_known_negative_returns_no() { - let src = r" - #[derive(Copy, Clone)] - pub struct Pod; - "; - let map = Arc::new(parse(src)); - let cb = BaseDerivesCallback::new(map); - assert!(matches!( - cb.blocklisted_type_implements_trait("Pod", DeriveTrait::Debug), - Some(ImplementsTrait::No) - )); + /// An alias whose target is neither a recorded type nor another pending + /// alias must surface as `UnresolvedAlias`. + #[test] + fn unresolvable_alias_errors() { + let src = r" + pub type UnknownAlias = SomeUnparsedType; + "; + let err = DerivesMap::from_source(src).expect_err("unresolvable must error"); + match err { + DerivesError::UnresolvedAlias { target } => { + assert_eq!(target, "SomeUnparsedType"); + } + other => panic!("expected UnresolvedAlias, got {other:?}"), + } + } } - #[test] - fn base_callback_unknown_returns_no() { - let map = Arc::new(parse("")); - let cb = BaseDerivesCallback::new(map); - assert!(matches!( - cb.blocklisted_type_implements_trait("Nonexistent", DeriveTrait::Debug), - Some(ImplementsTrait::No) - )); + mod base_callback { + use super::*; + + /// `BaseDerivesCallback` must translate `bool` into the bindgen + /// `Some(Yes)` / `Some(No)` answers expected for blocklisted types. + #[test] + fn base_callback_known_positive_returns_yes() { + let src = r" + #[derive(Copy, Clone, Debug)] + pub struct Pod; + "; + let map = Arc::new(DerivesMap::from_source(src).expect("parses")); + let cb = BaseDerivesCallback::new(map); + + assert!(matches!( + cb.blocklisted_type_implements_trait("Pod", DeriveTrait::Copy), + Some(ImplementsTrait::Yes) + )); + + assert!(matches!( + cb.blocklisted_type_implements_trait("Pod", DeriveTrait::Debug), + Some(ImplementsTrait::Yes) + )); + } + + #[test] + fn base_callback_known_negative_returns_no() { + let src = r" + #[derive(Copy, Clone)] + pub struct Pod; + "; + let map = Arc::new(DerivesMap::from_source(src).expect("parses")); + let cb = BaseDerivesCallback::new(map); + assert!(matches!( + cb.blocklisted_type_implements_trait("Pod", DeriveTrait::Debug), + Some(ImplementsTrait::No) + )); + } + + #[test] + fn base_callback_unknown_returns_no() { + let map = Arc::new(DerivesMap::from_source("").expect("parses")); + let cb = BaseDerivesCallback::new(map); + assert!(matches!( + cb.blocklisted_type_implements_trait("Nonexistent", DeriveTrait::Debug), + Some(ImplementsTrait::No) + )); + } } } diff --git a/crates/wdk-build/tests/derives.rs b/crates/wdk-build/tests/derives.rs new file mode 100644 index 000000000..57114bb07 --- /dev/null +++ b/crates/wdk-build/tests/derives.rs @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation +// License: MIT OR Apache-2.0 + +//! Integration tests for [`wdk_build::derives::DerivesMap`] driven through +//! [`DerivesMap::from_file`]: writes a representative bindgen source snippet +//! to a temp file and asserts the recovered derive sets match each documented +//! bindgen output shape. + +use assert_fs::{NamedTempFile, fixture::FileWriteStr}; +use bindgen::callbacks::DeriveTrait; +use wdk_build::derives::{DerivesError, DerivesMap}; + +const ALL_TRAITS: &[DeriveTrait] = &[ + DeriveTrait::Copy, + DeriveTrait::Debug, + DeriveTrait::Default, + DeriveTrait::Hash, + DeriveTrait::PartialEqOrPartialOrd, +]; + +/// Writes `src` to a temp file and parses it through the public +/// [`DerivesMap::from_file`] entry point. +fn parse(src: &str) -> DerivesMap { + let tmp = NamedTempFile::new("bindgen_output.rs").expect("create temp file"); + tmp.write_str(src).expect("write temp file"); + DerivesMap::from_file(tmp.path()).expect("parses") +} + +/// Assert that `map` reports `satisfies(name, t) == true` for exactly the +/// traits in `expected`, and `false` for every other trait in [`ALL_TRAITS`]. +fn assert_derives(map: &DerivesMap, name: &str, expected: &[DeriveTrait]) { + for &t in ALL_TRAITS { + let want = expected.contains(&t); + let got = map.satisfies(name, t); + assert_eq!( + got, want, + "{name}: satisfies({t:?}) = {got}, expected {want}" + ); + } +} + +#[test] +fn parses_representative_bindgen_output() { + use DeriveTrait::{Copy, Debug, Default, Hash, PartialEqOrPartialOrd}; + + // Shapes observed in real bindgen output for wdk-sys: + // - POD struct with the common four-trait derive + // - Union with only Copy/Clone (Rust unions can't auto-derive Debug/Default) + // - Bindgen's `__BindgenUnionField` wrapper — PartialEq without PartialOrd + // - Bindgen's `__IncompleteArrayField` wrapper — the full nine-trait derive + // - Type alias chain: `PodAliasChain = PodAlias = Pod` should inherit Pod's + // derives. + let src = r#" + #[repr(C)] + #[derive(Debug, Default, Copy, Clone)] + pub struct Pod { pub x: u32 } + + #[repr(C)] + #[derive(Copy, Clone)] + pub union Uni { pub a: u32, pub b: u64 } + + #[derive(PartialEq, Copy, Clone, Debug, Hash)] + pub struct UnionField; + + #[derive(Copy, Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] + pub struct ArrayField; + + pub type PodAlias = Pod; + pub type PodAliasChain = PodAlias; + + pub type UCHAR = ::core::ffi::c_uchar; + pub type ULONG = ::core::ffi::c_ulong; + pub type PVOID = *mut ::core::ffi::c_void; + pub type PULONG = *mut ULONG; + + // Option: fn contributes all-except-Default, Option adds Default back — ends up with all 5. + pub type OptFn = ::core::option::Option u32>; + + // Bindgen module-enum pattern: inner `Type` aliases a primitive, and a use-rename re-exports it under a friendly name. The re-export must resolve to the inner `Type`'s derive set. + pub mod _INTERFACE_TYPE { + pub type Type = ::core::ffi::c_int; + pub const Isa: Type = 1; + } + pub use self::_INTERFACE_TYPE::Type as INTERFACE_TYPE; + "#; + let map = parse(src); + + assert_derives(&map, "Pod", &[Copy, Debug, Default]); + assert_derives(&map, "Uni", &[Copy]); + assert_derives( + &map, + "UnionField", + &[Copy, Debug, Hash, PartialEqOrPartialOrd], + ); + assert_derives( + &map, + "ArrayField", + &[Copy, Debug, Default, Hash, PartialEqOrPartialOrd], + ); + + // Alias chain resolves through to Pod's derives. + assert_derives(&map, "PodAlias", &[Copy, Debug, Default]); + assert_derives(&map, "PodAliasChain", &[Copy, Debug, Default]); + + // Primitive-target aliases: terminal shapes get the full standard derive + // set directly, without chain resolution. + for name in ["UCHAR", "ULONG", "PVOID", "PULONG"] { + assert_derives(&map, name, ALL_TRAITS); + } + + // Unknown type name: returns false for every trait, does not panic. + assert_derives(&map, "Nonexistent", &[]); + + // Option — fn gives 4, Option adds Default → all 5. + assert_derives(&map, "OptFn", ALL_TRAITS); + + // Module-enum pattern — both the compound key (`_INTERFACE_TYPE::Type`) and + // the re-exported friendly name (`INTERFACE_TYPE`) inherit the primitive's + // full derive set. + assert_derives(&map, "_INTERFACE_TYPE::Type", ALL_TRAITS); + assert_derives(&map, "INTERFACE_TYPE", ALL_TRAITS); +} + +#[test] +fn from_file_missing_path_returns_io_error() { + let err = DerivesMap::from_file(std::path::Path::new( + "/this/path/does/not/exist/bindgen_output.rs", + )) + .expect_err("missing file must error"); + assert!( + matches!(err, DerivesError::Io { .. }), + "expected Io, got {err:?}" + ); +} + +#[test] +fn from_file_invalid_rust_returns_parse_error() { + let tmp = NamedTempFile::new("bad.rs").expect("create temp file"); + tmp.write_str("not @ valid @ rust @@@") + .expect("write temp file"); + let err = DerivesMap::from_file(tmp.path()).expect_err("invalid syntax must error"); + assert!( + matches!(err, DerivesError::Parse(_)), + "expected Parse, got {err:?}" + ); +}