diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 8a4ff72eb..b8ff22c97 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -5,9 +5,11 @@ // https://opensource.org/licenses/MIT. use clap::Parser; +use futures::{SinkExt, future}; use service::{WorldClient, init_tracing}; use std::{net::SocketAddr, time::Duration}; -use tarpc::{client, context, tokio_serde::formats::Json}; +use tarpc::context::{SharedContext}; +use tarpc::{client, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; @@ -29,15 +31,20 @@ async fn main() -> anyhow::Result<()> { let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default); transport.config_mut().max_frame_length(usize::MAX); + let transport = transport.await?; + // WorldClient is generated by the service attribute. It has a constructor `new` that takes a // config and any Transport as input. - let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); + let client = WorldClient::new(client::Config::default(), transport).spawn(); let hello = async move { + let mut context = SharedContext::current(); + let mut context2 = SharedContext::current(); + // Send the request twice, just to be safe! ;) tokio::select! { - hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } - hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } + hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 } + hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 } } } .instrument(tracing::info_span!("Two Hellos")) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 896280c3d..019a2d7b1 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -11,12 +11,14 @@ use rand::{ thread_rng, }; use service::{World, init_tracing}; +use std::ops::Deref; use std::{ net::{IpAddr, Ipv6Addr, SocketAddr}, time::Duration, }; +use tarpc::context::{SharedContext}; use tarpc::{ - context, + ClientMessage, context, server::{self, Channel, incoming::Incoming}, tokio_serde::formats::Json, }; @@ -35,7 +37,8 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; @@ -64,11 +67,11 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())) .map(server::BaseChannel::with_defaults) // Limit channels to 1 per IP. - .max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip()) + .max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip()) // serve is generated by the service attribute. It takes as input any type implementing // the generated World trait. .map(|channel| { - let server = HelloServer(channel.transport().peer_addr().unwrap()); + let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap()); channel.execute(server.serve()).for_each(spawn) }) // Max 10 channels. diff --git a/plugins/Cargo.toml b/plugins/Cargo.toml index 8be746c26..eeab84924 100644 --- a/plugins/Cargo.toml +++ b/plugins/Cargo.toml @@ -30,5 +30,6 @@ proc-macro = true [dev-dependencies] assert-type-eq = "0.1.0" futures = "0.3" +futures-util = "0.3.31" serde = { version = "1.0", features = ["derive"] } tarpc = { path = "../tarpc", features = ["serde1"] } diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index da6443edf..cf107d0ad 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -375,7 +375,10 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, transport, service, server::{self, Channel}}; +/// use futures_util::{TryStreamExt, sink::SinkExt};/// +/// +/// use tarpc::context::SharedContext; /// /// #[service] /// pub trait Calculator { @@ -401,7 +404,8 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// type Context = SharedContext; +/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -558,7 +562,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output; } }, ); @@ -567,6 +571,8 @@ impl ServiceGenerator<'_> { quote! { #( #attrs )* #vis trait #service_ident: ::core::marker::Sized { + type Context: ::tarpc::context::ExtractContext<::tarpc::context::SharedContext>; + #( #rpc_fns )* /// Returns a serving function to use with @@ -577,11 +583,11 @@ impl ServiceGenerator<'_> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { } - impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + impl #client_stub_ident for S + where S: ::tarpc::client::stub::Stub { } } @@ -620,9 +626,9 @@ impl ServiceGenerator<'_> { { type Req = #request_ident; type Resp = #response_ident; + type ServerCtx = S::Context; - - async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -711,12 +717,19 @@ impl ServiceGenerator<'_> { quote! { #[allow(unused)] - #[derive(Clone, Debug)] + #[derive(Debug)] /// The client stub that makes RPC calls to the server. All request methods return /// [Futures](::core::future::Future). #vis struct #client_ident< - Stub = ::tarpc::client::Channel<#request_ident, #response_ident> - >(Stub); + ClientCtx, + Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx> + >(Stub, ::std::marker::PhantomData); + + impl ::std::clone::Clone for #client_ident { + fn clone(&self) -> Self { + Self(self.0.clone(), ::std::marker::PhantomData) + } + } } } @@ -730,32 +743,33 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident { + impl #client_ident { /// Returns a new client stub that sends requests over the given transport. #vis fn new(config: ::tarpc::client::Config, transport: T) -> ::tarpc::client::NewClient< Self, - ::tarpc::client::RequestDispatch<#request_ident, #response_ident, T> + ::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, T> > where - T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>> + T: ::tarpc::Transport<::tarpc::ClientMessage, ::tarpc::Response> { let new_client = ::tarpc::client::new(config, transport); ::tarpc::client::NewClient { - client: #client_ident(new_client.client), + client: #client_ident(new_client.client, ::std::marker::PhantomData), dispatch: new_client.dispatch, } } } - impl ::core::convert::From for #client_ident + impl ::core::convert::From for #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { /// Returns a new client stub that sends requests over the given transport. fn from(stub: Stub) -> Self { - #client_ident(stub) + #client_ident::(stub, ::std::marker::PhantomData) } } @@ -778,15 +792,16 @@ impl ServiceGenerator<'_> { } = self; quote! { - impl #client_ident + impl #client_ident where Stub: ::tarpc::client::stub::Stub< Req = #request_ident, - Resp = #response_ident> + Resp = #response_ident, + ClientCtx = ClientCtx> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 26ee1ec39..d8213f4d4 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,6 +1,7 @@ use serde::{Deserialize, Serialize}; use std::hash::Hash; use tarpc::context; +use tarpc::context::SharedContext; #[test] fn att_service_trait() { @@ -12,15 +13,21 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + type Context = SharedContext; + async fn two_part( + self, + _: &mut context::SharedContext, + s: String, + i: i32, + ) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut Self::Context, s: String) -> String { s } - async fn baz(self, _: context::Context) {} + async fn baz(self, _: &mut Self::Context) {} } } @@ -37,20 +44,21 @@ fn raw_idents() { } impl r#trait for () { + type Context = SharedContext; async fn r#await( self, - _: context::Context, + _: &mut Self::Context, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut Self::Context, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) {} + async fn r#async(self, _: &mut Self::Context) {} } } @@ -64,7 +72,8 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: context::Context) {} + type Context = SharedContext; + async fn foo(self, _: &mut Self::Context) {} } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 778eb0938..625c1f72f 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -61,6 +61,8 @@ tracing = { version = "0.1", default-features = false, features = [ tracing-opentelemetry = { version = "0.31.0", default-features = false } opentelemetry = { version = "0.30.0", default-features = false } opentelemetry-semantic-conventions = "0.30.0" +anymap3 = "1.0.1" +serde-value = "0.7" [dev-dependencies] assert_matches = "1.4" @@ -81,6 +83,7 @@ trybuild = "1.0" tokio-rustls = "0.26" rustls-pemfile = "2.0" + [package.metadata.docs.rs] all-features = true rustdoc-args = ["--cfg", "docsrs"] diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index d66261d19..f201521ad 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -15,6 +15,7 @@ use tarpc::{ server::{BaseChannel, Channel}, tokio_serde::formats::Bincode, }; +use tarpc::context::SharedContext; /// Type of compression that should be enabled on the request. The transport is free to ignore this. #[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)] @@ -108,7 +109,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}!") } } @@ -120,21 +122,26 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { let mut incoming = tcp::listen("localhost:0", Bincode::default).await?; + let addr = incoming.local_addr(); tokio::spawn(async move { let transport = incoming.next().await.unwrap().unwrap(); - BaseChannel::with_defaults(add_compression(transport)) + let transport = add_compression(transport); + BaseChannel::with_defaults(transport) .execute(HelloServer.serve()) .for_each(spawn) .await; }); let transport = tcp::connect(addr, Bincode::default).await?; - let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn(); + let transport = add_compression(transport); + let client = WorldClient::new(client::Config::default(), transport).spawn(); println!( "{}", - client.hello(context::current(), "friend".into()).await? + client + .hello(&mut context::SharedContext::current(), "friend".into()) + .await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 5f5386785..c9eb871ea 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -4,8 +4,9 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use console_subscriber::Server; use futures::prelude::*; -use tarpc::context::Context; +use tarpc::context::{SharedContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -21,9 +22,9 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + type Context = SharedContext; + async fn ping(self, _: &mut Self::Context) {} } - #[tokio::main] async fn main() -> anyhow::Result<()> { let bind_addr = "/tmp/tarpc_on_unix_example.sock"; @@ -52,7 +53,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut SharedContext::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index d61f68c48..07a93becf 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -47,9 +47,13 @@ use std::{ net::SocketAddr, sync::{Arc, Mutex, RwLock}, }; +use std::ops::Shl; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; use subscriber::Subscriber as _; +use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ - client, context, + ClientMessage, client, context, serde_transport::tcp, server::{self, Channel}, tokio_serde::formats::Json, @@ -80,11 +84,12 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + type Context = SharedContext; + async fn topics(self, _: &mut Self::Context) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut Self::Context, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -132,10 +137,19 @@ struct Subscription { topics: Vec, } -#[derive(Clone, Debug)] -struct Publisher { +#[derive(Debug)] +struct Publisher { clients: Arc>>, - subscriptions: Arc>>>, + subscriptions: Arc>>>>, +} + +impl Clone for Publisher { + fn clone(&self) -> Self { + Publisher { + clients: self.clients.clone(), + subscriptions: self.subscriptions.clone(), + } + } } struct PublisherAddrs { @@ -147,7 +161,7 @@ async fn spawn(fut: impl Future + Send + 'static) { tokio::spawn(fut); } -impl Publisher { +impl Publisher where ClientCtx: ExtractContext + From + Serialize + DeserializeOwned + Send + Sync + 'static { // TODO: Remove serde bounds here async fn start(self) -> io::Result { let mut connecting_publishers = tcp::listen("localhost:0", Json::default).await?; @@ -164,6 +178,7 @@ impl Publisher { let publisher = connecting_publishers.next().await.unwrap().unwrap(); info!(publisher.peer_addr = ?publisher.peer_addr(), "publisher connected."); + server::BaseChannel::with_defaults(publisher) .execute(self.serve()) .for_each(spawn) @@ -183,7 +198,6 @@ impl Publisher { tokio::spawn(async move { while let Some(conn) = connecting_subscribers.next().await { let subscriber_addr = conn.peer_addr().unwrap(); - let tarpc::client::NewClient { client: subscriber, dispatch, @@ -207,10 +221,13 @@ impl Publisher { async fn initialize_subscription( &mut self, subscriber_addr: SocketAddr, - subscriber: subscriber::SubscriberClient, + subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(context::current()).await { + if let Ok(topics) = subscriber + .topics(&mut ClientCtx::from(context::SharedContext::current())) + .await + { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -262,16 +279,26 @@ impl Publisher { } } -impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { +impl publisher::Publisher for Publisher where ClientCtx: ExtractContext + From + Send + Sync + 'static { + type Context = ClientCtx; + async fn publish(self, _: &mut Self::Context, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, Some(subscriptions) => subscriptions.clone(), }; let mut publications = Vec::new(); + for client in subscribers.values_mut() { - publications.push(client.receive(context::current(), topic.clone(), message.clone())); + publications.push(async { + client + .receive( + &mut ClientCtx::from(context::SharedContext::current()), + topic.clone(), + message.clone(), + ) + .await + }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until // subscribers ack. Of course, a lot would be different in a real pubsub :) @@ -316,7 +343,7 @@ pub fn init_tracing( async fn main() -> anyhow::Result<()> { let tracer_provider = init_tracing("Pub/Sub")?; - let addrs = Publisher { + let addrs = Publisher:: { clients: Arc::new(Mutex::new(HashMap::new())), subscriptions: Arc::new(RwLock::new(HashMap::new())), } @@ -342,26 +369,34 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(context::current(), "calculus".into(), "sqrt(2)".into()) + .publish( + &mut SharedContext::current(), + "calculus".into(), + "sqrt(2)".into(), + ) .await?; publisher .publish( - context::current(), + &mut SharedContext::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(context::current(), "history".into(), "napoleon".to_string()) + .publish( + &mut SharedContext::current(), + "history".into(), + "napoleon".to_string(), + ) .await?; drop(_subscriber0); publisher .publish( - context::current(), + &mut SharedContext::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c328bd884..359b4af8b 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -5,9 +5,11 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; +use tarpc::context::{SharedContext}; use tarpc::{ - client, context, + ClientMessage, client, context, server::{self, Channel}, + transport, }; /// This is the service definition. It looks a lot like a trait definition. @@ -23,7 +25,8 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + type Context = SharedContext; + async fn hello(self, _: &mut Self::Context, name: String) -> String { format!("Hello, {name}!") } } @@ -34,7 +37,8 @@ async fn spawn(fut: impl Future + Send + 'static) { #[tokio::main] async fn main() -> anyhow::Result<()> { - let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); + let (client_transport, server_transport) = + transport::channel::unbounded(); let server = server::BaseChannel::with_defaults(server_transport); tokio::spawn(server.execute(HelloServer.serve()).for_each(spawn)); @@ -46,7 +50,9 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(context::current(), "Stim".to_string()).await?; + let hello = client + .hello(&mut context::SharedContext::current(), "Stim".to_string()) + .await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 968f76c17..c203bf0b8 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -10,6 +10,11 @@ use std::io::{self, BufReader, Cursor}; use std::net::{IpAddr, Ipv4Addr}; use std::sync::Arc; +use tarpc::context::{SharedContext}; +use tarpc::serde_transport as transport; +use tarpc::server::{BaseChannel, Channel}; +use tarpc::tokio_serde::formats::Bincode; +use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio_rustls::rustls::{ @@ -18,12 +23,6 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::Context; -use tarpc::serde_transport as transport; -use tarpc::server::{BaseChannel, Channel}; -use tarpc::tokio_serde::formats::Bincode; -use tarpc::tokio_util::codec::length_delimited::LengthDelimitedCodec; - #[tarpc::service] pub trait PingService { async fn ping() -> String; @@ -33,7 +32,8 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) -> String { + type Context = SharedContext; + async fn ping(self, _: &mut Self::Context) -> String { "🔒".to_owned() } } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut SharedContext::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 79a7026c0..525a16a47 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -19,6 +19,8 @@ use std::{ atomic::{AtomicBool, Ordering}, }, }; +use std::marker::PhantomData; +use tarpc::context::{ExtractContext, SharedContext}; use tarpc::{ ClientMessage, RequestName, Response, ServerError, Transport, client::{ @@ -56,23 +58,27 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + type Context = SharedContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } } #[derive(Clone)] -struct DoubleServer { - add_client: add::AddClient, +struct DoubleServer { + add_client: add::AddClient, + ghost: PhantomData } -impl DoubleService for DoubleServer +impl DoubleService for DoubleServer where - Stub: AddStub + Clone + Send + Sync + 'static, + Stub: AddStub + Clone + Send + Sync + 'static, + ClientCtx: From + Send + Sync + 'static { - async fn double(self, _: context::Context, x: i32) -> Result { + type Context = SharedContext; + async fn double(self, _: &mut Self::Context, x: i32) -> Result { self.add_client - .add(context::current(), x, x) + .add(&mut ClientCtx::from(context::SharedContext::current()), x, x) .await .map_err(|e| e.to_string()) } @@ -123,15 +129,19 @@ where Ok((listener, addr)) } -fn make_stub( - backends: [impl Transport>, Response> + Send + Sync + 'static; N], +fn make_stub( + backends: [impl Transport>, Response> + + Send + + Sync + + 'static; N], ) -> retry::Retry< impl Fn(&Result, u32) -> bool + Clone, - load_balance::RoundRobin, Resp>>, + load_balance::RoundRobin, Resp, ClientCtx>>, > where Req: RequestName + Send + Sync + 'static, Resp: Send + Sync + 'static, + ClientCtx: ExtractContext + From + Send + Sync + 'static { let stub = load_balance::RoundRobin::new( backends @@ -186,16 +196,20 @@ async fn main() -> anyhow::Result<()> { .filter_map(|r| future::ready(r.ok())); let addr = double_listener.get_ref().local_addr(); let double_server = double_listener.map(BaseChannel::with_defaults).take(1); - let server = DoubleServer { add_client }.serve(); + let server = DoubleServer::<_, SharedContext> { add_client, ghost: PhantomData }.serve(); tokio::spawn(spawn_incoming(double_server.execute(server))); let to_double_server = tarpc::serde_transport::tcp::connect(addr, Json::default).await?; let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!( + "{:?}", + double_client + .double(&mut context::SharedContext::current(), 1) + .await? + ); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index 3cf9ff07a..40ba7e461 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -9,6 +9,7 @@ mod in_flight_requests; pub mod stub; +use crate::context::{ExtractContext, SharedContext}; use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -29,6 +30,7 @@ use std::{ }, time::SystemTime, }; +use std::marker::PhantomData; use tokio::sync::{mpsc, oneshot}; use tracing::Span; @@ -95,27 +97,32 @@ const _CHECK_USIZE: () = assert!( /// Handles communication from the client to request dispatch. #[derive(Debug)] -pub struct Channel { +pub struct Channel { to_dispatch: mpsc::Sender>, /// Channel to send a cancel message to the dispatcher. cancellation: RequestCancellation, /// The ID to use for the next request to stage. next_request_id: Arc, + + ///TODO: Document + ghost: PhantomData } -impl Clone for Channel { +impl Clone for Channel { fn clone(&self) -> Self { Self { to_dispatch: self.to_dispatch.clone(), cancellation: self.cancellation.clone(), next_request_id: self.next_request_id.clone(), + ghost: PhantomData } } } -impl Channel +impl Channel where Req: RequestName, + ClientCtx: ExtractContext { /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that /// resolves to the response. @@ -124,19 +131,24 @@ where skip(self, ctx, request), fields( rpc.trace_id = tracing::field::Empty, - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.deadline.time_until()), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + ctx.extract().deadline.time_until()), otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, mut ctx: context::Context, request: Req) -> Result { + pub async fn call( + &self, + ctx: &mut ClientCtx, + request: Req, + ) -> Result { let span = Span::current(); - ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + let mut shared_context = ctx.extract(); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled child context." ); - ctx.trace_context.new_child() + shared_context.trace_context.new_child() }); - span.record("rpc.trace_id", tracing::field::display(ctx.trace_id())); + span.record("rpc.trace_id", tracing::field::display(shared_context.trace_id())); let (response_completion, mut response) = oneshot::channel(); let request_id = u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); @@ -153,7 +165,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx, + ctx: shared_context, span, request_id, request, @@ -161,14 +173,19 @@ where }) .await .map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?; - response_guard.response().await + + let (response_ctx, r) = response_guard.response().await?; + + ctx.update(response_ctx); + + Ok(r) } } /// A server response that is completed by request dispatch when the corresponding response /// arrives off the wire. struct ResponseGuard<'a, Resp> { - response: &'a mut oneshot::Receiver>, + response: &'a mut oneshot::Receiver>, cancellation: &'a RequestCancellation, request_id: u64, cancel: bool, @@ -196,7 +213,7 @@ pub enum RpcError { } impl ResponseGuard<'_, Resp> { - async fn response(mut self) -> Result { + async fn response(mut self) -> Result<(SharedContext, Resp), RpcError> { let response = (&mut self.response).await; // Cancel drop logic once a response has been received. self.cancel = false; @@ -234,12 +251,12 @@ impl Drop for ResponseGuard<'_, Resp> { /// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the /// channel. -pub fn new( +pub fn new( config: Config, transport: C, -) -> NewClient, RequestDispatch> +) -> NewClient, RequestDispatch> where - C: Transport, Response>, + C: Transport, Response>, { let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer); let (cancellation, canceled_requests) = cancellations(); @@ -249,6 +266,7 @@ where to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }, dispatch: RequestDispatch { config, @@ -257,6 +275,7 @@ where in_flight_requests: InFlightRequests::default(), pending_requests, terminal_error: None, + ghost: PhantomData }, } } @@ -266,7 +285,7 @@ where #[must_use] #[pin_project()] #[derive(Debug)] -pub struct RequestDispatch { +pub struct RequestDispatch { /// Writes requests to the wire and reads responses off the wire. #[pin] transport: Fuse, @@ -275,7 +294,7 @@ pub struct RequestDispatch { /// Requests that were dropped. canceled_requests: CanceledRequests, /// Requests already written to the wire that haven't yet received responses. - in_flight_requests: InFlightRequests>, + in_flight_requests: InFlightRequests, /// Configures limits to prevent unlimited resource usage. config: Config, /// Produces errors that can be sent in response to any unprocessed requests at the time @@ -283,15 +302,18 @@ pub struct RequestDispatch { /// RequestDispatch::poll, which relies on downcasting the Any to a concrete error type /// determined within the poll function. terminal_error: Option>, + + ghost: PhantomData, } -impl RequestDispatch +impl RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From { fn in_flight_requests<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut InFlightRequests> { + ) -> &'a mut InFlightRequests { self.as_mut().project().in_flight_requests } @@ -308,7 +330,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(self: &mut Pin<&mut Self>, message: ClientMessage) -> Result<(), C::Error> { + fn start_send( + self: &mut Pin<&mut Self>, + message: ClientMessage, + ) -> Result<(), C::Error> { self.transport_pin_mut().start_send(message) } @@ -457,7 +482,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -510,16 +535,24 @@ where // poll_next_request only returns Ready if there is room to buffer another request. // Therefore, we can call write_request without fear of erroring due to a full // buffer. + + let trace_context = ctx.trace_context; + let deadline = ctx.deadline; + let request = ClientMessage::Request(Request { id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx.into(), }); + self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request( + request_id, + trace_context, + deadline, + span.clone(), + response_completion, + ) .expect("Request IDs should be unique"); match self.start_send(request) { Ok(()) => tracing::debug!("SendRequest"), @@ -541,14 +574,15 @@ where self: &mut Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>>> { - let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) { - Some(triple) => triple, - None => return Poll::Ready(None), - }; + let (trace_context, span, request_id) = + match ready!(self.as_mut().poll_next_cancellation(cx)?) { + Some(triple) => triple, + None => return Poll::Ready(None), + }; let _entered = span.enter(); let cancel = ClientMessage::Cancel { - trace_context: context.trace_context, + trace_context, request_id, }; self.start_send(cancel) @@ -558,10 +592,10 @@ where } /// Sends a server response to the client task that initiated the associated request. - fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { + fn complete(mut self: Pin<&mut Self>, response: Response) -> bool { if let Some(span) = self.in_flight_requests().complete_request( response.request_id, - response.message.map_err(RpcError::Server), + response.message.map_err(RpcError::Server).map(|m| (response.context.extract(), m)), ) { let _entered = span.enter(); tracing::debug!("ReceiveResponse"); @@ -636,9 +670,10 @@ where } } -impl Future for RequestDispatch +impl Future for RequestDispatch where - C: Transport, Response>, + C: Transport, Response>, + ClientCtx: ExtractContext + From { type Output = Result<(), ChannelError>; @@ -669,11 +704,11 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::Context, + pub ctx: context::SharedContext, ///TODO: <-- this should be a &mut ClientContext pub span: Span, pub request_id: u64, pub request: Req, - pub response_completion: oneshot::Sender>, + pub response_completion: oneshot::Sender>, } #[cfg(test)] @@ -681,10 +716,10 @@ mod tests { use super::{ Channel, DispatchRequest, RequestDispatch, ResponseGuard, RpcError, cancellations, }; + use crate::context::{SharedContext}; use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, - context::{self, current}, transport::{self, channel::UnboundedChannel}, }; use assert_matches::assert_matches; @@ -708,23 +743,32 @@ mod tests { #[tokio::test] async fn response_completes_request_future() { - let (mut dispatch, mut _channel, mut server_channel) = set_up(); + let (mut dispatch, _channel, mut server_channel) = set_up(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); + let context = SharedContext::current(); + dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request( + 0, + context.trace_context, + context.deadline, + Span::current(), + tx, + ) .unwrap(); server_channel .send(Response { request_id: 0, + context: SharedContext::current(), message: Ok("Resp".into()), }) .await .unwrap(); assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending); - assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp"); + assert_matches!(rx.try_recv(), Ok(Ok((_, resp))) if resp == "Resp"); } #[tokio::test] @@ -746,11 +790,7 @@ mod tests { async fn dispatch_response_doesnt_cancel_after_complete() { let (cancellation, mut canceled_requests) = cancellations(); let (tx, mut response) = oneshot::channel(); - tx.send(Ok(Response { - request_id: 0, - message: Ok("well done"), - })) - .unwrap(); + tx.send(Ok((SharedContext::current(), "well done"))).unwrap(); // resp's drop() is run, but should not send a cancel message. ResponseGuard { response: &mut response, @@ -768,7 +808,7 @@ mod tests { #[tokio::test] async fn stage_request() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -798,6 +838,7 @@ mod tests { &mut server_channel, Response { request_id: 0, + context: SharedContext::current(), message: Ok("hello".into()), }, ) @@ -808,7 +849,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_before_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -824,7 +865,7 @@ mod tests { #[allow(unstable_name_collisions)] #[tokio::test] async fn stage_request_response_future_dropped_is_canceled_after_sending() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -845,7 +886,7 @@ mod tests { #[tokio::test] async fn stage_request_response_closed_skipped() { - let (mut dispatch, mut channel, _server_channel) = set_up(); + let (mut dispatch, mut channel, _server_channel) = set_up::(); let cx = &mut Context::from_waker(noop_waker_ref()); let (tx, mut rx) = oneshot::channel(); @@ -861,7 +902,7 @@ mod tests { #[tokio::test] async fn test_permit_before_transport_error() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (mut dispatch, mut channel, mut cx) = set_up_always_err(TransportError::Flush); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(TransportError::Flush); let (tx, mut rx) = oneshot::channel(); // reserve succeeds let permit = reserve_for_send(&mut channel, tx, &mut rx).await; @@ -878,17 +919,19 @@ mod tests { #[tokio::test] async fn test_shutdown() { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); - let (dispatch, channel, _server_channel) = set_up(); + let (dispatch, channel, _server_channel) = set_up::(); drop(dispatch); // error on send - let resp = channel.call(current(), "hi".to_string()).await; + let resp = channel + .call(&mut SharedContext::current(), "hi".to_string()) + .await; assert_matches!(resp, Err(RpcError::Shutdown)); } #[tokio::test] async fn test_transport_error_write() { let cause = TransportError::Write; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; @@ -911,7 +954,7 @@ mod tests { #[tokio::test] async fn test_transport_error_read() { let cause = TransportError::Read; - let (mut dispatch, mut channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, mut channel, mut cx) = set_up_always_err::(cause); let (tx, mut rx) = oneshot::channel(); let resp = send_request(&mut channel, "hi", tx, &mut rx).await; assert_eq!( @@ -928,7 +971,7 @@ mod tests { #[tokio::test] async fn test_transport_error_ready() { let cause = TransportError::Ready; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Ready(Arc::new(cause)))) @@ -938,7 +981,7 @@ mod tests { #[tokio::test] async fn test_transport_error_flush() { let cause = TransportError::Flush; - let (mut dispatch, _, mut cx) = set_up_always_err(cause); + let (mut dispatch, _, mut cx) = set_up_always_err::(cause); assert_eq!( dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(Arc::new(cause)))) @@ -948,7 +991,7 @@ mod tests { #[tokio::test] async fn test_transport_error_close() { let cause = TransportError::Close; - let (mut dispatch, channel, mut cx) = set_up_always_err(cause); + let (mut dispatch, channel, mut cx) = set_up_always_err::(cause); drop(channel); assert_eq!( dispatch.as_mut().poll(&mut cx), @@ -957,34 +1000,36 @@ mod tests { } /// Sets up a RequestDispatch with a transport that always errors. - fn set_up_always_err( + fn set_up_always_err( cause: TransportError, ) -> ( - Pin>>>, - Channel, + Pin>>>, + Channel, Context<'static>, ) { let (to_dispatch, pending_requests) = mpsc::channel(1); let (cancellation, canceled_requests) = cancellations(); - let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); - let dispatch = Box::pin(RequestDispatch:: { + let transport: AlwaysErrorTransport = AlwaysErrorTransport(cause, PhantomData); + let dispatch = Box::pin(RequestDispatch:: { transport: transport.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData }); let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }; let cx = Context::from_waker(noop_waker_ref()); (dispatch, channel, cx) } - struct AlwaysErrorTransport(TransportError, PhantomData); + struct AlwaysErrorTransport(TransportError, PhantomData<( I, ClientCtx)>); #[derive(Debug, Error, PartialEq, Eq, Clone, Copy)] enum TransportError { @@ -1001,7 +1046,7 @@ mod tests { } } - impl Sink for AlwaysErrorTransport { + impl Sink for AlwaysErrorTransport { type Error = TransportError; fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { match self.0 { @@ -1033,8 +1078,8 @@ mod tests { } } - impl Stream for AlwaysErrorTransport { - type Item = Result, TransportError>; + impl Stream for AlwaysErrorTransport { + type Item = Result, TransportError>; fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { if matches!(self.0, TransportError::Read) { Poll::Ready(Some(Err(self.0))) @@ -1044,18 +1089,22 @@ mod tests { } } - fn set_up() -> ( + fn set_up() -> ( Pin< Box< RequestDispatch< String, String, - UnboundedChannel, ClientMessage>, + ClientCtx, + UnboundedChannel< + Response, + ClientMessage, + >, >, >, >, - Channel, - UnboundedChannel, Response>, + Channel, + UnboundedChannel, Response>, ) { let _ = tracing_subscriber::fmt().with_test_writer().try_init(); @@ -1063,60 +1112,36 @@ mod tests { let (cancellation, canceled_requests) = cancellations(); let (client_channel, server_channel) = transport::channel::unbounded(); - let dispatch = RequestDispatch:: { + let dispatch = RequestDispatch:: { transport: client_channel.fuse(), pending_requests, canceled_requests, in_flight_requests: InFlightRequests::default(), config: Config::default(), terminal_error: None, + ghost: PhantomData }; let channel = Channel { to_dispatch, cancellation, next_request_id: Arc::new(AtomicUsize::new(0)), + ghost: PhantomData }; (Box::pin(dispatch), channel, server_channel) } - async fn reserve_for_send<'a>( - channel: &'a mut Channel, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, - ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { - let permit = channel.to_dispatch.reserve().await.unwrap(); - |request| { - let request_id = - u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); - let request = DispatchRequest { - ctx: context::current(), - span: Span::current(), - request_id, - request: request.to_string(), - response_completion, - }; - permit.send(request); - ResponseGuard { - response, - cancellation: &channel.cancellation, - request_id, - cancel: true, - } - } - } - - async fn send_request<'a>( - channel: &'a mut Channel, + async fn send_request<'a, ClientCtx>( + channel: &'a mut Channel, request: &str, - response_completion: oneshot::Sender>, - response: &'a mut oneshot::Receiver>, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, ) -> ResponseGuard<'a, String> { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: SharedContext::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1132,9 +1157,38 @@ mod tests { response_guard } - async fn send_response( - channel: &mut UnboundedChannel, Response>, - response: Response, + async fn reserve_for_send<'a, ClientCtx>( + channel: &'a mut Channel, + response_completion: oneshot::Sender>, + response: &'a mut oneshot::Receiver>, + ) -> impl FnOnce(&str) -> ResponseGuard<'a, String> { + let permit = channel.to_dispatch.reserve().await.unwrap(); + |request| { + let request_id = + u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); + let request = DispatchRequest { + ctx: SharedContext::current(), + span: Span::current(), + request_id, + request: request.to_string(), + response_completion, + }; + permit.send(request); + ResponseGuard { + response, + cancellation: &channel.cancellation, + request_id, + cancel: true, + } + } + } + + async fn send_response( + channel: &mut UnboundedChannel< + ClientMessage, + Response, + >, + response: Response, ) { channel.send(response).await.unwrap(); } diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 1776a74a0..5b648098b 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -1,8 +1,6 @@ -use crate::{ - context, - util::{Compact, TimeUntil}, -}; +use crate::{trace, util::{Compact, TimeUntil}}; use fnv::FnvHashMap; +use std::time::Instant; use std::{ collections::hash_map, task::{Context, Poll}, @@ -10,6 +8,8 @@ use std::{ use tokio::sync::oneshot; use tokio_util::time::delay_queue::{self, DelayQueue}; use tracing::Span; +use crate::client::RpcError; +use crate::context::{SharedContext}; /// Requests already written to the wire that haven't yet received responses. #[derive(Debug)] @@ -29,9 +29,9 @@ impl Default for InFlightRequests { #[derive(Debug)] struct RequestData { - ctx: context::Context, + ctx: trace::Context, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, /// The key to remove the timer for the request's deadline. deadline_key: delay_queue::Key, } @@ -56,13 +56,14 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::Context, + ctx: trace::Context, + deadline: Instant, span: Span, - response_completion: oneshot::Sender, + response_completion: oneshot::Sender>, ) -> Result<(), AlreadyExistsError> { match self.request_data.entry(request_id) { hash_map::Entry::Vacant(vacant) => { - let timeout = ctx.deadline.time_until(); + let timeout = deadline.time_until(); let deadline_key = self.deadlines.insert(request_id, timeout); vacant.insert(RequestData { ctx, @@ -76,8 +77,8 @@ impl InFlightRequests { } } - /// Removes a request without aborting. Returns true iff the request was found. - pub fn complete_request(&mut self, request_id: u64, result: Res) -> Option { + /// Removes a request without aborting. Returns true if the request was found. + pub fn complete_request(&mut self, request_id: u64, result: Result<(SharedContext, Res), RpcError>) -> Option { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -95,7 +96,7 @@ impl InFlightRequests { /// Returns Spans for all completes requests. pub fn complete_all_requests<'a>( &'a mut self, - mut result: impl FnMut() -> Res + 'a, + mut result: impl FnMut() -> Result<(SharedContext, Res), RpcError> + 'a, ) -> impl Iterator + 'a { self.deadlines.clear(); self.request_data.drain().map(move |(_, request_data)| { @@ -106,7 +107,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(trace::Context, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); @@ -121,7 +122,7 @@ impl InFlightRequests { pub fn poll_expired( &mut self, cx: &mut Context, - expired_error: impl Fn() -> Res, + expired_error: impl Fn() -> Result<(SharedContext, Res), RpcError>, ) -> Poll> { self.deadlines.poll_expired(cx).map(|expired| { let request_id = expired?.into_inner(); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 85746b7f2..992f6d611 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -3,9 +3,9 @@ use crate::{ RequestName, client::{Channel, RpcError}, - context, server::Serve, }; +use crate::context::{ExtractContext, SharedContext}; pub mod load_balance; pub mod retry; @@ -23,30 +23,57 @@ pub trait Stub { /// The service response type. type Resp; + ///TODO: document + type ClientCtx; + /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) - -> Result; + async fn call( + &self, + ctx: &mut Self::ClientCtx, + request: Self::Req, + ) -> Result; } -impl Stub for Channel +impl Stub for Channel where Req: RequestName, + ClientCtx: ExtractContext { type Req = Req; type Resp = Resp; + type ClientCtx = ClientCtx; - async fn call(&self, ctx: context::Context, request: Req) -> Result { + async fn call( + &self, + ctx: &mut Self::ClientCtx, + request: Req, + ) -> Result { Self::call(self, ctx, request).await } } impl Stub for S where - S: Serve + Clone, + S: Serve + Clone, { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { - self.clone().serve(ctx, req).await.map_err(RpcError::Server) + type ClientCtx = SharedContext; + async fn call( + &self, + ctx: &mut Self::ClientCtx, + req: Self::Req, + ) -> Result { + let mut server_ctx = ctx.clone(); + + let res = self + .clone() + .serve(&mut server_ctx, req) + .await + .map_err(RpcError::Server); + + *ctx = server_ctx; + + res } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index d28a3c137..60efafc91 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -7,7 +7,6 @@ pub use round_robin::RoundRobin; mod round_robin { use crate::{ client::{RpcError, stub}, - context, }; use cycle::AtomicCycle; @@ -17,10 +16,11 @@ mod round_robin { { type Req = Stub::Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -99,8 +99,7 @@ mod round_robin { /// the same stub. mod consistent_hash { use crate::{ - client::{RpcError, stub}, - context, + client::{RpcError, stub} }; use std::{ collections::hash_map::RandomState, @@ -116,10 +115,11 @@ mod consistent_hash { { type Req = Stub::Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,19 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(context::current(), 'a').await?; + let resp = stub + .call(&mut context::SharedContext::current(), 'a') + .await?; assert_eq!(resp, 1); - let resp = stub.call(context::current(), 'b').await?; + let resp = stub + .call(&mut context::SharedContext::current(), 'b') + .await?; assert_eq!(resp, 2); - let resp = stub.call(context::current(), 'c').await?; + let resp = stub + .call(&mut context::SharedContext::current(), 'c') + .await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 145c14c1f..577ef5362 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -1,16 +1,17 @@ use crate::{ RequestName, ServerError, client::{RpcError, stub::Stub}, - context, }; use std::{collections::HashMap, hash::Hash, io}; +use std::marker::PhantomData; /// A mock stub that returns user-specified responses. -pub struct Mock { +pub struct Mock { responses: HashMap, + ghost: PhantomData } -impl Mock +impl Mock where Req: Eq + Hash, { @@ -18,19 +19,25 @@ where pub fn new(responses: [(Req, Resp); N]) -> Self { Self { responses: HashMap::from(responses), + ghost: PhantomData } } } -impl Stub for Mock +impl Stub for Mock where Req: Eq + Hash + RequestName, Resp: Clone, { type Req = Req; type Resp = Resp; + type ClientCtx = ServerCtx; - async fn call(&self, _: context::Context, request: Self::Req) -> Result { + async fn call( + &self, + _: &mut Self::ClientCtx, + request: Self::Req, + ) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index a07b05fc5..5499f60e4 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -3,7 +3,6 @@ use crate::{ RequestName, client::{RpcError, stub}, - context, }; use std::sync::Arc; @@ -15,10 +14,11 @@ where { type Req = Req; type Resp = Stub::Resp; + type ClientCtx = Stub::ClientCtx; async fn call( &self, - ctx: context::Context, + ctx: &mut Self::ClientCtx, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 8e77cf223..a0f697f81 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -21,10 +21,9 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] -#[non_exhaustive] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Context { +pub struct SharedContext { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -35,7 +34,25 @@ pub struct Context { /// When a service handles a request by making requests itself, those requests should /// include the same `trace_id` as that included on the original request. This way, /// users can trace related actions across a distributed system. - pub trace_context: trace::Context, + pub trace_context: trace::Context +} + +///TODO: Document +pub trait ExtractContext { + ///TODO: Document + fn extract(&self) -> Ctx; + ///TODO: Document + fn update(&mut self, value: Ctx); +} + +impl ExtractContext for T where T: Clone { + fn extract(&self) -> T { + self.clone() + } + + fn update(&mut self, value: T) { + *self = value + } } #[cfg(feature = "serde1")] @@ -91,17 +108,12 @@ mod absolute_to_relative_time { } } -assert_impl_all!(Context: Send, Sync); +assert_impl_all!(SharedContext: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } -/// Returns the context for the current request, or a default Context if no request is active. -pub fn current() -> Context { - Context::current() -} - #[derive(Clone)] struct Deadline(Instant); @@ -111,7 +123,7 @@ impl Default for Deadline { } } -impl Context { +impl SharedContext { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -137,11 +149,11 @@ impl Context { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &Context); + fn set_context(&self, context: &SharedContext); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &Context) { + fn set_context(&self, context: &SharedContext) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 7e1944305..fc79e3056 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -124,8 +124,9 @@ //! struct HelloServer; //! //! impl World for HelloServer { +//! type Context = context::SharedContext; //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut Self::Context, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -142,7 +143,10 @@ //! # prelude::*, //! # }; //! # use tarpc::{ +//! # ClientMessage, //! # client, context, +//! # context::{SharedContext}, +//! # transport::channel, //! # server::{self, Channel}, //! # }; //! # // This is the service definition. It looks a lot like a trait definition. @@ -157,8 +161,9 @@ //! # #[derive(Clone)] //! # struct HelloServer; //! # impl World for HelloServer { -//! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # type Context = SharedContext; +//! # // Each defined rpc generates an async fn that serves the RPC +//! # async fn hello(self, _: &mut Self::Context, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -167,8 +172,8 @@ //! # #[cfg(feature = "tokio1")] //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let (client_transport, server_transport) = tarpc::transport::channel::unbounded(); -//! +//! use futures::future::Shared; +//! let (client_transport, server_transport) = channel::unbounded(); //! let server = server::BaseChannel::with_defaults(server_transport); //! tokio::spawn( //! server.execute(HelloServer.serve()) @@ -179,12 +184,13 @@ //! //! // WorldClient is generated by the #[tarpc::service] attribute. It has a constructor `new` //! // that takes a config and any Transport as input. -//! let mut client = WorldClient::new(client::Config::default(), client_transport).spawn(); +//! let mut client = WorldClient::::new(client::Config::default(), client_transport).spawn(); //! //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let hello = client.hello(context::current(), "Stim".to_string()).await?; +//! let mut context = context::SharedContext::current(); +//! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); //! @@ -197,7 +203,7 @@ //! Use `cargo doc` as you normally would to see the documentation created for all //! items expanded by a `service!` invocation. -#![deny(missing_docs)] +#![deny(missing_docs, warnings, unused, dead_code)] #![allow(clippy::type_complexity)] #![cfg_attr(docsrs, feature(doc_cfg))] @@ -250,17 +256,19 @@ pub(crate) mod util; pub use crate::transport::sealed::Transport; +use crate::context::SharedContext; +use std::ops::Deref; use std::{any::Any, error::Error, io, sync::Arc, time::Instant}; /// A message from a client to a server. #[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[non_exhaustive] -pub enum ClientMessage { +pub enum ClientMessage { /// A request initiated by a user. The server responds to a request by invoking a /// service-provided request handler. The handler completes with a [`response`](Response), which /// the server sends back to the client. - Request(Request), + Request(Request), /// A command to cancel an in-flight request, automatically sent by the client when a response /// future is dropped. /// @@ -278,16 +286,43 @@ pub enum ClientMessage { }, } +impl ClientMessage { + /// Creates a new ClientMessage by mapping the context using the provided function. + pub fn map_context(self, f: F) -> ClientMessage + where + F: FnOnce(Ctx) -> Ctx2, + { + match self { + ClientMessage::Request(Request { + context, + id, + message, + }) => ClientMessage::Request(Request { + context: f(context), + id, + message, + }), + ClientMessage::Cancel { + trace_context, + request_id, + } => ClientMessage::Cancel { + trace_context, + request_id, + }, + } + } +} + /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Request { +pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::Context, + pub context: Ctx, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. - pub message: T, + pub message: Req, } /// Implemented by the request types generated by tarpc::service. @@ -360,13 +395,29 @@ impl RequestName for u64 { /// A response from a server to a client. #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Response { +pub struct Response { /// The ID of the request being responded to. pub request_id: u64, + /// Trace context, deadline, and other cross-cutting concerns. + pub context: Ctx, /// The response body, or an error if the request failed. pub message: Result, } +impl Response { + /// Creates a modified Response by mapping the context using the provided function. + pub fn map_context(self, f: F) -> Response + where + F: FnOnce(Ctx) -> Ctx2, + { + Response { + request_id: self.request_id, + context: f(self.context), + message: self.message, + } + } +} + /// An error indicating the server aborted the request early, e.g., due to request throttling. #[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)] #[error("{kind:?}: {detail}")] @@ -490,7 +541,10 @@ impl ServerError { } } -impl Request { +impl Request +where + Ctx: Deref, +{ /// Returns the deadline for this request. pub fn deadline(&self) -> &Instant { &self.context.deadline diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index da3b3ae21..7d345a203 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -9,7 +9,7 @@ use crate::{ ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport, cancellations::{CanceledRequests, RequestCancellation, cancellations}, - context::{self, SpanExt}, + context::{SpanExt}, trace, util::TimeUntil, }; @@ -27,6 +27,7 @@ use std::{ convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin, sync::Arc, time::SystemTime, }; use tracing::{Span, info_span, instrument::Instrument}; +use crate::context::{ExtractContext, SharedContext}; mod in_flight_requests; pub mod request_hook; @@ -58,9 +59,10 @@ impl Default for Config { impl Config { /// Returns a channel backed by `transport` and configured with `self`. - pub fn channel(self, transport: T) -> BaseChannel + pub fn channel(self, transport: T) -> BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { BaseChannel::new(self, transport) } @@ -69,6 +71,9 @@ impl Config { /// Equivalent to a `FnOnce(Req) -> impl Future`. #[allow(async_fn_in_trait)] pub trait Serve { + ///TODO document + type ServerCtx; + /// Type of request. type Req: RequestName; @@ -76,17 +81,21 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve( + self, + ctx: &mut Self::ServerCtx, + req: Self::Req, + ) -> Result; } /// A Serve wrapper around a Fn. #[derive(Debug)] -pub struct ServeFn { +pub struct ServeFn { f: F, - data: PhantomData Resp>, + data: PhantomData<(Req, Resp, ServerCtx)>, } -impl Clone for ServeFn +impl Clone for ServeFn where F: Clone, { @@ -98,14 +107,13 @@ where } } -impl Copy for ServeFn where F: Copy {} +impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce(&'a mut ServerCtx, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -113,16 +121,19 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce( + &'a mut ServerCtx, + Req, + ) -> Pin> + 'a + Send>>, { + type ServerCtx = ServerCtx; type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut ServerCtx, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -138,7 +149,7 @@ where /// messages. Instead, it internally handles them by cancelling corresponding requests (removing /// the corresponding in-flight requests and aborting their handlers). #[pin_project] -pub struct BaseChannel { +pub struct BaseChannel { config: Config, /// Writes responses to the wire and reads requests off the wire. #[pin] @@ -151,12 +162,13 @@ pub struct BaseChannel { /// Holds data necessary to clean up in-flight requests. in_flight_requests: InFlightRequests, /// Types the request and response. - ghost: PhantomData<(fn() -> Req, fn(Resp))>, + ghost: PhantomData<(Req, Resp, ServeCtx)>, } -impl BaseChannel +impl BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext { /// Creates a new channel backed by `transport` and configured with `config`. pub fn new(config: Config, transport: T) -> Self { @@ -202,28 +214,29 @@ where fn start_request( mut self: Pin<&mut Self>, - mut request: Request, - ) -> Result, AlreadyExistsError> { + request: Request, + ) -> Result, AlreadyExistsError> { + let mut shared_context = request.context.extract(); let span = info_span!( "RPC", - rpc.trace_id = %request.context.trace_id(), - rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + request.context.deadline.time_until()), + rpc.trace_id = %shared_context.trace_id(), + rpc.deadline = %humantime::format_rfc3339(SystemTime::now() + shared_context.deadline.time_until()), otel.kind = "server", otel.name = tracing::field::Empty, ); - span.set_context(&request.context); - request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { + span.set_context(&shared_context); + shared_context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( "OpenTelemetry subscriber not installed; making unsampled \ child context." ); - request.context.trace_context.new_child() + shared_context.trace_context.new_child() }); let entered = span.enter(); tracing::debug!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( request.id, - request.context.deadline, + shared_context.deadline, span.clone(), ); match start { @@ -248,7 +261,7 @@ where } } -impl fmt::Debug for BaseChannel { +impl fmt::Debug for BaseChannel { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "BaseChannel") } @@ -256,9 +269,9 @@ impl fmt::Debug for BaseChannel { /// A request tracked by a [`Channel`]. #[derive(Debug)] -pub struct TrackedRequest { +pub struct TrackedRequest { /// The request sent by the client. - pub request: Request, + pub request: Request, /// A registration to abort a future when the [`Channel`] that produced this request stops /// tracking it. pub abort_registration: AbortRegistration, @@ -295,7 +308,10 @@ pub struct TrackedRequest { /// created by [`BaseChannel`]. pub trait Channel where - Self: Transport::Resp>, TrackedRequest<::Req>>, + Self: Transport< + Response::Resp>, + TrackedRequest::Req>, + >, { /// Type of request item. type Req; @@ -305,6 +321,8 @@ where /// The wrapped transport. type Transport; + ///TODO document + type ServerCtx; /// Configuration of the channel. fn config(&self) -> &Config; @@ -343,7 +361,9 @@ where /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, + /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -360,10 +380,11 @@ where /// let mut requests = server.requests(); /// tokio::spawn(async move { /// while let Some(Ok(request)) = requests.next().await { - /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::SharedContext::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` fn requests(self) -> Requests @@ -386,7 +407,7 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{ClientMessage, context, client, server::{self, BaseChannel, Channel, serve}, transport, context::{SharedContext}}; /// use futures::prelude::*; /// use tracing_subscriber::prelude::*; /// @@ -399,12 +420,13 @@ where /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( - /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) })) + /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); + /// }.boxed())); + /// let mut context = context::SharedContext::current(); /// assert_eq!( - /// client.call(context::current(), 1).await.unwrap(), + /// client.call(&mut context, 1).await.unwrap(), /// 2); /// } /// ``` @@ -412,17 +434,18 @@ where where Self: Sized, Self::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.requests().execute(serve) } } -impl Stream for BaseChannel +impl Stream for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext { - type Item = Result, ChannelError>; + type Item = Result, ChannelError>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { #[derive(Clone, Copy, Debug)] @@ -525,10 +548,11 @@ where } } -impl Sink> for BaseChannel +impl Sink> for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, T::Error: Error, + ServerCtx: ExtractContext { type Error = ChannelError; @@ -539,7 +563,10 @@ where .map_err(|e| ChannelError::Ready(Arc::new(e))) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { if let Some(span) = self .in_flight_requests_mut() .remove_request(response.request_id) @@ -572,19 +599,22 @@ where } } -impl AsRef for BaseChannel { +impl AsRef for BaseChannel { fn as_ref(&self) -> &T { self.transport.get_ref() } } -impl Channel for BaseChannel +impl Channel for BaseChannel where - T: Transport, ClientMessage>, + T: Transport, ClientMessage>, + ServerCtx: ExtractContext, { + type Req = Req; type Resp = Resp; type Transport = T; + type ServerCtx = ServerCtx; fn config(&self) -> &Config { &self.config @@ -609,9 +639,9 @@ where #[pin] channel: C, /// Responses waiting to be written to the wire. - pending_responses: mpsc::Receiver>, + pending_responses: mpsc::Receiver>, /// Handed out to request handlers to fan in responses. - responses_tx: mpsc::Sender>, + responses_tx: mpsc::Sender>, } impl Requests @@ -631,14 +661,14 @@ where /// Returns the inner channel over which messages are sent and received. pub fn pending_responses_mut<'a>( self: &'a mut Pin<&mut Self>, - ) -> &'a mut mpsc::Receiver> { + ) -> &'a mut mpsc::Receiver> { self.as_mut().project().pending_responses } fn pump_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { self.channel_pin_mut().poll_next(cx).map_ok( |TrackedRequest { request, @@ -703,7 +733,7 @@ where fn poll_next_response( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, C::Error>>> { + ) -> Poll, C::Error>>> { ready!(self.ensure_writeable(cx)?); match ready!(self.pending_responses_mut().poll_recv(cx)) { @@ -736,7 +766,8 @@ where /// # Example /// /// ```rust - /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport}; + /// use tarpc::{context, client, server::{self, BaseChannel, Channel, serve}, transport, ClientMessage}; + /// use tarpc::context::{SharedContext}; /// use futures::prelude::*; /// /// # #[cfg(not(feature = "tokio1"))] @@ -748,17 +779,18 @@ where /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( - /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// requests.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// }.boxed())); + /// let mut context = context::SharedContext::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub fn execute(self, serve: S) -> impl Stream> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.take_while(|result| { if let Err(e) = result { @@ -805,17 +837,17 @@ impl Drop for ResponseGuard { /// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will /// be sent to the Channel to clean up associated request state. #[derive(Debug)] -pub struct InFlightRequest { - request: Request, +pub struct InFlightRequest { + request: Request, abort_registration: AbortRegistration, response_guard: ResponseGuard, span: Span, - response_tx: mpsc::Sender>, + response_tx: mpsc::Sender>, } -impl InFlightRequest { +impl InFlightRequest { /// Returns a reference to the request. - pub fn get(&self) -> &Request { + pub fn get(&self) -> &Request { &self.request } @@ -838,7 +870,9 @@ impl InFlightRequest { /// /// ```rust /// use tarpc::{ + /// ClientMessage, /// context, + /// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, serve}, /// transport, @@ -855,18 +889,18 @@ impl InFlightRequest { /// tokio::spawn(async move { /// let mut requests = server.requests(); /// while let Some(Ok(in_flight_request)) = requests.next().await { - /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } - /// /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::SharedContext::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` /// pub async fn execute(self, serve: S) where Req: RequestName, - S: Serve, + S: Serve, { let Self { response_tx, @@ -875,7 +909,7 @@ impl InFlightRequest { span, request: Request { - context, + mut context, message, id: request_id, }, @@ -883,10 +917,11 @@ impl InFlightRequest { span.record("otel.name", message.name()); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; + let message = serve.serve(&mut context, message).await; tracing::debug!("CompleteRequest"); let response = Response { request_id, + context, message, }; let _ = response_tx.send(response).await; @@ -914,7 +949,7 @@ impl Stream for Requests where C: Channel, { - type Item = Result, C::Error>; + type Item = Result, C::Error>; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { loop { @@ -977,10 +1012,23 @@ mod tests { task::Poll, time::{Duration, Instant}, }; + use crate::context::{ExtractContext, SharedContext}; fn test_channel() -> ( - Pin, Response>>>>, - UnboundedChannel, ClientMessage>, + Pin< + Box< + BaseChannel< + Req, + Resp, + UnboundedChannel< + ClientMessage, + Response, + >, + SharedContext + >, + >, + >, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); (Box::pin(BaseChannel::new(Config::default(), rx)), tx) @@ -990,11 +1038,20 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + UnboundedChannel< + ClientMessage, + Response, + >, + SharedContext + >, + >, >, >, - UnboundedChannel, ClientMessage>, + UnboundedChannel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::unbounded(); ( @@ -1009,11 +1066,19 @@ mod tests { Pin< Box< Requests< - BaseChannel, Response>>, + BaseChannel< + Req, + Resp, + channel::Channel< + ClientMessage, + Response, + >, + SharedContext + >, >, >, >, - channel::Channel, ClientMessage>, + channel::Channel, ClientMessage>, ) { let (tx, rx) = crate::transport::channel::bounded(capacity); // Add 1 because capacity 0 is not supported (but is supported by transport::channel::bounded). @@ -1023,9 +1088,9 @@ mod tests { (Box::pin(BaseChannel::new(config, rx).requests()), tx) } - fn fake_request(req: Req) -> ClientMessage { + fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::current(), + context: context::SharedContext::current(), id: 0, message: req, }) @@ -1039,20 +1104,25 @@ mod tests { #[tokio::test] async fn test_serve() { - let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + let serve = serve(|_, i| async move { Ok(i) }.boxed()); + assert_matches!( + serve.serve(&mut context::SharedContext::current(), 7).await, + Ok(7) + ); } #[tokio::test] async fn serve_before_mutates_context() -> anyhow::Result<()> { struct SetDeadline(Instant); - impl BeforeRequest for SetDeadline { + impl BeforeRequest for SetDeadline where ServerCtx: ExtractContext { async fn before( &mut self, - ctx: &mut context::Context, + ctx: &mut ServerCtx, _: &Req, ) -> Result<(), ServerError> { - ctx.deadline = self.0; + let mut inner = ctx.extract(); + inner.deadline = self.0; + ctx.update(inner); Ok(()) } } @@ -1060,14 +1130,17 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: context::Context, i| async move { - assert_eq!(ctx.deadline, some_time); - Ok(i) + let serve = serve(move |ctx: &mut context::SharedContext, i| { + async move { + assert_eq!(ctx.deadline, some_time); + Ok(i) + } + .boxed() }); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::current(); + let mut ctx = context::SharedContext::current(); ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; + deadline_hook.serve(&mut ctx, 7).await?; Ok(()) } @@ -1085,37 +1158,43 @@ mod tests { } } } - impl BeforeRequest for PrintLatency { + impl BeforeRequest for PrintLatency { async fn before( &mut self, - _: &mut context::Context, + _: &mut ServerCtx, _: &Req, ) -> Result<(), ServerError> { self.start = Instant::now(); Ok(()) } } - impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::Context, _: &mut Result) { + impl AfterRequest for PrintLatency { + async fn after( + &mut self, + _: &mut ServerCtx, + _: &mut Result, + ) { tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::SharedContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::SharedContext::current(), 7) .await?; Ok(()) } #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { - let serve = serve(|_, _| async { panic!("Shouldn't get here") }); - let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); + let deadline_hook = serve.before(|_: &mut context::SharedContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook + .serve(&mut context::SharedContext::current(), 7) + .await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1128,14 +1207,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: () }), Err(AlreadyExistsError) @@ -1151,7 +1230,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1159,7 +1238,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1182,7 +1261,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1211,7 +1290,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1253,7 +1332,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1276,7 +1355,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1285,6 +1364,7 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1320,7 +1400,9 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) })).await; + request + .execute(serve(|_, _| async { Ok(()) }.boxed())) + .await; assert!( requests .as_mut() @@ -1341,7 +1423,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1350,6 +1432,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1361,6 +1444,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: SharedContext::current(), message: Ok(()), }) .await @@ -1371,7 +1455,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1392,7 +1476,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1401,6 +1485,7 @@ mod tests { .channel_pin_mut() .start_send(Response { request_id: 0, + context: SharedContext::current(), message: Ok(()), }) .unwrap(); @@ -1411,7 +1496,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: SharedContext::current(), message: (), }) .unwrap(); @@ -1421,6 +1506,7 @@ mod tests { .responses_tx .send(Response { request_id: 1, + context: SharedContext::current(), message: Ok(()), }) .await diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 428eb1a7d..6a71124b1 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -33,7 +33,7 @@ where ) -> impl Stream>> where C::Req: RequestName, - S: Serve + Clone, + S: Serve + Clone, { self.map(move |channel| channel.execute(serve.clone())) } @@ -48,7 +48,9 @@ where /// # Example /// ```rust /// use tarpc::{ +/// ClientMessage, /// context, +/// context::{SharedContext}, /// client::{self, NewClient}, /// server::{self, BaseChannel, Channel, incoming::{Incoming, spawn_incoming}, serve}, /// transport, @@ -63,9 +65,10 @@ where /// /// let incoming = stream::once(async move { /// BaseChannel::new(server::Config::default(), rx) -/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); +/// let mut context = context::SharedContext::current(); +/// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub async fn spawn_incoming( diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 64b644278..3ffdfac89 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -107,6 +107,7 @@ where type Req = C::Req; type Resp = C::Resp; type Transport = C::Transport; + type ServerCtx = C::ServerCtx; fn config(&self) -> &server::Config { self.inner.config() diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index bd9c103b0..deb723bda 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -67,6 +67,7 @@ where self.as_mut().start_send(Response { request_id: r.request.id, + context: r.request.context, message: Err(ServerError { kind: io::ErrorKind::WouldBlock, detail: "server throttled the request.".into(), @@ -80,7 +81,7 @@ where } } -impl Sink::Resp>> for MaxRequests +impl Sink::Resp>> for MaxRequests where C: Channel, { @@ -92,7 +93,7 @@ where fn start_send( self: Pin<&mut Self>, - item: Response<::Resp>, + item: Response::Resp>, ) -> Result<(), Self::Error> { self.project().inner.start_send(item) } @@ -119,6 +120,7 @@ where type Req = ::Req; type Resp = ::Resp; type Transport = ::Transport; + type ServerCtx = ::ServerCtx; fn in_flight_requests(&self) -> usize { self.inner.in_flight_requests() @@ -188,6 +190,7 @@ mod tests { time::{Duration, Instant}, }; use tracing::Span; + use crate::context::{SharedContext}; #[tokio::test] async fn throttler_in_flight_requests() { @@ -268,7 +271,8 @@ mod tests { } impl PendingSink<(), ()> { pub fn default() - -> PendingSink>, Response> { + -> PendingSink>, Response> + { PendingSink { ghost: PhantomData } } } @@ -293,10 +297,13 @@ mod tests { Poll::Pending } } - impl Channel for PendingSink>, Response> { + impl Channel + for PendingSink>, Response> + { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = SharedContext; fn config(&self) -> &Config { unimplemented!() } @@ -326,16 +333,16 @@ mod tests { .as_mut() .start_send(Response { request_id: 0, + context: SharedContext::current(), message: Ok(1), }) .unwrap(); assert_eq!(throttler.inner.in_flight_requests.len(), 0); - assert_eq!( - throttler.inner.sink.front(), - Some(&Response { - request_id: 0, - message: Ok(1), - }) - ); + + let result = throttler.inner.sink.front(); + + assert_eq!(result.map(|r| r.request_id), Some(0)); + + assert_eq!(result.map(|r| &r.message), Some(&Ok(1))); } } diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 66cf2878c..4f3d60377 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -43,12 +43,12 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// - /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) - /// .before(|_ctx: &mut context::Context, req: &i32| { + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) + /// .before(|_ctx: &mut context::SharedContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,12 +58,13 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::SharedContext::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` - fn before(self, hook: Hook) -> HookThenServe + fn before(self, hook: Hook) -> HookThenServe where - Hook: BeforeRequest, + Hook: BeforeRequest, Self: Sized, { HookThenServe::new(self, hook) @@ -80,7 +81,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// @@ -93,20 +94,20 @@ pub trait RequestHook: Serve { /// } else { /// Ok(i + 1) /// } - /// }) - /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// }.boxed()) + /// .after(|_ctx: &mut context::SharedContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::SharedContext::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> ServeThenHook where - Hook: AfterRequest, + Hook: AfterRequest, Self: Sized, { ServeThenHook::new(self, hook) @@ -123,7 +124,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{ /// context, ServerError, /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} @@ -132,17 +133,17 @@ pub trait RequestHook: Serve { /// /// struct PrintLatency(Instant); /// - /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + /// impl BeforeRequest for PrintLatency { + /// async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } /// } /// - /// impl AfterRequest for PrintLatency { + /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::Context, + /// _: &mut ServerCtx, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -151,16 +152,17 @@ pub trait RequestHook: Serve { /// /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) - /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// }.boxed()).before_and_after(PrintLatency(Instant::now())); + /// let mut context = context::SharedContext::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( self, hook: Hook, - ) -> HookThenServeThenHook + ) -> HookThenServeThenHook where - Hook: BeforeRequest + AfterRequest, + Hook: BeforeRequest + AfterRequest, Self: Sized, { HookThenServeThenHook::new(self, hook) diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index b2ef9ccbd..ce6319e25 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -6,24 +6,32 @@ //! Provides a hook that runs after request execution. -use crate::{ServerError, context, server::Serve}; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs after request execution. #[allow(async_fn_in_trait)] -pub trait AfterRequest { +pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); + async fn after( + &mut self, + ctx: &mut ServerCtx, + resp: &mut Result, + ); } -impl AfterRequest for F +impl AfterRequest for F where - F: FnMut(&mut context::Context, &mut Result) -> Fut, + F: FnMut(&mut ServerCtx, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { + async fn after( + &mut self, + ctx: &mut ServerCtx, + resp: &mut Result, + ) { self(ctx, resp).await } } @@ -52,21 +60,22 @@ impl Clone for ServeThenHook { impl Serve for ServeThenHook where Serv: Serve, - Hook: AfterRequest, + Hook: AfterRequest, { type Req = Serv::Req; type Resp = Serv::Resp; + type ServerCtx = Serv::ServerCtx; async fn serve( self, - mut ctx: context::Context, + ctx: &mut Serv::ServerCtx, req: Serv::Req, ) -> Result { let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index e72e28a42..3e2e091c8 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -6,12 +6,13 @@ //! Provides a hook that runs before request execution. -use crate::{ServerError, context, server::Serve}; +use std::marker::PhantomData; +use crate::{ServerError, server::Serve}; use futures::prelude::*; /// A hook that runs before request execution. #[allow(async_fn_in_trait)] -pub trait BeforeRequest { +pub trait BeforeRequest { /// The function that is called before request execution. /// /// If this function returns an error, the request will not be executed and the error will be @@ -19,22 +20,26 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; + async fn before( + &mut self, + ctx: &mut ServerCtx, + req: &Req, + ) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. -pub trait BeforeRequestList: BeforeRequest { +pub trait BeforeRequestList: BeforeRequest { /// The hook returned by `BeforeRequestList::then`. - type Then: BeforeRequest + type Then: BeforeRequest where - Next: BeforeRequest; + Next: BeforeRequest; /// Returns a hook that, when run, runs two hooks, first `self` and then `next`. - fn then>(self, next: Next) -> Self::Then; + fn then>(self, next: Next) -> Self::Then; /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::Context, &Req) -> Fut, + Next: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, >( self, @@ -47,53 +52,64 @@ pub trait BeforeRequestList: BeforeRequest { } /// The service fn returned by `BeforeRequestList::serving`. - type Serve>: Serve; + type Serve>: Serve; /// Runs the list of request hooks before execution of the given serve fn. /// This is equivalent to `serve.before(before_request_chain)` but may be syntactically nicer. - fn serving>(self, serve: S) -> Self::Serve; + fn serving>(self, serve: S) -> Self::Serve; } -impl BeforeRequest for F +impl BeforeRequest for F where - F: FnMut(&mut context::Context, &Req) -> Fut, + F: FnMut(&mut ServerCtx, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut ServerCtx, + req: &Req, + ) -> Result<(), ServerError> { self(ctx, req).await } } /// A Service function that runs a hook before request execution. -#[derive(Clone)] -pub struct HookThenServe { +pub struct HookThenServe { serve: Serv, hook: Hook, + ghost: PhantomData } -impl HookThenServe { +impl Clone for HookThenServe { + fn clone(&self) -> Self { + Self::new(self.serve.clone(), self.hook.clone()) + } +} + +impl HookThenServe { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { - Self { serve, hook } + Self { serve, hook, ghost: PhantomData } } } -impl Serve for HookThenServe +impl Serve for HookThenServe where - Serv: Serve, - Hook: BeforeRequest, + Serv: Serve, + Hook: BeforeRequest, { + type ServerCtx = ServerCtx; type Req = Serv::Req; type Resp = Serv::Resp; async fn serve( self, - mut ctx: context::Context, + ctx: &mut ServerCtx, req: Self::Req, ) -> Result { let HookThenServe { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } @@ -103,7 +119,7 @@ where /// Example /// /// ```rust -/// use futures::{executor::block_on, future}; +/// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, /// BeforeRequest, BeforeRequestList}}}; /// use std::{cell::Cell, io}; @@ -120,8 +136,9 @@ where /// i.set(2); /// Ok(()) /// }) -/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); -/// let response = serve.clone().serve(context::current(), 1); +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); +/// let mut context = context::SharedContext::current(); +/// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); /// ``` @@ -137,10 +154,14 @@ pub struct BeforeRequestCons(First, Rest); #[derive(Clone, Copy)] pub struct BeforeRequestNil; -impl, Rest: BeforeRequest> BeforeRequest +impl, Rest: BeforeRequest, ServerCtx> BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before( + &mut self, + ctx: &mut ServerCtx, + req: &Req, + ) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -148,45 +169,45 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { +impl BeforeRequest for BeforeRequestNil { + async fn before(&mut self, _: &mut ServerCtx, _: &Req) -> Result<(), ServerError> { Ok(()) } } -impl, Rest: BeforeRequestList> BeforeRequestList +impl, Rest: BeforeRequestList, ServerCtx> BeforeRequestList for BeforeRequestCons { type Then = BeforeRequestCons> where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { let BeforeRequestCons(first, rest) = self; BeforeRequestCons(first, rest.then(next)) } - type Serve> = HookThenServe; + type Serve> = HookThenServe; - fn serving>(self, serve: S) -> Self::Serve { + fn serving>(self, serve: S) -> Self::Serve { HookThenServe::new(serve, self) } } -impl BeforeRequestList for BeforeRequestNil { +impl BeforeRequestList for BeforeRequestNil { type Then = BeforeRequestCons where - Next: BeforeRequest; + Next: BeforeRequest; - fn then>(self, next: Next) -> Self::Then { + fn then>(self, next: Next) -> Self::Then { BeforeRequestCons(next, BeforeRequestNil) } - type Serve> = S; + type Serve> = S; - fn serving>(self, serve: S) -> S { + fn serving>(self, serve: S) -> S { serve } } @@ -209,8 +230,9 @@ fn before_request_list() { i.set(2); Ok(()) }) - .serving(serve(|_ctx, i| async move { Ok(i + 1) })); - let response = serve.clone().serve(context::current(), 1); + .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); + let mut context = crate::context::SharedContext::current(); + let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 0761a7df3..080c53b21 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -7,17 +7,17 @@ //! Provides a hook that runs both before and after request execution. use super::{after::AfterRequest, before::BeforeRequest}; -use crate::{RequestName, ServerError, context, server::Serve}; +use crate::{RequestName, ServerError, server::Serve}; use std::marker::PhantomData; /// A Service function that runs a hook both before and after request execution. -pub struct HookThenServeThenHook { +pub struct HookThenServeThenHook { serve: Serv, hook: Hook, - fns: PhantomData<(fn(Req), fn(Resp))>, + fns: PhantomData<(Req, Resp, ServerCtx)>, } -impl HookThenServeThenHook { +impl HookThenServeThenHook { pub(crate) fn new(serve: Serv, hook: Hook) -> Self { Self { serve, @@ -27,7 +27,7 @@ impl HookThenServeThenHook { } } -impl Clone for HookThenServeThenHook { +impl Clone for HookThenServeThenHook { fn clone(&self) -> Self { Self { serve: self.serve.clone(), @@ -37,22 +37,27 @@ impl Clone for HookThenServeThenHook Serve for HookThenServeThenHook +impl Serve for HookThenServeThenHook where Req: RequestName, - Serv: Serve, - Hook: BeforeRequest + AfterRequest, + Serv: Serve, + Hook: BeforeRequest + AfterRequest, { type Req = Req; type Resp = Resp; + type ServerCtx = ServerCtx; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve( + self, + ctx: &mut ServerCtx, + req: Req, + ) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index db167c42e..9a941f711 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -4,6 +4,7 @@ // license that can be found in the LICENSE file or at // https://opensource.org/licenses/MIT. +use crate::context::{SharedContext}; use crate::{ Request, Response, cancellations::{CanceledRequests, RequestCancellation, cancellations}, @@ -38,14 +39,19 @@ where } } -impl Sink> for FakeChannel> { +impl Sink> + for FakeChannel> +{ type Error = io::Error; fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { self.project().sink.poll_ready(cx).map_err(|e| match e {}) } - fn start_send(mut self: Pin<&mut Self>, response: Response) -> Result<(), Self::Error> { + fn start_send( + mut self: Pin<&mut Self>, + response: Response, + ) -> Result<(), Self::Error> { self.as_mut() .project() .in_flight_requests @@ -65,13 +71,15 @@ impl Sink> for FakeChannel> { } } -impl Channel for FakeChannel>, Response> +impl Channel + for FakeChannel>, Response> where Req: Unpin, { type Req = Req; type Resp = Resp; type Transport = (); + type ServerCtx = SharedContext; fn config(&self) -> &Config { &self.config @@ -86,13 +94,13 @@ where } } -impl FakeChannel>, Response> { +impl FakeChannel>, Response> { pub fn push_req(&mut self, id: u64, message: Req) { let (_, abort_registration) = futures::future::AbortHandle::new_pair(); let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::Context { + context: context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), }, @@ -111,8 +119,14 @@ impl FakeChannel>, Response> { } impl FakeChannel<(), ()> { - pub fn default() -> FakeChannel>, Response> { + pub fn default() + -> FakeChannel>, Response> { let (request_cancellation, canceled_requests) = cancellations(); + + let mut x = anymap3::AnyMap::new(); + + x.entry::<&str>(); + FakeChannel { stream: Default::default(), sink: Default::default(), diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 0268300dc..de9a8afdc 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -175,6 +175,7 @@ mod tests { use futures::{prelude::*, stream}; use std::io; use tracing::trace; + use crate::context::SharedContext; #[test] fn ensure_is_transport() { @@ -187,17 +188,22 @@ mod tests { async fn integration() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (client_channel, server_channel) = transport::channel::unbounded(); + let (client_channel, server_channel) = + transport::channel::unbounded(); + tokio::spawn( stream::once(future::ready(server_channel)) .map(BaseChannel::with_defaults) - .execute(serve(|_ctx, request: String| async move { - request.parse::().map_err(|_| { - ServerError::new( - io::ErrorKind::InvalidInput, - format!("{request:?} is not an int"), - ) - }) + .execute(serve(|_ctx: &mut SharedContext, request: String| { + async move { + request.parse::().map_err(|_| { + ServerError::new( + io::ErrorKind::InvalidInput, + format!("{request:?} is not an int"), + ) + }) + } + .boxed() })) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); @@ -206,8 +212,12 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client + .call(&mut context::SharedContext::current(), "123".into()) + .await; + let response2 = client + .call(&mut context::SharedContext::current(), "abc".into()) + .await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.rs b/tarpc/tests/compile_fail/must_use_request_dispatch.rs index 2915d3237..a5238fe8b 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.rs +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.rs @@ -1,5 +1,5 @@ use tarpc::client; - +use tarpc::context::SharedContext; #[tarpc::service] trait World { async fn hello(name: String) -> String; @@ -10,6 +10,6 @@ fn main() { #[deny(unused_must_use)] { - WorldClient::new(client::Config::default(), client_transport).dispatch; + WorldClient::::new(client::Config::default(), client_transport).dispatch; } } diff --git a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr index e652cc8e8..e0ec77ff3 100644 --- a/tarpc/tests/compile_fail/must_use_request_dispatch.stderr +++ b/tarpc/tests/compile_fail/must_use_request_dispatch.stderr @@ -1,8 +1,8 @@ error: unused `RequestDispatch` that must be used --> tests/compile_fail/must_use_request_dispatch.rs:13:9 | -13 | WorldClient::new(client::Config::default(), client_transport).dispatch; - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +13 | WorldClient::::new(client::Config::default(), client_transport).dispatch; + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | note: the lint level is defined here --> tests/compile_fail/must_use_request_dispatch.rs:11:12 @@ -11,5 +11,5 @@ note: the lint level is defined here | ^^^^^^^^^^^^^^^ help: use `let _ = ...` to ignore the resulting value | -13 | let _ = WorldClient::new(client::Config::default(), client_transport).dispatch; +13 | let _ = WorldClient::::new(client::Config::default(), client_transport).dispatch; | +++++++ diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 18bb3a997..6bcd255c4 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -1,5 +1,6 @@ use futures::prelude::*; -use tarpc::serde_transport; +use tarpc::context::{SharedContext}; +use tarpc::{serde_transport}; use tarpc::{ client, context, server::{BaseChannel, incoming::Incoming}, @@ -22,7 +23,8 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + type Context = SharedContext; + async fn get_opposite_color(self, _: &mut Self::Context, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +55,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(context::current(), TestData::White) + .get_opposite_color(&mut context::SharedContext::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 06542b43b..7d1f96e18 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -5,12 +5,15 @@ use futures::{ }; use std::time::{Duration, Instant}; use tarpc::{ + ClientMessage, client::{self}, context, server::{BaseChannel, Channel, incoming::Incoming}, + transport, transport::channel, }; use tokio::join; +use tarpc::context::SharedContext; #[tarpc_plugins::service] trait Service { @@ -22,26 +25,36 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + type Context = SharedContext; + async fn add(self, _: &mut Self::Context, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut Self::Context, name: String) -> String { format!("Hey, {name}.") } } #[tokio::test] async fn sequential() { - let (tx, rx) = tarpc::transport::channel::unbounded(); + let (tx, rx) = channel::unbounded(); + let client = client::new(client::Config::default(), tx).spawn(); let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) + .execute(tarpc::server::serve(|_, i: u32| { + async move { Ok(i + 1) }.boxed() + })) .for_each(|response| response), ); - assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + assert_eq!( + client + .call(&mut context::SharedContext::current(), 1) + .await + .unwrap(), + 2 + ); } #[tokio::test] @@ -55,7 +68,8 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + type Context = SharedContext; + async fn r#loop(self, _: &mut Self::Context) { loop { futures::pending!(); } @@ -64,16 +78,16 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded(); // Set up a client that initiates a long-lived request. // The request will complete in error when the server drops the connection. tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::current(); + let mut ctx = context::SharedContext::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); - let _ = client.r#loop(ctx).await; + let _ = client.r#loop(&mut ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -112,9 +126,14 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(context::current(), "Tim".to_string()).await, + client + .add(&mut context::SharedContext::current(), 1, 2) + .await, + Ok(3) + ); + assert_matches!( + client.hey(&mut context::SharedContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -142,11 +161,16 @@ async fn serde_uds() -> anyhow::Result<()> { ); let transport = serde_transport::unix::connect(&sock, Json::default).await?; + let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(context::current(), 1, 2).await; - let res2 = client.hey(context::current(), "Tim".to_string()).await; + let res1 = client + .add(&mut context::SharedContext::current(), 1, 2) + .await; + let res2 = client + .hey(&mut context::SharedContext::current(), "Tim".to_string()) + .await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -158,7 +182,8 @@ async fn serde_uds() -> anyhow::Result<()> { async fn concurrent() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded(); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -169,12 +194,15 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context = context::SharedContext::current(); + let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); + + let req2 = client.add(&mut context, 3, 4); assert_matches!(req2.await, Ok(7)); + + let req3 = client.hey(&mut context, "Tim".to_string()); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) @@ -184,7 +212,8 @@ async fn concurrent() -> anyhow::Result<()> { async fn concurrent_join() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded(); + tokio::spawn( stream::once(ready(rx)) .map(BaseChannel::with_defaults) @@ -195,9 +224,13 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context1 = context::SharedContext::current(); + let mut context2 = context::SharedContext::current(); + let mut context3 = context::SharedContext::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); + let req3 = client.hey(&mut context3, "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -216,7 +249,7 @@ async fn spawn(fut: impl Future + Send + 'static) { async fn concurrent_join_all() -> anyhow::Result<()> { let _ = tracing_subscriber::fmt::try_init(); - let (tx, rx) = channel::unbounded(); + let (tx, rx) = transport::channel::unbounded(); tokio::spawn( BaseChannel::with_defaults(rx) .execute(Server.serve()) @@ -225,8 +258,11 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); + let mut context1 = context::SharedContext::current(); + let mut context2 = context::SharedContext::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); @@ -245,14 +281,16 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + type Context = SharedContext; + async fn count(self, _: &mut Self::Context) -> u32 { self.0 += 1; self.0 } } let (tx, rx) = channel::unbounded(); - tokio::spawn(async { + + tokio::task::spawn(async move { let mut requests = BaseChannel::with_defaults(rx).requests(); let mut counter = CountService(0); @@ -262,8 +300,14 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(context::current()).await, Ok(1)); - assert_matches!(client.count(context::current()).await, Ok(2)); + assert_matches!( + client.count(&mut context::SharedContext::current()).await, + Ok(1) + ); + assert_matches!( + client.count(&mut context::SharedContext::current()).await, + Ok(2) + ); Ok(()) }