diff --git a/source/builtin_macros/src/struct_decl_inv.rs b/source/builtin_macros/src/struct_decl_inv.rs index 7c6034888d..0ccc8adba2 100644 --- a/source/builtin_macros/src/struct_decl_inv.rs +++ b/source/builtin_macros/src/struct_decl_inv.rs @@ -52,6 +52,8 @@ fn struct_decl_inv_main(sdi: SDI) -> parse::Result { sdi.item_struct.to_tokens(&mut stream); let fields_filled_in = get_fields(&sdi.item_struct.fields)?; + let all_struct_params: HashSet = + sdi.item_struct.generics.params.iter().map(generic_param_to_string).collect(); for field in fields_filled_in.iter() { output_field_type_alias( &main_name, @@ -59,6 +61,8 @@ fn struct_decl_inv_main(sdi: SDI) -> parse::Result { &mut stream, field, &used_type_params, + &all_struct_params, + &sdi.item_struct.generics.where_clause, ); } @@ -604,12 +608,19 @@ fn fill_in_item_struct( invariant_decls: &Vec, used_type_params: &HashMap>, ) { + let struct_type_params = item_struct.generics.params.clone(); match &mut item_struct.fields { Fields::Named(fields_named) => { for field in fields_named.named.iter_mut() { let name = field.ident.as_ref().unwrap().to_string(); let invdecls = get_invariant_decls_by_name(invariant_decls, &name); - field.ty = fill_in_type(&field.ty, main_name, invdecls, used_type_params); + field.ty = fill_in_type( + &field.ty, + main_name, + invdecls, + used_type_params, + &struct_type_params, + ); } } _ => { @@ -682,6 +693,41 @@ fn output_invariant( let where_clause = &sdi.item_struct.generics.where_clause; let vis = &sdi.item_struct.vis; + // Build the predicate struct as carrying a `PhantomData` over each + // struct generic param so the auto-impl can constrain those params. + let (pred_struct_decl, pred_self_ty) = if type_params.is_empty() { + (quote! { #vis struct #predname { } }, quote! { #predname }) + } else { + let type_args = remove_bounds(type_params); + let phantom_tys: Vec = type_params + .iter() + .map(|gp| match gp { + GenericParam::Type(tp) => { + let id = &tp.ident; + quote! { #id } + } + GenericParam::Lifetime(ld) => { + let lt = &ld.lifetime; + quote! { & #lt () } + } + GenericParam::Const(cp) => { + // Const generics can't appear in PhantomData + // directly; reference them via `[(); N]`. + let id = &cp.ident; + quote! { [(); #id] } + } + }) + .collect(); + ( + quote! { + #vis struct #predname<#type_params> #where_clause { + _phantom: ::core::marker::PhantomData<(#(#phantom_tys,)*)>, + } + }, + quote! { #predname<#type_args> }, + ) + }; + let span = field_name.span(); let mut e_stream_conjuncts = vec![]; @@ -704,8 +750,8 @@ fn output_invariant( let g_pat = &v_pats[1]; stream.extend(quote_spanned_vstd! { vstd, predicate.span() => - #vis struct #predname { } - impl<#type_params> #vstd::atomic_ghost::AtomicInvariantPredicate<#k_type, #v_type, #g_type> for #predname #where_clause { + #pred_struct_decl + impl<#type_params> #vstd::atomic_ghost::AtomicInvariantPredicate<#k_type, #v_type, #g_type> for #pred_self_ty #where_clause { #publish_kind spec fn atomic_inv(#tmp_k: #k_type, #tmp_v: #v_type, #tmp_g: #g_type) -> bool { let #k_pat = #tmp_k; let #v_pat = #tmp_v; @@ -723,8 +769,8 @@ fn output_invariant( let v_pat = maybe_tuple(&v_pats); stream.extend(quote_spanned_vstd! { vstd, predicate.span() => - #vis struct #predname { } - impl<#type_params> #vstd::invariant::InvariantPredicate<#k_type, #v_type> for #predname #where_clause { + #pred_struct_decl + impl<#type_params> #vstd::invariant::InvariantPredicate<#k_type, #v_type> for #pred_self_ty #where_clause { #publish_kind spec fn inv(#tmp_k: #k_type, #tmp_v: #v_type) -> bool { let #k_pat = #tmp_k; let #v_pat = #tmp_v; @@ -821,16 +867,80 @@ fn output_field_type_alias( stream: &mut TokenStream, field: &Field, used_type_params: &HashMap>, + all_struct_params: &HashSet, + where_clause: &Option, ) { let field_ident = field.ident.as_ref().unwrap(); - let alias = get_type_alias(main_name, field_ident, used_type_params); + let ident = Ident::new(&format!("FieldType_{main_name}_{field_ident}"), Span::call_site()); + let utp = used_type_params.get(&field_ident.to_string()).unwrap(); let field_ty = &field.ty; + // Restrict the struct's where clause to predicates that only reference + // generic params that are present on this alias. + let used_param_names: HashSet = utp.iter().map(generic_param_to_string).collect(); + // Strip from each kept param any bounds that reference params not in scope. + let utp_restricted: Vec = utp + .iter() + .map(|gp| restrict_param_bounds(gp, &used_param_names, all_struct_params)) + .collect(); + let filtered_where = where_clause.as_ref().and_then(|wc| { + let preds: Punctuated = wc + .predicates + .iter() + .filter(|p| where_predicate_only_uses(p, &used_param_names, all_struct_params)) + .cloned() + .collect(); + if preds.is_empty() { + None + } else { + Some(verus_syn::WhereClause { where_token: wc.where_token, predicates: preds }) + } + }); + + let generics_tokens = if utp_restricted.is_empty() { + quote! {} + } else { + quote! { <#(#utp_restricted),*> } + }; stream.extend(quote! { - #vis type #alias = #field_ty; + #[allow(type_alias_bounds)] + #vis type #ident #generics_tokens #filtered_where = #field_ty; }); } +fn where_predicate_only_uses( + pred: &verus_syn::WherePredicate, + allowed: &HashSet, + all_struct_params: &HashSet, +) -> bool { + let mut params = HashSet::new(); + let mut visitor = CollectIdentsVisitor { result: &mut params }; + visit::visit_where_predicate(&mut visitor, pred); + params.iter().all(|p| !all_struct_params.contains(p) || allowed.contains(p)) +} + +struct CollectIdentsVisitor<'a> { + result: &'a mut HashSet, +} + +impl<'ast, 'a> Visit<'ast> for CollectIdentsVisitor<'a> { + fn visit_type_path(&mut self, type_path: &TypePath) { + let TypePath { qself, path } = type_path; + if qself.is_none() + && path.leading_colon.is_none() + && !path.segments.is_empty() + && path.segments[0].arguments == PathArguments::None + { + self.result.insert(path.segments[0].ident.to_string()); + } + visit::visit_type_path(self, type_path); + } + + fn visit_lifetime(&mut self, lt: &Lifetime) { + self.result.insert("'".to_string() + <.ident.to_string()); + } +} + // Defs fn get_pred_typename(main_name: &str, field_name: &Ident) -> Ident { @@ -990,6 +1100,7 @@ fn fill_in_type( main_name: &str, inv_decls: Vec<&InvariantDecl>, used_type_params: &HashMap>, + struct_type_params: &Punctuated, ) -> Type { let mut typs = vec![]; @@ -1003,8 +1114,14 @@ fn fill_in_type( } }; let pred = get_pred_typename(main_name, field_name); + let pred_ty = if struct_type_params.is_empty() { + quote! { #pred } + } else { + let type_args = remove_bounds(struct_type_params); + quote! { #pred<#type_args> } + }; typs.push(get_constant_type(main_name, depends_on, quants, used_type_params)); - typs.push(Type::Verbatim(quote! { #pred })); + typs.push(Type::Verbatim(pred_ty)); } fill_in_infers(ty, typs) @@ -1200,6 +1317,48 @@ fn get_params_used_in_type(params: &Punctuated, ty: &Type) upv.result } +fn restrict_param_bounds( + gp: &GenericParam, + allowed: &HashSet, + all_struct_params: &HashSet, +) -> GenericParam { + let mut gp = gp.clone(); + match &mut gp { + GenericParam::Type(tp) => { + tp.bounds = tp + .bounds + .iter() + .filter(|b| { + let mut found = HashSet::new(); + let mut visitor = CollectIdentsVisitor { result: &mut found }; + visit::visit_type_param_bound(&mut visitor, *b); + found.iter().all(|n| !all_struct_params.contains(n) || allowed.contains(n)) + }) + .cloned() + .collect(); + if tp.bounds.is_empty() { + tp.colon_token = None; + } + } + GenericParam::Lifetime(ld) => { + ld.bounds = ld + .bounds + .iter() + .filter(|lt| { + let n = "'".to_string() + <.ident.to_string(); + !all_struct_params.contains(&n) || allowed.contains(&n) + }) + .cloned() + .collect(); + if ld.bounds.is_empty() { + ld.colon_token = None; + } + } + GenericParam::Const(_) => {} + } + gp +} + struct UsedParamsVisitor { params: HashSet, result: HashSet, @@ -1210,7 +1369,7 @@ impl<'ast> Visit<'ast> for UsedParamsVisitor { let TypePath { qself, path } = type_path; if qself.is_none() && path.leading_colon.is_none() - && path.segments.len() == 1 + && !path.segments.is_empty() && path.segments[0].arguments == PathArguments::None { let id = path.segments[0].ident.to_string(); diff --git a/source/rust_verify_test/tests/struct_with_invariants.rs b/source/rust_verify_test/tests/struct_with_invariants.rs new file mode 100644 index 0000000000..fc01b95f43 --- /dev/null +++ b/source/rust_verify_test/tests/struct_with_invariants.rs @@ -0,0 +1,50 @@ +#![feature(rustc_private)] +#[macro_use] +mod common; +use common::*; + +test_verify_one_file! { + #[test] struct_with_invariants_assoc_type_prefix verus_code! { + use core::marker::PhantomData; + use vstd::atomic_ghost::AtomicPtr; + use vstd::prelude::*; + + pub trait HasTarget { + type Target; + } + + struct_with_invariants!{ + pub struct S { + a: AtomicPtr<::Target,_,(),_>, + _phantom: PhantomData<*const A::Target>, + } + + closed spec fn wf(self) -> bool { + invariant on a with (_phantom) is (v:*mut ::Target,g:()) { + true + } + } + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] struct_with_invariants_const_usage verus_code! { + use vstd::prelude::*; + use vstd::atomic_ghost::AtomicUsize; + + const ONE: usize = 1; + + struct_with_invariants!{ + pub struct S { + x: AtomicUsize<_,(),_>, + } + + closed spec fn wf(self) -> bool { + invariant on x is (v:usize,g:()) { + v == ONE + } + } + } + } => Ok(()) +} diff --git a/source/rust_verify_test/tests/user_defined_type_invariants.rs b/source/rust_verify_test/tests/user_defined_type_invariants.rs index 842c412ccc..d2c894bc7f 100644 --- a/source/rust_verify_test/tests/user_defined_type_invariants.rs +++ b/source/rust_verify_test/tests/user_defined_type_invariants.rs @@ -2272,27 +2272,6 @@ test_verify_one_file! { } => Err(err) => assert_vir_error_msg(err, "expected generics to match") } -test_verify_one_file! { - #[test] struct_with_invariants_const_usage verus_code! { - use vstd::prelude::*; - use vstd::atomic_ghost::AtomicUsize; - - const ONE: usize = 1; - - struct_with_invariants!{ - pub struct S { - x: AtomicUsize<_,(),_>, - } - - closed spec fn wf(self) -> bool { - invariant on x is (v:usize,g:()) { - v == ONE - } - } - } - } => Ok(()) -} - test_verify_one_file_with_options! { #[test] mut_ref_not_supported [] => verus_code! { struct A {