Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions example-service/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
// https://opensource.org/licenses/MIT.

use clap::Parser;
use futures::{SinkExt, future};
use service::{WorldClient, init_tracing};
use std::{net::SocketAddr, time::Duration};
use tarpc::{client, context, tokio_serde::formats::Json};
use tarpc::context::{SharedContext};
use tarpc::{client, tokio_serde::formats::Json};
use tokio::time::sleep;
use tracing::Instrument;

Expand All @@ -29,15 +31,20 @@ async fn main() -> anyhow::Result<()> {
let mut transport = tarpc::serde_transport::tcp::connect(flags.server_addr, Json::default);
transport.config_mut().max_frame_length(usize::MAX);

let transport = transport.await?;

// WorldClient is generated by the service attribute. It has a constructor `new` that takes a
// config and any Transport as input.
let client = WorldClient::new(client::Config::default(), transport.await?).spawn();
let client = WorldClient::new(client::Config::default(), transport).spawn();

let hello = async move {
let mut context = SharedContext::current();
let mut context2 = SharedContext::current();

// Send the request twice, just to be safe! ;)
tokio::select! {
hello1 = client.hello(context::current(), format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(context::current(), format!("{}2", flags.name)) => { hello2 }
hello1 = client.hello(&mut context, format!("{}1", flags.name)) => { hello1 }
hello2 = client.hello(&mut context2, format!("{}2", flags.name)) => { hello2 }
}
}
.instrument(tracing::info_span!("Two Hellos"))
Expand Down
11 changes: 7 additions & 4 deletions example-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ use rand::{
thread_rng,
};
use service::{World, init_tracing};
use std::ops::Deref;
use std::{
net::{IpAddr, Ipv6Addr, SocketAddr},
time::Duration,
};
use tarpc::context::{SharedContext};
use tarpc::{
context,
ClientMessage, context,
server::{self, Channel, incoming::Incoming},
tokio_serde::formats::Json,
};
Expand All @@ -35,7 +37,8 @@ struct Flags {
struct HelloServer(SocketAddr);

impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
type Context = SharedContext;
async fn hello(self, _: &mut Self::Context, name: String) -> String {
let sleep_time =
Duration::from_millis(Uniform::new_inclusive(1, 10).sample(&mut thread_rng()));
time::sleep(sleep_time).await;
Expand Down Expand Up @@ -64,11 +67,11 @@ async fn main() -> anyhow::Result<()> {
.filter_map(|r| future::ready(r.ok()))
.map(server::BaseChannel::with_defaults)
// Limit channels to 1 per IP.
.max_channels_per_key(1, |t| t.transport().peer_addr().unwrap().ip())
.max_channels_per_key(1, |t| t.transport().get_ref().peer_addr().unwrap().ip())
// serve is generated by the service attribute. It takes as input any type implementing
// the generated World trait.
.map(|channel| {
let server = HelloServer(channel.transport().peer_addr().unwrap());
let server = HelloServer(channel.transport().get_ref().peer_addr().unwrap());
channel.execute(server.serve()).for_each(spawn)
})
// Max 10 channels.
Expand Down
1 change: 1 addition & 0 deletions plugins/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,6 @@ proc-macro = true
[dev-dependencies]
assert-type-eq = "0.1.0"
futures = "0.3"
futures-util = "0.3.31"
serde = { version = "1.0", features = ["derive"] }
tarpc = { path = "../tarpc", features = ["serde1"] }
57 changes: 36 additions & 21 deletions plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,10 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
/// # Example
///
/// ```no_run
/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context};
/// use tarpc::{client, transport, service, server::{self, Channel}};
/// use futures_util::{TryStreamExt, sink::SinkExt};///
///
/// use tarpc::context::SharedContext;
///
/// #[service]
/// pub trait Calculator {
Expand All @@ -401,7 +404,8 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec<Vec<&Attribute>> {
/// #[derive(Clone)]
/// struct CalculatorServer;
/// impl Calculator for CalculatorServer {
/// async fn add(self, context: Context, a: i32, b: i32) -> i32 {
/// type Context = SharedContext;
/// async fn add(self, context: &mut Self::Context, a: i32, b: i32) -> i32 {
/// a + b
/// }
/// }
Expand Down Expand Up @@ -558,7 +562,7 @@ impl ServiceGenerator<'_> {
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
async fn #ident(self, context: &mut Self::Context, #( #args ),*) -> #output;
}
},
);
Expand All @@ -567,6 +571,8 @@ impl ServiceGenerator<'_> {
quote! {
#( #attrs )*
#vis trait #service_ident: ::core::marker::Sized {
type Context: ::tarpc::context::ExtractContext<::tarpc::context::SharedContext>;

#( #rpc_fns )*

/// Returns a serving function to use with
Expand All @@ -577,11 +583,11 @@ impl ServiceGenerator<'_> {
}

#[doc = #stub_doc]
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
#vis trait #client_stub_ident<ClientCtx>: ::tarpc::client::stub::Stub<ClientCtx = ClientCtx, Req = #request_ident, Resp = #response_ident> {
}

impl<S> #client_stub_ident for S
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
impl<S, ClientCtx> #client_stub_ident<ClientCtx> for S
where S: ::tarpc::client::stub::Stub<ClientCtx = ClientCtx, Req = #request_ident, Resp = #response_ident>
{
}
}
Expand Down Expand Up @@ -620,9 +626,9 @@ impl ServiceGenerator<'_> {
{
type Req = #request_ident;
type Resp = #response_ident;
type ServerCtx = S::Context;


async fn serve(self, ctx: ::tarpc::context::Context, req: #request_ident)
async fn serve(self, ctx: &mut Self::ServerCtx, req: #request_ident)
-> ::core::result::Result<#response_ident, ::tarpc::ServerError> {
match req {
#(
Expand Down Expand Up @@ -711,12 +717,19 @@ impl ServiceGenerator<'_> {

quote! {
#[allow(unused)]
#[derive(Clone, Debug)]
#[derive(Debug)]
/// The client stub that makes RPC calls to the server. All request methods return
/// [Futures](::core::future::Future).
#vis struct #client_ident<
Stub = ::tarpc::client::Channel<#request_ident, #response_ident>
>(Stub);
ClientCtx,
Stub = ::tarpc::client::Channel<#request_ident, #response_ident, ClientCtx>
>(Stub, ::std::marker::PhantomData<ClientCtx>);

impl<ClientCtx, Stub: ::std::clone::Clone> ::std::clone::Clone for #client_ident<ClientCtx,Stub> {
fn clone(&self) -> Self {
Self(self.0.clone(), ::std::marker::PhantomData)
}
}
}
}

Expand All @@ -730,32 +743,33 @@ impl ServiceGenerator<'_> {
} = self;

quote! {
impl #client_ident {
impl<ClientCtx> #client_ident<ClientCtx> {
/// Returns a new client stub that sends requests over the given transport.
#vis fn new<T>(config: ::tarpc::client::Config, transport: T)
-> ::tarpc::client::NewClient<
Self,
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
::tarpc::client::RequestDispatch<#request_ident, #response_ident, ClientCtx, T>
>
where
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
T: ::tarpc::Transport<::tarpc::ClientMessage<ClientCtx, #request_ident>, ::tarpc::Response<ClientCtx, #response_ident>>
{
let new_client = ::tarpc::client::new(config, transport);
::tarpc::client::NewClient {
client: #client_ident(new_client.client),
client: #client_ident(new_client.client, ::std::marker::PhantomData),
dispatch: new_client.dispatch,
}
}
}

impl<Stub> ::core::convert::From<Stub> for #client_ident<Stub>
impl<ClientCtx, Stub> ::core::convert::From<Stub> for #client_ident<ClientCtx, Stub>
where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
Resp = #response_ident,
ClientCtx = ClientCtx>
{
/// Returns a new client stub that sends requests over the given transport.
fn from(stub: Stub) -> Self {
#client_ident(stub)
#client_ident::<ClientCtx, Stub>(stub, ::std::marker::PhantomData)
}

}
Expand All @@ -778,15 +792,16 @@ impl ServiceGenerator<'_> {
} = self;

