Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = ["ImJeremyHe<yiliang.he@qq.com>"]
edition = "2018"
name = "gents"
version = "1.3.1"
version = "1.3.2"
license = "MIT"
description = "generate TypeScript interfaces from Rust code"
repository = "https://github.com/ImJeremyHe/gents"
Expand Down
2 changes: 1 addition & 1 deletion derives/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "gents_derives"
version = "1.3.1"
version = "1.3.2"
description = "provides some macros for gents"
authors = ["ImJeremyHe<yiliang.he@qq.com>"]
license = "MIT"
Expand Down
121 changes: 42 additions & 79 deletions derives/src/rpc_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ impl Parse for InterfaceAttrs {
struct RpcParam {
name: String,
ty: Type,
optional: bool,
}

/// Represents a parsed RPC method field
struct RpcMethod {
name: String,
params: Vec<RpcParam>,
ok_type: Type,
err_type: Type,
ret_type: Type,
}

pub fn derive_interface(input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -129,17 +129,21 @@ fn expand_interface(input: &DeriveInput) -> Result<proc_macro2::TokenStream> {
.iter()
.map(|m| {
let name = &m.name;
let ok_ty = &m.ok_type;
let err_ty = &m.err_type;
let ret_ty = &m.ret_type;

let param_tokens: Vec<_> = m
.params
.iter()
.map(|p| {
let param_name = &p.name;
let param_ty = &p.ty;
let optional = p.optional;
quote! {
(#param_name.to_string(), std::any::TypeId::of::<#param_ty>())
::gents::RpcParamDescriptor {
name: #param_name.to_string(),
type_id: std::any::TypeId::of::<#param_ty>(),
optional: #optional,
}
}
})
.collect();
Expand All @@ -148,8 +152,7 @@ fn expand_interface(input: &DeriveInput) -> Result<proc_macro2::TokenStream> {
::gents::RpcMethodDescriptor {
name: #name.to_string(),
params: vec![ #(#param_tokens),* ],
ok_type: std::any::TypeId::of::<#ok_ty>(),
err_type: std::any::TypeId::of::<#err_ty>(),
ret_type: std::any::TypeId::of::<#ret_ty>(),
}
}
})
Expand All @@ -159,12 +162,8 @@ fn expand_interface(input: &DeriveInput) -> Result<proc_macro2::TokenStream> {
let register_tokens: Vec<_> = methods
.iter()
.flat_map(|m| {
let ok_ty = &m.ok_type;
let err_ty = &m.err_type;
let mut tokens = vec![
quote! { <#ok_ty as ::gents::TS>::_register(manager, true); },
quote! { <#err_ty as ::gents::TS>::_register(manager, true); },
];
let ret_ty = &m.ret_type;
let mut tokens = vec![quote! { <#ret_ty as ::gents::TS>::_register(manager, true); }];
for p in &m.params {
let param_ty = &p.ty;
tokens.push(quote! { <#param_ty as ::gents::TS>::_register(manager, true); });
Expand Down Expand Up @@ -226,14 +225,21 @@ fn parse_fn_field(
// Parse parameters
let params = parse_params(bare_fn, rename_all)?;

// Parse return type (must be Result<T, E>)
let (ok_type, err_type) = parse_result_return_type(bare_fn)?;
// Parse return type
let ret_type = match &bare_fn.output {
syn::ReturnType::Type(_, ty) => ty.as_ref().clone(),
syn::ReturnType::Default => {
return Err(Error::new_spanned(
bare_fn,
"RPC method must have a return type",
))
}
};

Ok(RpcMethod {
name,
params,
ok_type,
err_type,
ret_type,
})
}

Expand All @@ -250,76 +256,33 @@ fn parse_params(bare_fn: &TypeBareFn, rename_all: Option<RenameAll>) -> Result<V
None => format!("arg{}", idx), // fallback if no name
};

// Check if the type is Option<T>, extract inner T and mark as optional
let (ty, optional) = extract_option_inner(&arg.ty);

params.push(RpcParam {
name: param_name,
ty: arg.ty.clone(),
ty,
optional,
});
}

Ok(params)
}

fn parse_result_return_type(bare_fn: &TypeBareFn) -> Result<(Type, Type)> {
let ret_type = match &bare_fn.output {
syn::ReturnType::Type(_, ty) => ty.as_ref(),
syn::ReturnType::Default => {
return Err(Error::new_spanned(
bare_fn,
"RPC method must have a return type: Result<T, E>",
))
}
};

// Extract Result<T, E>
let type_path = match ret_type {
Type::Path(p) => p,
_ => {
return Err(Error::new_spanned(
ret_type,
"return type must be Result<T, E>",
))
}
};

let last_segment = type_path
.path
.segments
.last()
.ok_or_else(|| Error::new_spanned(ret_type, "invalid return type"))?;

if last_segment.ident != "Result" {
return Err(Error::new_spanned(
ret_type,
"return type must be Result<T, E>",
));
}

let args = match &last_segment.arguments {
syn::PathArguments::AngleBracketed(args) => args,
_ => {
return Err(Error::new_spanned(
ret_type,
"Result must have type parameters: Result<T, E>",
))
/// If the type is `Option<T>`, returns `(T, true)`. Otherwise returns `(ty, false)`.
fn extract_option_inner(ty: &Type) -> (Type, bool) {
if let Type::Path(type_path) = ty {
if let Some(segment) = type_path.path.segments.last() {
if segment.ident == "Option" {
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
if args.args.len() == 1 {
if let syn::GenericArgument::Type(inner) = &args.args[0] {
return (inner.clone(), true);
}
}
}
}
}
};

if args.args.len() != 2 {
return Err(Error::new_spanned(
ret_type,
"Result must have exactly two type parameters: Result<T, E>",
));
}

let ok_type = match &args.args[0] {
syn::GenericArgument::Type(t) => t.clone(),
_ => return Err(Error::new_spanned(&args.args[0], "expected type argument")),
};

let err_type = match &args.args[1] {
syn::GenericArgument::Type(t) => t.clone(),
_ => return Err(Error::new_spanned(&args.args[1], "expected type argument")),
};

Ok((ok_type, err_type))
(ty.clone(), false)
}
54 changes: 27 additions & 27 deletions src/descriptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{
collections::{HashMap, HashSet},
};

use crate::ts_formatter::TsFormatter;
use crate::ts_formatter::{RpcTsParam, TsFormatter};
use crate::utils::remove_ext;

// `TS` trait defines the behavior of your types when generating files.
Expand Down Expand Up @@ -63,12 +63,18 @@ pub struct RpcInterfaceDescriptor {
pub methods: Vec<RpcMethodDescriptor>,
}

/// Descriptor for a single RPC method: `fn(param1: T1, param2: T2) -> Result<T, E>`
/// Descriptor for a single RPC method parameter
pub struct RpcParamDescriptor {
pub name: String,
pub type_id: TypeId,
pub optional: bool,
}

/// Descriptor for a single RPC method
pub struct RpcMethodDescriptor {
pub name: String,
pub params: Vec<(String, TypeId)>,
pub ok_type: TypeId,
pub err_type: TypeId,
pub params: Vec<RpcParamDescriptor>,
pub ret_type: TypeId,
}

#[derive(Default)]
Expand Down Expand Up @@ -269,8 +275,8 @@ impl DescriptorManager {

// Collect all type dependencies for imports
for m in &rpc.methods {
for (_, type_id) in &m.params {
if let Some(&idx) = id_map.get(type_id) {
for p in &m.params {
if let Some(&idx) = id_map.get(&p.type_id) {
let desc = descriptors.get(idx).unwrap();
if let Descriptor::BuiltinType(_) = desc {
continue;
Expand All @@ -282,12 +288,9 @@ impl DescriptorManager {
}
}
}
for type_id in [m.ok_type, m.err_type] {
if let Some(&idx) = id_map.get(&type_id) {
let desc = descriptors.get(idx).unwrap();
if let Descriptor::BuiltinType(_) = desc {
continue;
}
if let Some(&idx) = id_map.get(&m.ret_type) {
let desc = descriptors.get(idx).unwrap();
if !matches!(desc, Descriptor::BuiltinType(_)) {
let import_deps = get_import_deps_idx(&descriptors, idx);
for dep in import_deps {
let (ts_name, file_name) = get_import_deps(&descriptors, dep);
Expand All @@ -300,33 +303,30 @@ impl DescriptorManager {
fmt.start_interface(&rpc.name, "");

for m in rpc.methods {
let params: Vec<(String, String)> = m
let params: Vec<RpcTsParam> = m
.params
.iter()
.map(|(name, type_id)| {
.map(|p| {
let ts_type = id_map
.get(type_id)
.get(&p.type_id)
.and_then(|idx| descriptors.get(*idx))
.map(|d| d.ts_name().to_string())
.unwrap_or_else(|| "unknown".to_string());
(name.clone(), ts_type)
RpcTsParam {
name: p.name.clone(),
ts_type,
optional: p.optional,
}
})
.collect();

let ok_ts = id_map
.get(&m.ok_type)
.and_then(|idx| descriptors.get(*idx))
.map(|d| d.ts_name().to_string())
.unwrap_or_else(|| "unknown".to_string());

let err_ts = id_map
.get(&m.err_type)
let ret_ts = id_map
.get(&m.ret_type)
.and_then(|idx| descriptors.get(*idx))
.map(|d| d.ts_name().to_string())
.unwrap_or_else(|| "unknown".to_string());

let ret_type = format!("{} | {}", ok_ts, err_ts);
fmt.add_rpc_method(&m.name, params, &ret_type);
fmt.add_rpc_method(&m.name, params, &ret_ts);
}

fmt.end_interface();
Expand Down
17 changes: 15 additions & 2 deletions src/ts_formatter.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use std::collections::BTreeSet;

/// Resolved RPC parameter for TS generation
pub struct RpcTsParam {
pub name: String,
pub ts_type: String,
pub optional: bool,
}

#[derive(Default)]
pub struct TsFormatter {
imports: BTreeSet<String>,
Expand Down Expand Up @@ -91,10 +98,16 @@ impl TsFormatter {
));
}

pub fn add_rpc_method(&mut self, name: &str, params: Vec<(String, String)>, ret_type: &str) {
pub fn add_rpc_method(&mut self, name: &str, params: Vec<RpcTsParam>, ret_type: &str) {
let param_str = params
.iter()
.map(|(n, t)| format!("{}: {}", n, t))
.map(|p| {
if p.optional {
format!("{}?: {}", p.name, p.ts_type)
} else {
format!("{}: {}", p.name, p.ts_type)
}
})
.collect::<Vec<_>>()
.join(", ");
self.write_line(&format!("{}({}): Promise<{}>;", name, param_str, ret_type));
Expand Down
Loading