Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1036,10 +1036,13 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::LangItem::PinNewUnchecked,
arena_vec![self; ref_mut_awaitee],
);
let get_context = self.expr_call_lang_item_fn_mut(
let get_context = self.expr(
gen_future_span,
hir::LangItem::GetContext,
arena_vec![self; task_context],
hir::ExprKind::UnsafeBinderCast(
UnsafeBinderCastKind::Unwrap,
self.arena.alloc(task_context),
None,
),
);
let call = match await_kind {
FutureKind::Future => self.expr_call_lang_item_fn(
Expand Down
10 changes: 6 additions & 4 deletions compiler/rustc_borrowck/src/places_conflict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,11 @@ fn place_projection_conflict<'tcx>(
debug!("place_element_conflict: DISJOINT-OR-EQ-OPAQUE");
Overlap::EqualOrDisjoint
}
(ProjectionElem::UnwrapUnsafeBinder(_), ProjectionElem::UnwrapUnsafeBinder(_)) => {
// casts to other types may always conflict irrespective of the type being cast to.
debug!("place_element_conflict: DISJOINT-OR-EQ-OPAQUE");
Overlap::EqualOrDisjoint
}
(ProjectionElem::Field(f1, _), ProjectionElem::Field(f2, _)) => {
if f1 == f2 {
// same field (e.g., `a.y` vs. `a.y`) - recur.
Expand Down Expand Up @@ -510,6 +515,7 @@ fn place_projection_conflict<'tcx>(
| ProjectionElem::Index(..)
| ProjectionElem::ConstantIndex { .. }
| ProjectionElem::OpaqueCast { .. }
| ProjectionElem::UnwrapUnsafeBinder { .. }
| ProjectionElem::Subslice { .. }
| ProjectionElem::Downcast(..),
_,
Expand All @@ -518,9 +524,5 @@ fn place_projection_conflict<'tcx>(
pi1_elem,
pi2_elem
),

(ProjectionElem::UnwrapUnsafeBinder(_), _) => {
todo!()
}
}
}
8 changes: 7 additions & 1 deletion compiler/rustc_const_eval/src/interpret/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,6 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
// Nothing to check.
interp_ok(true)
}
ty::UnsafeBinder(_) => todo!("FIXME(unsafe_binder)"),
// The above should be all the primitive types. The rest is compound, we
// check them by visiting their fields/variants.
ty::Adt(..)
Expand All @@ -1002,6 +1001,7 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValidityVisitor<'rt, 'tcx, M> {
| ty::Dynamic(..)
| ty::Closure(..)
| ty::Pat(..)
| ty::UnsafeBinder(_)
| ty::CoroutineClosure(..)
| ty::Coroutine(..) => interp_ok(false),
// Some types only occur during typechecking, they have no layout.
Expand Down Expand Up @@ -1536,6 +1536,12 @@ impl<'rt, 'tcx, M: Machine<'tcx>> ValueVisitor<'tcx, M> for ValidityVisitor<'rt,
BackendRepr::Memory { .. } => unreachable!()
}
}
ty::UnsafeBinder(base) => {
// First check that the base type is valid
let base = self.ecx.tcx.instantiate_bound_regions_with_erased((*base).into());
let inner_layout = self.ecx.layout_of(base)?;
self.visit_value(&val.transmute(inner_layout, self.ecx)?)?;
}
ty::Adt(adt, _) if adt.is_maybe_dangling() => {
let old_may_dangle = mem::replace(&mut self.may_dangle, true);

Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_hir/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,13 @@ impl CoroutineKind {
matches!(self, CoroutineKind::Desugared(_, CoroutineSource::Fn))
}

pub fn is_async_desugaring(self) -> bool {
matches!(
self,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
)
}

pub fn to_plural_string(&self) -> String {
match self {
CoroutineKind::Desugared(d, CoroutineSource::Fn) => format!("{d:#}fn bodies"),
Expand Down
3 changes: 1 addition & 2 deletions compiler/rustc_hir/src/lang_items.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,7 @@ language_item_table! {

// FIXME(swatinem): the following lang items are used for async lowering and
// should become obsolete eventually.
ResumeTy, sym::ResumeTy, resume_ty, Target::Struct, GenericRequirement::None;
GetContext, sym::get_context, get_context_fn, Target::Fn, GenericRequirement::None;
ResumeTy, sym::ResumeTy, resume_ty, Target::TyAlias, GenericRequirement::None;

Context, sym::Context, context, Target::Struct, GenericRequirement::None;
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
Expand Down
23 changes: 23 additions & 0 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,29 @@ impl<'tcx> Ty<'tcx> {
Ty::new_generic_adt(tcx, def_id, ty)
}

/// Creates a `unsafe<'a, 'b> &'a mut Context<'b>` [`Ty`].
pub fn new_resume_ty(tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let context_did = tcx.require_lang_item(LangItem::Context, DUMMY_SP);
let context_adt_ref = tcx.adt_def(context_did);

let lt = |n| {
ty::Region::new_bound(
tcx,
ty::INNERMOST,
ty::BoundRegion { var: ty::BoundVar::from_u32(n), kind: BoundRegionKind::Anon },
)
};

let context_args = tcx.mk_args(&[lt(1).into()]);
let context_ty = Ty::new_adt(tcx, context_adt_ref, context_args);
let context_mut_ref = Ty::new_mut_ref(tcx, lt(0), context_ty);
let bound_vars = tcx.mk_bound_variable_kinds(&[
BoundVariableKind::Region(BoundRegionKind::Anon),
BoundVariableKind::Region(BoundRegionKind::Anon),
]);
Ty::new_unsafe_binder(tcx, ty::Binder::bind_with_vars(context_mut_ref, bound_vars))
}

/// Creates a `&mut Context<'_>` [`Ty`] with erased lifetimes.
pub fn new_task_context(tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let context_did = tcx.require_lang_item(LangItem::Context, DUMMY_SP);
Expand Down
138 changes: 32 additions & 106 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,20 +563,13 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
);
}

/// Transforms the `body` of the coroutine applying the following transforms:
/// Transforms the `body` of the coroutine replacing `CTX_ARG: ResumeTy` types with
/// `CTX_ARG: &mut Context<'_>` (`context_mut_ref`), and making all users wrap/unwrap
/// into a `ResumeTy`.
///
/// - Eliminates all the `get_context` calls that async lowering created.
/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe binder.
///
/// The `Local`s that have their types replaced are:
/// - The `resume` argument itself.
/// - The argument to `get_context`.
/// - The yielded value of a `yield`.
///
/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
///
/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
/// Ideally the async lowering would not use the `ResumeTy` indirection,
/// but rather directly use `&mut Context<'_>`, however that would currently
/// lead to higher-kinded lifetime errors.
/// See <https://github.com/rust-lang/rust/issues/105501>.
Expand All @@ -585,78 +578,20 @@ fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body
/// still using the `ResumeTy` indirection for the time being, and that indirection
/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
let context_mut_ref = Ty::new_task_context(tcx);

// replace the type of the `resume` argument
replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let resume_ty = body.local_decls[CTX_ARG].ty;
let resume_local = body.local_decls.push(LocalDecl::new(Ty::new_task_context(tcx), body.span));
body.local_decls.swap(CTX_ARG, resume_local);

let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
RenameLocalVisitor { from: CTX_ARG, to: resume_local, tcx }.visit_body(body);

for bb in body.basic_blocks.indices() {
let bb_data = &body[bb];
if bb_data.is_cleanup {
continue;
}

match &bb_data.terminator().kind {
TerminatorKind::Call { func, .. } => {
let func_ty = func.ty(body, tcx);
if let ty::FnDef(def_id, _) = *func_ty.kind()
&& def_id == get_context_def_id
{
let local = eliminate_get_context_call(&mut body[bb]);
replace_resume_ty_local(tcx, body, local, context_mut_ref);
}
}
TerminatorKind::Yield { resume_arg, .. } => {
replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
}
_ => {}
}
}
context_mut_ref
}

fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
let terminator = bb_data.terminator.take().unwrap();
let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
bug!();
};
let [arg] = *Box::try_from(args).unwrap();
let local = arg.node.place().unwrap().local;

let arg = Rvalue::Use(arg.node, WithRetag::Yes);
let assign =
Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
bb_data.statements.push(assign);
bb_data.terminator = Some(Terminator {
source_info: terminator.source_info,
kind: TerminatorKind::Goto { target: target.unwrap() },
});
local
}