quote! {
impl<Stub> #client_ident<Stub>
impl<ClientCtx, Stub> #client_ident<ClientCtx, Stub>
where Stub: ::tarpc::client::stub::Stub<
Req = #request_ident,
Resp = #response_ident>
Resp = #response_ident,
ClientCtx = ClientCtx>
{
#(
#[allow(unused)]
#( #method_attrs )*
#vis fn #method_idents(&self, ctx: ::tarpc::context::Context, #( #args ),*)
#vis fn #method_idents<'a>(&'a self, ctx: &'a mut Stub::ClientCtx, #( #args ),*)
-> impl ::core::future::Future<Output = ::core::result::Result<#return_types, ::tarpc::client::RpcError>> + '_ {
let request = #request_ident::#camel_case_idents { #( #arg_pats ),* };
let resp = self.0.call(ctx, request);
Expand Down
23 changes: 16 additions & 7 deletions plugins/tests/service.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};
use std::hash::Hash;
use tarpc::context;
use tarpc::context::SharedContext;

#[test]
fn att_service_trait() {
Expand All @@ -12,15 +13,21 @@ fn att_service_trait() {
}

impl Foo for () {
async fn two_part(self, _: context::Context, s: String, i: i32) -> (String, i32) {
type Context = SharedContext;
async fn two_part(
self,
_: &mut context::SharedContext,
s: String,
i: i32,
) -> (String, i32) {
(s, i)
}

async fn bar(self, _: context::Context, s: String) -> String {
async fn bar(self, _: &mut Self::Context, s: String) -> String {
s
}

async fn baz(self, _: context::Context) {}
async fn baz(self, _: &mut Self::Context) {}
}
}

Expand All @@ -37,20 +44,21 @@ fn raw_idents() {
}

impl r#trait for () {
type Context = SharedContext;
async fn r#await(
self,
_: context::Context,
_: &mut Self::Context,
r#struct: r#yield,
r#enum: i32,
) -> (r#yield, i32) {
(r#struct, r#enum)
}

async fn r#fn(self, _: context::Context, r#impl: r#yield) -> r#yield {
async fn r#fn(self, _: &mut Self::Context, r#impl: r#yield) -> r#yield {
r#impl
}

async fn r#async(self, _: context::Context) {}
async fn r#async(self, _: &mut Self::Context) {}
}
}

