Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Add `Hash` derive similar to `std`'s one, but considering generics correctly,
and supporting custom hash functions per field or skipping fields.
([#532](https://github.com/JelteF/derive_more/pull/532))
- Add support for custom eq functions in `PartialEq`/`Eq` derive.
([#535](https://github.com/JelteF/derive_more/pull/535))

### Fixed

Expand Down
56 changes: 56 additions & 0 deletions impl/doc/eq.md
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,59 @@ impl PartialEq for Enum {
}
}
```

### Custom comparison with `with`

The `#[partial_eq(with(...))]` attribute allows specifying a custom comparison function for a field.
The function must have the signature `fn(&T, &T) -> bool` where `T` is the field type. `derive(Eq)` honors
`#[partial_eq(with(...))]` by not requiring an `Eq` bound on that field's type.

```rust
# use derive_more::{Eq, PartialEq};

mod custom_eq {
#[derive(Debug)]
pub struct NotPartialEq(pub i32);

pub fn compare(a: &NotPartialEq, b: &NotPartialEq) -> bool {
a.0 == b.0
}
}
use custom_eq::NotPartialEq;

#[derive(Debug, PartialEq, Eq)]
struct Foo(#[partial_eq(with(custom_eq::compare))] NotPartialEq);

assert_eq!(Foo(NotPartialEq(42)), Foo(NotPartialEq(42)));
assert_ne!(Foo(NotPartialEq(42)), Foo(NotPartialEq(73)));
```

This generates code equivalent to:

```rust
#
# mod custom_eq {
# #[derive(Debug)]
# pub struct NotPartialEq(i32);
#
# pub fn compare(a: &NotPartialEq, b: &NotPartialEq) -> bool {
# a.0 == b.0
# }
# }
# use custom_eq::NotPartialEq;
#
# struct Foo(NotPartialEq);
#
impl PartialEq for Foo {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self(self_0), Self(other_0)) => custom_eq::compare(self_0, other_0)
}
}
fn ne(&self, other: &Self) -> bool {
match (self, other) {
(Self(self_0), Self(other_0)) => !custom_eq::compare(self_0, other_0)
}
}
}
```
4 changes: 4 additions & 0 deletions impl/doc/hash.md
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,7 @@ impl Hash for Foo {

This is useful for types that don't implement `Hash` but can be hashed in a custom way, or when you need different
hashing behavior than the default.

Note: if a field carries `#[partial_eq(with(...))]`, then `#[derive(Hash)]` requires either
`#[hash(with(...))]` or `#[hash(skip)]` on the same field. Otherwise the default per-field hashing
may disagree with the custom equality and break the `k1 == k2 -> hash(k1) == hash(k2)` invariant.
32 changes: 24 additions & 8 deletions impl/src/cmp/eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,18 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStr
}
if !is_skipped {
'fields: for field in &data.fields {
for attr_name in [&attr_name, &secondary_attr_name] {
if attr::Skip::parse_attrs(&field.attrs, attr_name)?.is_some() {
continue 'fields;
}
if attr::Skip::parse_attrs(&field.attrs, &attr_name)?.is_some()
|| attr::WithOrSkip::parse_attrs(
&field.attrs,
&secondary_attr_name,
)?
.is_some()
{
// If skipped, we don't want to add a bound on the field type.
// If a `with` function is provided for `PartialEq`, then adding a bound is
// counterproductive as the `with` function may handle types that do not
// implement `Eq`.
continue 'fields;
}
_ = fields_types.insert(&field.ty);
}
Expand All @@ -44,10 +52,18 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStr
}
}
'fields: for field in &variant.fields {
for attr_name in [&attr_name, &secondary_attr_name] {
if attr::Skip::parse_attrs(&field.attrs, attr_name)?.is_some() {
continue 'fields;
}
if attr::Skip::parse_attrs(&field.attrs, &attr_name)?.is_some()
|| attr::WithOrSkip::parse_attrs(
&field.attrs,
&secondary_attr_name,
)?
.is_some()
{
// If skipped, we don't want to add a bound on the field type.
// If a `with` function is provided for `PartialEq`, then adding a bound is
// counterproductive as the `with` function may handle types that do not
// implement `Eq`.
continue 'fields;
}
_ = fields_types.insert(&field.ty);
}
Expand Down
91 changes: 77 additions & 14 deletions impl/src/cmp/partial_eq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::utils::{
attr::{self, ParseMultiple as _},
pattern_matching::FieldsExt as _,
structural_inclusion::TypeExt as _,
GenericsSearch, HashSet,
GenericsSearch, HashMap, HashSet, Spanning,
};

/// Expands a [`PartialEq`] derive macro.
Expand All @@ -33,15 +33,38 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStr
}
if !has_skipped_variants {
let mut skipped_fields = SkippedFields::default();
let mut custom_eq_functions = FieldsWithCustomEqFunction::default();
'fields: for (n, field) in data.fields.iter().enumerate() {
for attr_name in [&attr_name, &secondary_attr_name] {
if attr::Skip::parse_attrs(&field.attrs, attr_name)?.is_some() {
match attr::WithOrSkip::parse_attrs(&field.attrs, &attr_name)? {
Some(Spanning {
item: attr::WithOrSkip::Skip,
..
}) => {
_ = skipped_fields.insert(n);
continue 'fields;
}
Some(Spanning {
item: attr::WithOrSkip::With(with),
..
}) => {
custom_eq_functions.insert(n, with.func);
continue 'fields;
}
None => {}
}
if attr::Skip::parse_attrs(&field.attrs, &secondary_attr_name)?
.is_some()
{
_ = skipped_fields.insert(n);
continue 'fields;
}
}
variants.push((None, &data.fields, skipped_fields));
variants.push((
None,
&data.fields,
skipped_fields,
custom_eq_functions,
));
}
}
syn::Data::Enum(data) => {
Expand All @@ -53,15 +76,38 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStr
}
}
let mut skipped_fields = SkippedFields::default();
let mut custom_eq_functions = FieldsWithCustomEqFunction::default();
'fields: for (n, field) in variant.fields.iter().enumerate() {
for attr_name in [&attr_name, &secondary_attr_name] {
if attr::Skip::parse_attrs(&field.attrs, attr_name)?.is_some() {
match attr::WithOrSkip::parse_attrs(&field.attrs, &attr_name)? {
Some(Spanning {
item: attr::WithOrSkip::Skip,
..
}) => {
_ = skipped_fields.insert(n);
continue 'fields;
}
Some(Spanning {
item: attr::WithOrSkip::With(with),
..
}) => {
custom_eq_functions.insert(n, with.func);
continue 'fields;
}
None => {}
}
if attr::Skip::parse_attrs(&field.attrs, &secondary_attr_name)?
.is_some()
{
_ = skipped_fields.insert(n);
continue 'fields;
}
}
variants.push((Some(&variant.ident), &variant.fields, skipped_fields));
variants.push((
Some(&variant.ident),
&variant.fields,
skipped_fields,
custom_eq_functions,
));
}
}
syn::Data::Union(data) => {
Expand All @@ -84,6 +130,10 @@ pub fn expand(input: &syn::DeriveInput, _: &'static str) -> syn::Result<TokenStr
/// Indices of [`syn::Field`]s marked with an [`attr::Skip`].
type SkippedFields = HashSet<usize>;

/// Mapping from [`syn::Field`] marked with an [`attr::With`] to the [`syn::Path`] of the custom
/// eq function.
type FieldsWithCustomEqFunction = HashMap<usize, syn::Path>;

/// Expansion of a macro for generating a structural [`PartialEq`] implementation of an enum or a
/// struct.
struct StructuralExpansion<'i> {
Expand All @@ -93,7 +143,12 @@ struct StructuralExpansion<'i> {
self_ty: (&'i syn::Ident, &'i syn::Generics),

/// [`syn::Fields`] of the enum/struct to be compared in this [`StructuralExpansion`].
variants: Vec<(Option<&'i syn::Ident>, &'i syn::Fields, SkippedFields)>,
variants: Vec<(
Option<&'i syn::Ident>,
&'i syn::Fields,
SkippedFields,
FieldsWithCustomEqFunction,
)>,

/// Indicator whether some original enum variants where skipped with an [`attr::Skip`].
has_skipped_variants: bool,
Expand Down Expand Up @@ -144,7 +199,7 @@ impl StructuralExpansion<'_> {
let match_arms = self
.variants
.iter()
.filter_map(|(variant, all_fields, skipped_fields)| {
.filter_map(|(variant, all_fields, skipped_fields, custom_eq_functions)| {
if all_fields.is_empty() || skipped_fields.len() == all_fields.len() {
return None;
}
Expand All @@ -160,7 +215,15 @@ impl StructuralExpansion<'_> {
.map(|num| {
let self_val = format_ident!("__self_{num}");
let other_val = format_ident!("__other_{num}");
punctuated::Pair::Punctuated(quote! { #self_val #cmp #other_val }, &chain)
let equality = custom_eq_functions
.get(&num)
.map(|eq_fn| {
let maybe_not = (!eq).then(|| quote! {!});
quote! { #maybe_not #eq_fn(#self_val, #other_val) }
}
).unwrap_or_else(|| quote! { #self_val #cmp #other_val }
);
punctuated::Pair::Punctuated(equality, &chain)
})
.collect::<Punctuated<TokenStream, _>>();
_ = val_eqs.pop_punct();
Expand All @@ -178,14 +241,14 @@ impl StructuralExpansion<'_> {
});
let unreachable_arm = (self.variants.len() > 1
&& no_fields_arm.is_none())
.then(|| {
quote! {
.then(|| {
quote! {
// SAFETY: This arm is never reachable due to `mem::discriminant()` comparison
// preceding the expanded `match (self, other)` expression, but is
// required by it when there is more than one variant.
_ => unsafe { derive_more::core::hint::unreachable_unchecked() },
}
});
});

quote! {
match (self, __other) {
Expand Down Expand Up @@ -221,7 +284,7 @@ impl ToTokens for StructuralExpansion<'_> {
{
let self_ty: syn::Type = parse_quote! { Self };
let implementor_ty: syn::Type = parse_quote! { #ty #ty_generics };
for (_, all_fields, skipped_fields) in &self.variants {
for (_, all_fields, skipped_fields, _) in &self.variants {
for field_ty in
all_fields.iter().enumerate().filter_map(|(n, field)| {
(!skipped_fields.contains(&n)).then_some(&field.ty)
Expand Down
Loading
Loading