#[cfg_attr(not(debug_assertions), allow(unused))]
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn replace_resume_ty_local<'tcx>(
tcx: TyCtxt<'tcx>,
body: &mut Body<'tcx>,
local: Local,
context_mut_ref: Ty<'tcx>,
) {
let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
assert_eq!(*resume_ty_adt, expected_adt);
} else {
panic!("expected `ResumeTy`, found `{:?}`", local_ty);
};
}
// Now `CTX_ARG` is `&mut Context` and `resume_local` is a `unsafe<>`.
let source_info = SourceInfo::outermost(body.span);
let rhs = Rvalue::WrapUnsafeBinder(Operand::Move(CTX_ARG.into()), resume_ty);
let assign = StatementKind::Assign(Box::new((resume_local.into(), rhs)));
body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK]
.statements
.insert(0, Statement::new(source_info, assign));
}

/// Transforms the `body` of the coroutine applying the following transform:
Expand Down Expand Up @@ -1314,6 +1249,12 @@ fn create_coroutine_resume_function<'tcx>(

pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);

// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, body, false);
if transform.coroutine_kind.is_async_desugaring() {
transform_async_context(tcx, body);
}

if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
dumper.dump_mir(body);
}
Expand Down Expand Up @@ -1361,13 +1302,13 @@ fn create_cases<'tcx>(
}
}