Expand All @@ -64,7 +72,8 @@ fn service_with_cfg_rpc() {
}

impl Foo for () {
async fn foo(self, _: context::Context) {}
type Context = SharedContext;
async fn foo(self, _: &mut Self::Context) {}
}
}

Expand Down
3 changes: 3 additions & 0 deletions tarpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ tracing = { version = "0.1", default-features = false, features = [
tracing-opentelemetry = { version = "0.31.0", default-features = false }
opentelemetry = { version = "0.30.0", default-features = false }
opentelemetry-semantic-conventions = "0.30.0"
anymap3 = "1.0.1"
serde-value = "0.7"

[dev-dependencies]
assert_matches = "1.4"
Expand All @@ -81,6 +83,7 @@ trybuild = "1.0"
tokio-rustls = "0.26"
rustls-pemfile = "2.0"


[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
Expand Down
15 changes: 11 additions & 4 deletions tarpc/examples/compression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use tarpc::{
server::{BaseChannel, Channel},
tokio_serde::formats::Bincode,
};
use tarpc::context::SharedContext;

/// Type of compression that should be enabled on the request. The transport is free to ignore this.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Deserialize, Serialize)]
Expand Down Expand Up @@ -108,7 +109,8 @@ pub trait World {
struct HelloServer;

impl World for HelloServer {
async fn hello(self, _: context::Context, name: String) -> String {
type Context = SharedContext;
async fn hello(self, _: &mut Self::Context, name: String) -> String {
format!("Hey, {name}!")
}
}
Expand All @@ -120,21 +122,26 @@ async fn spawn(fut: impl Future<Output = ()> + Send + 'static) {
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let mut incoming = tcp::listen("localhost:0", Bincode::default).await?;

let addr = incoming.local_addr();
tokio::spawn(async move {
let transport = incoming.next().await.unwrap().unwrap();
BaseChannel::with_defaults(add_compression(transport))
let transport = add_compression(transport);
BaseChannel::with_defaults(transport)
.execute(HelloServer.serve())
.for_each(spawn)
.await;
});

let transport = tcp::connect(addr, Bincode::default).await?;
let client = WorldClient::new(client::Config::default(), add_compression(transport)).spawn();
let transport = add_compression(transport);
let client = WorldClient::new(client::Config::default(), transport).spawn();

println!(
"{}",
client.hello(context::current(), "friend".into()).await?
client
.hello(&mut context::SharedContext::current(), "friend".into())
.await?
);
Ok(())
}
9 changes: 5 additions & 4 deletions tarpc/examples/custom_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

use console_subscriber::Server;
use futures::prelude::*;
use tarpc::context::Context;
use tarpc::context::{SharedContext};
use tarpc::serde_transport as transport;
use tarpc::server::{BaseChannel, Channel};
use tarpc::tokio_serde::formats::Bincode;
Expand All @@ -21,9 +22,9 @@ pub trait PingService {
struct Service;

impl PingService for Service {
async fn ping(self, _: Context) {}
type Context = SharedContext;
async fn ping(self, _: &mut Self::Context) {}
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let bind_addr = "/tmp/tarpc_on_unix_example.sock";
Expand Down Expand Up @@ -52,7 +53,7 @@ async fn main() -> anyhow::Result<()> {
let transport = transport::new(codec_builder.new_framed(conn), Bincode::default());
PingServiceClient::new(Default::default(), transport)
.spawn()
.ping(tarpc::context::current())
.ping(&mut SharedContext::current())
.await?;

Ok(())
Expand Down
Loading