diff --git a/source/builtin/src/lib.rs b/source/builtin/src/lib.rs index 56966a312b..9154917595 100644 --- a/source/builtin/src/lib.rs +++ b/source/builtin/src/lib.rs @@ -23,6 +23,29 @@ pub fn admit() { unimplemented!(); } +/// Pass tracked or ghost values to the immediately following external function call. +/// Used with `external_fn_specification` functions that have extra tracked/ghost parameters. +/// Pass ghost/tracked arguments to a function call with proper borrow checking. +/// `_a` is a tuple of `Tracked` or `Ghost` values. +/// `_b` is the function call expression whose result is returned. +/// Usage: `proof_with((Tracked(&mut x), Ghost(y)), f(a))` +#[cfg(verus_keep_ghost)] +#[rustc_diagnostic_item = "verus::verus_builtin::proof_with"] +#[verifier::proof] +pub fn proof_with(_a: A, _b: B) -> B { + unimplemented!(); +} + +/// Zero-arg version of proof_with for output parameters. +/// The type parameter indicates the output container type. + +#[cfg(verus_keep_ghost)] +#[rustc_diagnostic_item = "verus::verus_builtin::declare_with"] +#[verifier::proof] +pub fn declare_with() -> A { + unimplemented!(); +} + // Can only appear at beginning of function body #[cfg(verus_keep_ghost)] #[rustc_diagnostic_item = "verus::verus_builtin::no_method_body"] diff --git a/source/rust_verify/src/context.rs b/source/rust_verify/src/context.rs index b4e49dc7bd..76df710da6 100644 --- a/source/rust_verify/src/context.rs +++ b/source/rust_verify/src/context.rs @@ -8,7 +8,7 @@ use rustc_mir_build_verus::verus::BodyErasure; use rustc_span::SpanData; use rustc_span::def_id::DefId; use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::ops::DerefMut; use std::rc::Rc; use std::sync::Arc; @@ -16,6 +16,17 @@ use std::sync::atomic::AtomicU64; use vir::ast::{CrateId, Mode, Path, Pattern, VirErr}; use vir::messages::AstId; +/// A pending tracked/ghost argument from a `proof_with()` call, waiting to be +/// appended to the next function call's argument list. +pub(crate) struct PendingTrackedArg { + /// The VIR expression for the argument value + pub expr: vir::ast::Expr, + /// true for Tracked, false for Ghost + pub is_tracked: bool, + /// HirId of the proof_with argument expression (for type lookup) + pub arg_hir_id: HirId, +} + pub struct ErasureInfo { pub(crate) hir_vir_ids: Vec<(HirId, AstId)>, pub(crate) resolved_calls: Vec<(HirId, SpanData, ResolvedCall)>, @@ -48,6 +59,10 @@ pub struct ContextX<'tcx> { pub(crate) crate_name: CrateId, pub(crate) name_def_id_map: Rc>>, pub(crate) next_read_kind_id: AtomicU64, + /// For functions with extra ghost/tracked params (from declare_with() stmts): + /// maps the function's DefId to a Vec of (is_tracked, expected_ty) pairs + pub(crate) declare_with_params: + Rc)>>>>, } /// The context in which a given header node might be interpretted @@ -92,6 +107,11 @@ pub(crate) struct BodyCtxt<'tcx> { /// Assume specification defines a new opaque type for each opaque type in the external function. /// We use this map to resolve them later. pub(crate) external_opaque_type_map: Option>, + /// Pending tracked/ghost args from proof_with() calls, to be consumed by the next function call. + /// Set to Some(...) by ProofWith handler, taken (consumed) by fn_call_to_vir. + pub(crate) pending_tracked_args: Rc>>>, + /// HirIds of declare_with() let-stmts to skip during body conversion + pub(crate) declare_with_hir_ids: Rc>, } impl<'tcx> ContextX<'tcx> { @@ -117,6 +137,7 @@ impl<'tcx> ContextX<'tcx> { crate_name, name_def_id_map: Rc::new(RefCell::new(HashMap::new())), next_read_kind_id: AtomicU64::new(0), + declare_with_params: Rc::new(RefCell::new(HashMap::new())), } } diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index 60a8b8d348..9d3bb81a5c 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -42,6 +42,8 @@ use vir::ast_util::{ }; use vir::def::field_ident_from_rust; +use crate::proof_with_lifetime::check_proof_with_lifetime; + pub(crate) fn fn_call_to_vir<'tcx>( bctx: &BodyCtxt<'tcx>, expr: &Expr<'tcx>, @@ -304,6 +306,95 @@ fn fn_call_or_assoc_const_to_vir<'tcx>( let vir_args = if let Some(args) = args { mk_vir_args(bctx, &args)? } else { vec![] }; + // Consume pending tracked args from proof_with() if present. + // Only consume when callee expects extra params (has declare_with entries). + let vir_args = { + let mut args = vir_args; + let extra_params = bctx.ctxt.declare_with_params.borrow().get(&f).cloned(); + if let Some(ref expected_params) = extra_params { + let extra_count = expected_params.len(); + let mut pending_opt = bctx.pending_tracked_args.borrow_mut(); + let pending = match pending_opt.take() { + Some(p) => p, + None => { + return err_span( + expr.span, + format!( + "this external function requires {} extra tracked/ghost argument(s) via proof_with()", + extra_count + ), + ); + } + }; + if pending.len() != extra_count { + return err_span( + expr.span, + format!( + "expected {} tracked/ghost argument(s) via proof_with(), got {}", + extra_count, + pending.len() + ), + ); + } + // Check mode and type for each pending arg + for (i, (pending_arg, (expected_is_tracked, expected_ty))) in + pending.iter().zip(expected_params.iter()).enumerate() + { + // Mode check + if pending_arg.is_tracked != *expected_is_tracked { + let expected_mode = if *expected_is_tracked { "Tracked" } else { "Ghost" }; + let actual_mode = if pending_arg.is_tracked { "Tracked" } else { "Ghost" }; + return err_span( + expr.span, + format!( + "proof_with argument {} has wrong mode: expected {}, got {}", + i + 1, + expected_mode, + actual_mode, + ), + ); + } + // Type check: compare rustc types with regions erased. + { + let tcx = bctx.ctxt.tcx; + let expected_ty_instantiated = + rustc_middle::ty::EarlyBinder::bind(*expected_ty) + .instantiate(tcx, node_substs); + let actual_ty = bctx.types.node_type(pending_arg.arg_hir_id); + use rustc_middle::ty::TypeFoldable; + let expected_erased = expected_ty_instantiated.fold_with( + &mut rustc_middle::ty::RegionFolder::new(tcx, &mut |_, _| { + tcx.lifetimes.re_erased + }), + ); + if actual_ty != expected_erased { + return err_span( + expr.span, + format!( + "proof_with argument {} has wrong type: expected `{}`, got `{}`", + i + 1, + expected_ty_instantiated, + actual_ty, + ), + ); + } + check_proof_with_lifetime( + bctx, + f, + expr.hir_id, + pending_arg.arg_hir_id, + *expected_ty, + expr.span, + i, + )?; + } + } + let exprs: Vec<_> = pending.into_iter().map(|a| a.expr).collect(); + args.extend(exprs); + } + args + }; + let typ_args = mk_typ_args(bctx, node_substs, f, expr.span)?; let impl_paths = get_impl_paths(bctx, f, node_substs, None, const_var, expr.span)?; let target = @@ -2109,6 +2200,65 @@ fn verus_item_to_vir<'tcx, 'a>( let p = crate::rust_to_vir_expr::simplify_place_by_cancelling(&p); mk_expr(ExprX::BorrowMutTracked(p)) } + VerusItem::ProofWith => { + // proof_with(ghost_args, call) — ghost args and call in one expression. + // First arg: single Tracked/Ghost or tuple of them. + // Second arg: the actual function call, whose result is returned. + unsupported_err_unless!( + args_len == 2, + expr.span, + "expected proof_with(ghost_args, call)", + &args + ); + + // Extract individual ghost/tracked items from the first arg. + let ghost_arg = &args[0]; + let ghost_items: Vec<&rustc_hir::Expr> = match &ghost_arg.kind { + rustc_hir::ExprKind::Tup(elems) => elems.iter().collect(), + _ => vec![ghost_arg], + }; + + let bctx_ghost = &BodyCtxt { in_ghost: true, ..bctx.clone() }; + let mut pending_args = Vec::new(); + for item in &ghost_items { + let arg_typ = typ_of_expr_adjusted(bctx, item.span, &item.hir_id)?; + let is_tracked = match &*arg_typ { + TypX::Decorate(TypDecoration::Tracked, _, _) => true, + TypX::Decorate(TypDecoration::Ghost, _, _) => false, + _ => { + return err_span( + item.span, + "proof_with expects arguments of type Tracked or Ghost", + ); + } + }; + let arg_expr = expr_to_vir_consume(bctx_ghost, item)?; + pending_args.push(crate::context::PendingTrackedArg { + expr: arg_expr, + is_tracked, + arg_hir_id: item.hir_id, + }); + } + + // Set pending args for consumption by the inner fn_call_to_vir + *bctx.pending_tracked_args.borrow_mut() = Some(pending_args); + + // Record erasure: replace proof_with(A, B) with just B (arg index 1) + record_call(bctx, expr, ResolvedCall::SpecAllowProofArgs); + + // Process second arg (the function call) — this will consume pending args + let call_expr = expr_to_vir_consume(bctx, &args[1])?; + + // Assert pending args were consumed + if bctx.pending_tracked_args.borrow().is_some() { + return err_span( + expr.span, + "proof_with second argument must be a function call that accepts extra tracked/ghost arguments", + ); + } + + Ok(call_expr) + } VerusItem::BuiltinDeref(d) => { // This would be easy to support (similar to handling borrow_mut etc.) but their usage // would be very rare so I'm skipping for now @@ -2126,6 +2276,11 @@ fn verus_item_to_vir<'tcx, 'a>( ) .help("you can implicitly dereference this type using `*`")); } + VerusItem::DeclareWith => { + // declare_with() is handled at let-stmt level + // in rust_to_vir_expr.rs. If we reach here, it's used outside a let-stmt. + return err_span(expr.span, "declare_with() must be used as a let initializer"); + } VerusItem::Vstd(_, _) | VerusItem::Marker(_) | VerusItem::BuiltinType(_) diff --git a/source/rust_verify/src/lib.rs b/source/rust_verify/src/lib.rs index 2a55714363..5ce7189ded 100644 --- a/source/rust_verify/src/lib.rs +++ b/source/rust_verify/src/lib.rs @@ -53,6 +53,7 @@ mod fn_call_to_vir; mod hir_hide_reveal_rewrite; mod import_export; pub mod profiler; +mod proof_with_lifetime; mod resolve_traits; pub mod reveal_hide; mod rust_intrinsics_to_vir; diff --git a/source/rust_verify/src/proof_with_lifetime.rs b/source/rust_verify/src/proof_with_lifetime.rs new file mode 100644 index 0000000000..7aeb03803f --- /dev/null +++ b/source/rust_verify/src/proof_with_lifetime.rs @@ -0,0 +1,287 @@ +//! Lifetime checking for proof_with / declare_with. +//! +//! These functions erase tracked/ghost parameters from fn signatures before the borrow +//! checker runs, so Rust's NLL cannot enforce lifetime constraints on those parameters. +//! This module performs a first-pass lifetime check during VIR lowering to catch +//! incompatible lifetimes early. +//! +//! Strategy: use rustc's `InferCtxt` to do a subtype check between the actual arg type +//! (with real caller regions) and the expected type (with caller regions derived from +//! exec arg matching). This reuses Rust's region constraint solver rather than +//! reimplementing outlives checking. + +use crate::context::BodyCtxt; +use crate::util::err_span; +use rustc_hir::def::Res; +use rustc_hir::{ExprKind, QPath}; +use rustc_middle::ty::error::TypeErrorToStringExt; +use rustc_span::Span; +use rustc_span::def_id::DefId; +use rustc_trait_selection::regions::InferCtxtRegionExt; +use vir::ast::VirErr; + +/// Check that a proof_with argument's lifetime is compatible with the callee's expected lifetime. +/// +/// The expected type (from `lower_ty` on the callee's `declare_with` HIR type) has +/// `ReLateParam(callee, 'a)` regions. We need to: +/// 1. Map callee's late-bound `'a` to the caller's corresponding lifetime via the call args +/// 2. Get the proof_with arg's lifetime from its declaration +/// 3. Check that the arg's lifetime outlives the expected lifetime in the caller's scope +pub(crate) fn check_proof_with_lifetime<'tcx>( + bctx: &BodyCtxt<'tcx>, + callee_def_id: DefId, + call_hir_id: rustc_hir::HirId, + arg_hir_id: rustc_hir::HirId, + expected_ty_raw: rustc_middle::ty::Ty<'tcx>, + call_span: Span, + arg_index: usize, +) -> Result<(), VirErr> { + use rustc_infer::infer::TyCtxtInferExt; + use rustc_middle::ty::{Region, RegionKind, TypeFoldable, TypeVisitable, TypeVisitor}; + use std::ops::ControlFlow; + + let tcx = bctx.ctxt.tcx; + + // --- Helpers --- + + struct RegionCollector<'tcx> { + regions: Vec>, + } + impl<'tcx> TypeVisitor> for RegionCollector<'tcx> { + type Result = ControlFlow<()>; + fn visit_region(&mut self, r: Region<'tcx>) -> Self::Result { + self.regions.push(r); + ControlFlow::Continue(()) + } + } + + fn collect_regions<'tcx>(ty: rustc_middle::ty::Ty<'tcx>) -> Vec> { + let mut collector = RegionCollector { regions: vec![] }; + let _ = ty.visit_with(&mut collector); + collector.regions + } + + /// Get DefId for a named region in a given function's scope. + fn region_def_id<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + owner_def_id: DefId, + region: Region<'tcx>, + ) -> Option { + match region.kind() { + RegionKind::ReLateParam(lp) => match lp.kind { + rustc_middle::ty::LateParamRegionKind::Named(def_id) => Some(def_id), + _ => None, + }, + RegionKind::ReEarlyParam(ep) => { + let generics = tcx.generics_of(owner_def_id); + let param = generics.param_at(ep.index as usize, tcx); + Some(param.def_id) + } + RegionKind::ReBound(_, br) => { + if let rustc_middle::ty::BoundRegionKind::Named(def_id) = br.kind { + Some(def_id) + } else { + None + } + } + _ => None, + } + } + + // --- Step 1: Check if expected type has any regions worth checking --- + + let raw_regions = collect_regions(expected_ty_raw); + let callee_region_def_ids: Vec<_> = + raw_regions.iter().filter_map(|r| region_def_id(tcx, callee_def_id, *r)).collect(); + + if callee_region_def_ids.is_empty() { + return Ok(()); // No lifetime params in the ghost type + } + + // --- Step 2: Get the actual arg type with real regions --- + + let caller_def_id = bctx.fun_id; + let caller_poly_sig = tcx.fn_sig(caller_def_id).instantiate_identity(); + let caller_liberated = tcx.liberate_late_bound_regions(caller_def_id, caller_poly_sig); + let caller_inputs = caller_liberated.inputs(); + + let param_idx = resolve_expr_to_param_index(tcx, caller_def_id, arg_hir_id); + let actual_ty_with_regions = match param_idx { + Some(idx) if idx < caller_inputs.len() => caller_inputs[idx], + _ => return Ok(()), // Can't determine actual type with regions + }; + + // --- Step 3: Build callee→caller region mapping from exec args --- + + let callee_poly_sig = tcx.fn_sig(callee_def_id).instantiate_identity(); + let callee_sig_inputs = callee_poly_sig.skip_binder().inputs(); + + // Extract call arg HirIds from the HIR + let call_arg_hir_ids: Vec = { + let hir_node = tcx.hir_node(call_hir_id); + match hir_node { + rustc_hir::Node::Expr(expr) => match &expr.kind { + ExprKind::Call(_, hir_args) => hir_args.iter().map(|a| a.hir_id).collect(), + ExprKind::MethodCall(_, receiver, hir_args, _) => { + let mut ids = vec![receiver.hir_id]; + ids.extend(hir_args.iter().map(|a| a.hir_id)); + ids + } + _ => return Ok(()), + }, + _ => return Ok(()), + } + }; + + // Map callee region DefId → caller Region by matching callee fn_sig params with call args + let mut callee_to_caller: std::collections::HashMap> = + std::collections::HashMap::new(); + + for (param_idx, callee_param_ty) in callee_sig_inputs.iter().enumerate() { + let param_regions: Vec<_> = collect_regions(*callee_param_ty) + .into_iter() + .filter_map(|r| { + let did = region_def_id(tcx, callee_def_id, r)?; + Some(did) + }) + .collect(); + + if param_regions.is_empty() || param_idx >= call_arg_hir_ids.len() { + continue; + } + + let arg_hir = call_arg_hir_ids[param_idx]; + if let Some(caller_param_idx) = resolve_expr_to_param_index(tcx, caller_def_id, arg_hir) { + if caller_param_idx < caller_inputs.len() { + let caller_param_ty = caller_inputs[caller_param_idx]; + let caller_regions: Vec<_> = collect_regions(caller_param_ty) + .into_iter() + .filter(|r| { + matches!(r.kind(), RegionKind::ReLateParam(_) | RegionKind::ReEarlyParam(_)) + }) + .collect(); + + for (callee_did, caller_region) in param_regions.iter().zip(caller_regions.iter()) { + callee_to_caller.insert(*callee_did, *caller_region); + } + } + } + } + + // --- Step 4: Build expected type with caller regions using InferCtxt --- + + let infcx = tcx.infer_ctxt().build(rustc_type_ir::TypingMode::PostAnalysis); + + // For ghost-only regions not found in exec args, create fresh region variables + for callee_did in &callee_region_def_ids { + if !callee_to_caller.contains_key(callee_did) { + let fresh = + infcx.next_region_var(rustc_infer::infer::RegionVariableOrigin::Misc(call_span)); + callee_to_caller.insert(*callee_did, fresh); + } + } + + // Fold expected_ty_raw, replacing callee regions with mapped caller regions + let expected_ty_caller = + expected_ty_raw.fold_with(&mut rustc_middle::ty::RegionFolder::new(tcx, &mut |r, _| { + if let Some(callee_did) = region_def_id(tcx, callee_def_id, r) { + if let Some(caller_region) = callee_to_caller.get(&callee_did) { + return *caller_region; + } + } + r + })); + + // --- Step 5: Add callee's where-clause region bounds as constraints --- + + let callee_predicates = tcx.predicates_of(callee_def_id); + for (pred, _) in callee_predicates.predicates { + if let Some(rustc_middle::ty::ClauseKind::RegionOutlives(outlives)) = + pred.kind().no_bound_vars() + { + let longer_did = region_def_id(tcx, callee_def_id, outlives.0); + let shorter_did = region_def_id(tcx, callee_def_id, outlives.1); + + if let (Some(longer_did), Some(shorter_did)) = (longer_did, shorter_did) { + if let (Some(&caller_longer), Some(&caller_shorter)) = + (callee_to_caller.get(&longer_did), callee_to_caller.get(&shorter_did)) + { + // Add constraint: caller_longer: caller_shorter + // (sub_regions(a, b) means a <= b, i.e., b outlives a) + infcx.sub_regions( + rustc_infer::infer::SubregionOrigin::RelateParamBound( + call_span, + actual_ty_with_regions, + None, + ), + caller_shorter, + caller_longer, + ); + } + } + } + } + + // --- Step 6: Subtype check: actual_ty <: expected_ty --- + + let param_env = tcx.param_env(caller_def_id); + let cause = rustc_infer::traits::ObligationCause::dummy(); + + if let Err(error) = infcx.at(&cause, param_env).sub( + rustc_infer::infer::DefineOpaqueTypes::Yes, + actual_ty_with_regions, + expected_ty_caller, + ) { + return err_span( + call_span, + format!( + "proof_with argument {} has incompatible lifetime: {}", + arg_index + 1, + error.to_string(tcx) + ), + ); + } + + // --- Step 7: Resolve region constraints and check for errors --- + + let assumed_wf: Vec> = vec![]; + let errors = infcx.resolve_regions(caller_def_id.expect_local(), param_env, assumed_wf); + + if !errors.is_empty() { + return err_span( + call_span, + format!( + "proof_with argument {} has incompatible lifetime: {:?}", + arg_index + 1, + errors + ), + ); + } + + Ok(()) +} + +/// Try to resolve an expression (by HirId) to a function parameter index. +fn resolve_expr_to_param_index<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + fn_def_id: DefId, + arg_hir_id: rustc_hir::HirId, +) -> Option { + let hir_node = tcx.hir_node(arg_hir_id); + let expr = match hir_node { + rustc_hir::Node::Expr(e) => e, + _ => return None, + }; + if let ExprKind::Path(QPath::Resolved(None, path)) = &expr.kind { + if let Res::Local(local_hir_id) = path.res { + let local_def_id = fn_def_id.as_local()?; + let body = tcx.hir_body_owned_by(local_def_id); + for (i, param) in body.params.iter().enumerate() { + if param.pat.hir_id == local_hir_id { + return Some(i); + } + } + } + } + None +} diff --git a/source/rust_verify/src/rust_to_vir_expr.rs b/source/rust_verify/src/rust_to_vir_expr.rs index b20cc1710a..57f498d27e 100644 --- a/source/rust_verify/src/rust_to_vir_expr.rs +++ b/source/rust_verify/src/rust_to_vir_expr.rs @@ -3572,6 +3572,10 @@ pub(crate) fn stmt_to_vir<'tcx>( return Ok(vec![]); } } + // Skip declare_with() stmts (handled as extra params) + if bctx.declare_with_hir_ids.contains(&stmt.hir_id) { + return Ok(vec![]); + } let_stmt_to_vir(bctx, pat, init, els, bctx.ctxt.tcx.hir_attrs(stmt.hir_id)) } diff --git a/source/rust_verify/src/rust_to_vir_func.rs b/source/rust_verify/src/rust_to_vir_func.rs index 41ebedd5f0..3ed7a2a614 100644 --- a/source/rust_verify/src/rust_to_vir_func.rs +++ b/source/rust_verify/src/rust_to_vir_func.rs @@ -291,6 +291,7 @@ fn mk_bctx<'tcx>( migrate_postcondition_vars: Option>, param_names: Vec, external_opaque_type_map: Option>, + declare_with_hir_ids: HashSet, ) -> BodyCtxt<'tcx> { BodyCtxt { ctxt: ctxt.clone(), @@ -310,9 +311,135 @@ fn mk_bctx<'tcx>( header_setting: HeaderSetting::Fn, unwrap_param_map: std::rc::Rc::new(std::cell::RefCell::new(HashMap::new())), external_opaque_type_map, + pending_tracked_args: std::rc::Rc::new(std::cell::RefCell::new(None)), + declare_with_hir_ids: std::rc::Rc::new(declare_with_hir_ids), } } +/// Pre-scan the HIR body for `declare_with()` let-stmts. +/// Returns (extra_vir_params, hir_ids_to_skip) so that: +/// 1. Extra params can be appended to `vir_params` before body conversion +/// 2. `stmt_to_vir` can skip these let-stmts during body conversion +fn pre_scan_declare_with_params<'tcx>( + ctxt: &Context<'tcx>, + id: DefId, + body: &Body<'tcx>, + body_id: &BodyId, +) -> Result< + (Vec<(vir::ast::Param, Option, rustc_middle::ty::Ty<'tcx>)>, HashSet), + VirErr, +> { + let mut extra_params = Vec::new(); + let mut hir_ids = HashSet::new(); + let types = body_id_to_types(ctxt.tcx, body_id); + + // Navigate into the body's top-level block + let stmts = match &body.value.kind { + ExprKind::Block(block, _) => block.stmts, + _ => return Ok((extra_params, hir_ids)), + }; + + for stmt in stmts { + if let rustc_hir::StmtKind::Let(rustc_hir::LetStmt { + pat, + init: Some(init), + ty: hir_ty, + .. + }) = &stmt.kind + { + // Check if init is a call to declare_with() + let verus_item = match &init.kind { + ExprKind::Call(fun, _) => match &fun.kind { + ExprKind::Path(rustc_hir::QPath::Resolved( + None, + rustc_hir::Path { res: rustc_hir::def::Res::Def(_, fun_id), .. }, + )) => ctxt.get_verus_item(*fun_id).cloned(), + _ => None, + }, + _ => None, + }; + + let is_declare_with = match verus_item { + Some(VerusItem::DeclareWith) => true, + _ => false, + }; + if !is_declare_with { + continue; + } + + // Require simple binding pattern + let (is_mut_var, name) = pat_to_mut_var(pat)?; + + // Get the resolved type. Use lower_ty on the HIR type annotation if available, + // because it preserves early-bound regions (ReEarlyParam) which are needed for + // lifetime checking. typeck's node_type() returns types with erased regions. + let ty = if let Some(hir_ty) = hir_ty { + rustc_hir_analysis::lower_ty(ctxt.tcx, hir_ty) + } else { + types.node_type(init.hir_id) + }; + + // Derive is_tracked from the type (Tracked vs Ghost) via ADT DefId + let is_tracked = match ty.kind() { + rustc_middle::ty::TyKind::Adt(adt_def, _) + if matches!( + ctxt.get_verus_item(adt_def.did()), + Some(VerusItem::BuiltinType(crate::verus_items::BuiltinTypeItem::Tracked)) + ) => + { + true + } + rustc_middle::ty::TyKind::Adt(adt_def, _) + if matches!( + ctxt.get_verus_item(adt_def.did()), + Some(VerusItem::BuiltinType(crate::verus_items::BuiltinTypeItem::Ghost)) + ) => + { + false + } + _ => { + return err_span( + init.span, + "declare_with() must be assigned to a Tracked or Ghost type", + ); + } + }; + let inner_mode = if is_tracked { Mode::Proof } else { Mode::Spec }; + + // Handle &mut types - check if the type contains a mutable reference + let is_ref_mut = is_mut_ty(ctxt, ty); + let is_mut = is_ref_mut.is_some(); + // Use mid_ty_to_vir on the FULL type (matching normal param processing in + // check_fn_decl). Don't manually unwrap the MutRef — it must be preserved + // in the typ so that VIR can properly generate mut_ref_current% / update% AIR. + let typ = ctxt.mid_ty_to_vir(id, pat.span, &ty, None)?; + + // All declare_with params are inputs — unwrap Ghost/Tracked as usual + let outer_name = vir::ast_util::air_unique_var(&format!( + "declare_with_{}", + vir::def::user_local_name(&name) + )); + let unwrapped_info = Some((inner_mode, outer_name)); + let param_mode = Mode::Exec; + let vir_param = ctxt.spanned_new( + pat.span, + ParamX { + name: name.clone(), + typ, + mode: param_mode, + unwrapped_info, + user_mut: is_mut_var || is_mut, + }, + ); + + extra_params.push((vir_param, None, ty)); + hir_ids.insert(stmt.hir_id); + } + } + + Ok((extra_params, hir_ids)) +} + fn body_to_vir<'tcx>( ctxt: &Context<'tcx>, fun_id: DefId, @@ -325,6 +452,7 @@ fn body_to_vir<'tcx>( param_names: Vec, external_opaque_type_map: Option>, is_async: bool, + declare_with_hir_ids: HashSet, ) -> Result { let bctx = mk_bctx( ctxt, @@ -336,6 +464,7 @@ fn body_to_vir<'tcx>( migrate_postcondition_vars, param_names, external_opaque_type_map, + declare_with_hir_ids, ); let body_expr = if is_async { extract_desugared_async_body(&bctx.ctxt, body)? } else { &body.value }; @@ -1088,7 +1217,9 @@ fn equalize_substs<'tcx>( let mut l1 = vec![]; let mut l2 = vec![]; - if substs1_early.len() + num_late1 != substs2_early.len() + num_late2 { + // Allow proxy (sig1) to have more total generics than external (sig2), + // since extra ghost/tracked params may introduce additional late-bound lifetimes. + if substs1_early.len() + num_late1 < substs2_early.len() + num_late2 { return None; } @@ -1724,6 +1855,57 @@ pub(crate) fn check_item_fn<'tcx>( rustc_hir::IsAsync::Async(..) => true, }; let body = find_body(ctxt, body_id); + + // Pre-scan for declare_with() calls + let (declare_with_extra_params, declare_with_hir_ids) = + pre_scan_declare_with_params(ctxt, id, body, body_id)?; + let declare_with_modes: Vec<(bool, rustc_middle::ty::Ty<'tcx>)> = + declare_with_extra_params + .iter() + .map(|(p, _, ty)| { + // unwrapped_info mode: Proof = Tracked, Spec = Ghost + match p.x.unwrapped_info { + Some((Mode::Proof, _)) => (true, *ty), + Some((Mode::Spec, _)) => (false, *ty), + _ => unreachable!(), + } + }) + .collect(); + for (p, _mode, _) in declare_with_extra_params { + vir_params.push(p); + } + + // Register declare_with extra param modes early so callers can find them + // even if body_to_vir fails for this function. + if !declare_with_modes.is_empty() { + let target_id = proxy_id.unwrap_or(id); + ctxt.declare_with_params.borrow_mut().insert(target_id, declare_with_modes.clone()); + + // For unerased_proxy functions (const fn proxies), also register under + // the original function's DefId, since call sites resolve to the original. + if vattrs.unerased_proxy { + let proxy_name = ctxt.tcx.item_name(id).as_str().to_string(); + let prefix = "VERUS_UNERASED_PROXY__"; + if let Some(original_name) = proxy_name.strip_prefix(prefix) { + let parent = ctxt.tcx.parent(id); + // Find sibling with matching name by iterating HIR items in the parent + for item_id in ctxt.tcx.hir_free_items() { + let child_def_id = item_id.owner_id.to_def_id(); + if let Some(name) = ctxt.tcx.opt_item_name(child_def_id) { + if name.as_str() == original_name + && child_def_id != id + && ctxt.tcx.parent(child_def_id) == parent + { + ctxt.declare_with_params + .borrow_mut() + .insert(child_def_id, declare_with_modes.clone()); + } + } + } + } + } + } + let external_body = vattrs.external_body || vattrs.external_fn_specification; let param_names = vir_params.iter().map(|p| p.x.name.clone()).collect::>(); let mut vir_body = body_to_vir( @@ -1738,6 +1920,7 @@ pub(crate) fn check_item_fn<'tcx>( param_names, assume_specification_opaque_type_map.clone(), is_async, + declare_with_hir_ids, )?; let header = vir::headers::read_header(&mut vir_body, &vir::headers::HeaderAllows::All)?; @@ -2821,6 +3004,7 @@ pub(crate) fn check_item_const_or_static<'tcx>( vec![], None, false, + HashSet::new(), )?; let header = vir::headers::read_header( &mut vir_body, diff --git a/source/rust_verify/src/verus_items.rs b/source/rust_verify/src/verus_items.rs index c2e666216a..138ae50219 100644 --- a/source/rust_verify/src/verus_items.rs +++ b/source/rust_verify/src/verus_items.rs @@ -448,6 +448,8 @@ pub(crate) enum VerusItem { GetFirst, DummyCapture(DummyCaptureItem), MutRefTracked, + ProofWith, + DeclareWith, } #[derive(PartialEq, Eq, Debug, Clone, Hash)] @@ -460,6 +462,8 @@ pub(crate) enum DummyCaptureItem { #[rustfmt::skip] fn verus_items_map() -> Vec<(&'static str, VerusItem)> { vec![ + ("verus::verus_builtin::proof_with", VerusItem::ProofWith), + ("verus::verus_builtin::declare_with", VerusItem::DeclareWith), ("verus::verus_builtin::admit", VerusItem::Spec(SpecItem::Admit)), ("verus::verus_builtin::assume_", VerusItem::Spec(SpecItem::Assume)), ("verus::verus_builtin::no_method_body", VerusItem::Spec(SpecItem::NoMethodBody)), diff --git a/source/rust_verify_test/tests/proof_with.rs b/source/rust_verify_test/tests/proof_with.rs new file mode 100644 index 0000000000..b9e267c835 --- /dev/null +++ b/source/rust_verify_test/tests/proof_with.rs @@ -0,0 +1,329 @@ +#![feature(rustc_private)] +#[macro_use] +mod common; +use common::*; + +test_verify_one_file! { + #[test] test_proof_with verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a == 0 && b@ == 1 && c@ == 2); + } + + fn call_test() { + proof_with((Tracked(1u64), Ghost(2u32)), test(0)); + } + + #[verifier(external)] + fn unverified_call_test() { + test(0); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_impl verus_code!{ + use vstd::prelude::*; + struct A { + a: u64 + } + impl A { + fn test(&self) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(self.a == 0 && b@ == 1 && c@ == 2); + } + } + fn call_test() { + let a = A { a: 0 }; + proof_with((Tracked(1u64), Ghost(2u32)), a.test()); + } + + #[verifier(external)] + fn unverified_call_test() { + let a = A { a: 0 }; + a.test(); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_trait verus_code!{ + use vstd::prelude::*; + struct A { + a: u64 + } + + trait AOp { + fn test(&self) { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(b@ == 1 && c@ == 2); + } + } + impl AOp for A { + fn test(&self) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + assert(b@ == 1); + assert(c@ == 0); // FAILS + } + } + fn call_test() { + let a = A { a: 0 }; + proof_with((Tracked(1u64), Ghost(2u32)), a.test()); + } + + #[verifier(external)] + fn unverified_call_test() { + let a = A { a: 0 }; + a.test(); + } + } => Err(e) => assert_one_fails(e) +} + +test_verify_one_file! { + #[test] test_proof_with_external verus_code!{ + use vstd::prelude::*; + #[verifier(external)] + fn negate_bool(b: bool, x: u8) -> bool { + !b + } + + #[verifier(external_fn_specification)] + fn negate_bool_requires_ensures(b: bool, x: u8) -> bool + { + let extra: Tracked = declare_with(); + requires(x == extra@); + ensures(|ret_b: bool| ret_b == !b); + negate_bool(b, x) + } + + fn call_test() { + proof_with(Tracked(1u8), negate_bool(true, 1)); + } + + #[verifier(external)] + fn unverified_call_test() { + negate_bool(true, 1); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_external_failed verus_code!{ + use vstd::prelude::*; + #[verifier(external)] + fn negate_bool(b: bool, x: u8) -> bool { + !b + } + + #[verifier(external_fn_specification)] + fn negate_bool_requires_ensures(b: bool, x: u8) -> bool + { + let extra: Tracked = declare_with(); + requires(x == extra@); + ensures(|ret_b: bool| ret_b == !b); + negate_bool(b, x) + } + + fn call_test() { + negate_bool(true, 1); + } + } => Err(e) => assert_vir_error_msg(e, "this external function requires 1 extra tracked/ghost argument(s) via proof_with()") +} + +test_verify_one_file! { + #[test] test_proof_with_failed_requires verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a == 0 && b@ == 1 && c@ == 2); + } + + fn call_test() { + proof_with((Tracked(0u64), Ghost(2u32)), test(0)); // FAILS + } + } => Err(e) => assert_one_fails(e) +} + +test_verify_one_file! { + #[test] test_proof_with_invalid_type verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + requires(a == 0 && b@ == 1); + } + + fn call_test() { + proof_with(0u64, test(0)); + } + } => Err(e) => assert_vir_error_msg(e, "proof_with expects arguments of type Tracked or Ghost") +} + +test_verify_one_file! { + #[test] test_proof_with_wrong_mode_type verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + requires(a == 0 && b@ == 1); + } + + fn call_test() { + proof_with(Ghost(0u64), test(0)); + } + } => Err(e) => assert_vir_error_msg(e, "proof_with argument 1 has wrong mode: expected Tracked, got Ghost") +} + +// ---- Lifetime soundness tests ---- +// These tests verify that lifetime constraints on tracked/ghost params are properly checked. + +test_verify_one_file! { + #[test] test_proof_with_lifetime_mismatch verus_code!{ + use vstd::prelude::*; + fn test<'a>(a: &'a u64, b: u64) -> u64 + { + let c: Tracked<&'a u64> = declare_with(); + 1 + } + + fn test2<'a, 'b>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> u64 + { + proof_with(c, test(a, b)) + } + } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") +} + +test_verify_one_file! { + #[test] test_proof_with_lifetime_compatible verus_code!{ + use vstd::prelude::*; + fn test<'a>(a: &'a u64, b: u64) -> u64 + { + let c: Tracked<&'a u64> = declare_with(); + 1 + } + + fn test2<'a, 'b: 'a>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> u64 + { + proof_with(c, test(a, b)) + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_lifetime_bound_mismatch verus_code!{ + use vstd::prelude::*; + fn test<'a, 'b: 'a>(a: &'a u64, b: u64) -> &'a u64 + { + let c: Tracked<&'b u64> = declare_with(); + a + } + + fn test2<'a, 'b>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> &'a u64 + { + proof_with(c, test(a, b)) + } + } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") +} + +// Same as test_proof_with_lifetime_mismatch but for Ghost. +test_verify_one_file! { + #[test] test_declare_with_ghost_lifetime_mismatch verus_code!{ + use vstd::prelude::*; + fn test<'a>(a: &'a u64) -> u64 + { + let g: Ghost<&'a u64> = declare_with(); + 1 + } + + fn test2<'a, 'b>(a: &'a u64, c: Ghost<&'b u64>) -> u64 + { + proof_with(c, test(a)) + } + } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") +} + +test_verify_one_file! { + #[test] test_proof_with_generic_type verus_code!{ + use vstd::prelude::*; + fn test(a: T) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a === b@ && c@ == 2); + } + + fn call_test() { + proof_with((Tracked(0u64), Ghost(2u32)), test(0u64)); + } + + #[verifier(external)] + fn unverified_call_test() { + test(0u64); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_generic_type2 verus_code!{ + use vstd::prelude::*; + trait X {} + fn test(a: T1, b: T2) + { + let c: Tracked = declare_with(); + let d: Ghost = declare_with(); + } + + fn call_test(a: T1, b: T2, c: Tracked, d: Ghost) { + proof_with((c, d), test(a, b)); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_generic_type_wrong_type verus_code!{ + use vstd::prelude::*; + fn test(a: T) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a === b@ && c@ == 2); + } + + fn call_test() { + proof_with((Tracked(0u8), Ghost(2u32)), test(0u64)); + } + + #[verifier(external)] + fn unverified_call_test() { + test(0u64); + } + } => Err(e) => assert_vir_error_msg(e, "proof_with argument 1 has wrong type") +} + +test_verify_one_file! { + #[test] test_proof_with_ownership verus_code!{ + use vstd::prelude::*; + + struct A; + + fn test<'a>(a: &'a mut A) + { + let b: Tracked<&'a mut A> = declare_with(); + let c: Ghost = declare_with(); + } + + fn call_test(mut a: A, mut b: A) { + proof_with((Tracked(&mut a), Ghost(2u32)), test(&mut a)); + } + } => Err(e) => assert_rust_error_msg_skip_spec_msgs(e, "cannot borrow `a` as mutable more than once at a time") +}