diff --git a/example-service/src/client.rs b/example-service/src/client.rs index 8a4ff72eb..dc7104bfd 100644 --- a/example-service/src/client.rs +++ b/example-service/src/client.rs @@ -10,6 +10,7 @@ use std::{net::SocketAddr, time::Duration}; use tarpc::{client, context, tokio_serde::formats::Json}; use tokio::time::sleep; use tracing::Instrument; +use tarpc::context::ClientContext; #[derive(Parser)] struct Flags { @@ -34,10 +35,13 @@ async fn main() -> anyhow::Result<()> { let client = WorldClient::new(client::Config::default(), transport.await?).spawn(); let hello = async move { + let mut context = ClientContext::current(); + let mut context2 = ClientContext::current(); + // Send the request twice, just to be safe! ;) tokio::select! { - hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 } - hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 } + hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 } + hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 } } } .instrument(tracing::info_span!("Two Hellos")) diff --git a/example-service/src/server.rs b/example-service/src/server.rs index 896280c3d..0845783c7 100644 --- a/example-service/src/server.rs +++ b/example-service/src/server.rs @@ -35,7 +35,7 @@ struct Flags { struct HelloServer(SocketAddr); impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { let sleep_time = Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng())); time::sleep(sleep_time).await; diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index da6443edf..886b85b48 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -375,7 +375,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// # Example /// /// ```no_run -/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// use tarpc::{client, transport, service, server::{self, Channel}, context::ServerContext}; /// /// #[service] /// pub trait Calculator { @@ -401,7 +401,7 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { /// #[derive(Clone)] /// struct CalculatorServer; /// impl Calculator for CalculatorServer { -/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// async fn add(self, context: &mut ServerContext, a: i32, b: i32) -> i32 { /// a + b /// } /// } @@ -558,7 +558,7 @@ impl ServiceGenerator<'_> { )| { quote! { #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; + async fn #ident(self, context: &mut ::tarpc::context::ServerContext, #( #args ),*) -> #output; } }, ); @@ -622,7 +622,7 @@ impl ServiceGenerator<'_> { type Resp = #response_ident; - async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident) + async fn serve(self, ctx: &mut ::tarpc::context::ServerContext, req: #request_ident) -> ::core::result::Result<#response_ident, ::tarpc::ServerError> { match req { #( @@ -786,7 +786,7 @@ impl ServiceGenerator<'_> { #( #[allow(unused)] #( #method_attrs )* - #vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*) + #vis fn #method_idents<'a>(&'a self, ctx: &'a mut ::tarpc::context::ClientContext, #( #args ),*) -> impl ::core::future::Future> + '_ { let request = #request_ident::#camel_case_idents { #( #arg_pats ),* }; let resp = self.0.call(ctx, request); diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 26ee1ec39..b03f3470f 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -12,15 +12,15 @@ fn att_service_trait() { } impl Foo for () { - async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) { + async fn two_part(self, _: &mut context::ServerContext, s: String, i: i32) -> (String, i32) { (s, i) } - async fn bar(self, _: context::Context, s: String) -> String { + async fn bar(self, _: &mut context::ServerContext, s: String) -> String { s } - async fn baz(self, _: context::Context) {} + async fn baz(self, _: &mut context::ServerContext) {} } } @@ -39,18 +39,18 @@ fn raw_idents() { impl r#trait for () { async fn r#await( self, - _: context::Context, + _: &mut context::ServerContext, r#struct: r#yield, r#enum: i32, ) -> (r#yield, i32) { (r#struct, r#enum) } - async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield { + async fn r#fn(self, _: &mut context::ServerContext, r#impl: r#yield) -> r#yield { r#impl } - async fn r#async(self, _: context::Context) {} + async fn r#async(self, _: &mut context::ServerContext) {} } } @@ -64,7 +64,7 @@ fn service_with_cfg_rpc() { } impl Foo for () { - async fn foo(self, _: context::Context) {} + async fn foo(self, _: &mut context::ServerContext) {} } } diff --git a/tarpc/Cargo.toml b/tarpc/Cargo.toml index 778eb0938..cb837da39 100644 --- a/tarpc/Cargo.toml +++ b/tarpc/Cargo.toml @@ -61,6 +61,7 @@ tracing = { version = "0.1", default-features = false, features = [ tracing-opentelemetry = { version = "0.31.0", default-features = false } opentelemetry = { version = "0.30.0", default-features = false } opentelemetry-semantic-conventions = "0.30.0" +anymap3 = "1.0.1" [dev-dependencies] assert_matches = "1.4" diff --git a/tarpc/examples/compression.rs b/tarpc/examples/compression.rs index d66261d19..663236731 100644 --- a/tarpc/examples/compression.rs +++ b/tarpc/examples/compression.rs @@ -108,7 +108,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}!") } } @@ -134,7 +134,7 @@ async fn main() -> anyhow::Result<()> { println!( "{}", - client.hello(context::current(), "friend".into()).await? + client.hello(&mut context::ClientContext::current(), "friend".into()).await? ); Ok(()) } diff --git a/tarpc/examples/custom_transport.rs b/tarpc/examples/custom_transport.rs index 5f5386785..1c682173d 100644 --- a/tarpc/examples/custom_transport.rs +++ b/tarpc/examples/custom_transport.rs @@ -5,7 +5,7 @@ // https://opensource.org/licenses/MIT. use futures::prelude::*; -use tarpc::context::Context; +use tarpc::context::{ClientContext, ServerContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -21,7 +21,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) {} + async fn ping(self, _: &mut ServerContext) {} } #[tokio::main] @@ -52,7 +52,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(conn), Bincode::default()); PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut ClientContext::current()) .await?; Ok(()) diff --git a/tarpc/examples/pubsub.rs b/tarpc/examples/pubsub.rs index d61f68c48..2644ee617 100644 --- a/tarpc/examples/pubsub.rs +++ b/tarpc/examples/pubsub.rs @@ -80,11 +80,11 @@ struct Subscriber { } impl subscriber::Subscriber for Subscriber { - async fn topics(self, _: context::Context) -> Vec { + async fn topics(self, _: &mut context::ServerContext) -> Vec { self.topics.clone() } - async fn receive(self, _: context::Context, topic: String, message: String) { + async fn receive(self, _: &mut context::ServerContext, topic: String, message: String) { info!(local_addr = %self.local_addr, %topic, %message, "ReceivedMessage") } } @@ -210,7 +210,7 @@ impl Publisher { subscriber: subscriber::SubscriberClient, ) { // Populate the topics - if let Ok(topics) = subscriber.topics(context::current()).await { + if let Ok(topics) = subscriber.topics(&mut context::ClientContext::current()).await { self.clients.lock().unwrap().insert( subscriber_addr, Subscription { @@ -263,15 +263,19 @@ impl Publisher { } impl publisher::Publisher for Publisher { - async fn publish(self, _: context::Context, topic: String, message: String) { + async fn publish(self, _: &mut context::ServerContext, topic: String, message: String) { info!("received message to publish."); let mut subscribers = match self.subscriptions.read().unwrap().get(&topic) { None => return, Some(subscriptions) => subscriptions.clone(), }; let mut publications = Vec::new(); + + for client in subscribers.values_mut() { - publications.push(client.receive(context::current(), topic.clone(), message.clone())); + publications.push(async { + client.receive(&mut context::ClientContext::current(), topic.clone(), message.clone()).await + }); } // Ignore failing subscribers. In a real pubsub, you'd want to continually retry until // subscribers ack. Of course, a lot would be different in a real pubsub :) @@ -342,26 +346,26 @@ async fn main() -> anyhow::Result<()> { .spawn(); publisher - .publish(context::current(), "calculus".into(), "sqrt(2)".into()) + .publish(&mut context::ClientContext::current(), "calculus".into(), "sqrt(2)".into()) .await?; publisher .publish( - context::current(), + &mut context::ClientContext::current(), "cool shorts".into(), "hello to all".into(), ) .await?; publisher - .publish(context::current(), "history".into(), "napoleon".to_string()) + .publish(&mut context::ClientContext::current(), "history".into(), "napoleon".to_string()) .await?; drop(_subscriber0); publisher .publish( - context::current(), + &mut context::ClientContext::current(), "cool shorts".into(), "hello to who?".into(), ) diff --git a/tarpc/examples/readme.rs b/tarpc/examples/readme.rs index c328bd884..60daf4e45 100644 --- a/tarpc/examples/readme.rs +++ b/tarpc/examples/readme.rs @@ -23,7 +23,7 @@ pub trait World { struct HelloServer; impl World for HelloServer { - async fn hello(self, _: context::Context, name: String) -> String { + async fn hello(self, _: &mut context::ServerContext, name: String) -> String { format!("Hello, {name}!") } } @@ -46,7 +46,7 @@ async fn main() -> anyhow::Result<()> { // The client has an RPC method for each RPC defined in the annotated trait. It takes the same // args as defined, with the addition of a Context, which is always the first arg. The Context // specifies a deadline and trace information which can be helpful in debugging requests. - let hello = client.hello(context::current(), "Stim".to_string()).await?; + let hello = client.hello(&mut context::ClientContext::current(), "Stim".to_string()).await?; println!("{hello}"); diff --git a/tarpc/examples/tls_over_tcp.rs b/tarpc/examples/tls_over_tcp.rs index 968f76c17..cc3c1690b 100644 --- a/tarpc/examples/tls_over_tcp.rs +++ b/tarpc/examples/tls_over_tcp.rs @@ -18,7 +18,7 @@ use tokio_rustls::rustls::{ }; use tokio_rustls::{TlsAcceptor, TlsConnector}; -use tarpc::context::Context; +use tarpc::context::{ClientContext, ServerContext}; use tarpc::serde_transport as transport; use tarpc::server::{BaseChannel, Channel}; use tarpc::tokio_serde::formats::Bincode; @@ -33,7 +33,7 @@ pub trait PingService { struct Service; impl PingService for Service { - async fn ping(self, _: Context) -> String { + async fn ping(self, _: &mut ServerContext) -> String { "🔒".to_owned() } } @@ -146,7 +146,7 @@ async fn main() -> anyhow::Result<()> { let transport = transport::new(codec_builder.new_framed(stream), Bincode::default()); let answer = PingServiceClient::new(Default::default(), transport) .spawn() - .ping(tarpc::context::current()) + .ping(&mut ClientContext::current()) .await?; println!("ping answer: {answer}"); diff --git a/tarpc/examples/tracing.rs b/tarpc/examples/tracing.rs index 79a7026c0..be1b539c1 100644 --- a/tarpc/examples/tracing.rs +++ b/tarpc/examples/tracing.rs @@ -56,7 +56,7 @@ pub mod double { struct AddServer; impl AddService for AddServer { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } } @@ -70,9 +70,9 @@ impl DoubleService for DoubleServer where Stub: AddStub + Clone + Send + Sync + 'static, { - async fn double(self, _: context::Context, x: i32) -> Result { + async fn double(self, _: &mut context::ServerContext, x: i32) -> Result { self.add_client - .add(context::current(), x, x) + .add(&mut context::ClientContext::current(), x, x) .await .map_err(|e| e.to_string()) } @@ -193,9 +193,8 @@ async fn main() -> anyhow::Result<()> { let double_client = double::DoubleClient::new(client::Config::default(), to_double_server).spawn(); - let ctx = context::current(); for _ in 1..=5 { - tracing::info!("{:?}", double_client.double(ctx, 1).await?); + tracing::info!("{:?}", double_client.double(&mut context::ClientContext::current(), 1).await?); } tracer_provider.shutdown()?; diff --git a/tarpc/src/client.rs b/tarpc/src/client.rs index d14edf8ca..164dd5533 100644 --- a/tarpc/src/client.rs +++ b/tarpc/src/client.rs @@ -128,7 +128,7 @@ where otel.kind = "client", otel.name = %request.name()) )] - pub async fn call(&self, mut ctx: context::Context, request: Req) -> Result { + pub async fn call(&self, ctx: &mut context::SharedContext, request: Req) -> Result { let span = Span::current(); ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| { tracing::trace!( @@ -153,7 +153,7 @@ where }; self.to_dispatch .send(DispatchRequest { - ctx, + ctx: ctx.clone(), span, request_id, request, @@ -419,7 +419,7 @@ where cx: &mut Context<'_>, ) -> Poll, ChannelError>>> { if self.in_flight_requests().len() >= self.config.max_in_flight_requests { - tracing::info!( + tracing::debug!( "At in-flight request capacity ({}/{}).", self.in_flight_requests().len(), self.config.max_in_flight_requests @@ -437,7 +437,7 @@ where Some(request) => { if request.response_completion.is_closed() { let _entered = request.span.enter(); - tracing::info!("AbortRequest"); + tracing::debug!("AbortRequest"); continue; } @@ -457,7 +457,7 @@ where fn poll_next_cancellation( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll>>> { + ) -> Poll>>> { ready!(self.ensure_writeable(cx)?); loop { @@ -513,16 +513,18 @@ where let request = ClientMessage::Request(Request { id: request_id, message: request, - context: context::Context { - deadline: ctx.deadline, - trace_context: ctx.trace_context, - }, + context: ctx.clone(), }); + + //TODO: Feels like we could avoid either saving the request context in insert_request + // or submitting the context in start_request. + let full_context = context::ClientContext::new(ctx); + self.in_flight_requests() - .insert_request(request_id, ctx, span.clone(), response_completion) + .insert_request(request_id, full_context, span.clone(), response_completion) .expect("Request IDs should be unique"); match self.start_send(request) { - Ok(()) => tracing::info!("SendRequest"), + Ok(()) => tracing::debug!("SendRequest"), Err(e) => { self.in_flight_requests() .complete_request(request_id, Err(RpcError::Send(Box::new(e)))); @@ -553,7 +555,7 @@ where }; self.start_send(cancel) .map_err(|e| ChannelError::Write(Arc::new(e)))?; - tracing::info!("CancelRequest"); + tracing::debug!("CancelRequest"); Poll::Ready(Some(Ok(()))) } @@ -564,7 +566,7 @@ where response.message.map_err(RpcError::Server), ) { let _entered = span.enter(); - tracing::info!("ReceiveResponse"); + tracing::debug!("ReceiveResponse"); return true; } false @@ -594,7 +596,7 @@ where }) => { let _entered = span.enter(); if response_completion.is_closed() { - tracing::info!("AbortRequest"); + tracing::debug!("AbortRequest"); } else { tracing::warn!("RpcError::Channel"); let _ = response_completion.send(Err(RpcError::Channel(e.clone()))); @@ -612,15 +614,15 @@ where loop { match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) { (Poll::Ready(None), _) => { - tracing::info!("Shutdown: read half closed, so shutting down."); + tracing::debug!("Shutdown: read half closed, so shutting down."); return Poll::Ready(Ok(())); } (read, Poll::Ready(None)) => { if self.in_flight_requests.is_empty() { - tracing::info!("Shutdown: write half closed, and no requests in flight."); + tracing::debug!("Shutdown: write half closed, and no requests in flight."); return Poll::Ready(Ok(())); } - tracing::info!( + tracing::debug!( "Shutdown: write half closed, and {} requests in flight.", self.in_flight_requests().len() ); @@ -648,7 +650,7 @@ where ) -> Poll>> { loop { if let Some(e) = self.terminal_error_mut() { - tracing::info!("RpcError::Channel"); + tracing::debug!("RpcError::Channel"); let e: ChannelError = e .clone() .downcast() @@ -669,7 +671,7 @@ where /// the lifecycle of the request. #[derive(Debug)] struct DispatchRequest { - pub ctx: context::Context, + pub ctx: context::SharedContext, pub span: Span, pub request_id: u64, pub request: Req, @@ -684,7 +686,6 @@ mod tests { use crate::{ ChannelError, ClientMessage, Response, client::{Config, in_flight_requests::InFlightRequests}, - context::{self, current}, transport::{self, channel::UnboundedChannel}, }; use assert_matches::assert_matches; @@ -705,6 +706,7 @@ mod tests { oneshot, }; use tracing::Span; + use crate::context::{ClientContext, SharedContext}; #[tokio::test] async fn response_completes_request_future() { @@ -714,7 +716,7 @@ mod tests { dispatch .in_flight_requests - .insert_request(0, context::current(), Span::current(), tx) + .insert_request(0, ClientContext::current(), Span::current(), tx) .unwrap(); server_channel .send(Response { @@ -881,7 +883,7 @@ mod tests { let (dispatch, channel, _server_channel) = set_up(); drop(dispatch); // error on send - let resp = channel.call(current(), "hi".to_string()).await; + let resp = channel.call(&mut ClientContext::current(), "hi".to_string()).await; assert_matches!(resp, Err(RpcError::Shutdown)); } @@ -1091,7 +1093,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: SharedContext::current(), span: Span::current(), request_id, request: request.to_string(), @@ -1116,7 +1118,7 @@ mod tests { let request_id = u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap(); let request = DispatchRequest { - ctx: context::current(), + ctx: SharedContext::current(), span: Span::current(), request_id, request: request.to_string(), diff --git a/tarpc/src/client/in_flight_requests.rs b/tarpc/src/client/in_flight_requests.rs index 1776a74a0..a368a5a48 100644 --- a/tarpc/src/client/in_flight_requests.rs +++ b/tarpc/src/client/in_flight_requests.rs @@ -29,7 +29,7 @@ impl Default for InFlightRequests { #[derive(Debug)] struct RequestData { - ctx: context::Context, + ctx: context::ClientContext, span: Span, response_completion: oneshot::Sender, /// The key to remove the timer for the request's deadline. @@ -56,7 +56,7 @@ impl InFlightRequests { pub fn insert_request( &mut self, request_id: u64, - ctx: context::Context, + ctx: context::ClientContext, span: Span, response_completion: oneshot::Sender, ) -> Result<(), AlreadyExistsError> { @@ -106,7 +106,7 @@ impl InFlightRequests { /// Cancels a request without completing (typically used when a request handle was dropped /// before the request completed). - pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::Context, Span)> { + pub fn cancel_request(&mut self, request_id: u64) -> Option<(context::ClientContext, Span)> { if let Some(request_data) = self.request_data.remove(&request_id) { self.request_data.compact(0.1); self.deadlines.remove(&request_data.deadline_key); diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index 85746b7f2..c7dc12008 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -24,7 +24,7 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) + async fn call(&self, ctx: &mut context::ClientContext, request: Self::Req) -> Result; } @@ -35,7 +35,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, ctx: context::Context, request: Req) -> Result { + async fn call(&self, ctx: &mut context::ClientContext, request: Req) -> Result { Self::call(self, ctx, request).await } } @@ -46,7 +46,13 @@ where { type Req = S::Req; type Resp = S::Resp; - async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { - self.clone().serve(ctx, req).await.map_err(RpcError::Server) + async fn call(&self, ctx: &mut context::ClientContext, req: Self::Req) -> Result { + let mut server_ctx = context::ServerContext::new(ctx.shared_context.clone()); + + let res = self.clone().serve(&mut server_ctx, req).await.map_err(RpcError::Server); + + ctx.shared_context = server_ctx.shared_context; + + res } } diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index d28a3c137..bf70ebe2a 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -20,7 +20,7 @@ mod round_robin { async fn call( &self, - ctx: context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let next = self.stubs.next(); @@ -119,7 +119,7 @@ mod consistent_hash { async fn call( &self, - ctx: context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let index = usize::try_from(self.hasher.hash_one(&request) % self.stubs_len).expect( @@ -200,13 +200,13 @@ mod consistent_hash { )?; for _ in 0..2 { - let resp = stub.call(context::current(), 'a').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'a').await?; assert_eq!(resp, 1); - let resp = stub.call(context::current(), 'b').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'b').await?; assert_eq!(resp, 2); - let resp = stub.call(context::current(), 'c').await?; + let resp = stub.call(&mut context::ClientContext::current(), 'c').await?; assert_eq!(resp, 3); } diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index 145c14c1f..451544433 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -30,7 +30,7 @@ where type Req = Req; type Resp = Resp; - async fn call(&self, _: context::Context, request: Self::Req) -> Result { + async fn call(&self, _: &mut context::ClientContext, request: Self::Req) -> Result { self.responses .get(&request) .cloned() diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index a07b05fc5..d93daa156 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -18,7 +18,7 @@ where async fn call( &self, - ctx: context::Context, + ctx: &mut context::ClientContext, request: Self::Req, ) -> Result { let request = Arc::new(request); diff --git a/tarpc/src/context.rs b/tarpc/src/context.rs index 8e77cf223..91b0f4ee6 100644 --- a/tarpc/src/context.rs +++ b/tarpc/src/context.rs @@ -14,6 +14,7 @@ use std::{ convert::TryFrom, time::{Duration, Instant}, }; +use std::ops::{Deref, DerefMut}; use tracing_opentelemetry::OpenTelemetrySpanExt; /// A request context that carries request-scoped information like deadlines and trace information. @@ -21,10 +22,10 @@ use tracing_opentelemetry::OpenTelemetrySpanExt; /// /// The context should not be stored directly in a server implementation, because the context will /// be different for each request in scope. -#[derive(Clone, Copy, Debug)] +#[derive(Debug, Clone)] #[non_exhaustive] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct Context { +pub struct SharedContext { /// When the client expects the request to be complete by. The server should cancel the request /// if it is not complete by this time. #[cfg_attr(feature = "serde1", serde(default = "ten_seconds_from_now"))] @@ -38,6 +39,99 @@ pub struct Context { pub trace_context: trace::Context, } +/// Request context that carries request-scoped server side information like deadlines and trace information +/// as well as any server side extensions defined by the transport, hooks or service implementations. +/// It is build from the shared context sent from client to server. +/// +/// The context should not be stored directly in a server implementation, because the context will +/// be different for each request in scope. +#[derive(Debug)] +pub struct ServerContext { + /// Shared context sent from client to server which contains information used by both sides. + pub shared_context: SharedContext, + + /// Server side extensions that are not seen by the client + /// Transport implementations, hooks and service implementations + /// can use this to store per-request data, and communicate with eachother. + /// Note that this is NOT sent to the client, and they will always see an empty map here. + pub server_context: anymap3::Map, +} + +impl ServerContext { + /// Creates a new ServerContext from the given SharedContext with no extensions. + pub fn new(shared_context: SharedContext) -> Self { + Self { + shared_context, + server_context: anymap3::Map::new(), + } + } + + /// Creates a new ServerContext for the current shared context with no extensions. + pub fn current() -> Self { + Self::new(SharedContext::current()) + } +} + +impl Deref for ServerContext { + type Target = SharedContext; + + fn deref(&self) -> &Self::Target { + &self.shared_context + } +} +impl DerefMut for ServerContext { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.shared_context + } +} + + +/// Request context that carries request-scoped client side information like deadlines and trace information +/// as well as any server side extensions defined by the transport, hooks and stubs. +/// The shared part of the context is sent from client to server, while the client side extensions are only seen on the client side. +/// +/// The context should not be stored directly in a stub implementation, because the context will +/// be different for each request in scope. +#[derive(Debug)] +pub struct ClientContext { + /// Shared context sent from client to server which contains information used by both sides. + pub shared_context: SharedContext, + + /// Client side extensions that are not seen by the server + /// XXX, YYY, and ZZZ can use this to store per-request data, and communicate with eachother. + /// Note that this is NOT sent to the server, and they will always see an empty map here. + pub client_context: anymap3::Map, +} + +impl ClientContext { + /// Creates a new ServerContext from the given SharedContext with no extensions. + pub fn new(shared_context: SharedContext) -> Self { + Self { + shared_context, + client_context: anymap3::Map::new(), + } + } + + /// Creates a new ServerContext for the current shared context with no extensions. + pub fn current() -> Self { + Self::new(SharedContext::current()) + } +} + +impl Deref for ClientContext { + type Target = SharedContext; + + fn deref(&self) -> &Self::Target { + &self.shared_context + } +} + +impl DerefMut for ClientContext { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.shared_context + } +} + #[cfg(feature = "serde1")] mod absolute_to_relative_time { pub use serde::{Deserialize, Deserializer, Serialize, Serializer}; @@ -91,17 +185,12 @@ mod absolute_to_relative_time { } } -assert_impl_all!(Context: Send, Sync); +assert_impl_all!(SharedContext: Send, Sync); fn ten_seconds_from_now() -> Instant { Instant::now() + Duration::from_secs(10) } -/// Returns the context for the current request, or a default Context if no request is active. -pub fn current() -> Context { - Context::current() -} - #[derive(Clone)] struct Deadline(Instant); @@ -111,7 +200,7 @@ impl Default for Deadline { } } -impl Context { +impl SharedContext { /// Returns the context for the current request, or a default Context if no request is active. pub fn current() -> Self { let span = tracing::Span::current(); @@ -137,11 +226,11 @@ impl Context { pub(crate) trait SpanExt { /// Sets the given context on this span. Newly-created spans will be children of the given /// context's trace context. - fn set_context(&self, context: &Context); + fn set_context(&self, context: &SharedContext); } impl SpanExt for tracing::Span { - fn set_context(&self, context: &Context) { + fn set_context(&self, context: &SharedContext) { self.set_parent( opentelemetry::Context::new() .with_remote_span_context(opentelemetry::trace::SpanContext::new( diff --git a/tarpc/src/lib.rs b/tarpc/src/lib.rs index 7e1944305..a83efae02 100644 --- a/tarpc/src/lib.rs +++ b/tarpc/src/lib.rs @@ -125,7 +125,7 @@ //! //! impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! async fn hello(self, _: context::Context, name: String) -> String { +//! async fn hello(self, _: &mut context::ServerContext, name: String) -> String { //! format!("Hello, {name}!") //! } //! } @@ -158,7 +158,7 @@ //! # struct HelloServer; //! # impl World for HelloServer { //! // Each defined rpc generates an async fn that serves the RPC -//! # async fn hello(self, _: context::Context, name: String) -> String { +//! # async fn hello(self, _: &mut context::ServerContext, name: String) -> String { //! # format!("Hello, {name}!") //! # } //! # } @@ -184,7 +184,8 @@ //! // The client has an RPC method for each RPC defined in the annotated trait. It takes the same //! // args as defined, with the addition of a Context, which is always the first arg. The Context //! // specifies a deadline and trace information which can be helpful in debugging requests. -//! let hello = client.hello(context::current(), "Stim".to_string()).await?; +//! let mut context = context::ClientContext::current(); +//! let hello = client.hello(&mut context, "Stim".to_string()).await?; //! //! println!("{hello}"); //! @@ -279,11 +280,11 @@ pub enum ClientMessage { } /// A request from a client to a server. -#[derive(Clone, Copy, Debug)] +#[derive(Debug)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] pub struct Request { /// Trace context, deadline, and other cross-cutting concerns. - pub context: context::Context, + pub context: context::SharedContext, /// Uniquely identifies the request across all requests sent over a single channel. pub id: u64, /// The request body. diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d4551fd4b..8d0b08cbd 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -76,7 +76,7 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + async fn serve(self, ctx: &mut context::ServerContext, req: Self::Req) -> Result; } /// A Serve wrapper around a Fn. @@ -102,10 +102,9 @@ impl Copy for ServeFn where F: Copy {} /// Creates a [`Serve`] wrapper around a `FnOnce(context::Context, Req) -> impl Future>`. -pub fn serve(f: F) -> ServeFn +pub fn serve(f: F) -> ServeFn where - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, { ServeFn { f, @@ -113,16 +112,15 @@ where } } -impl Serve for ServeFn +impl Serve for ServeFn where Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + for<'a> F: FnOnce(&'a mut context::ServerContext, Req) -> Pin> + 'a + Send>>, { type Req = Req; type Resp = Resp; - async fn serve(self, ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { (self.f)(ctx, req).await } } @@ -220,7 +218,7 @@ where request.context.trace_context.new_child() }); let entered = span.enter(); - tracing::info!("ReceiveRequest"); + tracing::debug!("ReceiveRequest"); let start = self.in_flight_requests_mut().start_request( request.id, request.context.deadline, @@ -360,10 +358,11 @@ where /// let mut requests = server.requests(); /// tokio::spawn(async move { /// while let Some(Ok(request)) = requests.next().await { - /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }))); + /// tokio::spawn(request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed()))); /// } /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::ClientContext::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` fn requests(self) -> Requests @@ -399,12 +398,13 @@ where /// let client = client::new(client::Config::default(), tx).spawn(); /// let channel = BaseChannel::with_defaults(rx); /// tokio::spawn( - /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) })) + /// channel.execute(serve(|_, i: i32| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); + /// }.boxed())); + /// let mut context = context::ClientContext::current(); /// assert_eq!( - /// client.call(context::current(), 1).await.unwrap(), + /// client.call(&mut context, 1).await.unwrap(), /// 2); /// } /// ``` @@ -450,7 +450,7 @@ where Poll::Ready(Some(request_id)) => { if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) { let _entered = span.enter(); - tracing::info!("ResponseCancelled"); + tracing::debug!("ResponseCancelled"); } Ready } @@ -545,7 +545,7 @@ where .remove_request(response.request_id) { let _entered = span.enter(); - tracing::info!("SendResponse"); + tracing::debug!("SendResponse"); self.project() .transport .start_send(response) @@ -650,7 +650,7 @@ where response_guard.cancel = true; { let _entered = span.enter(); - tracing::info!("BeginRequest"); + tracing::debug!("BeginRequest"); } InFlightRequest { request, @@ -748,11 +748,12 @@ where /// let requests = BaseChannel::new(server::Config::default(), rx).requests(); /// let client = client::new(client::Config::default(), tx).spawn(); /// tokio::spawn( - /// requests.execute(serve(|_, i| async move { Ok(i + 1) })) + /// requests.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())) /// .for_each(|response| async move { /// tokio::spawn(response); - /// })); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// }.boxed())); + /// let mut context = context::ClientContext::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub fn execute(self, serve: S) -> impl Stream> @@ -855,11 +856,11 @@ impl InFlightRequest { /// tokio::spawn(async move { /// let mut requests = server.requests(); /// while let Some(Ok(in_flight_request)) = requests.next().await { - /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) })).await; + /// in_flight_request.execute(serve(|_, i| async move { Ok(i + 1) }.boxed())).await; /// } - /// /// }); - /// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + /// let mut context = context::ClientContext::current(); + /// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` /// @@ -881,16 +882,17 @@ impl InFlightRequest { }, } = self; span.record("otel.name", message.name()); + let mut full_context = context::ServerContext::new(context); let _ = Abortable::new( async move { - let message = serve.serve(context, message).await; - tracing::info!("CompleteRequest"); + let message = serve.serve(&mut full_context, message).await; + tracing::debug!("CompleteRequest"); let response = Response { request_id, message, }; let _ = response_tx.send(response).await; - tracing::info!("BufferResponse"); + tracing::debug!("BufferResponse"); }, abort_registration, ) @@ -1025,7 +1027,7 @@ mod tests { fn fake_request(req: Req) -> ClientMessage { ClientMessage::Request(Request { - context: context::current(), + context: context::SharedContext::current(), id: 0, message: req, }) @@ -1039,8 +1041,8 @@ mod tests { #[tokio::test] async fn test_serve() { - let serve = serve(|_, i| async move { Ok(i) }); - assert_matches!(serve.serve(context::current(), 7).await, Ok(7)); + let serve = serve(|_, i| async move { Ok(i) }.boxed()); + assert_matches!(serve.serve(&mut context::ServerContext::current(), 7).await, Ok(7)); } #[tokio::test] @@ -1049,7 +1051,7 @@ mod tests { impl BeforeRequest for SetDeadline { async fn before( &mut self, - ctx: &mut context::Context, + ctx: &mut context::ServerContext, _: &Req, ) -> Result<(), ServerError> { ctx.deadline = self.0; @@ -1060,14 +1062,14 @@ mod tests { let some_time = Instant::now() + Duration::from_secs(37); let some_other_time = Instant::now() + Duration::from_secs(83); - let serve = serve(move |ctx: context::Context, i| async move { + let serve = serve(move |ctx: &mut context::ServerContext, i| async move { assert_eq!(ctx.deadline, some_time); Ok(i) - }); + }.boxed()); let deadline_hook = serve.before(SetDeadline(some_time)); - let mut ctx = context::current(); + let mut ctx = context::ServerContext::current(); ctx.deadline = some_other_time; - deadline_hook.serve(ctx, 7).await?; + deadline_hook.serve(&mut ctx, 7).await?; Ok(()) } @@ -1088,7 +1090,7 @@ mod tests { impl BeforeRequest for PrintLatency { async fn before( &mut self, - _: &mut context::Context, + _: &mut context::ServerContext, _: &Req, ) -> Result<(), ServerError> { self.start = Instant::now(); @@ -1096,26 +1098,26 @@ mod tests { } } impl AfterRequest for PrintLatency { - async fn after(&mut self, _: &mut context::Context, _: &mut Result) { - tracing::info!("Elapsed: {:?}", self.start.elapsed()); + async fn after(&mut self, _: &mut context::ServerContext, _: &mut Result) { + tracing::debug!("Elapsed: {:?}", self.start.elapsed()); } } - let serve = serve(move |_: context::Context, i| async move { Ok(i) }); + let serve = serve(move |_: &mut context::ServerContext, i| async move { Ok(i) }.boxed()); serve .before_and_after(PrintLatency::new()) - .serve(context::current(), 7) + .serve(&mut context::ServerContext::current(), 7) .await?; Ok(()) } #[tokio::test] async fn serve_before_error_aborts_request() -> anyhow::Result<()> { - let serve = serve(|_, _| async { panic!("Shouldn't get here") }); - let deadline_hook = serve.before(|_: &mut context::Context, _: &i32| async { + let serve = serve(|_, _| async { panic!("Shouldn't get here") }.boxed()); + let deadline_hook = serve.before(|_: &mut context::ServerContext, _: &i32| async { Err(ServerError::new(io::ErrorKind::Other, "oops".into())) }); - let resp: Result = deadline_hook.serve(context::current(), 7).await; + let resp: Result = deadline_hook.serve(&mut context::ServerContext::current(), 7).await; assert_matches!(resp, Err(_)); Ok(()) } @@ -1128,14 +1130,14 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); assert_matches!( channel.as_mut().start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: () }), Err(AlreadyExistsError) @@ -1151,7 +1153,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1159,7 +1161,7 @@ mod tests { .as_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1182,7 +1184,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1211,7 +1213,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1253,7 +1255,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1276,7 +1278,7 @@ mod tests { .as_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1320,7 +1322,7 @@ mod tests { Poll::Ready(Some(Ok(request))) => request, result => panic!("Unexpected result: {result:?}"), }; - request.execute(serve(|_, _| async { Ok(()) })).await; + request.execute(serve(|_, _| async { Ok(()) }.boxed())).await; assert!( requests .as_mut() @@ -1341,7 +1343,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1371,7 +1373,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1392,7 +1394,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 0, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); @@ -1411,7 +1413,7 @@ mod tests { .channel_pin_mut() .start_request(Request { id: 1, - context: context::current(), + context: context::SharedContext::current(), message: (), }) .unwrap(); diff --git a/tarpc/src/server/in_flight_requests.rs b/tarpc/src/server/in_flight_requests.rs index 4abf8b1e2..252262a83 100644 --- a/tarpc/src/server/in_flight_requests.rs +++ b/tarpc/src/server/in_flight_requests.rs @@ -74,7 +74,7 @@ impl InFlightRequests { self.request_data.compact(0.1); abort_handle.abort(); self.deadlines.remove(&deadline_key); - tracing::info!("ReceiveCancel"); + tracing::debug!("ReceiveCancel"); true } else { false diff --git a/tarpc/src/server/incoming.rs b/tarpc/src/server/incoming.rs index 428eb1a7d..cb01021f5 100644 --- a/tarpc/src/server/incoming.rs +++ b/tarpc/src/server/incoming.rs @@ -63,9 +63,10 @@ where /// /// let incoming = stream::once(async move { /// BaseChannel::new(server::Config::default(), rx) -/// }).execute(serve(|_, i| async move { Ok(i + 1) })); +/// }).execute(serve(|_, i| async move { Ok(i + 1) }.boxed())); /// tokio::spawn(spawn_incoming(incoming)); -/// assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); +/// let mut context = context::ClientContext::current(); +/// assert_eq!(client.call(&mut context, 1).await.unwrap(), 2); /// } /// ``` pub async fn spawn_incoming( diff --git a/tarpc/src/server/limits/channels_per_key.rs b/tarpc/src/server/limits/channels_per_key.rs index 46e9d9fa1..64b644278 100644 --- a/tarpc/src/server/limits/channels_per_key.rs +++ b/tarpc/src/server/limits/channels_per_key.rs @@ -16,7 +16,7 @@ use std::{ collections::hash_map::Entry, convert::TryFrom, fmt, hash::Hash, marker::Unpin, pin::Pin, }; use tokio::sync::mpsc; -use tracing::{debug, info, trace}; +use tracing::{debug, trace}; /// An [`Incoming`](crate::server::incoming::Incoming) stream that drops new channels based on /// per-key limits. @@ -198,7 +198,7 @@ where Entry::Occupied(mut o) => { let count = o.get().strong_count(); if count >= usize::try_from(*self_.channels_per_key).unwrap() { - info!( + debug!( channel_filter_key = %key, open_channels = count, max_open_channels = *self_.channels_per_key, diff --git a/tarpc/src/server/limits/requests_per_channel.rs b/tarpc/src/server/limits/requests_per_channel.rs index b559f6a7d..bd9c103b0 100644 --- a/tarpc/src/server/limits/requests_per_channel.rs +++ b/tarpc/src/server/limits/requests_per_channel.rs @@ -60,7 +60,7 @@ where match ready!(self.as_mut().project().inner.poll_next(cx)?) { Some(r) => { let _entered = r.span.enter(); - tracing::info!( + tracing::debug!( in_flight_requests = self.as_mut().in_flight_requests(), "ThrottleRequest", ); diff --git a/tarpc/src/server/request_hook.rs b/tarpc/src/server/request_hook.rs index 66cf2878c..38b0998bf 100644 --- a/tarpc/src/server/request_hook.rs +++ b/tarpc/src/server/request_hook.rs @@ -43,12 +43,12 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// - /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }) - /// .before(|_ctx: &mut context::Context, req: &i32| { + /// let serve = serve(|_ctx, i| async move { Ok(i + 1) }.boxed()) + /// .before(|_ctx: &mut context::ServerContext, req: &i32| { /// future::ready( /// if *req == 1 { /// Err(ServerError::new( @@ -58,7 +58,8 @@ pub trait RequestHook: Serve { /// Ok(()) /// }) /// }); - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::ServerContext::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn before(self, hook: Hook) -> HookThenServe @@ -80,7 +81,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, request_hook::RequestHook, serve}}; /// use std::io; /// @@ -93,15 +94,15 @@ pub trait RequestHook: Serve { /// } else { /// Ok(i + 1) /// } - /// }) - /// .after(|_ctx: &mut context::Context, resp: &mut Result| { + /// }.boxed()) + /// .after(|_ctx: &mut context::ServerContext, resp: &mut Result| { /// if let Err(e) = resp { /// eprintln!("server error: {e:?}"); /// } /// future::ready(()) /// }); - /// - /// let response = serve.serve(context::current(), 1); + /// let mut context = context::ServerContext::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_err()); /// ``` fn after(self, hook: Hook) -> ServeThenHook @@ -123,7 +124,7 @@ pub trait RequestHook: Serve { /// # Example /// /// ```rust - /// use futures::{executor::block_on, future}; + /// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{ /// context, ServerError, /// server::{Serve, serve, request_hook::{BeforeRequest, AfterRequest, RequestHook}} @@ -133,7 +134,7 @@ pub trait RequestHook: Serve { /// struct PrintLatency(Instant); /// /// impl BeforeRequest for PrintLatency { - /// async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + /// async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { /// self.0 = Instant::now(); /// Ok(()) /// } @@ -142,7 +143,7 @@ pub trait RequestHook: Serve { /// impl AfterRequest for PrintLatency { /// async fn after( /// &mut self, - /// _: &mut context::Context, + /// _: &mut context::ServerContext, /// _: &mut Result, /// ) { /// tracing::info!("Elapsed: {:?}", self.0.elapsed()); @@ -151,8 +152,9 @@ pub trait RequestHook: Serve { /// /// let serve = serve(|_ctx, i| async move { /// Ok(i + 1) - /// }).before_and_after(PrintLatency(Instant::now())); - /// let response = serve.serve(context::current(), 1); + /// }.boxed()).before_and_after(PrintLatency(Instant::now())); + /// let mut context = context::ServerContext::current(); + /// let response = serve.serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// ``` fn before_and_after( diff --git a/tarpc/src/server/request_hook/after.rs b/tarpc/src/server/request_hook/after.rs index b2ef9ccbd..d9e676ca4 100644 --- a/tarpc/src/server/request_hook/after.rs +++ b/tarpc/src/server/request_hook/after.rs @@ -15,15 +15,15 @@ pub trait AfterRequest { /// The function that is called after request execution. /// /// The hook can modify the request context and the response. - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result); + async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result); } impl AfterRequest for F where - F: FnMut(&mut context::Context, &mut Result) -> Fut, + F: FnMut(&mut context::ServerContext, &mut Result) -> Fut, Fut: Future, { - async fn after(&mut self, ctx: &mut context::Context, resp: &mut Result) { + async fn after(&mut self, ctx: &mut context::ServerContext, resp: &mut Result) { self(ctx, resp).await } } @@ -59,14 +59,14 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::ServerContext, req: Serv::Req, ) -> Result { let ServeThenHook { serve, mut hook, .. } = self; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/request_hook/before.rs b/tarpc/src/server/request_hook/before.rs index e72e28a42..4a1b2ad8a 100644 --- a/tarpc/src/server/request_hook/before.rs +++ b/tarpc/src/server/request_hook/before.rs @@ -19,7 +19,7 @@ pub trait BeforeRequest { /// /// This function can also modify the request context. This could be used, for example, to /// enforce a maximum deadline on all requests. - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError>; + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError>; } /// A list of hooks that run in order before request execution. @@ -34,7 +34,7 @@ pub trait BeforeRequestList: BeforeRequest { /// Same as `then`, but helps the compiler with type inference when Next is a closure. fn then_fn< - Next: FnMut(&mut context::Context, &Req) -> Fut, + Next: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, >( self, @@ -56,10 +56,10 @@ pub trait BeforeRequestList: BeforeRequest { impl BeforeRequest for F where - F: FnMut(&mut context::Context, &Req) -> Fut, + F: FnMut(&mut context::ServerContext, &Req) -> Fut, Fut: Future>, { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { self(ctx, req).await } } @@ -87,13 +87,13 @@ where async fn serve( self, - mut ctx: context::Context, + ctx: &mut context::ServerContext, req: Self::Req, ) -> Result { let HookThenServe { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; serve.serve(ctx, req).await } } @@ -103,7 +103,7 @@ where /// Example /// /// ```rust -/// use futures::{executor::block_on, future}; +/// use futures::{executor::block_on, future, FutureExt}; /// use tarpc::{context, ServerError, server::{Serve, serve, request_hook::{self, /// BeforeRequest, BeforeRequestList}}}; /// use std::{cell::Cell, io}; @@ -120,8 +120,9 @@ where /// i.set(2); /// Ok(()) /// }) -/// .serving(serve(|_ctx, i| async move { Ok(i + 1) })); -/// let response = serve.clone().serve(context::current(), 1); +/// .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); +/// let mut context = context::ServerContext::current(); +/// let response = serve.clone().serve(&mut context, 1); /// assert!(block_on(response).is_ok()); /// assert!(i.get() == 2); /// ``` @@ -140,7 +141,7 @@ pub struct BeforeRequestNil; impl, Rest: BeforeRequest> BeforeRequest for BeforeRequestCons { - async fn before(&mut self, ctx: &mut context::Context, req: &Req) -> Result<(), ServerError> { + async fn before(&mut self, ctx: &mut context::ServerContext, req: &Req) -> Result<(), ServerError> { let BeforeRequestCons(first, rest) = self; first.before(ctx, req).await?; rest.before(ctx, req).await?; @@ -149,7 +150,7 @@ impl, Rest: BeforeRequest> BeforeRequest BeforeRequest for BeforeRequestNil { - async fn before(&mut self, _: &mut context::Context, _: &Req) -> Result<(), ServerError> { + async fn before(&mut self, _: &mut context::ServerContext, _: &Req) -> Result<(), ServerError> { Ok(()) } } @@ -209,8 +210,9 @@ fn before_request_list() { i.set(2); Ok(()) }) - .serving(serve(|_ctx, i| async move { Ok(i + 1) })); - let response = serve.clone().serve(context::current(), 1); + .serving(serve(|_ctx, i| async move { Ok(i + 1) }.boxed())); + let mut context = context::ServerContext::current(); + let response = serve.clone().serve(&mut context, 1); assert!(block_on(response).is_ok()); assert!(i.get() == 2); } diff --git a/tarpc/src/server/request_hook/before_and_after.rs b/tarpc/src/server/request_hook/before_and_after.rs index 0761a7df3..af37427af 100644 --- a/tarpc/src/server/request_hook/before_and_after.rs +++ b/tarpc/src/server/request_hook/before_and_after.rs @@ -46,13 +46,13 @@ where type Req = Req; type Resp = Resp; - async fn serve(self, mut ctx: context::Context, req: Req) -> Result { + async fn serve(self, ctx: &mut context::ServerContext, req: Req) -> Result { let HookThenServeThenHook { serve, mut hook, .. } = self; - hook.before(&mut ctx, &req).await?; + hook.before(ctx, &req).await?; let mut resp = serve.serve(ctx, req).await; - hook.after(&mut ctx, &mut resp).await; + hook.after(ctx, &mut resp).await; resp } } diff --git a/tarpc/src/server/testing.rs b/tarpc/src/server/testing.rs index db167c42e..70c4e7f69 100644 --- a/tarpc/src/server/testing.rs +++ b/tarpc/src/server/testing.rs @@ -92,7 +92,7 @@ impl FakeChannel>, Response> { let (request_cancellation, _) = cancellations(); self.stream.push_back(Ok(TrackedRequest { request: Request { - context: context::Context { + context: context::SharedContext { deadline: Instant::now(), trace_context: Default::default(), }, diff --git a/tarpc/src/transport/channel.rs b/tarpc/src/transport/channel.rs index 0268300dc..e785ff49a 100644 --- a/tarpc/src/transport/channel.rs +++ b/tarpc/src/transport/channel.rs @@ -198,7 +198,7 @@ mod tests { format!("{request:?} is not an int"), ) }) - })) + }.boxed())) .for_each(|channel| async move { tokio::spawn(channel.for_each(|response| response)); }), @@ -206,8 +206,8 @@ mod tests { let client = client::new(client::Config::default(), client_channel).spawn(); - let response1 = client.call(context::current(), "123".into()).await; - let response2 = client.call(context::current(), "abc".into()).await; + let response1 = client.call(&mut context::ClientContext::current(), "123".into()).await; + let response2 = client.call(&mut context::ClientContext::current(), "abc".into()).await; trace!("response1: {:?}, response2: {:?}", response1, response2); diff --git a/tarpc/tests/dataservice.rs b/tarpc/tests/dataservice.rs index 18bb3a997..e4cbf338d 100644 --- a/tarpc/tests/dataservice.rs +++ b/tarpc/tests/dataservice.rs @@ -22,7 +22,7 @@ pub trait ColorProtocol { struct ColorServer; impl ColorProtocol for ColorServer { - async fn get_opposite_color(self, _: context::Context, color: TestData) -> TestData { + async fn get_opposite_color(self, _: &mut context::ServerContext, color: TestData) -> TestData { match color { TestData::White => TestData::Black, TestData::Black => TestData::White, @@ -53,7 +53,7 @@ async fn test_call() -> anyhow::Result<()> { let client = ColorProtocolClient::new(client::Config::default(), transport).spawn(); let color = client - .get_opposite_color(context::current(), TestData::White) + .get_opposite_color(&mut context::ClientContext::current(), TestData::White) .await?; assert_eq!(color, TestData::Black); diff --git a/tarpc/tests/service_functional.rs b/tarpc/tests/service_functional.rs index 06542b43b..77b4606e8 100644 --- a/tarpc/tests/service_functional.rs +++ b/tarpc/tests/service_functional.rs @@ -22,11 +22,11 @@ trait Service { struct Server; impl Service for Server { - async fn add(self, _: context::Context, x: i32, y: i32) -> i32 { + async fn add(self, _: &mut context::ServerContext, x: i32, y: i32) -> i32 { x + y } - async fn hey(self, _: context::Context, name: String) -> String { + async fn hey(self, _: &mut context::ServerContext, name: String) -> String { format!("Hey, {name}.") } } @@ -38,10 +38,10 @@ async fn sequential() { let channel = BaseChannel::with_defaults(rx); tokio::spawn( channel - .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) })) + .execute(tarpc::server::serve(|_, i: u32| async move { Ok(i + 1) }.boxed())) .for_each(|response| response), ); - assert_eq!(client.call(context::current(), 1).await.unwrap(), 2); + assert_eq!(client.call(&mut context::ClientContext::current(), 1).await.unwrap(), 2); } #[tokio::test] @@ -55,7 +55,7 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { struct LoopServer; impl Loop for LoopServer { - async fn r#loop(self, _: context::Context) { + async fn r#loop(self, _: &mut context::ServerContext) { loop { futures::pending!(); } @@ -71,9 +71,9 @@ async fn dropped_channel_aborts_in_flight_requests() -> anyhow::Result<()> { tokio::spawn(async move { let client = LoopClient::new(client::Config::default(), tx).spawn(); - let mut ctx = context::current(); + let mut ctx = context::ClientContext::current(); ctx.deadline = Instant::now() + Duration::from_secs(60 * 60); - let _ = client.r#loop(ctx).await; + let _ = client.r#loop(&mut ctx).await; }); let mut requests = BaseChannel::with_defaults(rx).requests(); @@ -112,9 +112,9 @@ async fn serde_tcp() -> anyhow::Result<()> { let transport = serde_transport::tcp::connect(addr, Json::default).await?; let client = ServiceClient::new(client::Config::default(), transport).spawn(); - assert_matches!(client.add(context::current(), 1, 2).await, Ok(3)); + assert_matches!(client.add(&mut context::ClientContext::current(), 1, 2).await, Ok(3)); assert_matches!( - client.hey(context::current(), "Tim".to_string()).await, + client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await, Ok(ref s) if s == "Hey, Tim." ); @@ -145,8 +145,8 @@ async fn serde_uds() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), transport).spawn(); // Save results using socket so we can clean the socket even if our test assertions fail - let res1 = client.add(context::current(), 1, 2).await; - let res2 = client.hey(context::current(), "Tim".to_string()).await; + let res1 = client.add(&mut context::ClientContext::current(), 1, 2).await; + let res2 = client.hey(&mut context::ClientContext::current(), "Tim".to_string()).await; assert_matches!(res1, Ok(3)); assert_matches!(res2, Ok(ref s) if s == "Hey, Tim."); @@ -169,12 +169,15 @@ async fn concurrent() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context = context::ClientContext::current(); + let req1 = client.add(&mut context, 1, 2); assert_matches!(req1.await, Ok(3)); + + let req2 = client.add(&mut context, 3, 4); assert_matches!(req2.await, Ok(7)); + + let req3 = client.hey(&mut context, "Tim".to_string()); assert_matches!(req3.await, Ok(ref s) if s == "Hey, Tim."); Ok(()) @@ -195,9 +198,13 @@ async fn concurrent_join() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); - let req3 = client.hey(context::current(), "Tim".to_string()); + let mut context1 = context::ClientContext::current(); + let mut context2 = context::ClientContext::current(); + let mut context3 = context::ClientContext::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); + let req3 = client.hey(&mut context3, "Tim".to_string()); let (resp1, resp2, resp3) = join!(req1, req2, req3); assert_matches!(resp1, Ok(3)); @@ -225,8 +232,11 @@ async fn concurrent_join_all() -> anyhow::Result<()> { let client = ServiceClient::new(client::Config::default(), tx).spawn(); - let req1 = client.add(context::current(), 1, 2); - let req2 = client.add(context::current(), 3, 4); + let mut context1 = context::ClientContext::current(); + let mut context2 = context::ClientContext::current(); + + let req1 = client.add(&mut context1, 1, 2); + let req2 = client.add(&mut context2, 3, 4); let responses = join_all(vec![req1, req2]).await; assert_matches!(responses[0], Ok(3)); @@ -245,7 +255,7 @@ async fn counter() -> anyhow::Result<()> { struct CountService(u32); impl Counter for &mut CountService { - async fn count(self, _: context::Context) -> u32 { + async fn count(self, _: &mut context::ServerContext) -> u32 { self.0 += 1; self.0 } @@ -262,8 +272,8 @@ async fn counter() -> anyhow::Result<()> { }); let client = CounterClient::new(client::Config::default(), tx).spawn(); - assert_matches!(client.count(context::current()).await, Ok(1)); - assert_matches!(client.count(context::current()).await, Ok(2)); + assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(1)); + assert_matches!(client.count(&mut context::ClientContext::current()).await, Ok(2)); Ok(()) }