Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 168 additions & 9 deletions source/builtin_macros/src/struct_decl_inv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,17 @@ fn struct_decl_inv_main(sdi: SDI) -> parse::Result<TokenStream> {
sdi.item_struct.to_tokens(&mut stream);

let fields_filled_in = get_fields(&sdi.item_struct.fields)?;
let all_struct_params: HashSet<String> =
sdi.item_struct.generics.params.iter().map(generic_param_to_string).collect();
for field in fields_filled_in.iter() {
output_field_type_alias(
&main_name,
&sdi.item_struct.vis,
&mut stream,
field,
&used_type_params,
&all_struct_params,
&sdi.item_struct.generics.where_clause,
);
}

Expand Down Expand Up @@ -604,12 +608,19 @@ fn fill_in_item_struct(
invariant_decls: &Vec<InvariantDecl>,
used_type_params: &HashMap<String, Vec<GenericParam>>,
) {
let struct_type_params = item_struct.generics.params.clone();
match &mut item_struct.fields {
Fields::Named(fields_named) => {
for field in fields_named.named.iter_mut() {
let name = field.ident.as_ref().unwrap().to_string();
let invdecls = get_invariant_decls_by_name(invariant_decls, &name);
field.ty = fill_in_type(&field.ty, main_name, invdecls, used_type_params);
field.ty = fill_in_type(
&field.ty,
main_name,
invdecls,
used_type_params,
&struct_type_params,
);
}
}
_ => {
Expand Down Expand Up @@ -682,6 +693,41 @@ fn output_invariant(
let where_clause = &sdi.item_struct.generics.where_clause;
let vis = &sdi.item_struct.vis;

// Build the predicate struct as carrying a `PhantomData` over each
// struct generic param so the auto-impl can constrain those params.
let (pred_struct_decl, pred_self_ty) = if type_params.is_empty() {
(quote! { #vis struct #predname { } }, quote! { #predname })
} else {
let type_args = remove_bounds(type_params);
let phantom_tys: Vec<TokenStream> = type_params
.iter()
.map(|gp| match gp {
GenericParam::Type(tp) => {
let id = &tp.ident;
quote! { #id }
}
GenericParam::Lifetime(ld) => {
let lt = &ld.lifetime;
quote! { & #lt () }
}
GenericParam::Const(cp) => {
// Const generics can't appear in PhantomData
// directly; reference them via `[(); N]`.
let id = &cp.ident;
quote! { [(); #id] }
}
})
.collect();
(
quote! {
#vis struct #predname<#type_params> #where_clause {
_phantom: ::core::marker::PhantomData<(#(#phantom_tys,)*)>,
}
},
quote! { #predname<#type_args> },
)
};

let span = field_name.span();

let mut e_stream_conjuncts = vec![];
Expand All @@ -704,8 +750,8 @@ fn output_invariant(
let g_pat = &v_pats[1];

stream.extend(quote_spanned_vstd! { vstd, predicate.span() =>
#vis struct #predname { }
impl<#type_params> #vstd::atomic_ghost::AtomicInvariantPredicate<#k_type, #v_type, #g_type> for #predname #where_clause {
#pred_struct_decl
impl<#type_params> #vstd::atomic_ghost::AtomicInvariantPredicate<#k_type, #v_type, #g_type> for #pred_self_ty #where_clause {
#publish_kind spec fn atomic_inv(#tmp_k: #k_type, #tmp_v: #v_type, #tmp_g: #g_type) -> bool {
let #k_pat = #tmp_k;
let #v_pat = #tmp_v;
Expand All @@ -723,8 +769,8 @@ fn output_invariant(
let v_pat = maybe_tuple(&v_pats);

stream.extend(quote_spanned_vstd! { vstd, predicate.span() =>
#vis struct #predname { }
impl<#type_params> #vstd::invariant::InvariantPredicate<#k_type, #v_type> for #predname #where_clause {
#pred_struct_decl
impl<#type_params> #vstd::invariant::InvariantPredicate<#k_type, #v_type> for #pred_self_ty #where_clause {
#publish_kind spec fn inv(#tmp_k: #k_type, #tmp_v: #v_type) -> bool {
let #k_pat = #tmp_k;
let #v_pat = #tmp_v;
Expand Down Expand Up @@ -821,16 +867,80 @@ fn output_field_type_alias(
stream: &mut TokenStream,
field: &Field,
used_type_params: &HashMap<String, Vec<GenericParam>>,
all_struct_params: &HashSet<String>,
where_clause: &Option<verus_syn::WhereClause>,
) {
let field_ident = field.ident.as_ref().unwrap();
let alias = get_type_alias(main_name, field_ident, used_type_params);
let ident = Ident::new(&format!("FieldType_{main_name}_{field_ident}"), Span::call_site());
let utp = used_type_params.get(&field_ident.to_string()).unwrap();
let field_ty = &field.ty;

// Restrict the struct's where clause to predicates that only reference
// generic params that are present on this alias.
let used_param_names: HashSet<String> = utp.iter().map(generic_param_to_string).collect();
// Strip from each kept param any bounds that reference params not in scope.
let utp_restricted: Vec<GenericParam> = utp
.iter()
.map(|gp| restrict_param_bounds(gp, &used_param_names, all_struct_params))
.collect();
let filtered_where = where_clause.as_ref().and_then(|wc| {
let preds: Punctuated<verus_syn::WherePredicate, Comma> = wc
.predicates
.iter()
.filter(|p| where_predicate_only_uses(p, &used_param_names, all_struct_params))
.cloned()
.collect();
if preds.is_empty() {
None
} else {
Some(verus_syn::WhereClause { where_token: wc.where_token, predicates: preds })
}
});

let generics_tokens = if utp_restricted.is_empty() {
quote! {}
} else {
quote! { <#(#utp_restricted),*> }
};
stream.extend(quote! {
#vis type #alias = #field_ty;
#[allow(type_alias_bounds)]
#vis type #ident #generics_tokens #filtered_where = #field_ty;
});
}

fn where_predicate_only_uses(
pred: &verus_syn::WherePredicate,
allowed: &HashSet<String>,
all_struct_params: &HashSet<String>,
) -> bool {
let mut params = HashSet::new();
let mut visitor = CollectIdentsVisitor { result: &mut params };
visit::visit_where_predicate(&mut visitor, pred);
params.iter().all(|p| !all_struct_params.contains(p) || allowed.contains(p))
}

struct CollectIdentsVisitor<'a> {
result: &'a mut HashSet<String>,
}

impl<'ast, 'a> Visit<'ast> for CollectIdentsVisitor<'a> {
fn visit_type_path(&mut self, type_path: &TypePath) {
let TypePath { qself, path } = type_path;
if qself.is_none()
&& path.leading_colon.is_none()
&& !path.segments.is_empty()
&& path.segments[0].arguments == PathArguments::None
{
self.result.insert(path.segments[0].ident.to_string());
}
visit::visit_type_path(self, type_path);
}

fn visit_lifetime(&mut self, lt: &Lifetime) {
self.result.insert("'".to_string() + &lt.ident.to_string());
}
}

// Defs

fn get_pred_typename(main_name: &str, field_name: &Ident) -> Ident {
Expand Down Expand Up @@ -990,6 +1100,7 @@ fn fill_in_type(
main_name: &str,
inv_decls: Vec<&InvariantDecl>,
used_type_params: &HashMap<String, Vec<GenericParam>>,
struct_type_params: &Punctuated<GenericParam, Comma>,
) -> Type {
let mut typs = vec![];

Expand All @@ -1003,8 +1114,14 @@ fn fill_in_type(
}
};
let pred = get_pred_typename(main_name, field_name);
let pred_ty = if struct_type_params.is_empty() {
quote! { #pred }
} else {
let type_args = remove_bounds(struct_type_params);
quote! { #pred<#type_args> }
};
typs.push(get_constant_type(main_name, depends_on, quants, used_type_params));
typs.push(Type::Verbatim(quote! { #pred }));
typs.push(Type::Verbatim(pred_ty));
}

fill_in_infers(ty, typs)
Expand Down Expand Up @@ -1200,6 +1317,48 @@ fn get_params_used_in_type(params: &Punctuated<GenericParam, Comma>, ty: &Type)
upv.result
}

fn restrict_param_bounds(
gp: &GenericParam,
allowed: &HashSet<String>,
all_struct_params: &HashSet<String>,
) -> GenericParam {
let mut gp = gp.clone();
match &mut gp {
GenericParam::Type(tp) => {
tp.bounds = tp
.bounds
.iter()
.filter(|b| {
let mut found = HashSet::new();
let mut visitor = CollectIdentsVisitor { result: &mut found };
visit::visit_type_param_bound(&mut visitor, *b);
found.iter().all(|n| !all_struct_params.contains(n) || allowed.contains(n))
})
.cloned()
.collect();
if tp.bounds.is_empty() {
tp.colon_token = None;
}
}
GenericParam::Lifetime(ld) => {
ld.bounds = ld
.bounds
.iter()
.filter(|lt| {
let n = "'".to_string() + &lt.ident.to_string();
!all_struct_params.contains(&n) || allowed.contains(&n)
})
.cloned()
.collect();
if ld.bounds.is_empty() {
ld.colon_token = None;
}
}
GenericParam::Const(_) => {}
}
gp
}

struct UsedParamsVisitor {
params: HashSet<String>,
result: HashSet<String>,
Expand All @@ -1210,7 +1369,7 @@ impl<'ast> Visit<'ast> for UsedParamsVisitor {
let TypePath { qself, path } = type_path;
if qself.is_none()
&& path.leading_colon.is_none()
&& path.segments.len() == 1
&& !path.segments.is_empty()
&& path.segments[0].arguments == PathArguments::None
{
let id = path.segments[0].ident.to_string();
Expand Down
25 changes: 25 additions & 0 deletions source/rust_verify_test/tests/user_defined_type_invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2722,3 +2722,28 @@ test_verify_one_file_with_options! {
}
} => Err(err) => assert_fails_type_invariant_error(err, 2)
}

test_verify_one_file! {
Comment thread
rikosellic marked this conversation as resolved.
Outdated
#[test] struct_with_invariants_assoc_type_prefix verus_code! {
use core::marker::PhantomData;
use vstd::atomic_ghost::AtomicPtr;
use vstd::prelude::*;

pub trait HasTarget {
type Target;
}

struct_with_invariants!{
pub struct S<A: HasTarget> {
a: AtomicPtr<<A as HasTarget>::Target,_,(),_>,
_phantom: PhantomData<*const A::Target>,
}

closed spec fn wf(self) -> bool {
invariant on a with (_phantom) is (v:*mut <A as HasTarget>::Target,g:()) {
true
}
}
}
} => Ok(())
}
Loading