From ea280298e396a94108c55bb58c5a19e64b962c09 Mon Sep 17 00:00:00 2001 From: Thibaut Lorrain Date: Mon, 12 Jan 2026 12:01:18 +0100 Subject: [PATCH 1/5] move WithOrSkip to utils --- impl/src/hash.rs | 51 ++++++---------------------------------------- impl/src/utils.rs | 52 +++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 56 insertions(+), 47 deletions(-) diff --git a/impl/src/hash.rs b/impl/src/hash.rs index c78ae7b3..c5ceae47 100644 --- a/impl/src/hash.rs +++ b/impl/src/hash.rs @@ -8,7 +8,6 @@ use crate::utils::{ }; use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; -use syn::parse::{Parse, ParseStream}; use syn::{ parse_quote, punctuated::{self, Punctuated}, @@ -39,7 +38,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { for attr_name in &secondary_attr_names { if attr::Skip::parse_attrs(&field.attrs, attr_name)? @@ -51,14 +50,14 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { skipped_fields.insert(n); } Some(Spanning { - item: FieldAttributes::With(with), + item: attr::WithOrSkip::With(with), .. }) => { alternate_hash_functions.insert(n, with.func.clone()); @@ -85,7 +84,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { for attr_name in &secondary_attr_names { if attr::Skip::parse_attrs(&field.attrs, attr_name)? @@ -97,14 +96,14 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { skipped_fields.insert(n); } Some(Spanning { - item: FieldAttributes::With(with), + item: attr::WithOrSkip::With(with), .. }) => { alternate_hash_functions.insert(n, with.func.clone()); @@ -136,44 +135,6 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result) -> syn::Result { - mod ident { - use syn::custom_keyword; - - custom_keyword!(with); - custom_keyword!(skip); - custom_keyword!(ignore); - } - - // `.lookahead1()` with all possible idents forms a nice error message including all the - // possible variants. - let ahead = input.lookahead1(); - - if ahead.peek(ident::with) { - Ok(Self::With(input.parse()?)) - } else if ahead.peek(ident::skip) || ahead.peek(ident::ignore) { - _ = input.parse::()?; - Ok(Self::Skip) - } else { - Err(ahead.error()) - } - } -} - -impl ParseMultiple for FieldAttributes {} - /// Indices of [`syn::Field`]s marked with an [`attr::Skip`]. type SkippedFields = HashSet; diff --git a/impl/src/utils.rs b/impl/src/utils.rs index 82178bb1..f0752d15 100644 --- a/impl/src/utils.rs +++ b/impl/src/utils.rs @@ -1566,8 +1566,10 @@ pub(crate) mod attr { pub(crate) use self::skip::Skip; #[cfg(any(feature = "as_ref", feature = "from", feature = "try_from"))] pub(crate) use self::types::Types; - #[cfg(feature = "hash")] + #[cfg(any(feature = "hash", feature = "eq"))] pub(crate) use self::with::With; + #[cfg(any(feature = "hash", feature = "eq"))] + pub(crate) use self::with_or_skip::WithOrSkip; #[cfg(any(feature = "as_ref", feature = "from"))] pub(crate) use self::{conversion::Conversion, field_conversion::FieldConversion}; #[cfg(feature = "try_from")] @@ -2404,7 +2406,7 @@ pub(crate) mod attr { impl ParseMultiple for RenameAll {} } - #[cfg(feature = "hash")] + #[cfg(any(feature = "hash", feature = "eq"))] mod with { use syn::parenthesized; use syn::parse::{Parse, ParseStream}; @@ -2439,6 +2441,52 @@ pub(crate) mod attr { impl ParseMultiple for With {} } + + + #[cfg(any(feature = "hash", feature = "eq"))] + mod with_or_skip { + use syn::parse::{Parse, ParseStream}; + use crate::utils::attr; + use crate::utils::attr::ParseMultiple; + + /// Custom combination of an [`attr::Skip`] and [`attr::With`] used for a better error message + /// including all the possible variants. + pub enum WithOrSkip { + /// Parsed [`attr::Skip`]. + Skip, + /// Parsed [`attr::With`]. + With(attr::With), + } + + // TODO: Try generalize in `Either`. + impl Parse for WithOrSkip { + fn parse(input: ParseStream<'_>) -> syn::Result { + mod ident { + use syn::custom_keyword; + + custom_keyword!(with); + custom_keyword!(skip); + custom_keyword!(ignore); + } + + // `.lookahead1()` with all possible idents forms a nice error message including all the + // possible variants. + let ahead = input.lookahead1(); + + if ahead.peek(ident::with) { + Ok(Self::With(input.parse()?)) + } else if ahead.peek(ident::skip) || ahead.peek(ident::ignore) { + _ = input.parse::()?; + Ok(Self::Skip) + } else { + Err(ahead.error()) + } + } + } + + impl ParseMultiple for WithOrSkip {} + } + } #[cfg(any(feature = "from", feature = "into"))] From d090b1a26b09c52750789b50fbc82f08d64e2dbe Mon Sep 17 00:00:00 2001 From: Thibaut Lorrain Date: Tue, 13 Jan 2026 16:22:48 +0100 Subject: [PATCH 2/5] add support for `with` attribute in Eq/PartialEq --- CHANGELOG.md | 2 + impl/doc/eq.md | 55 +++++ impl/src/cmp/eq.rs | 14 +- impl/src/cmp/partial_eq.rs | 90 ++++++-- impl/src/utils.rs | 6 +- .../eq/unknown_field_attribute.stderr | 4 +- .../partial_eq/unknown_field_attribute.stderr | 4 +- .../partial_eq/unknown_with_function.rs | 18 ++ .../partial_eq/unknown_with_function.stderr | 80 +++++++ tests/eq.rs | 146 +++++++++++++ tests/partial_eq.rs | 202 ++++++++++++++++++ 11 files changed, 595 insertions(+), 26 deletions(-) create mode 100644 tests/compile_fail/partial_eq/unknown_with_function.rs create mode 100644 tests/compile_fail/partial_eq/unknown_with_function.stderr diff --git a/CHANGELOG.md b/CHANGELOG.md index 833502c2..ae9724bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Add `Hash` derive similar to `std`'s one, but considering generics correctly, and supporting custom hash functions per field or skipping fields. ([#532](https://github.com/JelteF/derive_more/pull/532)) +- Add support for custom eq functions in `PartialEq` derive. + ([#535](https://github.com/JelteF/derive_more/pull/535)) ### Fixed diff --git a/impl/doc/eq.md b/impl/doc/eq.md index 111ea03f..51211e0f 100644 --- a/impl/doc/eq.md +++ b/impl/doc/eq.md @@ -286,3 +286,58 @@ impl PartialEq for Enum { } } ``` + +### Custom comparison with `with` + +Both `#[eq(with(...))]` and `#[partial_eq(with(...))]` attribute allows specifying a custom comparison function +for a field. The function must have the signature `fn(&T, &T) -> bool` where `T` is the field type. + +```rust +# use derive_more::{Eq, PartialEq}; + +mod custom_eq { + #[derive(Debug)] + pub struct NotPartialEq(pub i32); + + pub fn compare(a: &NotPartialEq, b: &NotPartialEq) -> bool { + a.0 == b.0 + } +} +use custom_eq::NotPartialEq; + +#[derive(Debug, PartialEq, Eq)] +struct Foo(#[partial_eq(with(custom_eq::compare))] NotPartialEq); + +assert_eq!(Foo(NotPartialEq(42)), Foo(NotPartialEq(42))); +assert_ne!(Foo(NotPartialEq(42)), Foo(NotPartialEq(73))); +``` + +This generates code equivalent to: + +```rust +# +# mod custom_eq { +# #[derive(Debug)] +# pub struct NotPartialEq(i32); +# +# pub fn compare(a: &NotPartialEq, b: &NotPartialEq) -> bool { +# a.0 == b.0 +# } +# } +# use custom_eq::NotPartialEq; +# +# struct Foo(NotPartialEq); +# +impl PartialEq for Foo { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self(self_0), Self(other_0)) => custom_eq::compare(self_0, other_0) + } + } + fn ne(&self, other: &Self) -> bool { + match (self, other) { + (Self(self_0), Self(other_0)) => !custom_eq::compare(self_0, other_0) + } + } +} +``` \ No newline at end of file diff --git a/impl/src/cmp/eq.rs b/impl/src/cmp/eq.rs index c18b7790..69232636 100644 --- a/impl/src/cmp/eq.rs +++ b/impl/src/cmp/eq.rs @@ -28,7 +28,12 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result syn::Result syn::Result { + _ = skipped_fields.insert(n); + continue 'fields; + } + Some(Spanning { + item: attr::WithOrSkip::With(with), + .. + }) => { + alternate_eq_functions.insert(n, with.func); + continue 'fields; + } + None => {} } } } - variants.push((None, &data.fields, skipped_fields)); + variants.push(( + None, + &data.fields, + skipped_fields, + alternate_eq_functions, + )); } } syn::Data::Enum(data) => { @@ -53,15 +73,35 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { + _ = skipped_fields.insert(n); + continue 'fields; + } + Some(Spanning { + item: attr::WithOrSkip::With(with), + .. + }) => { + alternate_eq_functions.insert(n, with.func); + continue 'fields; + } + None => {} } } } - variants.push((Some(&variant.ident), &variant.fields, skipped_fields)); + variants.push(( + Some(&variant.ident), + &variant.fields, + skipped_fields, + alternate_eq_functions, + )); } } syn::Data::Union(data) => { @@ -84,6 +124,10 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result; +/// Mapping from [`syn::Field`] marked with an [`attr::With`] to the [`syn::Path`] of the alternate +/// eq function. +type FieldsWithAlternateEqFunction = HashMap; + /// Expansion of a macro for generating a structural [`PartialEq`] implementation of an enum or a /// struct. struct StructuralExpansion<'i> { @@ -93,7 +137,12 @@ struct StructuralExpansion<'i> { self_ty: (&'i syn::Ident, &'i syn::Generics), /// [`syn::Fields`] of the enum/struct to be compared in this [`StructuralExpansion`]. - variants: Vec<(Option<&'i syn::Ident>, &'i syn::Fields, SkippedFields)>, + variants: Vec<( + Option<&'i syn::Ident>, + &'i syn::Fields, + SkippedFields, + FieldsWithAlternateEqFunction, + )>, /// Indicator whether some original enum variants where skipped with an [`attr::Skip`]. has_skipped_variants: bool, @@ -144,7 +193,7 @@ impl StructuralExpansion<'_> { let match_arms = self .variants .iter() - .filter_map(|(variant, all_fields, skipped_fields)| { + .filter_map(|(variant, all_fields, skipped_fields, alternate_eq_functions)| { if all_fields.is_empty() || skipped_fields.len() == all_fields.len() { return None; } @@ -160,7 +209,16 @@ impl StructuralExpansion<'_> { .map(|num| { let self_val = format_ident!("__self_{num}"); let other_val = format_ident!("__other_{num}"); - punctuated::Pair::Punctuated(quote! { #self_val #cmp #other_val }, &chain) + let equality = alternate_eq_functions + .get(&num) + .map(|eq_fn| { + let maybe_not = (!eq).then(|| quote! {!}); + quote! { #maybe_not #eq_fn(#self_val, #other_val) } + } + ).unwrap_or_else(|| + quote! { #self_val #cmp #other_val } + ); + punctuated::Pair::Punctuated(equality, &chain) }) .collect::>(); _ = val_eqs.pop_punct(); @@ -178,14 +236,14 @@ impl StructuralExpansion<'_> { }); let unreachable_arm = (self.variants.len() > 1 && no_fields_arm.is_none()) - .then(|| { - quote! { + .then(|| { + quote! { // SAFETY: This arm is never reachable due to `mem::discriminant()` comparison // preceding the expanded `match (self, other)` expression, but is // required by it when there is more than one variant. _ => unsafe { derive_more::core::hint::unreachable_unchecked() }, } - }); + }); quote! { match (self, __other) { @@ -221,7 +279,7 @@ impl ToTokens for StructuralExpansion<'_> { { let self_ty: syn::Type = parse_quote! { Self }; let implementor_ty: syn::Type = parse_quote! { #ty #ty_generics }; - for (_, all_fields, skipped_fields) in &self.variants { + for (_, all_fields, skipped_fields, _) in &self.variants { for field_ty in all_fields.iter().enumerate().filter_map(|(n, field)| { (!skipped_fields.contains(&n)).then_some(&field.ty) diff --git a/impl/src/utils.rs b/impl/src/utils.rs index f0752d15..589a8e24 100644 --- a/impl/src/utils.rs +++ b/impl/src/utils.rs @@ -2442,12 +2442,11 @@ pub(crate) mod attr { impl ParseMultiple for With {} } - #[cfg(any(feature = "hash", feature = "eq"))] mod with_or_skip { - use syn::parse::{Parse, ParseStream}; use crate::utils::attr; use crate::utils::attr::ParseMultiple; + use syn::parse::{Parse, ParseStream}; /// Custom combination of an [`attr::Skip`] and [`attr::With`] used for a better error message /// including all the possible variants. @@ -2485,8 +2484,7 @@ pub(crate) mod attr { } impl ParseMultiple for WithOrSkip {} - } - + } } #[cfg(any(feature = "from", feature = "into"))] diff --git a/tests/compile_fail/eq/unknown_field_attribute.stderr b/tests/compile_fail/eq/unknown_field_attribute.stderr index ccd3198a..c93ad77a 100644 --- a/tests/compile_fail/eq/unknown_field_attribute.stderr +++ b/tests/compile_fail/eq/unknown_field_attribute.stderr @@ -1,10 +1,10 @@ -error: only `skip`/`ignore` allowed here +error: expected one of: `with`, `skip`, `ignore` --> tests/compile_fail/eq/unknown_field_attribute.rs:2:17 | 2 | struct Foo(#[eq(unknown)] i32); | ^^^^^^^ -error: only `skip`/`ignore` allowed here +error: expected one of: `with`, `skip`, `ignore` --> tests/compile_fail/eq/unknown_field_attribute.rs:12:16 | 12 | Bar { #[eq(unknown)] i: i32 }, diff --git a/tests/compile_fail/partial_eq/unknown_field_attribute.stderr b/tests/compile_fail/partial_eq/unknown_field_attribute.stderr index fdfed3b8..05c88369 100644 --- a/tests/compile_fail/partial_eq/unknown_field_attribute.stderr +++ b/tests/compile_fail/partial_eq/unknown_field_attribute.stderr @@ -1,10 +1,10 @@ -error: only `skip`/`ignore` allowed here +error: expected one of: `with`, `skip`, `ignore` --> tests/compile_fail/partial_eq/unknown_field_attribute.rs:2:25 | 2 | struct Foo(#[partial_eq(unknown)] i32); | ^^^^^^^ -error: only `skip`/`ignore` allowed here +error: expected one of: `with`, `skip`, `ignore` --> tests/compile_fail/partial_eq/unknown_field_attribute.rs:6:24 | 6 | Bar { #[partial_eq(unknown)] i: i32 }, diff --git a/tests/compile_fail/partial_eq/unknown_with_function.rs b/tests/compile_fail/partial_eq/unknown_with_function.rs new file mode 100644 index 00000000..64db6416 --- /dev/null +++ b/tests/compile_fail/partial_eq/unknown_with_function.rs @@ -0,0 +1,18 @@ +#[derive(derive_more::PartialEq)] +struct Foo(#[partial_eq(with(unknown))] i32); + +fn incompatible_types(a:& str) ->i32 {0} + +#[derive(derive_more::PartialEq)] +struct Bar(#[partial_eq(with(incompatible_types))] i32); + + +#[derive(derive_more::PartialEq)] +enum Enum { + Bar { #[partial_eq(with(unknown))] i: i32 }, + Baz { #[partial_eq(with(incompatible_types))] i: i32 }, +} + + + +fn main() {} diff --git a/tests/compile_fail/partial_eq/unknown_with_function.stderr b/tests/compile_fail/partial_eq/unknown_with_function.stderr new file mode 100644 index 00000000..13546586 --- /dev/null +++ b/tests/compile_fail/partial_eq/unknown_with_function.stderr @@ -0,0 +1,80 @@ +error[E0425]: cannot find function `unknown` in this scope + --> tests/compile_fail/partial_eq/unknown_with_function.rs:2:30 + | +2 | struct Foo(#[partial_eq(with(unknown))] i32); + | ^^^^^^^ not found in this scope + +error[E0061]: this function takes 1 argument but 2 arguments were supplied + --> tests/compile_fail/partial_eq/unknown_with_function.rs:7:30 + | +6 | #[derive(derive_more::PartialEq)] + | ---------------------- + | | + | expected `&str`, found `&i32` + | unexpected argument #2 of type `&i32` +7 | struct Bar(#[partial_eq(with(incompatible_types))] i32); + | ^^^^^^^^^^^^^^^^^^ + | + = note: expected reference `&str` + found reference `&i32` +note: function defined here + --> tests/compile_fail/partial_eq/unknown_with_function.rs:4:4 + | +4 | fn incompatible_types(a:& str) ->i32 {0} + | ^^^^^^^^^^^^^^^^^^ ------- + +error[E0308]: mismatched types + --> tests/compile_fail/partial_eq/unknown_with_function.rs:6:10 + | +6 | #[derive(derive_more::PartialEq)] + | ^^^^^^^^^^^^^^^^^^^^^^ + | | + | expected `bool`, found `i32` + | expected `bool` because of return type + | + = note: this error originates in the derive macro `derive_more::PartialEq` (in Nightly builds, run with -Z macro-backtrace for more info) + +error[E0425]: cannot find function `unknown` in this scope + --> tests/compile_fail/partial_eq/unknown_with_function.rs:12:29 + | +12 | Bar { #[partial_eq(with(unknown))] i: i32 }, + | ^^^^^^^ not found in this scope + +error[E0061]: this function takes 1 argument but 2 arguments were supplied + --> tests/compile_fail/partial_eq/unknown_with_function.rs:13:29 + | +10 | #[derive(derive_more::PartialEq)] + | ---------------------- + | | + | expected `&str`, found `&i32` + | unexpected argument #2 of type `&i32` +... +13 | Baz { #[partial_eq(with(incompatible_types))] i: i32 }, + | ^^^^^^^^^^^^^^^^^^ + | + = note: expected reference `&str` + found reference `&i32` +note: function defined here + --> tests/compile_fail/partial_eq/unknown_with_function.rs:4:4 + | + 4 | fn incompatible_types(a:& str) ->i32 {0} + | ^^^^^^^^^^^^^^^^^^ ------- + +error[E0308]: mismatched types + --> tests/compile_fail/partial_eq/unknown_with_function.rs:10:10 + | +10 | #[derive(derive_more::PartialEq)] + | ^^^^^^^^^^^^^^^^^^^^^^ + | | + | expected `bool`, found `i32` + | expected `bool` because of return type + | + = note: this error originates in the derive macro `derive_more::PartialEq` (in Nightly builds, run with -Z macro-backtrace for more info) + +warning: unused variable: `a` + --> tests/compile_fail/partial_eq/unknown_with_function.rs:4:23 + | +4 | fn incompatible_types(a:& str) ->i32 {0} + | ^ help: if this is intentional, prefix it with an underscore: `_a` + | + = note: `#[warn(unused_variables)]` (part of `#[warn(unused)]`) on by default diff --git a/tests/eq.rs b/tests/eq.rs index e9203bb1..441dfbee 100644 --- a/tests/eq.rs +++ b/tests/eq.rs @@ -162,6 +162,64 @@ mod structs { } } + mod with { + use derive_more::{__private::AssertParamIsEq, Eq, PartialEq}; + + fn eq_special(_: &NotPartialEq, _: &NotPartialEq) -> bool { + true + } + + struct NotPartialEq(i32); + + #[test] + fn single_field() { + #[derive(Eq, PartialEq)] + struct Foo(#[partial_eq(with(eq_special))] NotPartialEq); + #[derive(Eq, PartialEq)] + struct Bar(#[eq(with(eq_special))] NotPartialEq); + + let _: AssertParamIsEq; + let _: AssertParamIsEq; + } + + #[test] + fn multiple_fields() { + #[derive(Eq, PartialEq)] + struct Foo { + #[partial_eq(with(eq_special))] + a: NotPartialEq, + b: i32, + } + + #[derive(Eq, PartialEq)] + struct Bar { + #[eq(with(eq_special))] + a: NotPartialEq, + b: i32, + } + + let _: AssertParamIsEq; + let _: AssertParamIsEq; + } + + #[test] + fn tuple_all() { + #[derive(Eq, PartialEq)] + struct Foo( + #[partial_eq(with(eq_special))] NotPartialEq, + #[partial_eq(with(eq_special))] NotPartialEq, + ); + + #[derive(Eq, PartialEq)] + struct Bar( + #[eq(with(eq_special))] NotPartialEq, + #[eq(with(eq_special))] NotPartialEq, + ); + + let _: AssertParamIsEq; + } + } + mod generic { #[cfg(not(feature = "std"))] use ::alloc::{boxed::Box, vec::Vec}; @@ -547,6 +605,94 @@ mod enums { } } + mod with { + use derive_more::{__private::AssertParamIsEq, Eq, PartialEq}; + + fn eq_special(_: &NotPartialEq, _: &NotPartialEq) -> bool { + true + } + + struct NotPartialEq(i32); + + #[test] + fn single_field() { + #[derive(Eq, PartialEq)] + enum E { + Foo(#[partial_eq(with(eq_special))] NotPartialEq), + Bar(#[eq(with(eq_special))] NotPartialEq), + Baz, + } + + let _: AssertParamIsEq; + } + + #[test] + fn multiple_fields() { + #[derive(Eq, PartialEq)] + enum E { + Foo { + #[partial_eq(with(eq_special))] + a: NotPartialEq, + b: i32, + }, + Bar { + #[eq(with(eq_special))] + a: NotPartialEq, + b: i32, + }, + Baz, + } + + let _: AssertParamIsEq; + } + + #[test] + fn tuple_all() { + #[derive(Eq, PartialEq)] + enum E { + Foo( + #[partial_eq(with(eq_special))] NotPartialEq, + #[partial_eq(with(eq_special))] NotPartialEq, + ), + Bar( + #[eq(with(eq_special))] NotPartialEq, + #[eq(with(eq_special))] NotPartialEq, + ), + Baz, + } + + let _: AssertParamIsEq; + } + + #[test] + fn multi_variant() { + #[derive(Eq, PartialEq)] + enum E { + Foo(#[partial_eq(with(eq_special))] NotPartialEq), + Bar { + #[partial_eq(with(eq_special))] + val: NotPartialEq, + }, + Baz, + } + + let _: AssertParamIsEq; + } + + #[test] + fn with_skip_combined() { + #[derive(Eq, PartialEq)] + enum E { + Foo( + #[partial_eq(with(eq_special))] NotPartialEq, + #[partial_eq(skip)] NotPartialEq, + ), + } + + let _: AssertParamIsEq; + } + } + mod generic { #[cfg(not(feature = "std"))] use ::alloc::{boxed::Box, vec::Vec}; diff --git a/tests/partial_eq.rs b/tests/partial_eq.rs index ff940539..a1b31dce 100644 --- a/tests/partial_eq.rs +++ b/tests/partial_eq.rs @@ -248,6 +248,69 @@ mod structs { } } + mod with { + use derive_more::PartialEq; + + #[test] + fn single_field() { + #[derive(Debug)] + struct NotPartialEq(i32); + fn eq_special(a: &NotPartialEq, b: &NotPartialEq) -> bool { + a.0 == b.0 + } + + #[derive(Debug, PartialEq)] + struct Foo(#[partial_eq(with(eq_special))] NotPartialEq); + + // assert both using == and != as both are overloaded + assert_eq!(Foo(NotPartialEq(42)), Foo(NotPartialEq(42))); + assert!(!(Foo(NotPartialEq(42)) != Foo(NotPartialEq(42)))); + + assert!(!(Foo(NotPartialEq(42)) == Foo(NotPartialEq(73)))); + assert_ne!(Foo(NotPartialEq(42)), Foo(NotPartialEq(73))); + } + + fn eq_special_always_equal(_: &u32, _: &u32) -> bool { + true + } + + #[test] + fn multiple_fields() { + #[derive(Debug, PartialEq)] + struct Foo { + #[partial_eq(with(eq_special_always_equal))] + a: u32, + b: u32, + } + + // assert both using == and != as both are overloaded + assert_eq!(Foo { a: 73, b: 1 }, Foo { a: 42, b: 1 }); + assert!(!(Foo { a: 73, b: 1 } != Foo { a: 42, b: 1 })); + + assert!(!(Foo { a: 73, b: 1 } == Foo { a: 42, b: 2 })); + assert_ne!(Foo { a: 73, b: 1 }, Foo { a: 42, b: 2 }); + } + + #[test] + fn tuple_all() { + mod eq { + pub fn always_eq(_: &i32, _: &i32) -> bool { + true + } + } + + #[derive(Debug, PartialEq)] + struct Foo( + #[partial_eq(with(eq::always_eq))] i32, + #[partial_eq(with(eq::always_eq))] i32, + ); + + // assert both using == and != as both are overloaded + assert_eq!(Foo(12, 13), Foo(14, 15)); + assert!(!(Foo(12, 13) != Foo(14, 15))); + } + } + mod generic { #[cfg(not(feature = "std"))] use ::alloc::{boxed::Box, vec, vec::Vec}; @@ -927,6 +990,145 @@ mod enums { } } + mod with { + use derive_more::PartialEq; + + #[test] + fn single_field() { + #[derive(Debug)] + struct NotPartialEq(i32); + fn eq_special(a: &NotPartialEq, b: &NotPartialEq) -> bool { + a.0 == b.0 + } + + #[derive(Debug, PartialEq)] + enum E { + Foo(#[partial_eq(with(eq_special))] NotPartialEq), + Bar, + } + + // assert both using == and != as both are overloaded + assert_eq!(E::Foo(NotPartialEq(42)), E::Foo(NotPartialEq(42))); + assert!(!(E::Foo(NotPartialEq(42)) != E::Foo(NotPartialEq(42)))); + + assert!(!(E::Foo(NotPartialEq(42)) == E::Foo(NotPartialEq(73)))); + assert_ne!(E::Foo(NotPartialEq(42)), E::Foo(NotPartialEq(73))); + + assert!(!(E::Foo(NotPartialEq(42)) == E::Bar)); + assert_ne!(E::Foo(NotPartialEq(42)), E::Bar); + } + + fn eq_special_always_equal(_: &u32, _: &u32) -> bool { + true + } + + #[test] + fn multiple_fields() { + #[derive(Debug, PartialEq)] + enum E { + Foo { + #[partial_eq(with(eq_special_always_equal))] + a: u32, + b: u32, + }, + Bar, + } + + // assert both using == and != as both are overloaded + assert_eq!(E::Foo { a: 73, b: 1 }, E::Foo { a: 42, b: 1 }); + assert!(!(E::Foo { a: 73, b: 1 } != E::Foo { a: 42, b: 1 })); + + assert!(!(E::Foo { a: 73, b: 1 } == E::Foo { a: 42, b: 2 })); + assert_ne!(E::Foo { a: 73, b: 1 }, E::Foo { a: 42, b: 2 }); + + assert!(!(E::Foo { a: 73, b: 1 } == E::Bar)); + assert_ne!(E::Foo { a: 73, b: 1 }, E::Bar); + } + + #[test] + fn tuple_all() { + mod eq { + pub fn always_eq(_: &i32, _: &i32) -> bool { + true + } + } + + #[derive(Debug, PartialEq)] + enum E { + Foo( + #[partial_eq(with(eq::always_eq))] i32, + #[partial_eq(with(eq::always_eq))] i32, + ), + Bar, + } + + // assert both using == and != as both are overloaded + assert_eq!(E::Foo(12, 13), E::Foo(14, 15)); + assert!(!(E::Foo(12, 13) != E::Foo(14, 15))); + + assert!(!(E::Foo(73, 1) == E::Bar)); + assert_ne!(E::Foo(73, 1), E::Bar); + } + + #[test] + fn multi_variant() { + fn eq_mod_10(a: &i32, b: &i32) -> bool { + a % 10 == b % 10 + } + + #[derive(Debug, PartialEq)] + enum E { + Foo(#[partial_eq(with(eq_mod_10))] i32), + Bar { + #[partial_eq(with(eq_mod_10))] + val: i32, + }, + Baz, + } + + // assert both using == and != as both are overloaded + assert_eq!(E::Foo(13), E::Foo(23)); + assert!(!(E::Foo(13) != E::Foo(23))); + + assert_ne!(E::Foo(13), E::Foo(24)); + assert!(!(E::Foo(13) == E::Foo(24))); + + assert_eq!(E::Bar { val: 15 }, E::Bar { val: 25 }); + assert!(!(E::Bar { val: 15 } != E::Bar { val: 25 })); + + assert_ne!(E::Bar { val: 15 }, E::Bar { val: 26 }); + assert!(!(E::Bar { val: 15 } == E::Bar { val: 26 })); + + assert_eq!(E::Baz, E::Baz); + assert!(!(E::Baz != E::Baz)); + + assert_ne!(E::Foo(13), E::Bar { val: 13 }); + assert!(!(E::Foo(13) == E::Bar { val: 13 })); + + assert_ne!(E::Foo(13), E::Baz); + assert!(!(E::Foo(13) == E::Baz)); + + assert_ne!(E::Bar { val: 13 }, E::Baz); + assert!(!(E::Bar { val: 13 } == E::Baz)); + } + + #[test] + fn with_skip_combined() { + fn eq_always(_: &i32, _: &i32) -> bool { + true + } + + #[derive(Debug, PartialEq)] + enum E { + Foo(#[partial_eq(with(eq_always))] i32, #[partial_eq(skip)] i32), + } + + // assert both using == and != as both are overloaded + assert_eq!(E::Foo(1, 2), E::Foo(3, 4)); + assert!(!(E::Foo(1, 2) != E::Foo(3, 4))); + } + } + mod generic { #[cfg(not(feature = "std"))] use ::alloc::{boxed::Box, vec, vec::Vec}; From e8ac838f0457a170ef4fedcbb0b6c9f2ad91ae81 Mon Sep 17 00:00:00 2001 From: Thibaut Lorrain Date: Mon, 11 May 2026 11:41:53 +0200 Subject: [PATCH 3/5] small nits --- CHANGELOG.md | 2 +- impl/doc/eq.md | 4 +- impl/src/cmp/partial_eq.rs | 3 +- .../partial_eq/unknown_with_function.rs | 5 +-- .../partial_eq/unknown_with_function.stderr | 40 ++++++++----------- 5 files changed, 21 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ae9724bc..542d95fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Add `Hash` derive similar to `std`'s one, but considering generics correctly, and supporting custom hash functions per field or skipping fields. ([#532](https://github.com/JelteF/derive_more/pull/532)) -- Add support for custom eq functions in `PartialEq` derive. +- Add support for custom eq functions in `PartialEq`/`Eq` derive. ([#535](https://github.com/JelteF/derive_more/pull/535)) ### Fixed diff --git a/impl/doc/eq.md b/impl/doc/eq.md index 51211e0f..f970b5de 100644 --- a/impl/doc/eq.md +++ b/impl/doc/eq.md @@ -289,7 +289,7 @@ impl PartialEq for Enum { ### Custom comparison with `with` -Both `#[eq(with(...))]` and `#[partial_eq(with(...))]` attribute allows specifying a custom comparison function +Both `#[eq(with(...))]` and `#[partial_eq(with(...))]` attributes allows specifying a custom comparison function for a field. The function must have the signature `fn(&T, &T) -> bool` where `T` is the field type. ```rust @@ -340,4 +340,4 @@ impl PartialEq for Foo { } } } -``` \ No newline at end of file +``` diff --git a/impl/src/cmp/partial_eq.rs b/impl/src/cmp/partial_eq.rs index c8111270..947ba393 100644 --- a/impl/src/cmp/partial_eq.rs +++ b/impl/src/cmp/partial_eq.rs @@ -215,8 +215,7 @@ impl StructuralExpansion<'_> { let maybe_not = (!eq).then(|| quote! {!}); quote! { #maybe_not #eq_fn(#self_val, #other_val) } } - ).unwrap_or_else(|| - quote! { #self_val #cmp #other_val } + ).unwrap_or_else(|| quote! { #self_val #cmp #other_val } ); punctuated::Pair::Punctuated(equality, &chain) }) diff --git a/tests/compile_fail/partial_eq/unknown_with_function.rs b/tests/compile_fail/partial_eq/unknown_with_function.rs index 64db6416..9577b619 100644 --- a/tests/compile_fail/partial_eq/unknown_with_function.rs +++ b/tests/compile_fail/partial_eq/unknown_with_function.rs @@ -1,18 +1,15 @@ #[derive(derive_more::PartialEq)] struct Foo(#[partial_eq(with(unknown))] i32); -fn incompatible_types(a:& str) ->i32 {0} +fn incompatible_types(_a: &str) ->i32 {0} #[derive(derive_more::PartialEq)] struct Bar(#[partial_eq(with(incompatible_types))] i32); - #[derive(derive_more::PartialEq)] enum Enum { Bar { #[partial_eq(with(unknown))] i: i32 }, Baz { #[partial_eq(with(incompatible_types))] i: i32 }, } - - fn main() {} diff --git a/tests/compile_fail/partial_eq/unknown_with_function.stderr b/tests/compile_fail/partial_eq/unknown_with_function.stderr index 13546586..7b2d6ede 100644 --- a/tests/compile_fail/partial_eq/unknown_with_function.stderr +++ b/tests/compile_fail/partial_eq/unknown_with_function.stderr @@ -20,8 +20,8 @@ error[E0061]: this function takes 1 argument but 2 arguments were supplied note: function defined here --> tests/compile_fail/partial_eq/unknown_with_function.rs:4:4 | -4 | fn incompatible_types(a:& str) ->i32 {0} - | ^^^^^^^^^^^^^^^^^^ ------- +4 | fn incompatible_types(_a: &str) ->i32 {0} + | ^^^^^^^^^^^^^^^^^^ -------- error[E0308]: mismatched types --> tests/compile_fail/partial_eq/unknown_with_function.rs:6:10 @@ -35,21 +35,21 @@ error[E0308]: mismatched types = note: this error originates in the derive macro `derive_more::PartialEq` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0425]: cannot find function `unknown` in this scope - --> tests/compile_fail/partial_eq/unknown_with_function.rs:12:29 + --> tests/compile_fail/partial_eq/unknown_with_function.rs:11:29 | -12 | Bar { #[partial_eq(with(unknown))] i: i32 }, +11 | Bar { #[partial_eq(with(unknown))] i: i32 }, | ^^^^^^^ not found in this scope error[E0061]: this function takes 1 argument but 2 arguments were supplied - --> tests/compile_fail/partial_eq/unknown_with_function.rs:13:29 + --> tests/compile_fail/partial_eq/unknown_with_function.rs:12:29 | -10 | #[derive(derive_more::PartialEq)] + 9 | #[derive(derive_more::PartialEq)] | ---------------------- | | | expected `&str`, found `&i32` | unexpected argument #2 of type `&i32` ... -13 | Baz { #[partial_eq(with(incompatible_types))] i: i32 }, +12 | Baz { #[partial_eq(with(incompatible_types))] i: i32 }, | ^^^^^^^^^^^^^^^^^^ | = note: expected reference `&str` @@ -57,24 +57,16 @@ error[E0061]: this function takes 1 argument but 2 arguments were supplied note: function defined here --> tests/compile_fail/partial_eq/unknown_with_function.rs:4:4 | - 4 | fn incompatible_types(a:& str) ->i32 {0} - | ^^^^^^^^^^^^^^^^^^ ------- + 4 | fn incompatible_types(_a: &str) ->i32 {0} + | ^^^^^^^^^^^^^^^^^^ -------- error[E0308]: mismatched types - --> tests/compile_fail/partial_eq/unknown_with_function.rs:10:10 - | -10 | #[derive(derive_more::PartialEq)] - | ^^^^^^^^^^^^^^^^^^^^^^ - | | - | expected `bool`, found `i32` - | expected `bool` because of return type - | - = note: this error originates in the derive macro `derive_more::PartialEq` (in Nightly builds, run with -Z macro-backtrace for more info) - -warning: unused variable: `a` - --> tests/compile_fail/partial_eq/unknown_with_function.rs:4:23 + --> tests/compile_fail/partial_eq/unknown_with_function.rs:9:10 | -4 | fn incompatible_types(a:& str) ->i32 {0} - | ^ help: if this is intentional, prefix it with an underscore: `_a` +9 | #[derive(derive_more::PartialEq)] + | ^^^^^^^^^^^^^^^^^^^^^^ + | | + | expected `bool`, found `i32` + | expected `bool` because of return type | - = note: `#[warn(unused_variables)]` (part of `#[warn(unused)]`) on by default + = note: this error originates in the derive macro `derive_more::PartialEq` (in Nightly builds, run with -Z macro-backtrace for more info) From 0353b9c5b0297ff235e377202f73aa98cf2b4b6f Mon Sep 17 00:00:00 2001 From: Thibaut Lorrain Date: Mon, 18 May 2026 15:21:08 +0200 Subject: [PATCH 4/5] no eq(with) and error on derive(Hash) / partial_eq(with) and no hash(with/skip) --- impl/doc/eq.md | 5 +- impl/doc/hash.md | 4 + impl/src/cmp/eq.rs | 42 +++++---- impl/src/cmp/partial_eq.rs | 76 +++++++++------- impl/src/hash.rs | 88 +++++++++++++------ .../eq/unknown_field_attribute.stderr | 4 +- tests/compile_fail/eq/with_attribute.rs | 28 ++++++ tests/compile_fail/eq/with_attribute.stderr | 11 +++ .../hash/partial_eq_with_without_hash.rs | 16 ++++ .../hash/partial_eq_with_without_hash.stderr | 11 +++ tests/eq.rs | 35 +------- tests/hash.rs | 38 ++++++++ 12 files changed, 245 insertions(+), 113 deletions(-) create mode 100644 tests/compile_fail/eq/with_attribute.rs create mode 100644 tests/compile_fail/eq/with_attribute.stderr create mode 100644 tests/compile_fail/hash/partial_eq_with_without_hash.rs create mode 100644 tests/compile_fail/hash/partial_eq_with_without_hash.stderr diff --git a/impl/doc/eq.md b/impl/doc/eq.md index f970b5de..8d83dc38 100644 --- a/impl/doc/eq.md +++ b/impl/doc/eq.md @@ -289,8 +289,9 @@ impl PartialEq for Enum { ### Custom comparison with `with` -Both `#[eq(with(...))]` and `#[partial_eq(with(...))]` attributes allows specifying a custom comparison function -for a field. The function must have the signature `fn(&T, &T) -> bool` where `T` is the field type. +The `#[partial_eq(with(...))]` attribute allows specifying a custom comparison function for a field. +The function must have the signature `fn(&T, &T) -> bool` where `T` is the field type. `derive(Eq)` honors +`#[partial_eq(with(...))]` by not requiring an `Eq` bound on that field's type. ```rust # use derive_more::{Eq, PartialEq}; diff --git a/impl/doc/hash.md b/impl/doc/hash.md index b19731a2..82ecc417 100644 --- a/impl/doc/hash.md +++ b/impl/doc/hash.md @@ -272,3 +272,7 @@ impl Hash for Foo { This is useful for types that don't implement `Hash` but can be hashed in a custom way, or when you need different hashing behavior than the default. + +Note: if a field carries `#[partial_eq(with(...))]`, then `#[derive(Hash)]` requires either +`#[hash(with(...))]` or `#[hash(skip)]` on the same field. Otherwise the default per-field hashing +may disagree with the custom equality and break the `k1 == k2 -> hash(k1) == hash(k2)` invariant. diff --git a/impl/src/cmp/eq.rs b/impl/src/cmp/eq.rs index 69232636..52036025 100644 --- a/impl/src/cmp/eq.rs +++ b/impl/src/cmp/eq.rs @@ -27,15 +27,18 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result syn::Result syn::Result { - _ = skipped_fields.insert(n); - continue 'fields; - } - Some(Spanning { - item: attr::WithOrSkip::With(with), - .. - }) => { - alternate_eq_functions.insert(n, with.func); - continue 'fields; - } - None => {} + match attr::WithOrSkip::parse_attrs(&field.attrs, &attr_name)? { + Some(Spanning { + item: attr::WithOrSkip::Skip, + .. + }) => { + _ = skipped_fields.insert(n); + continue 'fields; + } + Some(Spanning { + item: attr::WithOrSkip::With(with), + .. + }) => { + alternate_eq_functions.insert(n, with.func); + continue 'fields; } + None => {} + } + if attr::Skip::parse_attrs(&field.attrs, &secondary_attr_name)? + .is_some() + { + _ = skipped_fields.insert(n); + continue 'fields; } } variants.push(( @@ -76,24 +80,28 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { - _ = skipped_fields.insert(n); - continue 'fields; - } - Some(Spanning { - item: attr::WithOrSkip::With(with), - .. - }) => { - alternate_eq_functions.insert(n, with.func); - continue 'fields; - } - None => {} + match attr::WithOrSkip::parse_attrs(&field.attrs, &attr_name)? { + Some(Spanning { + item: attr::WithOrSkip::Skip, + .. + }) => { + _ = skipped_fields.insert(n); + continue 'fields; + } + Some(Spanning { + item: attr::WithOrSkip::With(with), + .. + }) => { + alternate_eq_functions.insert(n, with.func); + continue 'fields; } + None => {} + } + if attr::Skip::parse_attrs(&field.attrs, &secondary_attr_name)? + .is_some() + { + _ = skipped_fields.insert(n); + continue 'fields; } } variants.push(( diff --git a/impl/src/hash.rs b/impl/src/hash.rs index c5ceae47..2cad6a2b 100644 --- a/impl/src/hash.rs +++ b/impl/src/hash.rs @@ -14,20 +14,24 @@ use syn::{ spanned::Spanned as _, }; +const PARTIAL_EQ_WITH_WITHOUT_HASH_ERROR: &str = + "field has `#[partial_eq(with(...))]` but no `#[hash(with(...))]` or `#[hash(skip)]`: a custom \ + equality function requires a consistent `Hash` implementation to uphold the `Hash`/`Eq` \ + invariant (`a == b` implies `hash(a) == hash(b)`)"; + /// Expands a [`Hash`] derive macro. pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { let attr_name = format_ident!("hash"); - let secondary_attr_name = format_ident!("eq"); - let tertiary_attr_name = format_ident!("partial_eq"); - let attr_names = [&attr_name, &secondary_attr_name, &tertiary_attr_name]; - let secondary_attr_names = [&secondary_attr_name, &tertiary_attr_name]; + let partial_eq_attr_name = format_ident!("partial_eq"); + let eq_attr_name = format_ident!("eq"); + let skip_attr_names = [&attr_name, &partial_eq_attr_name, &eq_attr_name]; let mut has_skipped_variants = false; let mut variants = vec![]; match &input.data { syn::Data::Struct(data) => { - for attr_name in &attr_names { + for attr_name in skip_attr_names { if attr::Skip::parse_attrs(&input.attrs, attr_name)?.is_some() { has_skipped_variants = true; break; @@ -37,30 +41,46 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { - for attr_name in &secondary_attr_names { - if attr::Skip::parse_attrs(&field.attrs, attr_name)? - .is_some() - { - _ = skipped_fields.insert(n); - continue 'fields; + if let Some(Spanning { item, span, .. }) = + attr::WithOrSkip::parse_attrs( + &field.attrs, + &partial_eq_attr_name, + )? + { + match item { + attr::WithOrSkip::Skip => { + _ = skipped_fields.insert(n); + } + attr::WithOrSkip::With(_) => { + return Err(syn::Error::new( + span, + PARTIAL_EQ_WITH_WITHOUT_HASH_ERROR, + )); + } } + } else if attr::Skip::parse_attrs( + &field.attrs, + &eq_attr_name, + )? + .is_some() + { + _ = skipped_fields.insert(n); } } Some(Spanning { item: attr::WithOrSkip::Skip, .. }) => { - skipped_fields.insert(n); + _ = skipped_fields.insert(n); } - Some(Spanning { item: attr::WithOrSkip::With(with), .. }) => { - alternate_hash_functions.insert(n, with.func.clone()); + alternate_hash_functions.insert(n, with.func); } } } @@ -74,7 +94,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { 'variants: for variant in &data.variants { - for attr_name in &attr_names { + for attr_name in skip_attr_names { if attr::Skip::parse_attrs(&variant.attrs, attr_name)?.is_some() { has_skipped_variants = true; continue 'variants; @@ -83,30 +103,46 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { - for attr_name in &secondary_attr_names { - if attr::Skip::parse_attrs(&field.attrs, attr_name)? - .is_some() - { - _ = skipped_fields.insert(n); - continue 'fields; + if let Some(Spanning { item, span, .. }) = + attr::WithOrSkip::parse_attrs( + &field.attrs, + &partial_eq_attr_name, + )? + { + match item { + attr::WithOrSkip::Skip => { + _ = skipped_fields.insert(n); + } + attr::WithOrSkip::With(_) => { + return Err(syn::Error::new( + span, + PARTIAL_EQ_WITH_WITHOUT_HASH_ERROR, + )); + } } + } else if attr::Skip::parse_attrs( + &field.attrs, + &eq_attr_name, + )? + .is_some() + { + _ = skipped_fields.insert(n); } } Some(Spanning { item: attr::WithOrSkip::Skip, .. }) => { - skipped_fields.insert(n); + _ = skipped_fields.insert(n); } - Some(Spanning { item: attr::WithOrSkip::With(with), .. }) => { - alternate_hash_functions.insert(n, with.func.clone()); + alternate_hash_functions.insert(n, with.func); } } } diff --git a/tests/compile_fail/eq/unknown_field_attribute.stderr b/tests/compile_fail/eq/unknown_field_attribute.stderr index c93ad77a..ccd3198a 100644 --- a/tests/compile_fail/eq/unknown_field_attribute.stderr +++ b/tests/compile_fail/eq/unknown_field_attribute.stderr @@ -1,10 +1,10 @@ -error: expected one of: `with`, `skip`, `ignore` +error: only `skip`/`ignore` allowed here --> tests/compile_fail/eq/unknown_field_attribute.rs:2:17 | 2 | struct Foo(#[eq(unknown)] i32); | ^^^^^^^ -error: expected one of: `with`, `skip`, `ignore` +error: only `skip`/`ignore` allowed here --> tests/compile_fail/eq/unknown_field_attribute.rs:12:16 | 12 | Bar { #[eq(unknown)] i: i32 }, diff --git a/tests/compile_fail/eq/with_attribute.rs b/tests/compile_fail/eq/with_attribute.rs new file mode 100644 index 00000000..09fb7e36 --- /dev/null +++ b/tests/compile_fail/eq/with_attribute.rs @@ -0,0 +1,28 @@ +fn always_eq(_: &T, _: &T) -> bool { + true +} + +#[derive(derive_more::Eq)] +struct Foo(#[eq(with(always_eq))] i32); + +impl PartialEq for Foo { + fn eq(&self, _: &Self) -> bool { + unimplemented!() + } +} + +#[derive(derive_more::Eq)] +enum Enum { + Bar { + #[eq(with(always_eq))] + i: i32, + }, +} + +impl PartialEq for Enum { + fn eq(&self, _: &Self) -> bool { + unimplemented!() + } +} + +fn main() {} diff --git a/tests/compile_fail/eq/with_attribute.stderr b/tests/compile_fail/eq/with_attribute.stderr new file mode 100644 index 00000000..eec7a886 --- /dev/null +++ b/tests/compile_fail/eq/with_attribute.stderr @@ -0,0 +1,11 @@ +error: only `skip`/`ignore` allowed here + --> tests/compile_fail/eq/with_attribute.rs:6:17 + | +6 | struct Foo(#[eq(with(always_eq))] i32); + | ^^^^ + +error: only `skip`/`ignore` allowed here + --> tests/compile_fail/eq/with_attribute.rs:17:14 + | +17 | #[eq(with(always_eq))] + | ^^^^ diff --git a/tests/compile_fail/hash/partial_eq_with_without_hash.rs b/tests/compile_fail/hash/partial_eq_with_without_hash.rs new file mode 100644 index 00000000..da641733 --- /dev/null +++ b/tests/compile_fail/hash/partial_eq_with_without_hash.rs @@ -0,0 +1,16 @@ +fn always_eq(_: &i32, _: &i32) -> bool { + true +} + +#[derive(derive_more::Hash, derive_more::PartialEq, derive_more::Eq)] +struct Foo(#[partial_eq(with(always_eq))] i32); + +#[derive(derive_more::Hash, derive_more::PartialEq, derive_more::Eq)] +enum Enum { + Bar { + #[partial_eq(with(always_eq))] + i: i32, + }, +} + +fn main() {} diff --git a/tests/compile_fail/hash/partial_eq_with_without_hash.stderr b/tests/compile_fail/hash/partial_eq_with_without_hash.stderr new file mode 100644 index 00000000..408cf2dc --- /dev/null +++ b/tests/compile_fail/hash/partial_eq_with_without_hash.stderr @@ -0,0 +1,11 @@ +error: field has `#[partial_eq(with(...))]` but no `#[hash(with(...))]` or `#[hash(skip)]`: a custom equality function requires a consistent `Hash` implementation to uphold the `Hash`/`Eq` invariant (`a == b` implies `hash(a) == hash(b)`) + --> tests/compile_fail/hash/partial_eq_with_without_hash.rs:6:12 + | +6 | struct Foo(#[partial_eq(with(always_eq))] i32); + | ^ + +error: field has `#[partial_eq(with(...))]` but no `#[hash(with(...))]` or `#[hash(skip)]`: a custom equality function requires a consistent `Hash` implementation to uphold the `Hash`/`Eq` invariant (`a == b` implies `hash(a) == hash(b)`) + --> tests/compile_fail/hash/partial_eq_with_without_hash.rs:11:9 + | +11 | #[partial_eq(with(always_eq))] + | ^ diff --git a/tests/eq.rs b/tests/eq.rs index 441dfbee..9c7c7b6e 100644 --- a/tests/eq.rs +++ b/tests/eq.rs @@ -175,11 +175,8 @@ mod structs { fn single_field() { #[derive(Eq, PartialEq)] struct Foo(#[partial_eq(with(eq_special))] NotPartialEq); - #[derive(Eq, PartialEq)] - struct Bar(#[eq(with(eq_special))] NotPartialEq); let _: AssertParamIsEq; - let _: AssertParamIsEq; } #[test] @@ -191,15 +188,7 @@ mod structs { b: i32, } - #[derive(Eq, PartialEq)] - struct Bar { - #[eq(with(eq_special))] - a: NotPartialEq, - b: i32, - } - let _: AssertParamIsEq; - let _: AssertParamIsEq; } #[test] @@ -210,13 +199,7 @@ mod structs { #[partial_eq(with(eq_special))] NotPartialEq, ); - #[derive(Eq, PartialEq)] - struct Bar( - #[eq(with(eq_special))] NotPartialEq, - #[eq(with(eq_special))] NotPartialEq, - ); - - let _: AssertParamIsEq; + let _: AssertParamIsEq; } } @@ -619,8 +602,7 @@ mod enums { #[derive(Eq, PartialEq)] enum E { Foo(#[partial_eq(with(eq_special))] NotPartialEq), - Bar(#[eq(with(eq_special))] NotPartialEq), - Baz, + Bar, } let _: AssertParamIsEq; @@ -635,12 +617,7 @@ mod enums { a: NotPartialEq, b: i32, }, - Bar { - #[eq(with(eq_special))] - a: NotPartialEq, - b: i32, - }, - Baz, + Bar, } let _: AssertParamIsEq; @@ -654,11 +631,7 @@ mod enums { #[partial_eq(with(eq_special))] NotPartialEq, #[partial_eq(with(eq_special))] NotPartialEq, ), - Bar( - #[eq(with(eq_special))] NotPartialEq, - #[eq(with(eq_special))] NotPartialEq, - ), - Baz, + Bar, } let _: AssertParamIsEq; diff --git a/tests/hash.rs b/tests/hash.rs index f9854280..76918fae 100644 --- a/tests/hash.rs +++ b/tests/hash.rs @@ -297,6 +297,44 @@ mod enums { } } +#[cfg(feature = "eq")] +mod partial_eq_with_requires_hash_attr { + use derive_more::{Eq, Hash, PartialEq}; + + use super::{do_hash, utils}; + + fn eq_mod_10(a: &u32, b: &u32) -> bool { + a % 10 == b % 10 + } + + // Field uses a custom equality, so a matching custom hash function is provided. + #[derive(Hash, Eq, PartialEq)] + struct WithBoth { + #[partial_eq(with(eq_mod_10))] + #[hash(with(utils::alternate_u32_hash_function))] + a: u32, + b: i32, + } + + // Field uses a custom equality and is skipped from hashing — also consistent. + #[derive(Hash, Eq, PartialEq)] + struct WithSkip { + #[partial_eq(with(eq_mod_10))] + #[hash(skip)] + a: u32, + b: i32, + } + + #[test] + fn assert() { + assert_eq!( + do_hash(&WithBoth { a: 42, b: 7 }), + do_hash(&(42u32, 42u32, 7i32)), + ); + assert_eq!(do_hash(&WithSkip { a: 42, b: 7 }), do_hash(&7i32)); + } +} + #[cfg(feature = "eq")] mod hash_respects_eq_skip { use derive_more::{Eq, Hash, PartialEq}; From 6f7e5d6478cf5f82a5ecf18b60cfcc67bb57134b Mon Sep 17 00:00:00 2001 From: Thibaut Lorrain Date: Mon, 18 May 2026 15:26:34 +0200 Subject: [PATCH 5/5] alternate -> custom --- impl/src/cmp/partial_eq.rs | 24 +++++++++++------------- impl/src/hash.rs | 24 +++++++++++------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/impl/src/cmp/partial_eq.rs b/impl/src/cmp/partial_eq.rs index 2df36b8f..809ffc14 100644 --- a/impl/src/cmp/partial_eq.rs +++ b/impl/src/cmp/partial_eq.rs @@ -33,8 +33,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result syn::Result { - alternate_eq_functions.insert(n, with.func); + custom_eq_functions.insert(n, with.func); continue 'fields; } None => {} @@ -64,7 +63,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result syn::Result syn::Result { - alternate_eq_functions.insert(n, with.func); + custom_eq_functions.insert(n, with.func); continue 'fields; } None => {} @@ -108,7 +106,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result syn::Result; -/// Mapping from [`syn::Field`] marked with an [`attr::With`] to the [`syn::Path`] of the alternate +/// Mapping from [`syn::Field`] marked with an [`attr::With`] to the [`syn::Path`] of the custom /// eq function. -type FieldsWithAlternateEqFunction = HashMap; +type FieldsWithCustomEqFunction = HashMap; /// Expansion of a macro for generating a structural [`PartialEq`] implementation of an enum or a /// struct. @@ -149,7 +147,7 @@ struct StructuralExpansion<'i> { Option<&'i syn::Ident>, &'i syn::Fields, SkippedFields, - FieldsWithAlternateEqFunction, + FieldsWithCustomEqFunction, )>, /// Indicator whether some original enum variants where skipped with an [`attr::Skip`]. @@ -201,7 +199,7 @@ impl StructuralExpansion<'_> { let match_arms = self .variants .iter() - .filter_map(|(variant, all_fields, skipped_fields, alternate_eq_functions)| { + .filter_map(|(variant, all_fields, skipped_fields, custom_eq_functions)| { if all_fields.is_empty() || skipped_fields.len() == all_fields.len() { return None; } @@ -217,7 +215,7 @@ impl StructuralExpansion<'_> { .map(|num| { let self_val = format_ident!("__self_{num}"); let other_val = format_ident!("__other_{num}"); - let equality = alternate_eq_functions + let equality = custom_eq_functions .get(&num) .map(|eq_fn| { let maybe_not = (!eq).then(|| quote! {!}); diff --git a/impl/src/hash.rs b/impl/src/hash.rs index 2cad6a2b..b79a703d 100644 --- a/impl/src/hash.rs +++ b/impl/src/hash.rs @@ -39,8 +39,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { @@ -80,7 +79,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { - alternate_hash_functions.insert(n, with.func); + custom_hash_functions.insert(n, with.func); } } } @@ -88,7 +87,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result syn::Result { @@ -142,7 +140,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result { - alternate_hash_functions.insert(n, with.func); + custom_hash_functions.insert(n, with.func); } } } @@ -150,7 +148,7 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result syn::Result; -/// Mapping from [`syn::Field`] marked with an [`attr::With`] to the [`syn::Path`] of the alternate +/// Mapping from [`syn::Field`] marked with an [`attr::With`] to the [`syn::Path`] of the custom /// hash function. -type FieldsWithAlternateHashFunction = HashMap; +type FieldsWithCustomHashFunction = HashMap; /// Expansion of a macro for generating a structural [`Hash`] implementation of an enum or a struct. struct StructuralExpansion<'i> { @@ -190,7 +188,7 @@ struct StructuralExpansion<'i> { Option<&'i syn::Ident>, &'i syn::Fields, SkippedFields, - FieldsWithAlternateHashFunction, + FieldsWithCustomHashFunction, )>, /// Indicator whether some original enum variants where skipped with an [`attr::Skip`]. @@ -230,7 +228,7 @@ impl StructuralExpansion<'_> { .variants .iter() .map( - |(variant, all_fields, skipped_fields, alternate_hash_functions)| { + |(variant, all_fields, skipped_fields, custom_hash_functions)| { let variant = variant.map(|variant| quote! { :: #variant }); let self_pattern = all_fields .non_exhaustive_arm_pattern("__self_", skipped_fields); @@ -239,7 +237,7 @@ impl StructuralExpansion<'_> { .filter(|num| !skipped_fields.contains(num)) .map(|num| { let self_val = format_ident!("__self_{num}"); - let hash_function = alternate_hash_functions + let hash_function = custom_hash_functions .get(&num) .map(|it| quote! {#it}) .unwrap_or_else(