From 188e57c5bcff0324462eeba8e344c765a9221be2 Mon Sep 17 00:00:00 2001 From: Ziqiao Zhou Date: Tue, 12 May 2026 00:02:21 +0000 Subject: [PATCH 1/4] Add proof_with() and declare_with() builtin functions for ghost/tracked args Why: to allow with spec in verus_spec for traits and external functions. - Add builtin API: proof_with(), declare_with() (plus aliases declare_with_tracked/declare_with_ghost for backward compatibility) - Infer Tracked vs Ghost mode from the type annotation via ADT DefId - Add VIR lowering in fn_call_to_vir.rs with type and mode checking - Add pre-scan for declare_with params in rust_to_vir_func.rs - Add proof_with test suite - Add first-pass lifetime checking to ensure that tracked/ghost arguments passed via proof_with() satisfy the lifetime constraints declared by declare_with() in the callee. - Add proof_with_lifetime.rs module with region-aware lifetime checking - Use rustc_hir_analysis::lower_ty() to preserve real regions instead of erased regions from typeck writeback - Check that argument regions outlive expected parameter regions using the function's where-clause predicates - Add test cases for lifetime mismatch detection (both Tracked and Ghost) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Ziqiao Zhou --- source/builtin/src/lib.rs | 20 + source/rust_verify/src/context.rs | 24 +- source/rust_verify/src/fn_call_to_vir.rs | 145 ++++++- source/rust_verify/src/lib.rs | 1 + source/rust_verify/src/proof_with_lifetime.rs | 355 ++++++++++++++++++ source/rust_verify/src/rust_to_vir_expr.rs | 4 + source/rust_verify/src/rust_to_vir_func.rs | 185 ++++++++- source/rust_verify/src/verus_items.rs | 4 + source/rust_verify_test/tests/proof_with.rs | 260 +++++++++++++ 9 files changed, 995 insertions(+), 3 deletions(-) create mode 100644 source/rust_verify/src/proof_with_lifetime.rs create mode 100644 source/rust_verify_test/tests/proof_with.rs diff --git a/source/builtin/src/lib.rs b/source/builtin/src/lib.rs index 56966a312b..0a3ee74282 100644 --- a/source/builtin/src/lib.rs +++ b/source/builtin/src/lib.rs @@ -23,6 +23,26 @@ pub fn admit() { unimplemented!(); } +/// Pass tracked or ghost values to the immediately following external function call. +/// Used with `external_fn_specification` functions that have extra tracked/ghost parameters. +/// The argument must be of type `Tracked` or `Ghost`. +#[cfg(verus_keep_ghost)] +#[rustc_diagnostic_item = "verus::verus_builtin::proof_with"] +#[verifier::proof] +pub fn proof_with(_a: A) { + unimplemented!(); +} + +/// Zero-arg version of proof_with for output parameters. +/// The type parameter indicates the output container type. + +#[cfg(verus_keep_ghost)] +#[rustc_diagnostic_item = "verus::verus_builtin::declare_with"] +#[verifier::proof] +pub fn declare_with() -> A { + unimplemented!(); +} + // Can only appear at beginning of function body #[cfg(verus_keep_ghost)] #[rustc_diagnostic_item = "verus::verus_builtin::no_method_body"] diff --git a/source/rust_verify/src/context.rs b/source/rust_verify/src/context.rs index b4e49dc7bd..a8446a0abc 100644 --- a/source/rust_verify/src/context.rs +++ b/source/rust_verify/src/context.rs @@ -8,7 +8,7 @@ use rustc_mir_build_verus::verus::BodyErasure; use rustc_span::SpanData; use rustc_span::def_id::DefId; use std::cell::RefCell; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::ops::DerefMut; use std::rc::Rc; use std::sync::Arc; @@ -16,6 +16,17 @@ use std::sync::atomic::AtomicU64; use vir::ast::{CrateId, Mode, Path, Pattern, VirErr}; use vir::messages::AstId; +/// A pending tracked/ghost argument from a `proof_with()` call, waiting to be +/// appended to the next function call's argument list. +pub(crate) struct PendingTrackedArg { + /// The VIR expression for the argument value + pub expr: vir::ast::Expr, + /// true for Tracked, false for Ghost + pub is_tracked: bool, + /// HirId of the proof_with argument expression (for type lookup) + pub arg_hir_id: HirId, +} + pub struct ErasureInfo { pub(crate) hir_vir_ids: Vec<(HirId, AstId)>, pub(crate) resolved_calls: Vec<(HirId, SpanData, ResolvedCall)>, @@ -48,6 +59,10 @@ pub struct ContextX<'tcx> { pub(crate) crate_name: CrateId, pub(crate) name_def_id_map: Rc>>, pub(crate) next_read_kind_id: AtomicU64, + /// For external_fn_specification functions with extra ghost/tracked params: + /// maps the external function's DefId to a Vec of (is_tracked, expected_ty) pairs + pub(crate) external_fn_extra_tracked_params: + Rc)>>>>, } /// The context in which a given header node might be interpretted @@ -92,6 +107,12 @@ pub(crate) struct BodyCtxt<'tcx> { /// Assume specification defines a new opaque type for each opaque type in the external function. /// We use this map to resolve them later. pub(crate) external_opaque_type_map: Option>, + /// Pending tracked/ghost args from proof_with() calls, to be appended to the next function call. + pub(crate) pending_tracked_args: Rc>>, + /// Depth counter: >0 means we're inside argument processing, so don't consume pending_tracked_args. + pub(crate) in_args_depth: Rc>, + /// HirIds of declare_with_tracked()/declare_with_ghost() let-stmts to skip during body conversion + pub(crate) declare_with_hir_ids: Rc>, } impl<'tcx> ContextX<'tcx> { @@ -117,6 +138,7 @@ impl<'tcx> ContextX<'tcx> { crate_name, name_def_id_map: Rc::new(RefCell::new(HashMap::new())), next_read_kind_id: AtomicU64::new(0), + external_fn_extra_tracked_params: Rc::new(RefCell::new(HashMap::new())), } } diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index 60a8b8d348..90d325b5bb 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -42,6 +42,8 @@ use vir::ast_util::{ }; use vir::def::field_ident_from_rust; +use crate::proof_with_lifetime::check_proof_with_lifetime; + pub(crate) fn fn_call_to_vir<'tcx>( bctx: &BodyCtxt<'tcx>, expr: &Expr<'tcx>, @@ -302,7 +304,115 @@ fn fn_call_or_assoc_const_to_vir<'tcx>( record_call(bctx, expr, ResolvedCall::Call(name.clone(), record_name, bctx.in_ghost)); - let vir_args = if let Some(args) = args { mk_vir_args(bctx, &args)? } else { vec![] }; + let vir_args = if let Some(args) = args { + *bctx.in_args_depth.borrow_mut() += 1; + let result = mk_vir_args(bctx, &args); + *bctx.in_args_depth.borrow_mut() -= 1; + result? + } else { + vec![] + }; + + // Append any pending tracked args from proof_with() calls. + // If the function has extra tracked params but no proof_with was used, error. + // Skip consumption if we're inside argument processing (nested call) AND the callee + // doesn't expect extra params. If the callee does expect extra params (e.g., f() in + // f()? where ? desugars to Try::branch(f())), we must consume them here. + let vir_args = { + let mut args = vir_args; + let in_args = *bctx.in_args_depth.borrow() > 0; + let extra_params = bctx.ctxt.external_fn_extra_tracked_params.borrow().get(&f).cloned(); + if !in_args || extra_params.is_some() { + let mut pending = bctx.pending_tracked_args.borrow_mut(); + if let Some(ref expected_params) = extra_params { + let extra_count = expected_params.len(); + if pending.is_empty() { + return err_span( + expr.span, + format!( + "this external function requires {} extra tracked/ghost argument(s) via proof_with()", + extra_count + ), + ); + } + if pending.len() != extra_count { + return err_span( + expr.span, + format!( + "expected {} tracked/ghost argument(s) via proof_with(), got {}", + extra_count, + pending.len() + ), + ); + } + // Check mode and type for each pending arg + for (i, (pending_arg, (expected_is_tracked, expected_ty))) in + pending.iter().zip(expected_params.iter()).enumerate() + { + // Mode check + if pending_arg.is_tracked != *expected_is_tracked { + let expected_mode = if *expected_is_tracked { "Tracked" } else { "Ghost" }; + let actual_mode = if pending_arg.is_tracked { "Tracked" } else { "Ghost" }; + return err_span( + expr.span, + format!( + "proof_with argument {} has wrong mode: expected {}, got {}", + i + 1, + expected_mode, + actual_mode, + ), + ); + } + // Type check: compare rustc types with regions erased. + { + let tcx = bctx.ctxt.tcx; + let expected_ty_instantiated = + rustc_middle::ty::EarlyBinder::bind(*expected_ty) + .instantiate(tcx, node_substs); + let actual_ty = bctx.types.node_type(pending_arg.arg_hir_id); + use rustc_middle::ty::TypeFoldable; + let expected_erased = expected_ty_instantiated.fold_with( + &mut rustc_middle::ty::RegionFolder::new(tcx, &mut |_, _| { + tcx.lifetimes.re_erased + }), + ); + if actual_ty != expected_erased { + return err_span( + expr.span, + format!( + "proof_with argument {} has wrong type: expected `{}`, got `{}`", + i + 1, + expected_ty_instantiated, + actual_ty, + ), + ); + } + check_proof_with_lifetime( + bctx, + tcx, + f, + expr.hir_id, + node_substs, + pending_arg.arg_hir_id, + *expected_ty, + expected_ty_instantiated, + expr.span, + i, + )?; + } + } + let exprs: Vec<_> = pending.drain(..).map(|a| a.expr).collect(); + args.extend(exprs); + } else if !pending.is_empty() { + pending.drain(..); + return err_span( + expr.span, + "proof_with was used but this function does not expect extra tracked/ghost arguments", + ); + } + } + args + }; let typ_args = mk_typ_args(bctx, node_substs, f, expr.span)?; let impl_paths = get_impl_paths(bctx, f, node_substs, None, const_var, expr.span)?; @@ -2109,6 +2219,34 @@ fn verus_item_to_vir<'tcx, 'a>( let p = crate::rust_to_vir_expr::simplify_place_by_cancelling(&p); mk_expr(ExprX::BorrowMutTracked(p)) } + VerusItem::ProofWith => { + // proof_with(expr) stores the tracked/ghost arg to be appended to the next call + unsupported_err_unless!(args_len == 1, expr.span, "expected proof_with(expr)", &args); + let arg_typ = typ_of_expr_adjusted(bctx, args[0].span, &args[0].hir_id)?; + let is_tracked = match &*arg_typ { + TypX::Decorate(TypDecoration::Tracked, _, _) => true, + TypX::Decorate(TypDecoration::Ghost, _, _) => false, + _ => { + return err_span( + args[0].span, + "proof_with expects an argument of type Tracked or Ghost", + ); + } + }; + let bctx_ghost = &BodyCtxt { in_ghost: true, ..bctx.clone() }; + // Increment in_args_depth so nested function calls within the proof_with + // argument don't try to consume pending_tracked_args + *bctx.in_args_depth.borrow_mut() += 1; + let arg_expr = expr_to_vir_consume(bctx_ghost, &args[0])?; + *bctx.in_args_depth.borrow_mut() -= 1; + bctx.pending_tracked_args.borrow_mut().push(crate::context::PendingTrackedArg { + expr: arg_expr, + is_tracked, + arg_hir_id: args[0].hir_id, + }); + // Return a unit expression (no-op) + mk_expr(ExprX::Block(Arc::new(vec![]), None)) + } VerusItem::BuiltinDeref(d) => { // This would be easy to support (similar to handling borrow_mut etc.) but their usage // would be very rare so I'm skipping for now @@ -2126,6 +2264,11 @@ fn verus_item_to_vir<'tcx, 'a>( ) .help("you can implicitly dereference this type using `*`")); } + VerusItem::DeclareWith => { + // declare_with() is handled at let-stmt level + // in rust_to_vir_expr.rs. If we reach here, it's used outside a let-stmt. + return err_span(expr.span, "declare_with() must be used as a let initializer"); + } VerusItem::Vstd(_, _) | VerusItem::Marker(_) | VerusItem::BuiltinType(_) diff --git a/source/rust_verify/src/lib.rs b/source/rust_verify/src/lib.rs index 2a55714363..5ce7189ded 100644 --- a/source/rust_verify/src/lib.rs +++ b/source/rust_verify/src/lib.rs @@ -53,6 +53,7 @@ mod fn_call_to_vir; mod hir_hide_reveal_rewrite; mod import_export; pub mod profiler; +mod proof_with_lifetime; mod resolve_traits; pub mod reveal_hide; mod rust_intrinsics_to_vir; diff --git a/source/rust_verify/src/proof_with_lifetime.rs b/source/rust_verify/src/proof_with_lifetime.rs new file mode 100644 index 0000000000..2ec1297a2c --- /dev/null +++ b/source/rust_verify/src/proof_with_lifetime.rs @@ -0,0 +1,355 @@ +//! Lifetime checking for proof_with / declare_with_tracked / declare_with_ghost. +//! +//! These functions erase tracked/ghost parameters from fn signatures before the borrow +//! checker runs, so Rust's NLL cannot enforce lifetime constraints on those parameters. +//! This module performs a first-pass lifetime check during VIR lowering to catch +//! incompatible lifetimes early. + +use crate::context::BodyCtxt; +use crate::util::err_span; +use rustc_hir::def::Res; +use rustc_hir::{ExprKind, QPath}; +use rustc_span::Span; +use rustc_span::def_id::DefId; +use vir::ast::VirErr; + +/// Check that a proof_with argument's lifetime is compatible with the callee's expected lifetime. +/// +/// The expected type (from `lower_ty` on the callee's `declare_with_tracked` HIR type) has +/// `ReLateParam(callee, 'a)` regions. We need to: +/// 1. Map callee's late-bound `'a` to the caller's corresponding lifetime via the call args +/// 2. Get the proof_with arg's lifetime from its declaration +/// 3. Check that the arg's lifetime outlives the expected lifetime in the caller's scope +pub(crate) fn check_proof_with_lifetime<'tcx>( + bctx: &BodyCtxt<'tcx>, + tcx: rustc_middle::ty::TyCtxt<'tcx>, + callee_def_id: DefId, + call_hir_id: rustc_hir::HirId, + _node_substs: &'tcx rustc_middle::ty::List>, + arg_hir_id: rustc_hir::HirId, + _expected_ty_raw: rustc_middle::ty::Ty<'tcx>, + expected_ty_instantiated: rustc_middle::ty::Ty<'tcx>, + call_span: Span, + arg_index: usize, +) -> Result<(), VirErr> { + use rustc_middle::ty::{Region, RegionKind, TypeVisitable, TypeVisitor}; + use std::ops::ControlFlow; + + struct RegionCollector<'tcx> { + regions: Vec>, + } + impl<'tcx> TypeVisitor> for RegionCollector<'tcx> { + type Result = ControlFlow<()>; + fn visit_region(&mut self, r: Region<'tcx>) -> Self::Result { + self.regions.push(r); + ControlFlow::Continue(()) + } + } + + fn collect_regions<'tcx>(ty: rustc_middle::ty::Ty<'tcx>) -> Vec> { + let mut collector = RegionCollector { regions: vec![] }; + let _ = ty.visit_with(&mut collector); + collector.regions + } + + // Extract named regions from the expected type (could be ReLateParam or ReEarlyParam) + let expected_regions = collect_regions(expected_ty_instantiated); + + // Get DefId for a named region + fn region_def_id<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + owner_def_id: DefId, + region: &rustc_middle::ty::Region<'tcx>, + ) -> Option { + match region.kind() { + RegionKind::ReLateParam(lp) => match lp.kind { + rustc_middle::ty::LateParamRegionKind::Named(def_id) => Some(def_id), + _ => None, + }, + RegionKind::ReEarlyParam(ep) => { + let generics = tcx.generics_of(owner_def_id); + let param = generics.param_at(ep.index as usize, tcx); + Some(param.def_id) + } + _ => None, + } + } + + let named_regions: Vec<_> = expected_regions + .iter() + .filter(|r| region_def_id(tcx, callee_def_id, r).is_some()) + .collect(); + + if named_regions.is_empty() { + return Ok(()); // No named lifetimes to check + } + + // Get the callee's poly fn_sig to find which params use each late-bound region + let callee_poly_sig = tcx.fn_sig(callee_def_id).instantiate_identity(); + let callee_sig_inputs = callee_poly_sig.skip_binder().inputs(); + + // Get the caller's liberated fn_sig to find real lifetime regions for caller params + let caller_def_id = bctx.fun_id; + let caller_poly_sig = tcx.fn_sig(caller_def_id).instantiate_identity(); + let caller_liberated = tcx.liberate_late_bound_regions(caller_def_id, caller_poly_sig); + let caller_inputs = caller_liberated.inputs(); + + // Build mapping: callee late-bound region DefId → caller ReLateParam region + // by matching callee fn_sig params with call args + let mut callee_to_caller_region: std::collections::HashMap< + rustc_span::def_id::DefId, + rustc_middle::ty::Region<'tcx>, + > = std::collections::HashMap::new(); + + // Extract call arg HirIds from the HIR expression + let call_arg_hir_ids: Vec = { + let hir_node = tcx.hir_node(call_hir_id); + match hir_node { + rustc_hir::Node::Expr(expr) => match &expr.kind { + ExprKind::Call(_, hir_args) => hir_args.iter().map(|a| a.hir_id).collect(), + rustc_hir::ExprKind::MethodCall(_, receiver, hir_args, _) => { + let mut ids = vec![receiver.hir_id]; + ids.extend(hir_args.iter().map(|a| a.hir_id)); + ids + } + _ => return Ok(()), // Can't extract args + }, + _ => return Ok(()), + } + }; + + // For each callee fn_sig param, find bound regions and match with call arg + for (param_idx, callee_param_ty) in callee_sig_inputs.iter().enumerate() { + // Find named regions (ReBound or ReEarlyParam) in this callee param + let named_param_regions: Vec<_> = collect_regions(*callee_param_ty) + .into_iter() + .filter_map(|r| match r.kind() { + RegionKind::ReBound(rustc_middle::ty::BoundVarIndexKind::Bound(debruijn), br) + if debruijn == rustc_middle::ty::INNERMOST => + { + if let rustc_middle::ty::BoundRegionKind::Named(def_id) = br.kind { + Some(def_id) + } else { + None + } + } + RegionKind::ReEarlyParam(ep) => { + let generics = tcx.generics_of(callee_def_id); + Some(generics.param_at(ep.index as usize, tcx).def_id) + } + _ => None, + }) + .collect(); + + if named_param_regions.is_empty() || param_idx >= call_arg_hir_ids.len() { + continue; + } + + // Try to resolve the call arg to a caller parameter to get its real lifetime + let arg_hir_id_for_param = call_arg_hir_ids[param_idx]; + if let Some(caller_param_idx) = + resolve_expr_to_param_index(tcx, caller_def_id, arg_hir_id_for_param) + { + if caller_param_idx < caller_inputs.len() { + let caller_param_ty = caller_inputs[caller_param_idx]; + let caller_regions = collect_regions(caller_param_ty); + let caller_named_regions: Vec<_> = caller_regions + .into_iter() + .filter(|r| { + matches!(r.kind(), RegionKind::ReLateParam(_) | RegionKind::ReEarlyParam(_)) + }) + .collect(); + + // Map each bound region to the corresponding caller region + // (simple 1:1 mapping by position within the type) + for (region_def_id, caller_region) in + named_param_regions.iter().zip(caller_named_regions.iter()) + { + callee_to_caller_region.insert(*region_def_id, *caller_region); + } + } + } + } + + if callee_to_caller_region.is_empty() { + // Couldn't determine the mapping — skip the check + return Ok(()); + } + + // Get the proof_with arg's type with real regions + let actual_arg_region = + get_arg_late_param_region(tcx, caller_def_id, caller_inputs, arg_hir_id); + + // For each named region in the expected type, check compatibility + for expected_region in &named_regions { + let callee_lt_def_id = match region_def_id(tcx, callee_def_id, expected_region) { + Some(did) => did, + None => continue, + }; + + let Some(expected_caller_region) = callee_to_caller_region.get(&callee_lt_def_id) else { + continue; // Can't check this region + }; + + let Some(actual_region) = actual_arg_region else { + continue; // Can't determine actual region + }; + + // Compare: does actual_region outlive expected_caller_region? + let actual_did = region_def_id(tcx, caller_def_id, &actual_region); + let expected_did = region_def_id(tcx, caller_def_id, expected_caller_region); + + if let (Some(actual_did), Some(expected_did)) = (actual_did, expected_did) { + if actual_did == expected_did { + continue; // Same lifetime — OK + } + + // Check if actual outlives expected in the caller's where-clause bounds + if check_region_outlives(tcx, caller_def_id, actual_region, *expected_caller_region) { + continue; // Outlives relationship declared — OK + } + + // Lifetime mismatch! + let actual_name = tcx.item_name(actual_did); + let expected_name = tcx.item_name(expected_did); + return err_span( + call_span, + format!( + "proof_with argument {} has incompatible lifetime: \ + expected lifetime `'{}`, got `'{}`\n\ + help: consider adding `'{}: '{}`", + arg_index + 1, + expected_name, + actual_name, + actual_name, + expected_name, + ), + ); + } + } + + Ok(()) +} + +/// Try to resolve an expression (by HirId) to a function parameter index. +/// Returns Some(index) if the expression is a simple reference to a function parameter. +fn resolve_expr_to_param_index<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + fn_def_id: DefId, + arg_hir_id: rustc_hir::HirId, +) -> Option { + let hir_node = tcx.hir_node(arg_hir_id); + let expr = match hir_node { + rustc_hir::Node::Expr(e) => e, + _ => return None, + }; + if let ExprKind::Path(QPath::Resolved(None, path)) = &expr.kind { + if let Res::Local(local_hir_id) = path.res { + let local_def_id = fn_def_id.as_local()?; + let body = tcx.hir_body_owned_by(local_def_id); + for (i, param) in body.params.iter().enumerate() { + if param.pat.hir_id == local_hir_id { + return Some(i); + } + } + } + } + None +} + +/// Get the first named region (ReLateParam or ReEarlyParam) from a proof_with arg's type. +/// The arg is typically a function parameter of the caller. +fn get_arg_late_param_region<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + caller_def_id: DefId, + caller_inputs: &[rustc_middle::ty::Ty<'tcx>], + arg_hir_id: rustc_hir::HirId, +) -> Option> { + use rustc_middle::ty::RegionKind; + + if let Some(param_idx) = resolve_expr_to_param_index(tcx, caller_def_id, arg_hir_id) { + if param_idx < caller_inputs.len() { + let param_ty = caller_inputs[param_idx]; + let regions: Vec<_> = { + use rustc_middle::ty::{Region, TypeVisitable, TypeVisitor}; + use std::ops::ControlFlow; + struct Collector<'tcx> { + regions: Vec>, + } + impl<'tcx> TypeVisitor> for Collector<'tcx> { + type Result = ControlFlow<()>; + fn visit_region(&mut self, r: Region<'tcx>) -> Self::Result { + if matches!( + r.kind(), + RegionKind::ReLateParam(_) | RegionKind::ReEarlyParam(_) + ) { + self.regions.push(r); + } + ControlFlow::Continue(()) + } + } + let mut c = Collector { regions: vec![] }; + let _ = param_ty.visit_with(&mut c); + c.regions + }; + return regions.into_iter().next(); + } + } + + None +} + +/// Check if region `a` outlives region `b` based on the caller's where-clause bounds. +/// Compares by DefId since regions may be in different representations (ReLateParam vs ReEarlyParam). +fn check_region_outlives<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + caller_def_id: DefId, + actual: rustc_middle::ty::Region<'tcx>, + expected: rustc_middle::ty::Region<'tcx>, +) -> bool { + use rustc_middle::ty::{ClauseKind, RegionKind}; + + fn local_region_def_id<'tcx>( + tcx: rustc_middle::ty::TyCtxt<'tcx>, + owner_def_id: DefId, + r: rustc_middle::ty::Region<'tcx>, + ) -> Option { + match r.kind() { + RegionKind::ReLateParam(lp) => match lp.kind { + rustc_middle::ty::LateParamRegionKind::Named(def_id) => Some(def_id), + _ => None, + }, + RegionKind::ReEarlyParam(ep) => { + let generics = tcx.generics_of(owner_def_id); + Some(generics.param_at(ep.index as usize, tcx).def_id) + } + _ => None, + } + } + + let actual_did = local_region_def_id(tcx, caller_def_id, actual); + let expected_did = local_region_def_id(tcx, caller_def_id, expected); + + if actual_did.is_none() || expected_did.is_none() { + return false; + } + + // Check the caller's predicates for 'actual: 'expected (by DefId) + let predicates = tcx.predicates_of(caller_def_id); + for (pred, _) in predicates.predicates { + if let Some(ClauseKind::RegionOutlives(outlives)) = pred.kind().no_bound_vars() { + let pred_longer = local_region_def_id(tcx, caller_def_id, outlives.0); + let pred_shorter = local_region_def_id(tcx, caller_def_id, outlives.1); + if pred_longer == actual_did && pred_shorter == expected_did { + return true; + } + } + } + + // 'static outlives anything + if actual.is_static() { + return true; + } + + false +} diff --git a/source/rust_verify/src/rust_to_vir_expr.rs b/source/rust_verify/src/rust_to_vir_expr.rs index b20cc1710a..57f498d27e 100644 --- a/source/rust_verify/src/rust_to_vir_expr.rs +++ b/source/rust_verify/src/rust_to_vir_expr.rs @@ -3572,6 +3572,10 @@ pub(crate) fn stmt_to_vir<'tcx>( return Ok(vec![]); } } + // Skip declare_with() stmts (handled as extra params) + if bctx.declare_with_hir_ids.contains(&stmt.hir_id) { + return Ok(vec![]); + } let_stmt_to_vir(bctx, pat, init, els, bctx.ctxt.tcx.hir_attrs(stmt.hir_id)) } diff --git a/source/rust_verify/src/rust_to_vir_func.rs b/source/rust_verify/src/rust_to_vir_func.rs index 41ebedd5f0..dc6d63a282 100644 --- a/source/rust_verify/src/rust_to_vir_func.rs +++ b/source/rust_verify/src/rust_to_vir_func.rs @@ -291,6 +291,7 @@ fn mk_bctx<'tcx>( migrate_postcondition_vars: Option>, param_names: Vec, external_opaque_type_map: Option>, + declare_with_hir_ids: HashSet, ) -> BodyCtxt<'tcx> { BodyCtxt { ctxt: ctxt.clone(), @@ -310,9 +311,136 @@ fn mk_bctx<'tcx>( header_setting: HeaderSetting::Fn, unwrap_param_map: std::rc::Rc::new(std::cell::RefCell::new(HashMap::new())), external_opaque_type_map, + pending_tracked_args: std::rc::Rc::new(std::cell::RefCell::new(Vec::new())), + in_args_depth: std::rc::Rc::new(std::cell::RefCell::new(0)), + declare_with_hir_ids: std::rc::Rc::new(declare_with_hir_ids), } } +/// Pre-scan the HIR body for `declare_with_tracked()`/`declare_with_ghost()` let-stmts. +/// Returns (extra_vir_params, hir_ids_to_skip) so that: +/// 1. Extra params can be appended to `vir_params` before body conversion +/// 2. `stmt_to_vir` can skip these let-stmts during body conversion +fn pre_scan_declare_with_params<'tcx>( + ctxt: &Context<'tcx>, + id: DefId, + body: &Body<'tcx>, + body_id: &BodyId, +) -> Result< + (Vec<(vir::ast::Param, Option, rustc_middle::ty::Ty<'tcx>)>, HashSet), + VirErr, +> { + let mut extra_params = Vec::new(); + let mut hir_ids = HashSet::new(); + let types = body_id_to_types(ctxt.tcx, body_id); + + // Navigate into the body's top-level block + let stmts = match &body.value.kind { + ExprKind::Block(block, _) => block.stmts, + _ => return Ok((extra_params, hir_ids)), + }; + + for stmt in stmts { + if let rustc_hir::StmtKind::Let(rustc_hir::LetStmt { + pat, + init: Some(init), + ty: hir_ty, + .. + }) = &stmt.kind + { + // Check if init is a call to declare_with_tracked() or declare_with_ghost() + let verus_item = match &init.kind { + ExprKind::Call(fun, _) => match &fun.kind { + ExprKind::Path(rustc_hir::QPath::Resolved( + None, + rustc_hir::Path { res: rustc_hir::def::Res::Def(_, fun_id), .. }, + )) => ctxt.get_verus_item(*fun_id).cloned(), + _ => None, + }, + _ => None, + }; + + let is_declare_with = match verus_item { + Some(VerusItem::DeclareWith) => true, + _ => false, + }; + if !is_declare_with { + continue; + } + + // Require simple binding pattern + let (is_mut_var, name) = pat_to_mut_var(pat)?; + + // Get the resolved type. Use lower_ty on the HIR type annotation if available, + // because it preserves early-bound regions (ReEarlyParam) which are needed for + // lifetime checking. typeck's node_type() returns types with erased regions. + let ty = if let Some(hir_ty) = hir_ty { + rustc_hir_analysis::lower_ty(ctxt.tcx, hir_ty) + } else { + types.node_type(init.hir_id) + }; + + // Derive is_tracked from the type (Tracked vs Ghost) via ADT DefId + let is_tracked = match ty.kind() { + rustc_middle::ty::TyKind::Adt(adt_def, _) => { + match ctxt.get_verus_item(adt_def.did()) { + Some(VerusItem::BuiltinType( + crate::verus_items::BuiltinTypeItem::Tracked, + )) => true, + Some(VerusItem::BuiltinType( + crate::verus_items::BuiltinTypeItem::Ghost, + )) => false, + _ => { + return err_span( + init.span, + "declare_with() must be assigned to a Tracked or Ghost type", + ); + } + } + } + _ => { + return err_span( + init.span, + "declare_with() must be assigned to a Tracked or Ghost type", + ); + } + }; + let inner_mode = if is_tracked { Mode::Proof } else { Mode::Spec }; + + // Handle &mut types - check if the type contains a mutable reference + let is_ref_mut = is_mut_ty(ctxt, ty); + let is_mut = is_ref_mut.is_some(); + // Use mid_ty_to_vir on the FULL type (matching normal param processing in + // check_fn_decl). Don't manually unwrap the MutRef — it must be preserved + // in the typ so that VIR can properly generate mut_ref_current% / update% AIR. + let typ = ctxt.mid_ty_to_vir(id, pat.span, &ty, None)?; + + // All declare_with params are inputs — unwrap Ghost/Tracked as usual + let outer_name = vir::ast_util::air_unique_var(&format!( + "declare_with_{}", + vir::def::user_local_name(&name) + )); + let unwrapped_info = Some((inner_mode, outer_name)); + let param_mode = Mode::Exec; + let vir_param = ctxt.spanned_new( + pat.span, + ParamX { + name: name.clone(), + typ, + mode: param_mode, + unwrapped_info, + user_mut: is_mut_var || is_mut, + }, + ); + + extra_params.push((vir_param, None, ty)); + hir_ids.insert(stmt.hir_id); + } + } + + Ok((extra_params, hir_ids)) +} + fn body_to_vir<'tcx>( ctxt: &Context<'tcx>, fun_id: DefId, @@ -325,6 +453,7 @@ fn body_to_vir<'tcx>( param_names: Vec, external_opaque_type_map: Option>, is_async: bool, + declare_with_hir_ids: HashSet, ) -> Result { let bctx = mk_bctx( ctxt, @@ -336,6 +465,7 @@ fn body_to_vir<'tcx>( migrate_postcondition_vars, param_names, external_opaque_type_map, + declare_with_hir_ids, ); let body_expr = if is_async { extract_desugared_async_body(&bctx.ctxt, body)? } else { &body.value }; @@ -1088,7 +1218,9 @@ fn equalize_substs<'tcx>( let mut l1 = vec![]; let mut l2 = vec![]; - if substs1_early.len() + num_late1 != substs2_early.len() + num_late2 { + // Allow proxy (sig1) to have more total generics than external (sig2), + // since extra ghost/tracked params may introduce additional late-bound lifetimes. + if substs1_early.len() + num_late1 < substs2_early.len() + num_late2 { return None; } @@ -1724,6 +1856,55 @@ pub(crate) fn check_item_fn<'tcx>( rustc_hir::IsAsync::Async(..) => true, }; let body = find_body(ctxt, body_id); + + // Pre-scan for declare_with_tracked()/declare_with_ghost() calls + let (declare_with_extra_params, declare_with_hir_ids) = + pre_scan_declare_with_params(ctxt, id, body, body_id)?; + let declare_with_modes: Vec<(bool, rustc_middle::ty::Ty<'tcx>)> = + declare_with_extra_params + .iter() + .map(|(p, _, ty)| { + // unwrapped_info mode: Proof = Tracked, Spec = Ghost + (matches!(p.x.unwrapped_info, Some((Mode::Proof, _))), *ty) + }) + .collect(); + for (p, _mode, _) in declare_with_extra_params { + vir_params.push(p); + } + + // Register declare_with extra param modes early so callers can find them + // even if body_to_vir fails for this function. + if !declare_with_modes.is_empty() { + let target_id = proxy_id.unwrap_or(id); + ctxt.external_fn_extra_tracked_params + .borrow_mut() + .insert(target_id, declare_with_modes.clone()); + + // For unerased_proxy functions (const fn proxies), also register under + // the original function's DefId, since call sites resolve to the original. + if vattrs.unerased_proxy { + let proxy_name = ctxt.tcx.item_name(id).as_str().to_string(); + let prefix = "VERUS_UNERASED_PROXY__"; + if let Some(original_name) = proxy_name.strip_prefix(prefix) { + let parent = ctxt.tcx.parent(id); + // Find sibling with matching name by iterating HIR items in the parent + for item_id in ctxt.tcx.hir_free_items() { + let child_def_id = item_id.owner_id.to_def_id(); + if let Some(name) = ctxt.tcx.opt_item_name(child_def_id) { + if name.as_str() == original_name + && child_def_id != id + && ctxt.tcx.parent(child_def_id) == parent + { + ctxt.external_fn_extra_tracked_params + .borrow_mut() + .insert(child_def_id, declare_with_modes.clone()); + } + } + } + } + } + } + let external_body = vattrs.external_body || vattrs.external_fn_specification; let param_names = vir_params.iter().map(|p| p.x.name.clone()).collect::>(); let mut vir_body = body_to_vir( @@ -1738,6 +1919,7 @@ pub(crate) fn check_item_fn<'tcx>( param_names, assume_specification_opaque_type_map.clone(), is_async, + declare_with_hir_ids, )?; let header = vir::headers::read_header(&mut vir_body, &vir::headers::HeaderAllows::All)?; @@ -2821,6 +3003,7 @@ pub(crate) fn check_item_const_or_static<'tcx>( vec![], None, false, + HashSet::new(), )?; let header = vir::headers::read_header( &mut vir_body, diff --git a/source/rust_verify/src/verus_items.rs b/source/rust_verify/src/verus_items.rs index c2e666216a..138ae50219 100644 --- a/source/rust_verify/src/verus_items.rs +++ b/source/rust_verify/src/verus_items.rs @@ -448,6 +448,8 @@ pub(crate) enum VerusItem { GetFirst, DummyCapture(DummyCaptureItem), MutRefTracked, + ProofWith, + DeclareWith, } #[derive(PartialEq, Eq, Debug, Clone, Hash)] @@ -460,6 +462,8 @@ pub(crate) enum DummyCaptureItem { #[rustfmt::skip] fn verus_items_map() -> Vec<(&'static str, VerusItem)> { vec![ + ("verus::verus_builtin::proof_with", VerusItem::ProofWith), + ("verus::verus_builtin::declare_with", VerusItem::DeclareWith), ("verus::verus_builtin::admit", VerusItem::Spec(SpecItem::Admit)), ("verus::verus_builtin::assume_", VerusItem::Spec(SpecItem::Assume)), ("verus::verus_builtin::no_method_body", VerusItem::Spec(SpecItem::NoMethodBody)), diff --git a/source/rust_verify_test/tests/proof_with.rs b/source/rust_verify_test/tests/proof_with.rs new file mode 100644 index 0000000000..0cca0a62e6 --- /dev/null +++ b/source/rust_verify_test/tests/proof_with.rs @@ -0,0 +1,260 @@ +#![feature(rustc_private)] +#[macro_use] +mod common; +use common::*; + +test_verify_one_file! { + #[test] test_proof_with verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a == 0 && b@ == 1 && c@ == 2); + } + + fn call_test() { + proof_with(Tracked(1u64)); + proof_with(Ghost(2u32)); + test(0); + } + + #[verifier(external)] + fn unverified_call_test() { + test(0); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_impl verus_code!{ + use vstd::prelude::*; + struct A { + a: u64 + } + impl A { + fn test(&self) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(self.a == 0 && b@ == 1 && c@ == 2); + } + } + fn call_test() { + let a = A { a: 0 }; + proof_with(Tracked(1u64)); + proof_with(Ghost(2u32)); + a.test(); + } + + #[verifier(external)] + fn unverified_call_test() { + let a = A { a: 0 }; + a.test(); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_trait verus_code!{ + use vstd::prelude::*; + struct A { + a: u64 + } + + trait AOp { + fn test(&self) { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(b@ == 1 && c@ == 2); + } + } + impl AOp for A { + fn test(&self) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + assert(b@ == 1); + assert(c@ == 0); // FAILS + } + } + fn call_test() { + let a = A { a: 0 }; + proof_with(Tracked(1u64)); + proof_with(Ghost(2u32)); + a.test(); + } + + #[verifier(external)] + fn unverified_call_test() { + let a = A { a: 0 }; + a.test(); + } + } => Err(e) => assert_one_fails(e) +} + +test_verify_one_file! { + #[test] test_proof_with_external verus_code!{ + use vstd::prelude::*; + #[verifier(external)] + fn negate_bool(b: bool, x: u8) -> bool { + !b + } + + #[verifier(external_fn_specification)] + fn negate_bool_requires_ensures(b: bool, x: u8) -> bool + { + let extra: Tracked = declare_with(); + requires(x == extra@); + ensures(|ret_b: bool| ret_b == !b); + negate_bool(b, x) + } + + fn call_test() { + proof_with(Tracked(1u8)); + negate_bool(true, 1); + } + + #[verifier(external)] + fn unverified_call_test() { + negate_bool(true, 1); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_external_failed verus_code!{ + use vstd::prelude::*; + #[verifier(external)] + fn negate_bool(b: bool, x: u8) -> bool { + !b + } + + #[verifier(external_fn_specification)] + fn negate_bool_requires_ensures(b: bool, x: u8) -> bool + { + let extra: Tracked = declare_with(); + requires(x == extra@); + ensures(|ret_b: bool| ret_b == !b); + negate_bool(b, x) + } + + fn call_test() { + negate_bool(true, 1); + } + } => Err(e) => assert_vir_error_msg(e, "this external function requires 1 extra tracked/ghost argument(s) via proof_with()") +} + +test_verify_one_file! { + #[test] test_proof_with_failed_requires verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a == 0 && b@ == 1 && c@ == 2); + } + + fn call_test() { + proof_with(Tracked(0u64)); + proof_with(Ghost(2u32)); + test(0); // FAILS + } + } => Err(e) => assert_one_fails(e) +} + +test_verify_one_file! { + #[test] test_proof_with_invalid_type verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + requires(a == 0 && b@ == 1); + } + + fn call_test() { + proof_with(0u64); + test(0); + } + } => Err(e) => assert_vir_error_msg(e, "proof_with expects an argument of type Tracked or Ghost") +} + +test_verify_one_file! { + #[test] test_proof_with_wrong_mode_type verus_code!{ + use vstd::prelude::*; + fn test(a: u64) + { + let b: Tracked = declare_with(); + requires(a == 0 && b@ == 1); + } + + fn call_test() { + proof_with(Ghost(0u64)); + test(0); + } + } => Err(e) => assert_vir_error_msg(e, "proof_with argument 1 has wrong mode: expected Tracked, got Ghost") +} + +// ---- Lifetime soundness tests ---- +// These tests verify that lifetime constraints on tracked/ghost params are properly checked. + +// BUG: This should fail with a lifetime error but currently passes. +// The tracked param Tracked<&'b u64> is passed where Tracked<&'a u64> is expected, +// but 'b may not outlive 'a. This is unsound. +// Once the fix is in place, change this to: +// } => Err(err) => assert_rust_error_msg(err, "lifetime may not live long enough") +test_verify_one_file! { + #[test] test_proof_with_lifetime_mismatch verus_code!{ + use vstd::prelude::*; + // test expects a Tracked<&'a u64> where 'a is tied to param `a` + fn test<'a>(a: &'a u64, b: u64) -> u64 + { + let c: Tracked<&'a u64> = declare_with(); + 1 + } + + // test2 has independent lifetimes 'a and 'b + // Passing Tracked<&'b u64> where Tracked<&'a u64> is expected should fail + fn test2<'a, 'b>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> u64 + { + proof_with(c); + test(a, b) + } + } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") +} + +test_verify_one_file! { + #[test] test_proof_with_lifetime_compatible verus_code!{ + use vstd::prelude::*; + // Same as above, but 'b: 'a so the lifetime is compatible + fn test<'a>(a: &'a u64, b: u64) -> u64 + { + let c: Tracked<&'a u64> = declare_with(); + 1 + } + + fn test2<'a, 'b: 'a>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> u64 + { + proof_with(c); + test(a, b) + } + } => Ok(()) +} + +// Same as test_proof_with_lifetime_mismatch but for Ghost. +test_verify_one_file! { + #[test] test_declare_with_ghost_lifetime_mismatch verus_code!{ + use vstd::prelude::*; + fn test<'a>(a: &'a u64) -> u64 + { + let g: Ghost<&'a u64> = declare_with(); + 1 + } + + fn test2<'a, 'b>(a: &'a u64, c: Ghost<&'b u64>) -> u64 + { + proof_with(c); + test(a) + } + } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") +} From b6d0d4c55a7eade3a53c025ba462d99cd87051da Mon Sep 17 00:00:00 2001 From: Ziqiao Zhou Date: Tue, 12 May 2026 22:04:03 +0000 Subject: [PATCH 2/4] Add proof_with tests for generic types --- source/rust_verify_test/tests/proof_with.rs | 46 +++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/source/rust_verify_test/tests/proof_with.rs b/source/rust_verify_test/tests/proof_with.rs index 0cca0a62e6..96b342beb2 100644 --- a/source/rust_verify_test/tests/proof_with.rs +++ b/source/rust_verify_test/tests/proof_with.rs @@ -258,3 +258,49 @@ test_verify_one_file! { } } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") } + +test_verify_one_file! { + #[test] test_proof_with_generic_type verus_code!{ + use vstd::prelude::*; + fn test(a: T) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a === b@ && c@ == 2); + } + + fn call_test() { + proof_with(Tracked(0u64)); + proof_with(Ghost(2u32)); + test(0u64); + } + + #[verifier(external)] + fn unverified_call_test() { + test(0u64); + } + } => Ok(()) +} + +test_verify_one_file! { + #[test] test_proof_with_generic_type_wrong_type verus_code!{ + use vstd::prelude::*; + fn test(a: T) + { + let b: Tracked = declare_with(); + let c: Ghost = declare_with(); + requires(a === b@ && c@ == 2); + } + + fn call_test() { + proof_with(Tracked(0u8)); + proof_with(Ghost(2u32)); + test(0u64); + } + + #[verifier(external)] + fn unverified_call_test() { + test(0u64); + } + } => Err(e) => assert_vir_error_msg(e, "proof_with argument 1 has wrong type") +} From 792e1b00dd22a451edb8733986691652fa94d423 Mon Sep 17 00:00:00 2001 From: Ziqiao Zhou Date: Tue, 12 May 2026 23:15:12 +0000 Subject: [PATCH 3/4] Change proof_with(a: A> to proof_with(a: A, b: B) where a is tuple of args, and b should be call to a function that takes extra args --- source/builtin/src/lib.rs | 7 +- source/rust_verify/src/context.rs | 9 +- source/rust_verify/src/fn_call_to_vir.rs | 215 ++++++++++-------- source/rust_verify/src/proof_with_lifetime.rs | 4 +- source/rust_verify/src/rust_to_vir_func.rs | 9 +- source/rust_verify_test/tests/proof_with.rs | 71 +++--- 6 files changed, 163 insertions(+), 152 deletions(-) diff --git a/source/builtin/src/lib.rs b/source/builtin/src/lib.rs index 0a3ee74282..9154917595 100644 --- a/source/builtin/src/lib.rs +++ b/source/builtin/src/lib.rs @@ -25,11 +25,14 @@ pub fn admit() { /// Pass tracked or ghost values to the immediately following external function call. /// Used with `external_fn_specification` functions that have extra tracked/ghost parameters. -/// The argument must be of type `Tracked` or `Ghost`. +/// Pass ghost/tracked arguments to a function call with proper borrow checking. +/// `_a` is a tuple of `Tracked` or `Ghost` values. +/// `_b` is the function call expression whose result is returned. +/// Usage: `proof_with((Tracked(&mut x), Ghost(y)), f(a))` #[cfg(verus_keep_ghost)] #[rustc_diagnostic_item = "verus::verus_builtin::proof_with"] #[verifier::proof] -pub fn proof_with(_a: A) { +pub fn proof_with(_a: A, _b: B) -> B { unimplemented!(); } diff --git a/source/rust_verify/src/context.rs b/source/rust_verify/src/context.rs index a8446a0abc..2dc8096d71 100644 --- a/source/rust_verify/src/context.rs +++ b/source/rust_verify/src/context.rs @@ -107,11 +107,10 @@ pub(crate) struct BodyCtxt<'tcx> { /// Assume specification defines a new opaque type for each opaque type in the external function. /// We use this map to resolve them later. pub(crate) external_opaque_type_map: Option>, - /// Pending tracked/ghost args from proof_with() calls, to be appended to the next function call. - pub(crate) pending_tracked_args: Rc>>, - /// Depth counter: >0 means we're inside argument processing, so don't consume pending_tracked_args. - pub(crate) in_args_depth: Rc>, - /// HirIds of declare_with_tracked()/declare_with_ghost() let-stmts to skip during body conversion + /// Pending tracked/ghost args from proof_with() calls, to be consumed by the next function call. + /// Set to Some(...) by ProofWith handler, taken (consumed) by fn_call_to_vir. + pub(crate) pending_tracked_args: Rc>>>, + /// HirIds of declare_with() let-stmts to skip during body conversion pub(crate) declare_with_hir_ids: Rc>, } diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index 90d325b5bb..a72799010f 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -305,28 +305,22 @@ fn fn_call_or_assoc_const_to_vir<'tcx>( record_call(bctx, expr, ResolvedCall::Call(name.clone(), record_name, bctx.in_ghost)); let vir_args = if let Some(args) = args { - *bctx.in_args_depth.borrow_mut() += 1; - let result = mk_vir_args(bctx, &args); - *bctx.in_args_depth.borrow_mut() -= 1; - result? + mk_vir_args(bctx, &args)? } else { vec![] }; - // Append any pending tracked args from proof_with() calls. - // If the function has extra tracked params but no proof_with was used, error. - // Skip consumption if we're inside argument processing (nested call) AND the callee - // doesn't expect extra params. If the callee does expect extra params (e.g., f() in - // f()? where ? desugars to Try::branch(f())), we must consume them here. + // Consume pending tracked args from proof_with() if present. + // Only consume when callee expects extra params (has declare_with entries). let vir_args = { let mut args = vir_args; - let in_args = *bctx.in_args_depth.borrow() > 0; let extra_params = bctx.ctxt.external_fn_extra_tracked_params.borrow().get(&f).cloned(); - if !in_args || extra_params.is_some() { - let mut pending = bctx.pending_tracked_args.borrow_mut(); - if let Some(ref expected_params) = extra_params { - let extra_count = expected_params.len(); - if pending.is_empty() { + if let Some(ref expected_params) = extra_params { + let extra_count = expected_params.len(); + let mut pending_opt = bctx.pending_tracked_args.borrow_mut(); + let pending = match pending_opt.take() { + Some(p) => p, + None => { return err_span( expr.span, format!( @@ -335,81 +329,75 @@ fn fn_call_or_assoc_const_to_vir<'tcx>( ), ); } - if pending.len() != extra_count { + }; + if pending.len() != extra_count { + return err_span( + expr.span, + format!( + "expected {} tracked/ghost argument(s) via proof_with(), got {}", + extra_count, + pending.len() + ), + ); + } + // Check mode and type for each pending arg + for (i, (pending_arg, (expected_is_tracked, expected_ty))) in + pending.iter().zip(expected_params.iter()).enumerate() + { + // Mode check + if pending_arg.is_tracked != *expected_is_tracked { + let expected_mode = if *expected_is_tracked { "Tracked" } else { "Ghost" }; + let actual_mode = if pending_arg.is_tracked { "Tracked" } else { "Ghost" }; return err_span( expr.span, format!( - "expected {} tracked/ghost argument(s) via proof_with(), got {}", - extra_count, - pending.len() + "proof_with argument {} has wrong mode: expected {}, got {}", + i + 1, + expected_mode, + actual_mode, ), ); } - // Check mode and type for each pending arg - for (i, (pending_arg, (expected_is_tracked, expected_ty))) in - pending.iter().zip(expected_params.iter()).enumerate() + // Type check: compare rustc types with regions erased. { - // Mode check - if pending_arg.is_tracked != *expected_is_tracked { - let expected_mode = if *expected_is_tracked { "Tracked" } else { "Ghost" }; - let actual_mode = if pending_arg.is_tracked { "Tracked" } else { "Ghost" }; + let tcx = bctx.ctxt.tcx; + let expected_ty_instantiated = + rustc_middle::ty::EarlyBinder::bind(*expected_ty) + .instantiate(tcx, node_substs); + let actual_ty = bctx.types.node_type(pending_arg.arg_hir_id); + use rustc_middle::ty::TypeFoldable; + let expected_erased = expected_ty_instantiated.fold_with( + &mut rustc_middle::ty::RegionFolder::new(tcx, &mut |_, _| { + tcx.lifetimes.re_erased + }), + ); + if actual_ty != expected_erased { return err_span( expr.span, format!( - "proof_with argument {} has wrong mode: expected {}, got {}", + "proof_with argument {} has wrong type: expected `{}`, got `{}`", i + 1, - expected_mode, - actual_mode, + expected_ty_instantiated, + actual_ty, ), ); } - // Type check: compare rustc types with regions erased. - { - let tcx = bctx.ctxt.tcx; - let expected_ty_instantiated = - rustc_middle::ty::EarlyBinder::bind(*expected_ty) - .instantiate(tcx, node_substs); - let actual_ty = bctx.types.node_type(pending_arg.arg_hir_id); - use rustc_middle::ty::TypeFoldable; - let expected_erased = expected_ty_instantiated.fold_with( - &mut rustc_middle::ty::RegionFolder::new(tcx, &mut |_, _| { - tcx.lifetimes.re_erased - }), - ); - if actual_ty != expected_erased { - return err_span( - expr.span, - format!( - "proof_with argument {} has wrong type: expected `{}`, got `{}`", - i + 1, - expected_ty_instantiated, - actual_ty, - ), - ); - } - check_proof_with_lifetime( - bctx, - tcx, - f, - expr.hir_id, - node_substs, - pending_arg.arg_hir_id, - *expected_ty, - expected_ty_instantiated, - expr.span, - i, - )?; - } + check_proof_with_lifetime( + bctx, + tcx, + f, + expr.hir_id, + node_substs, + pending_arg.arg_hir_id, + *expected_ty, + expected_ty_instantiated, + expr.span, + i, + )?; } - let exprs: Vec<_> = pending.drain(..).map(|a| a.expr).collect(); - args.extend(exprs); - } else if !pending.is_empty() { - pending.drain(..); - return err_span( - expr.span, - "proof_with was used but this function does not expect extra tracked/ghost arguments", - ); } + let exprs: Vec<_> = pending.into_iter().map(|a| a.expr).collect(); + args.extend(exprs); } args }; @@ -2220,32 +2208,63 @@ fn verus_item_to_vir<'tcx, 'a>( mk_expr(ExprX::BorrowMutTracked(p)) } VerusItem::ProofWith => { - // proof_with(expr) stores the tracked/ghost arg to be appended to the next call - unsupported_err_unless!(args_len == 1, expr.span, "expected proof_with(expr)", &args); - let arg_typ = typ_of_expr_adjusted(bctx, args[0].span, &args[0].hir_id)?; - let is_tracked = match &*arg_typ { - TypX::Decorate(TypDecoration::Tracked, _, _) => true, - TypX::Decorate(TypDecoration::Ghost, _, _) => false, - _ => { - return err_span( - args[0].span, - "proof_with expects an argument of type Tracked or Ghost", - ); - } + // proof_with(ghost_args, call) — ghost args and call in one expression. + // First arg: single Tracked/Ghost or tuple of them. + // Second arg: the actual function call, whose result is returned. + unsupported_err_unless!( + args_len == 2, + expr.span, + "expected proof_with(ghost_args, call)", + &args + ); + + // Extract individual ghost/tracked items from the first arg. + let ghost_arg = &args[0]; + let ghost_items: Vec<&rustc_hir::Expr> = match &ghost_arg.kind { + rustc_hir::ExprKind::Tup(elems) => elems.iter().collect(), + _ => vec![ghost_arg], }; + let bctx_ghost = &BodyCtxt { in_ghost: true, ..bctx.clone() }; - // Increment in_args_depth so nested function calls within the proof_with - // argument don't try to consume pending_tracked_args - *bctx.in_args_depth.borrow_mut() += 1; - let arg_expr = expr_to_vir_consume(bctx_ghost, &args[0])?; - *bctx.in_args_depth.borrow_mut() -= 1; - bctx.pending_tracked_args.borrow_mut().push(crate::context::PendingTrackedArg { - expr: arg_expr, - is_tracked, - arg_hir_id: args[0].hir_id, - }); - // Return a unit expression (no-op) - mk_expr(ExprX::Block(Arc::new(vec![]), None)) + let mut pending_args = Vec::new(); + for item in &ghost_items { + let arg_typ = typ_of_expr_adjusted(bctx, item.span, &item.hir_id)?; + let is_tracked = match &*arg_typ { + TypX::Decorate(TypDecoration::Tracked, _, _) => true, + TypX::Decorate(TypDecoration::Ghost, _, _) => false, + _ => { + return err_span( + item.span, + "proof_with expects arguments of type Tracked or Ghost", + ); + } + }; + let arg_expr = expr_to_vir_consume(bctx_ghost, item)?; + pending_args.push(crate::context::PendingTrackedArg { + expr: arg_expr, + is_tracked, + arg_hir_id: item.hir_id, + }); + } + + // Set pending args for consumption by the inner fn_call_to_vir + *bctx.pending_tracked_args.borrow_mut() = Some(pending_args); + + // Record erasure: replace proof_with(A, B) with just B (arg index 1) + record_call(bctx, expr, ResolvedCall::SpecAllowProofArgs); + + // Process second arg (the function call) — this will consume pending args + let call_expr = expr_to_vir_consume(bctx, &args[1])?; + + // Assert pending args were consumed + if bctx.pending_tracked_args.borrow().is_some() { + return err_span( + expr.span, + "proof_with second argument must be a function call that accepts extra tracked/ghost arguments", + ); + } + + Ok(call_expr) } VerusItem::BuiltinDeref(d) => { // This would be easy to support (similar to handling borrow_mut etc.) but their usage diff --git a/source/rust_verify/src/proof_with_lifetime.rs b/source/rust_verify/src/proof_with_lifetime.rs index 2ec1297a2c..887e69c6eb 100644 --- a/source/rust_verify/src/proof_with_lifetime.rs +++ b/source/rust_verify/src/proof_with_lifetime.rs @@ -1,4 +1,4 @@ -//! Lifetime checking for proof_with / declare_with_tracked / declare_with_ghost. +//! Lifetime checking for proof_with / declare_with. //! //! These functions erase tracked/ghost parameters from fn signatures before the borrow //! checker runs, so Rust's NLL cannot enforce lifetime constraints on those parameters. @@ -15,7 +15,7 @@ use vir::ast::VirErr; /// Check that a proof_with argument's lifetime is compatible with the callee's expected lifetime. /// -/// The expected type (from `lower_ty` on the callee's `declare_with_tracked` HIR type) has +/// The expected type (from `lower_ty` on the callee's `declare_with` HIR type) has /// `ReLateParam(callee, 'a)` regions. We need to: /// 1. Map callee's late-bound `'a` to the caller's corresponding lifetime via the call args /// 2. Get the proof_with arg's lifetime from its declaration diff --git a/source/rust_verify/src/rust_to_vir_func.rs b/source/rust_verify/src/rust_to_vir_func.rs index dc6d63a282..d4ff3f2dbb 100644 --- a/source/rust_verify/src/rust_to_vir_func.rs +++ b/source/rust_verify/src/rust_to_vir_func.rs @@ -311,13 +311,12 @@ fn mk_bctx<'tcx>( header_setting: HeaderSetting::Fn, unwrap_param_map: std::rc::Rc::new(std::cell::RefCell::new(HashMap::new())), external_opaque_type_map, - pending_tracked_args: std::rc::Rc::new(std::cell::RefCell::new(Vec::new())), - in_args_depth: std::rc::Rc::new(std::cell::RefCell::new(0)), + pending_tracked_args: std::rc::Rc::new(std::cell::RefCell::new(None)), declare_with_hir_ids: std::rc::Rc::new(declare_with_hir_ids), } } -/// Pre-scan the HIR body for `declare_with_tracked()`/`declare_with_ghost()` let-stmts. +/// Pre-scan the HIR body for `declare_with()` let-stmts. /// Returns (extra_vir_params, hir_ids_to_skip) so that: /// 1. Extra params can be appended to `vir_params` before body conversion /// 2. `stmt_to_vir` can skip these let-stmts during body conversion @@ -348,7 +347,7 @@ fn pre_scan_declare_with_params<'tcx>( .. }) = &stmt.kind { - // Check if init is a call to declare_with_tracked() or declare_with_ghost() + // Check if init is a call to declare_with() let verus_item = match &init.kind { ExprKind::Call(fun, _) => match &fun.kind { ExprKind::Path(rustc_hir::QPath::Resolved( @@ -1857,7 +1856,7 @@ pub(crate) fn check_item_fn<'tcx>( }; let body = find_body(ctxt, body_id); - // Pre-scan for declare_with_tracked()/declare_with_ghost() calls + // Pre-scan for declare_with() calls let (declare_with_extra_params, declare_with_hir_ids) = pre_scan_declare_with_params(ctxt, id, body, body_id)?; let declare_with_modes: Vec<(bool, rustc_middle::ty::Ty<'tcx>)> = diff --git a/source/rust_verify_test/tests/proof_with.rs b/source/rust_verify_test/tests/proof_with.rs index 96b342beb2..6ab9ae36b8 100644 --- a/source/rust_verify_test/tests/proof_with.rs +++ b/source/rust_verify_test/tests/proof_with.rs @@ -14,9 +14,7 @@ test_verify_one_file! { } fn call_test() { - proof_with(Tracked(1u64)); - proof_with(Ghost(2u32)); - test(0); + proof_with((Tracked(1u64), Ghost(2u32)), test(0)); } #[verifier(external)] @@ -42,9 +40,7 @@ test_verify_one_file! { } fn call_test() { let a = A { a: 0 }; - proof_with(Tracked(1u64)); - proof_with(Ghost(2u32)); - a.test(); + proof_with((Tracked(1u64), Ghost(2u32)), a.test()); } #[verifier(external)] @@ -80,9 +76,7 @@ test_verify_one_file! { } fn call_test() { let a = A { a: 0 }; - proof_with(Tracked(1u64)); - proof_with(Ghost(2u32)); - a.test(); + proof_with((Tracked(1u64), Ghost(2u32)), a.test()); } #[verifier(external)] @@ -111,8 +105,7 @@ test_verify_one_file! { } fn call_test() { - proof_with(Tracked(1u8)); - negate_bool(true, 1); + proof_with(Tracked(1u8), negate_bool(true, 1)); } #[verifier(external)] @@ -156,9 +149,7 @@ test_verify_one_file! { } fn call_test() { - proof_with(Tracked(0u64)); - proof_with(Ghost(2u32)); - test(0); // FAILS + proof_with((Tracked(0u64), Ghost(2u32)), test(0)); // FAILS } } => Err(e) => assert_one_fails(e) } @@ -173,10 +164,9 @@ test_verify_one_file! { } fn call_test() { - proof_with(0u64); - test(0); + proof_with(0u64, test(0)); } - } => Err(e) => assert_vir_error_msg(e, "proof_with expects an argument of type Tracked or Ghost") + } => Err(e) => assert_vir_error_msg(e, "proof_with expects arguments of type Tracked or Ghost") } test_verify_one_file! { @@ -189,8 +179,7 @@ test_verify_one_file! { } fn call_test() { - proof_with(Ghost(0u64)); - test(0); + proof_with(Ghost(0u64), test(0)); } } => Err(e) => assert_vir_error_msg(e, "proof_with argument 1 has wrong mode: expected Tracked, got Ghost") } @@ -198,27 +187,18 @@ test_verify_one_file! { // ---- Lifetime soundness tests ---- // These tests verify that lifetime constraints on tracked/ghost params are properly checked. -// BUG: This should fail with a lifetime error but currently passes. -// The tracked param Tracked<&'b u64> is passed where Tracked<&'a u64> is expected, -// but 'b may not outlive 'a. This is unsound. -// Once the fix is in place, change this to: -// } => Err(err) => assert_rust_error_msg(err, "lifetime may not live long enough") test_verify_one_file! { #[test] test_proof_with_lifetime_mismatch verus_code!{ use vstd::prelude::*; - // test expects a Tracked<&'a u64> where 'a is tied to param `a` fn test<'a>(a: &'a u64, b: u64) -> u64 { let c: Tracked<&'a u64> = declare_with(); 1 } - // test2 has independent lifetimes 'a and 'b - // Passing Tracked<&'b u64> where Tracked<&'a u64> is expected should fail fn test2<'a, 'b>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> u64 { - proof_with(c); - test(a, b) + proof_with(c, test(a, b)) } } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") } @@ -226,7 +206,6 @@ test_verify_one_file! { test_verify_one_file! { #[test] test_proof_with_lifetime_compatible verus_code!{ use vstd::prelude::*; - // Same as above, but 'b: 'a so the lifetime is compatible fn test<'a>(a: &'a u64, b: u64) -> u64 { let c: Tracked<&'a u64> = declare_with(); @@ -235,8 +214,7 @@ test_verify_one_file! { fn test2<'a, 'b: 'a>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> u64 { - proof_with(c); - test(a, b) + proof_with(c, test(a, b)) } } => Ok(()) } @@ -253,8 +231,7 @@ test_verify_one_file! { fn test2<'a, 'b>(a: &'a u64, c: Ghost<&'b u64>) -> u64 { - proof_with(c); - test(a) + proof_with(c, test(a)) } } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") } @@ -270,9 +247,7 @@ test_verify_one_file! { } fn call_test() { - proof_with(Tracked(0u64)); - proof_with(Ghost(2u32)); - test(0u64); + proof_with((Tracked(0u64), Ghost(2u32)), test(0u64)); } #[verifier(external)] @@ -293,9 +268,7 @@ test_verify_one_file! { } fn call_test() { - proof_with(Tracked(0u8)); - proof_with(Ghost(2u32)); - test(0u64); + proof_with((Tracked(0u8), Ghost(2u32)), test(0u64)); } #[verifier(external)] @@ -304,3 +277,21 @@ test_verify_one_file! { } } => Err(e) => assert_vir_error_msg(e, "proof_with argument 1 has wrong type") } + +test_verify_one_file! { + #[test] test_proof_with_ownership verus_code!{ + use vstd::prelude::*; + + struct A; + + fn test<'a>(a: &'a mut A) + { + let b: Tracked<&'a mut A> = declare_with(); + let c: Ghost = declare_with(); + } + + fn call_test(mut a: A, mut b: A) { + proof_with((Tracked(&mut a), Ghost(2u32)), test(&mut a)); + } + } => Err(e) => assert_rust_error_msg_skip_spec_msgs(e, "cannot borrow `a` as mutable more than once at a time") +} From 8fb7da4405ed8e00fe9f0810c2bc02c57ac07edc Mon Sep 17 00:00:00 2001 From: Ziqiao Zhou Date: Wed, 13 May 2026 07:02:17 +0000 Subject: [PATCH 4/4] improve lifetime checking for proof_with MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rewrite proof_with_lifetime.rs using InferCtxt region solver - Rename external_fn_extra_tracked_params → declare_with_params - Clean up check_proof_with_lifetime signature (10 → 7 args) - Simplify is_tracked match in pre_scan_declare_with_params - Refactor pending_tracked_args to Option with take() semantics - Remove in_args_depth field from BodyCtxt Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- source/rust_verify/src/context.rs | 8 +- source/rust_verify/src/fn_call_to_vir.rs | 11 +- source/rust_verify/src/proof_with_lifetime.rs | 342 +++++++----------- source/rust_verify/src/rust_to_vir_func.rs | 42 ++- source/rust_verify_test/tests/proof_with.rs | 32 ++ 5 files changed, 197 insertions(+), 238 deletions(-) diff --git a/source/rust_verify/src/context.rs b/source/rust_verify/src/context.rs index 2dc8096d71..76df710da6 100644 --- a/source/rust_verify/src/context.rs +++ b/source/rust_verify/src/context.rs @@ -59,9 +59,9 @@ pub struct ContextX<'tcx> { pub(crate) crate_name: CrateId, pub(crate) name_def_id_map: Rc>>, pub(crate) next_read_kind_id: AtomicU64, - /// For external_fn_specification functions with extra ghost/tracked params: - /// maps the external function's DefId to a Vec of (is_tracked, expected_ty) pairs - pub(crate) external_fn_extra_tracked_params: + /// For functions with extra ghost/tracked params (from declare_with() stmts): + /// maps the function's DefId to a Vec of (is_tracked, expected_ty) pairs + pub(crate) declare_with_params: Rc)>>>>, } @@ -137,7 +137,7 @@ impl<'tcx> ContextX<'tcx> { crate_name, name_def_id_map: Rc::new(RefCell::new(HashMap::new())), next_read_kind_id: AtomicU64::new(0), - external_fn_extra_tracked_params: Rc::new(RefCell::new(HashMap::new())), + declare_with_params: Rc::new(RefCell::new(HashMap::new())), } } diff --git a/source/rust_verify/src/fn_call_to_vir.rs b/source/rust_verify/src/fn_call_to_vir.rs index a72799010f..9d3bb81a5c 100644 --- a/source/rust_verify/src/fn_call_to_vir.rs +++ b/source/rust_verify/src/fn_call_to_vir.rs @@ -304,17 +304,13 @@ fn fn_call_or_assoc_const_to_vir<'tcx>( record_call(bctx, expr, ResolvedCall::Call(name.clone(), record_name, bctx.in_ghost)); - let vir_args = if let Some(args) = args { - mk_vir_args(bctx, &args)? - } else { - vec![] - }; + let vir_args = if let Some(args) = args { mk_vir_args(bctx, &args)? } else { vec![] }; // Consume pending tracked args from proof_with() if present. // Only consume when callee expects extra params (has declare_with entries). let vir_args = { let mut args = vir_args; - let extra_params = bctx.ctxt.external_fn_extra_tracked_params.borrow().get(&f).cloned(); + let extra_params = bctx.ctxt.declare_with_params.borrow().get(&f).cloned(); if let Some(ref expected_params) = extra_params { let extra_count = expected_params.len(); let mut pending_opt = bctx.pending_tracked_args.borrow_mut(); @@ -384,13 +380,10 @@ fn fn_call_or_assoc_const_to_vir<'tcx>( } check_proof_with_lifetime( bctx, - tcx, f, expr.hir_id, - node_substs, pending_arg.arg_hir_id, *expected_ty, - expected_ty_instantiated, expr.span, i, )?; diff --git a/source/rust_verify/src/proof_with_lifetime.rs b/source/rust_verify/src/proof_with_lifetime.rs index 887e69c6eb..7aeb03803f 100644 --- a/source/rust_verify/src/proof_with_lifetime.rs +++ b/source/rust_verify/src/proof_with_lifetime.rs @@ -4,13 +4,20 @@ //! checker runs, so Rust's NLL cannot enforce lifetime constraints on those parameters. //! This module performs a first-pass lifetime check during VIR lowering to catch //! incompatible lifetimes early. +//! +//! Strategy: use rustc's `InferCtxt` to do a subtype check between the actual arg type +//! (with real caller regions) and the expected type (with caller regions derived from +//! exec arg matching). This reuses Rust's region constraint solver rather than +//! reimplementing outlives checking. use crate::context::BodyCtxt; use crate::util::err_span; use rustc_hir::def::Res; use rustc_hir::{ExprKind, QPath}; +use rustc_middle::ty::error::TypeErrorToStringExt; use rustc_span::Span; use rustc_span::def_id::DefId; +use rustc_trait_selection::regions::InferCtxtRegionExt; use vir::ast::VirErr; /// Check that a proof_with argument's lifetime is compatible with the callee's expected lifetime. @@ -22,19 +29,21 @@ use vir::ast::VirErr; /// 3. Check that the arg's lifetime outlives the expected lifetime in the caller's scope pub(crate) fn check_proof_with_lifetime<'tcx>( bctx: &BodyCtxt<'tcx>, - tcx: rustc_middle::ty::TyCtxt<'tcx>, callee_def_id: DefId, call_hir_id: rustc_hir::HirId, - _node_substs: &'tcx rustc_middle::ty::List>, arg_hir_id: rustc_hir::HirId, - _expected_ty_raw: rustc_middle::ty::Ty<'tcx>, - expected_ty_instantiated: rustc_middle::ty::Ty<'tcx>, + expected_ty_raw: rustc_middle::ty::Ty<'tcx>, call_span: Span, arg_index: usize, ) -> Result<(), VirErr> { - use rustc_middle::ty::{Region, RegionKind, TypeVisitable, TypeVisitor}; + use rustc_infer::infer::TyCtxtInferExt; + use rustc_middle::ty::{Region, RegionKind, TypeFoldable, TypeVisitable, TypeVisitor}; use std::ops::ControlFlow; + let tcx = bctx.ctxt.tcx; + + // --- Helpers --- + struct RegionCollector<'tcx> { regions: Vec>, } @@ -52,14 +61,11 @@ pub(crate) fn check_proof_with_lifetime<'tcx>( collector.regions } - // Extract named regions from the expected type (could be ReLateParam or ReEarlyParam) - let expected_regions = collect_regions(expected_ty_instantiated); - - // Get DefId for a named region + /// Get DefId for a named region in a given function's scope. fn region_def_id<'tcx>( tcx: rustc_middle::ty::TyCtxt<'tcx>, owner_def_id: DefId, - region: &rustc_middle::ty::Region<'tcx>, + region: Region<'tcx>, ) -> Option { match region.kind() { RegionKind::ReLateParam(lp) => match lp.kind { @@ -71,168 +77,191 @@ pub(crate) fn check_proof_with_lifetime<'tcx>( let param = generics.param_at(ep.index as usize, tcx); Some(param.def_id) } + RegionKind::ReBound(_, br) => { + if let rustc_middle::ty::BoundRegionKind::Named(def_id) = br.kind { + Some(def_id) + } else { + None + } + } _ => None, } } - let named_regions: Vec<_> = expected_regions - .iter() - .filter(|r| region_def_id(tcx, callee_def_id, r).is_some()) - .collect(); + // --- Step 1: Check if expected type has any regions worth checking --- + + let raw_regions = collect_regions(expected_ty_raw); + let callee_region_def_ids: Vec<_> = + raw_regions.iter().filter_map(|r| region_def_id(tcx, callee_def_id, *r)).collect(); - if named_regions.is_empty() { - return Ok(()); // No named lifetimes to check + if callee_region_def_ids.is_empty() { + return Ok(()); // No lifetime params in the ghost type } - // Get the callee's poly fn_sig to find which params use each late-bound region - let callee_poly_sig = tcx.fn_sig(callee_def_id).instantiate_identity(); - let callee_sig_inputs = callee_poly_sig.skip_binder().inputs(); + // --- Step 2: Get the actual arg type with real regions --- - // Get the caller's liberated fn_sig to find real lifetime regions for caller params let caller_def_id = bctx.fun_id; let caller_poly_sig = tcx.fn_sig(caller_def_id).instantiate_identity(); let caller_liberated = tcx.liberate_late_bound_regions(caller_def_id, caller_poly_sig); let caller_inputs = caller_liberated.inputs(); - // Build mapping: callee late-bound region DefId → caller ReLateParam region - // by matching callee fn_sig params with call args - let mut callee_to_caller_region: std::collections::HashMap< - rustc_span::def_id::DefId, - rustc_middle::ty::Region<'tcx>, - > = std::collections::HashMap::new(); + let param_idx = resolve_expr_to_param_index(tcx, caller_def_id, arg_hir_id); + let actual_ty_with_regions = match param_idx { + Some(idx) if idx < caller_inputs.len() => caller_inputs[idx], + _ => return Ok(()), // Can't determine actual type with regions + }; + + // --- Step 3: Build callee→caller region mapping from exec args --- + + let callee_poly_sig = tcx.fn_sig(callee_def_id).instantiate_identity(); + let callee_sig_inputs = callee_poly_sig.skip_binder().inputs(); - // Extract call arg HirIds from the HIR expression + // Extract call arg HirIds from the HIR let call_arg_hir_ids: Vec = { let hir_node = tcx.hir_node(call_hir_id); match hir_node { rustc_hir::Node::Expr(expr) => match &expr.kind { ExprKind::Call(_, hir_args) => hir_args.iter().map(|a| a.hir_id).collect(), - rustc_hir::ExprKind::MethodCall(_, receiver, hir_args, _) => { + ExprKind::MethodCall(_, receiver, hir_args, _) => { let mut ids = vec![receiver.hir_id]; ids.extend(hir_args.iter().map(|a| a.hir_id)); ids } - _ => return Ok(()), // Can't extract args + _ => return Ok(()), }, _ => return Ok(()), } }; - // For each callee fn_sig param, find bound regions and match with call arg + // Map callee region DefId → caller Region by matching callee fn_sig params with call args + let mut callee_to_caller: std::collections::HashMap> = + std::collections::HashMap::new(); + for (param_idx, callee_param_ty) in callee_sig_inputs.iter().enumerate() { - // Find named regions (ReBound or ReEarlyParam) in this callee param - let named_param_regions: Vec<_> = collect_regions(*callee_param_ty) + let param_regions: Vec<_> = collect_regions(*callee_param_ty) .into_iter() - .filter_map(|r| match r.kind() { - RegionKind::ReBound(rustc_middle::ty::BoundVarIndexKind::Bound(debruijn), br) - if debruijn == rustc_middle::ty::INNERMOST => - { - if let rustc_middle::ty::BoundRegionKind::Named(def_id) = br.kind { - Some(def_id) - } else { - None - } - } - RegionKind::ReEarlyParam(ep) => { - let generics = tcx.generics_of(callee_def_id); - Some(generics.param_at(ep.index as usize, tcx).def_id) - } - _ => None, + .filter_map(|r| { + let did = region_def_id(tcx, callee_def_id, r)?; + Some(did) }) .collect(); - if named_param_regions.is_empty() || param_idx >= call_arg_hir_ids.len() { + if param_regions.is_empty() || param_idx >= call_arg_hir_ids.len() { continue; } - // Try to resolve the call arg to a caller parameter to get its real lifetime - let arg_hir_id_for_param = call_arg_hir_ids[param_idx]; - if let Some(caller_param_idx) = - resolve_expr_to_param_index(tcx, caller_def_id, arg_hir_id_for_param) - { + let arg_hir = call_arg_hir_ids[param_idx]; + if let Some(caller_param_idx) = resolve_expr_to_param_index(tcx, caller_def_id, arg_hir) { if caller_param_idx < caller_inputs.len() { let caller_param_ty = caller_inputs[caller_param_idx]; - let caller_regions = collect_regions(caller_param_ty); - let caller_named_regions: Vec<_> = caller_regions + let caller_regions: Vec<_> = collect_regions(caller_param_ty) .into_iter() .filter(|r| { matches!(r.kind(), RegionKind::ReLateParam(_) | RegionKind::ReEarlyParam(_)) }) .collect(); - // Map each bound region to the corresponding caller region - // (simple 1:1 mapping by position within the type) - for (region_def_id, caller_region) in - named_param_regions.iter().zip(caller_named_regions.iter()) - { - callee_to_caller_region.insert(*region_def_id, *caller_region); + for (callee_did, caller_region) in param_regions.iter().zip(caller_regions.iter()) { + callee_to_caller.insert(*callee_did, *caller_region); } } } } - if callee_to_caller_region.is_empty() { - // Couldn't determine the mapping — skip the check - return Ok(()); - } + // --- Step 4: Build expected type with caller regions using InferCtxt --- - // Get the proof_with arg's type with real regions - let actual_arg_region = - get_arg_late_param_region(tcx, caller_def_id, caller_inputs, arg_hir_id); + let infcx = tcx.infer_ctxt().build(rustc_type_ir::TypingMode::PostAnalysis); - // For each named region in the expected type, check compatibility - for expected_region in &named_regions { - let callee_lt_def_id = match region_def_id(tcx, callee_def_id, expected_region) { - Some(did) => did, - None => continue, - }; + // For ghost-only regions not found in exec args, create fresh region variables + for callee_did in &callee_region_def_ids { + if !callee_to_caller.contains_key(callee_did) { + let fresh = + infcx.next_region_var(rustc_infer::infer::RegionVariableOrigin::Misc(call_span)); + callee_to_caller.insert(*callee_did, fresh); + } + } - let Some(expected_caller_region) = callee_to_caller_region.get(&callee_lt_def_id) else { - continue; // Can't check this region - }; + // Fold expected_ty_raw, replacing callee regions with mapped caller regions + let expected_ty_caller = + expected_ty_raw.fold_with(&mut rustc_middle::ty::RegionFolder::new(tcx, &mut |r, _| { + if let Some(callee_did) = region_def_id(tcx, callee_def_id, r) { + if let Some(caller_region) = callee_to_caller.get(&callee_did) { + return *caller_region; + } + } + r + })); - let Some(actual_region) = actual_arg_region else { - continue; // Can't determine actual region - }; + // --- Step 5: Add callee's where-clause region bounds as constraints --- - // Compare: does actual_region outlive expected_caller_region? - let actual_did = region_def_id(tcx, caller_def_id, &actual_region); - let expected_did = region_def_id(tcx, caller_def_id, expected_caller_region); + let callee_predicates = tcx.predicates_of(callee_def_id); + for (pred, _) in callee_predicates.predicates { + if let Some(rustc_middle::ty::ClauseKind::RegionOutlives(outlives)) = + pred.kind().no_bound_vars() + { + let longer_did = region_def_id(tcx, callee_def_id, outlives.0); + let shorter_did = region_def_id(tcx, callee_def_id, outlives.1); - if let (Some(actual_did), Some(expected_did)) = (actual_did, expected_did) { - if actual_did == expected_did { - continue; // Same lifetime — OK + if let (Some(longer_did), Some(shorter_did)) = (longer_did, shorter_did) { + if let (Some(&caller_longer), Some(&caller_shorter)) = + (callee_to_caller.get(&longer_did), callee_to_caller.get(&shorter_did)) + { + // Add constraint: caller_longer: caller_shorter + // (sub_regions(a, b) means a <= b, i.e., b outlives a) + infcx.sub_regions( + rustc_infer::infer::SubregionOrigin::RelateParamBound( + call_span, + actual_ty_with_regions, + None, + ), + caller_shorter, + caller_longer, + ); + } } + } + } - // Check if actual outlives expected in the caller's where-clause bounds - if check_region_outlives(tcx, caller_def_id, actual_region, *expected_caller_region) { - continue; // Outlives relationship declared — OK - } + // --- Step 6: Subtype check: actual_ty <: expected_ty --- + + let param_env = tcx.param_env(caller_def_id); + let cause = rustc_infer::traits::ObligationCause::dummy(); + + if let Err(error) = infcx.at(&cause, param_env).sub( + rustc_infer::infer::DefineOpaqueTypes::Yes, + actual_ty_with_regions, + expected_ty_caller, + ) { + return err_span( + call_span, + format!( + "proof_with argument {} has incompatible lifetime: {}", + arg_index + 1, + error.to_string(tcx) + ), + ); + } - // Lifetime mismatch! - let actual_name = tcx.item_name(actual_did); - let expected_name = tcx.item_name(expected_did); - return err_span( - call_span, - format!( - "proof_with argument {} has incompatible lifetime: \ - expected lifetime `'{}`, got `'{}`\n\ - help: consider adding `'{}: '{}`", - arg_index + 1, - expected_name, - actual_name, - actual_name, - expected_name, - ), - ); - } + // --- Step 7: Resolve region constraints and check for errors --- + + let assumed_wf: Vec> = vec![]; + let errors = infcx.resolve_regions(caller_def_id.expect_local(), param_env, assumed_wf); + + if !errors.is_empty() { + return err_span( + call_span, + format!( + "proof_with argument {} has incompatible lifetime: {:?}", + arg_index + 1, + errors + ), + ); } Ok(()) } /// Try to resolve an expression (by HirId) to a function parameter index. -/// Returns Some(index) if the expression is a simple reference to a function parameter. fn resolve_expr_to_param_index<'tcx>( tcx: rustc_middle::ty::TyCtxt<'tcx>, fn_def_id: DefId, @@ -256,100 +285,3 @@ fn resolve_expr_to_param_index<'tcx>( } None } - -/// Get the first named region (ReLateParam or ReEarlyParam) from a proof_with arg's type. -/// The arg is typically a function parameter of the caller. -fn get_arg_late_param_region<'tcx>( - tcx: rustc_middle::ty::TyCtxt<'tcx>, - caller_def_id: DefId, - caller_inputs: &[rustc_middle::ty::Ty<'tcx>], - arg_hir_id: rustc_hir::HirId, -) -> Option> { - use rustc_middle::ty::RegionKind; - - if let Some(param_idx) = resolve_expr_to_param_index(tcx, caller_def_id, arg_hir_id) { - if param_idx < caller_inputs.len() { - let param_ty = caller_inputs[param_idx]; - let regions: Vec<_> = { - use rustc_middle::ty::{Region, TypeVisitable, TypeVisitor}; - use std::ops::ControlFlow; - struct Collector<'tcx> { - regions: Vec>, - } - impl<'tcx> TypeVisitor> for Collector<'tcx> { - type Result = ControlFlow<()>; - fn visit_region(&mut self, r: Region<'tcx>) -> Self::Result { - if matches!( - r.kind(), - RegionKind::ReLateParam(_) | RegionKind::ReEarlyParam(_) - ) { - self.regions.push(r); - } - ControlFlow::Continue(()) - } - } - let mut c = Collector { regions: vec![] }; - let _ = param_ty.visit_with(&mut c); - c.regions - }; - return regions.into_iter().next(); - } - } - - None -} - -/// Check if region `a` outlives region `b` based on the caller's where-clause bounds. -/// Compares by DefId since regions may be in different representations (ReLateParam vs ReEarlyParam). -fn check_region_outlives<'tcx>( - tcx: rustc_middle::ty::TyCtxt<'tcx>, - caller_def_id: DefId, - actual: rustc_middle::ty::Region<'tcx>, - expected: rustc_middle::ty::Region<'tcx>, -) -> bool { - use rustc_middle::ty::{ClauseKind, RegionKind}; - - fn local_region_def_id<'tcx>( - tcx: rustc_middle::ty::TyCtxt<'tcx>, - owner_def_id: DefId, - r: rustc_middle::ty::Region<'tcx>, - ) -> Option { - match r.kind() { - RegionKind::ReLateParam(lp) => match lp.kind { - rustc_middle::ty::LateParamRegionKind::Named(def_id) => Some(def_id), - _ => None, - }, - RegionKind::ReEarlyParam(ep) => { - let generics = tcx.generics_of(owner_def_id); - Some(generics.param_at(ep.index as usize, tcx).def_id) - } - _ => None, - } - } - - let actual_did = local_region_def_id(tcx, caller_def_id, actual); - let expected_did = local_region_def_id(tcx, caller_def_id, expected); - - if actual_did.is_none() || expected_did.is_none() { - return false; - } - - // Check the caller's predicates for 'actual: 'expected (by DefId) - let predicates = tcx.predicates_of(caller_def_id); - for (pred, _) in predicates.predicates { - if let Some(ClauseKind::RegionOutlives(outlives)) = pred.kind().no_bound_vars() { - let pred_longer = local_region_def_id(tcx, caller_def_id, outlives.0); - let pred_shorter = local_region_def_id(tcx, caller_def_id, outlives.1); - if pred_longer == actual_did && pred_shorter == expected_did { - return true; - } - } - } - - // 'static outlives anything - if actual.is_static() { - return true; - } - - false -} diff --git a/source/rust_verify/src/rust_to_vir_func.rs b/source/rust_verify/src/rust_to_vir_func.rs index d4ff3f2dbb..3ed7a2a614 100644 --- a/source/rust_verify/src/rust_to_vir_func.rs +++ b/source/rust_verify/src/rust_to_vir_func.rs @@ -381,21 +381,21 @@ fn pre_scan_declare_with_params<'tcx>( // Derive is_tracked from the type (Tracked vs Ghost) via ADT DefId let is_tracked = match ty.kind() { - rustc_middle::ty::TyKind::Adt(adt_def, _) => { - match ctxt.get_verus_item(adt_def.did()) { - Some(VerusItem::BuiltinType( - crate::verus_items::BuiltinTypeItem::Tracked, - )) => true, - Some(VerusItem::BuiltinType( - crate::verus_items::BuiltinTypeItem::Ghost, - )) => false, - _ => { - return err_span( - init.span, - "declare_with() must be assigned to a Tracked or Ghost type", - ); - } - } + rustc_middle::ty::TyKind::Adt(adt_def, _) + if matches!( + ctxt.get_verus_item(adt_def.did()), + Some(VerusItem::BuiltinType(crate::verus_items::BuiltinTypeItem::Tracked)) + ) => + { + true + } + rustc_middle::ty::TyKind::Adt(adt_def, _) + if matches!( + ctxt.get_verus_item(adt_def.did()), + Some(VerusItem::BuiltinType(crate::verus_items::BuiltinTypeItem::Ghost)) + ) => + { + false } _ => { return err_span( @@ -1864,7 +1864,11 @@ pub(crate) fn check_item_fn<'tcx>( .iter() .map(|(p, _, ty)| { // unwrapped_info mode: Proof = Tracked, Spec = Ghost - (matches!(p.x.unwrapped_info, Some((Mode::Proof, _))), *ty) + match p.x.unwrapped_info { + Some((Mode::Proof, _)) => (true, *ty), + Some((Mode::Spec, _)) => (false, *ty), + _ => unreachable!(), + } }) .collect(); for (p, _mode, _) in declare_with_extra_params { @@ -1875,9 +1879,7 @@ pub(crate) fn check_item_fn<'tcx>( // even if body_to_vir fails for this function. if !declare_with_modes.is_empty() { let target_id = proxy_id.unwrap_or(id); - ctxt.external_fn_extra_tracked_params - .borrow_mut() - .insert(target_id, declare_with_modes.clone()); + ctxt.declare_with_params.borrow_mut().insert(target_id, declare_with_modes.clone()); // For unerased_proxy functions (const fn proxies), also register under // the original function's DefId, since call sites resolve to the original. @@ -1894,7 +1896,7 @@ pub(crate) fn check_item_fn<'tcx>( && child_def_id != id && ctxt.tcx.parent(child_def_id) == parent { - ctxt.external_fn_extra_tracked_params + ctxt.declare_with_params .borrow_mut() .insert(child_def_id, declare_with_modes.clone()); } diff --git a/source/rust_verify_test/tests/proof_with.rs b/source/rust_verify_test/tests/proof_with.rs index 6ab9ae36b8..b9e267c835 100644 --- a/source/rust_verify_test/tests/proof_with.rs +++ b/source/rust_verify_test/tests/proof_with.rs @@ -219,6 +219,22 @@ test_verify_one_file! { } => Ok(()) } +test_verify_one_file! { + #[test] test_proof_with_lifetime_bound_mismatch verus_code!{ + use vstd::prelude::*; + fn test<'a, 'b: 'a>(a: &'a u64, b: u64) -> &'a u64 + { + let c: Tracked<&'b u64> = declare_with(); + a + } + + fn test2<'a, 'b>(a: &'a u64, b: u64, c: Tracked<&'b u64>) -> &'a u64 + { + proof_with(c, test(a, b)) + } + } => Err(err) => assert_vir_error_msg(err, "proof_with argument 1 has incompatible lifetime") +} + // Same as test_proof_with_lifetime_mismatch but for Ghost. test_verify_one_file! { #[test] test_declare_with_ghost_lifetime_mismatch verus_code!{ @@ -257,6 +273,22 @@ test_verify_one_file! { } => Ok(()) } +test_verify_one_file! { + #[test] test_proof_with_generic_type2 verus_code!{ + use vstd::prelude::*; + trait X {} + fn test(a: T1, b: T2) + { + let c: Tracked = declare_with(); + let d: Ghost = declare_with(); + } + + fn call_test(a: T1, b: T2, c: Tracked, d: Ghost) { + proof_with((c, d), test(a, b)); + } + } => Ok(()) +} + test_verify_one_file! { #[test] test_proof_with_generic_type_wrong_type verus_code!{ use vstd::prelude::*;