Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion crates/hir-def/src/expr_store/body.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,15 @@ impl Body {
let mut params = None;

let mut is_async_fn = false;
let mut is_gen_fn = false;
let InFile { file_id, value: body } = {
match def {
DefWithBodyId::FunctionId(f) => {
let f = f.lookup(db);
let src = f.source(db);
params = src.value.param_list();
is_async_fn = src.value.async_token().is_some();
is_gen_fn = src.value.gen_token().is_some();
src.map(|it| it.body().map(ast::Expr::from))
}
DefWithBodyId::ConstId(c) => {
Expand All @@ -101,7 +103,8 @@ impl Body {
}
};
let module = def.module(db);
let (body, source_map) = lower_body(db, def, file_id, module, params, body, is_async_fn);
let (body, source_map) =
lower_body(db, def, file_id, module, params, body, is_async_fn, is_gen_fn);

(Arc::new(body), source_map)
}
Expand Down
125 changes: 94 additions & 31 deletions crates/hir-def/src/expr_store/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ use crate::{
},
hir::{
Array, Binding, BindingAnnotation, BindingId, BindingProblems, CaptureBy, ClosureKind,
CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm, Movability,
OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
CoroutineKind, CoroutineSource, Expr, ExprId, Item, Label, LabelId, Literal, MatchArm,
Movability, OffsetOf, Pat, PatId, RecordFieldPat, RecordLitField, RecordSpread, Statement,
generics::GenericParams,
},
item_scope::BuiltinShadowMode,
Expand All @@ -72,6 +72,7 @@ pub(super) fn lower_body(
parameters: Option<ast::ParamList>,
body: Option<ast::Expr>,
is_async_fn: bool,
is_gen_fn: bool,
) -> (Body, BodySourceMap) {
// We cannot leave the root span map empty and let any identifier from it be treated as root,
// because when inside nested macros `SyntaxContextId`s from the outer macro will be interleaved
Expand Down Expand Up @@ -176,6 +177,8 @@ pub(super) fn lower_body(
DefWithBodyId::VariantId(..) => Awaitable::No("enum variant"),
}
},
is_async_fn,
is_gen_fn,
);
collector.store.inference_roots = Some(smallvec![(body_expr, RootExprOrigin::BodyRoot)]);

Expand Down Expand Up @@ -376,12 +379,20 @@ pub(crate) fn lower_function(
expr_collector.lower_type_ref_opt(ret_type.ty(), &mut ExprCollector::impl_trait_allocator)
});