// Move the resume argument to the destination place of the `Yield` terminator
if operation == Operation::Resume && point.resume_arg != CTX_ARG.into() {
// Move the resume argument to the destination place of the `Yield` terminator
statements.push(Statement::new(
source_info,
StatementKind::Assign(Box::new((
point.resume_arg,
Rvalue::Use(Operand::Move(CTX_ARG.into()), WithRetag::Yes),
Rvalue::Use(Operand::Copy(CTX_ARG.into()), WithRetag::Yes),
))),
));
}
Expand Down Expand Up @@ -1519,18 +1460,11 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
// (finally in open_drop_for_tuple) before async drop expansion.
// Async drops, produced by this drop elaboration, will be expanded,
// and corresponding futures kept in layout.
let has_async_drops = matches!(
coroutine_kind,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
) && has_expandable_async_drops(tcx, body, coroutine_ty);
let has_async_drops = coroutine_kind.is_async_desugaring()
&& has_expandable_async_drops(tcx, body, coroutine_ty);

// Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
if matches!(
coroutine_kind,
CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
) {
let context_mut_ref = transform_async_context(tcx, body);
expand_async_drops(tcx, body, context_mut_ref, coroutine_kind, coroutine_ty);
if has_async_drops {
expand_async_drops(tcx, body, coroutine_kind, coroutine_ty);

if let Some(dumper) = MirDumper::new(tcx, "coroutine_async_drop_expand", body) {
dumper.dump_mir(body);
Expand Down Expand Up @@ -1650,30 +1584,22 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
// Create a copy of our MIR and use it to create the drop shim for the coroutine
if has_async_drops {
// If coroutine has async drops, generating async drop shim
let mut drop_shim =
let drop_shim =
create_coroutine_drop_shim_async(tcx, &transform, body, drop_clean, can_unwind);
// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, &mut drop_shim, false);
body.coroutine.as_mut().unwrap().coroutine_drop_async = Some(drop_shim);
} else {
// If coroutine has no async drops, generating sync drop shim
let mut drop_shim =
let drop_shim =
create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);
// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, &mut drop_shim, false);
body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

// For coroutine with sync drop, generating async proxy for `future_drop_poll` call
let mut proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
deref_finder(tcx, &mut proxy_shim, false);
let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body, coroutine_kind);
body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
}

// Create the Coroutine::resume / Future::poll function
create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind);

// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, body, false);
}

fn is_required(&self) -> bool {
Expand Down
Loading
Loading