let return_type = if fn_.value.async_token().is_some() {
let path = hir_expand::mod_path::path![core::future::Future];
let return_type = if fn_.value.async_token().is_some() || fn_.value.gen_token().is_some() {
let (path, assoc_name) =
match (fn_.value.async_token().is_some(), fn_.value.gen_token().is_some()) {
(true, true) => {
(hir_expand::mod_path::path![core::async_iter::AsyncIterator], sym::Item)
}
(true, false) => (hir_expand::mod_path::path![core::future::Future], sym::Output),
(false, true) => (hir_expand::mod_path::path![core::iter::Iterator], sym::Item),
(false, false) => unreachable!(),
};
let mut generic_args: Vec<_> =
std::iter::repeat_n(None, path.segments().len() - 1).collect();
let binding = AssociatedTypeBinding {
name: Name::new_symbol_root(sym::Output),
name: Name::new_symbol_root(assoc_name),
args: None,
type_ref: Some(
return_type
Expand Down Expand Up @@ -950,10 +961,11 @@ impl<'db> ExprCollector<'db> {
/// into the body. This is to make sure that the future actually owns the
/// arguments that are passed to the function, and to ensure things like
/// drop order are stable.
fn lower_async_block_with_moved_arguments(
fn lower_coroutine_with_moved_arguments(
&mut self,
params: &mut [PatId],
body: ExprId,
kind: CoroutineKind,
coroutine_source: CoroutineSource,
) -> ExprId {
let mut statements = Vec::new();
Expand Down Expand Up @@ -989,7 +1001,8 @@ impl<'db> ExprCollector<'db> {
*param = pat_id;
}

let async_ = self.async_block(
let coroutine = self.desugared_coroutine_expr(
kind,
coroutine_source,
// The default capture mode here is by-ref. Later on during upvar analysis,
// we will force the captured arguments to by-move, but for async closures,
Expand All @@ -1001,11 +1014,12 @@ impl<'db> ExprCollector<'db> {
Some(body),
);
// It's important that this comes last, see the lowering of async closures for why.
self.alloc_expr_desugared(async_)
self.alloc_expr_desugared(coroutine)
}

fn async_block(
fn desugared_coroutine_expr(
&mut self,
kind: CoroutineKind,
source: CoroutineSource,
capture_by: CaptureBy,
id: Option<BlockId>,
Expand All @@ -1018,7 +1032,7 @@ impl<'db> ExprCollector<'db> {
arg_types: Box::default(),
ret_type: None,
body: block,
closure_kind: ClosureKind::AsyncBlock { source },
closure_kind: ClosureKind::Coroutine { kind, source },
capture_by,
}
}
Expand All @@ -1028,12 +1042,20 @@ impl<'db> ExprCollector<'db> {
params: &mut [PatId],
expr: Option<ast::Expr>,
awaitable: Awaitable,
is_async_fn: bool,
is_gen_fn: bool,
) -> ExprId {
self.awaitable_context.replace(awaitable);
self.with_label_rib(RibKind::Closure, |this| {
let body = this.collect_expr_opt(expr);
if awaitable == Awaitable::Yes {
this.lower_async_block_with_moved_arguments(params, body, CoroutineSource::Fn)
if is_async_fn || is_gen_fn {
let kind = match (is_async_fn, is_gen_fn) {
(true, true) => CoroutineKind::AsyncGen,
(true, false) => CoroutineKind::Async,
(false, true) => CoroutineKind::Gen,
(false, false) => unreachable!(),
};
this.lower_coroutine_with_moved_arguments(params, body, kind, CoroutineSource::Fn)
} else {
body
}
Expand Down Expand Up @@ -1192,7 +1214,44 @@ impl<'db> ExprCollector<'db> {
self.with_label_rib(RibKind::Closure, |this| {
this.with_awaitable_block(Awaitable::Yes, |this| {
this.collect_block_(e, |this, id, statements, tail| {
this.async_block(
this.desugared_coroutine_expr(
CoroutineKind::Async,
CoroutineSource::Block,
capture_by,
id,
statements,
tail,
)
})
})
})
}
Some(ast::BlockModifier::Gen(_)) => {
let capture_by =
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
self.with_label_rib(RibKind::Closure, |this| {
this.with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
this.collect_block_(e, |this, id, statements, tail| {
this.desugared_coroutine_expr(
CoroutineKind::Gen,
CoroutineSource::Block,
capture_by,
id,
statements,
tail,
)
})
})
})
}
Some(ast::BlockModifier::AsyncGen(_)) => {
let capture_by =
if e.move_token().is_some() { CaptureBy::Value } else { CaptureBy::Ref };
self.with_label_rib(RibKind::Closure, |this| {
this.with_awaitable_block(Awaitable::Yes, |this| {
this.collect_block_(e, |this, id, statements, tail| {
this.desugared_coroutine_expr(
CoroutineKind::AsyncGen,
CoroutineSource::Block,
capture_by,
id,
Expand All @@ -1213,14 +1272,6 @@ impl<'db> ExprCollector<'db> {
})
})
}
// FIXME
Some(ast::BlockModifier::AsyncGen(_)) => {
self.with_awaitable_block(Awaitable::Yes, |this| this.collect_block(e))
}
Some(ast::BlockModifier::Gen(_)) => self
.with_awaitable_block(Awaitable::No("non-async gen block"), |this| {
this.collect_block(e)
}),
None => self.collect_block(e),
},
ast::Expr::LoopExpr(e) => {
Expand Down Expand Up @@ -1460,25 +1511,37 @@ impl<'db> ExprCollector<'db> {
};
let mut body = this
.with_awaitable_block(awaitable, |this| this.collect_expr_opt(e.body()));

let closure_kind = if this.is_lowering_coroutine {
let movability = if e.static_token().is_some() {
Movability::Static
let kind = {
if e.async_token().is_some() && e.gen_token().is_some() {
Some(CoroutineKind::AsyncGen)
} else if e.async_token().is_some() {
Some(CoroutineKind::Async)
} else if e.gen_token().is_some() {
Some(CoroutineKind::Gen)
} else {
Movability::Movable
};
ClosureKind::Coroutine(movability)
} else if e.async_token().is_some() {
None
}
};

let closure_kind = if let Some(kind) = kind {
// It's important that this expr is allocated immediately before the closure.
// We rely on it for `coroutine_for_closure()`.
body = this.lower_async_block_with_moved_arguments(
body = this.lower_coroutine_with_moved_arguments(
&mut args,
body,
kind,
CoroutineSource::Closure,
);
body_is_bindings_owner = true;

ClosureKind::AsyncClosure
ClosureKind::CoroutineClosure(kind)
} else if this.is_lowering_coroutine {
let movability = if e.static_token().is_some() {
Movability::Static
} else {
Movability::Movable
};
ClosureKind::OldCoroutine(movability)
} else {
ClosureKind::Closure
};
Expand Down
28 changes: 18 additions & 10 deletions crates/hir-def/src/expr_store/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use crate::{
attrs::AttrFlags,
expr_store::path::{GenericArg, GenericArgs},
hir::{
Array, BindingAnnotation, CaptureBy, ClosureKind, Literal, Movability, RecordSpread,
Statement,
Array, BindingAnnotation, CaptureBy, ClosureKind, CoroutineKind, Literal, Movability,
RecordSpread, Statement,
generics::{GenericParams, WherePredicate},
},
lang_item::LangItemTarget,
Expand Down Expand Up @@ -761,28 +761,36 @@ impl Printer<'_> {
let mut body = *body;
let mut print_pipes = true;
match closure_kind {
ClosureKind::Coroutine(Movability::Static) => {
ClosureKind::OldCoroutine(Movability::Static) => {
w!(self, "static ");
}
ClosureKind::AsyncClosure => {
ClosureKind::CoroutineClosure(kind) => {
if let Expr::Closure {
body: inner_body,
closure_kind: ClosureKind::AsyncBlock { .. },
closure_kind: ClosureKind::Coroutine { .. },
..
} = self.store[body]
{
body = inner_body;
} else {
never!("async closure should always have an async block body");
never!("coroutine closure should always have a coroutine body");
}

w!(self, "async ");
match kind {
CoroutineKind::Async => w!(self, "async "),
CoroutineKind::Gen => w!(self, "gen "),
CoroutineKind::AsyncGen => w!(self, "async gen "),
}
}
ClosureKind::AsyncBlock { .. } => {
w!(self, "async ");
ClosureKind::Coroutine { kind, .. } => {
match kind {
CoroutineKind::Async => w!(self, "async "),
CoroutineKind::Gen => w!(self, "gen "),
CoroutineKind::AsyncGen => w!(self, "async gen "),
}
print_pipes = false;
}
ClosureKind::Closure | ClosureKind::Coroutine(Movability::Movable) => (),
ClosureKind::Closure | ClosureKind::OldCoroutine(Movability::Movable) => (),
}
match capture_by {
CaptureBy::Value => {
Expand Down
13 changes: 10 additions & 3 deletions crates/hir-def/src/hir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,19 @@ pub enum InlineAsmRegOrRegClass {
RegClass(Symbol),
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CoroutineKind {
Async,
Gen,
AsyncGen,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClosureKind {
Closure,
Coroutine(Movability),
AsyncBlock { source: CoroutineSource },
AsyncClosure,
OldCoroutine(Movability),
Coroutine { kind: CoroutineKind, source: CoroutineSource },
CoroutineClosure(CoroutineKind),
}

/// In the case of a coroutine created as part of an async/gen construct,
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-def/src/lang_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ language_item_table! { LangItems =>
FnOnceOutput, sym::fn_once_output, TypeAliasId;

Future, sym::future_trait, TraitId;
AsyncIterator, sym::async_iterator, TraitId;
CoroutineState, sym::coroutine_state, EnumId;
Coroutine, sym::coroutine, TraitId;
CoroutineReturn, sym::coroutine_return, TypeAliasId;
Expand Down Expand Up @@ -522,7 +523,6 @@ language_item_table! { LangItems =>
IteratorNext, sym::next, FunctionId;
Iterator, sym::iterator, TraitId;
FusedIterator, sym::fused_iterator, TraitId;
AsyncIterator, sym::async_iterator, TraitId;

PinNewUnchecked, sym::new_unchecked, FunctionId;

Expand Down
Loading