diff --git a/Cargo.toml b/Cargo.toml index 9d1ee10..29a04e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ authors = ["ImJeremyHe"] edition = "2018" name = "gents" -version = "1.3.1" +version = "1.3.2" license = "MIT" description = "generate TypeScript interfaces from Rust code" repository = "https://github.com/ImJeremyHe/gents" diff --git a/derives/Cargo.toml b/derives/Cargo.toml index cc6686a..7bfd229 100644 --- a/derives/Cargo.toml +++ b/derives/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "gents_derives" -version = "1.3.1" +version = "1.3.2" description = "provides some macros for gents" authors = ["ImJeremyHe"] license = "MIT" diff --git a/derives/src/rpc_interface.rs b/derives/src/rpc_interface.rs index 01e7981..2619c28 100644 --- a/derives/src/rpc_interface.rs +++ b/derives/src/rpc_interface.rs @@ -63,14 +63,14 @@ impl Parse for InterfaceAttrs { struct RpcParam { name: String, ty: Type, + optional: bool, } /// Represents a parsed RPC method field struct RpcMethod { name: String, params: Vec, - ok_type: Type, - err_type: Type, + ret_type: Type, } pub fn derive_interface(input: TokenStream) -> TokenStream { @@ -129,8 +129,7 @@ fn expand_interface(input: &DeriveInput) -> Result { .iter() .map(|m| { let name = &m.name; - let ok_ty = &m.ok_type; - let err_ty = &m.err_type; + let ret_ty = &m.ret_type; let param_tokens: Vec<_> = m .params @@ -138,8 +137,13 @@ fn expand_interface(input: &DeriveInput) -> Result { .map(|p| { let param_name = &p.name; let param_ty = &p.ty; + let optional = p.optional; quote! { - (#param_name.to_string(), std::any::TypeId::of::<#param_ty>()) + ::gents::RpcParamDescriptor { + name: #param_name.to_string(), + type_id: std::any::TypeId::of::<#param_ty>(), + optional: #optional, + } } }) .collect(); @@ -148,8 +152,7 @@ fn expand_interface(input: &DeriveInput) -> Result { ::gents::RpcMethodDescriptor { name: #name.to_string(), params: vec![ #(#param_tokens),* ], - ok_type: std::any::TypeId::of::<#ok_ty>(), - err_type: std::any::TypeId::of::<#err_ty>(), + ret_type: std::any::TypeId::of::<#ret_ty>(), } } }) @@ -159,12 +162,8 @@ fn expand_interface(input: &DeriveInput) -> Result { let register_tokens: Vec<_> = methods .iter() .flat_map(|m| { - let ok_ty = &m.ok_type; - let err_ty = &m.err_type; - let mut tokens = vec![ - quote! { <#ok_ty as ::gents::TS>::_register(manager, true); }, - quote! { <#err_ty as ::gents::TS>::_register(manager, true); }, - ]; + let ret_ty = &m.ret_type; + let mut tokens = vec![quote! { <#ret_ty as ::gents::TS>::_register(manager, true); }]; for p in &m.params { let param_ty = &p.ty; tokens.push(quote! { <#param_ty as ::gents::TS>::_register(manager, true); }); @@ -226,14 +225,21 @@ fn parse_fn_field( // Parse parameters let params = parse_params(bare_fn, rename_all)?; - // Parse return type (must be Result) - let (ok_type, err_type) = parse_result_return_type(bare_fn)?; + // Parse return type + let ret_type = match &bare_fn.output { + syn::ReturnType::Type(_, ty) => ty.as_ref().clone(), + syn::ReturnType::Default => { + return Err(Error::new_spanned( + bare_fn, + "RPC method must have a return type", + )) + } + }; Ok(RpcMethod { name, params, - ok_type, - err_type, + ret_type, }) } @@ -250,76 +256,33 @@ fn parse_params(bare_fn: &TypeBareFn, rename_all: Option) -> Result format!("arg{}", idx), // fallback if no name }; + // Check if the type is Option, extract inner T and mark as optional + let (ty, optional) = extract_option_inner(&arg.ty); + params.push(RpcParam { name: param_name, - ty: arg.ty.clone(), + ty, + optional, }); } Ok(params) } -fn parse_result_return_type(bare_fn: &TypeBareFn) -> Result<(Type, Type)> { - let ret_type = match &bare_fn.output { - syn::ReturnType::Type(_, ty) => ty.as_ref(), - syn::ReturnType::Default => { - return Err(Error::new_spanned( - bare_fn, - "RPC method must have a return type: Result", - )) - } - }; - - // Extract Result - let type_path = match ret_type { - Type::Path(p) => p, - _ => { - return Err(Error::new_spanned( - ret_type, - "return type must be Result", - )) - } - }; - - let last_segment = type_path - .path - .segments - .last() - .ok_or_else(|| Error::new_spanned(ret_type, "invalid return type"))?; - - if last_segment.ident != "Result" { - return Err(Error::new_spanned( - ret_type, - "return type must be Result", - )); - } - - let args = match &last_segment.arguments { - syn::PathArguments::AngleBracketed(args) => args, - _ => { - return Err(Error::new_spanned( - ret_type, - "Result must have type parameters: Result", - )) +/// If the type is `Option`, returns `(T, true)`. Otherwise returns `(ty, false)`. +fn extract_option_inner(ty: &Type) -> (Type, bool) { + if let Type::Path(type_path) = ty { + if let Some(segment) = type_path.path.segments.last() { + if segment.ident == "Option" { + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { + if args.args.len() == 1 { + if let syn::GenericArgument::Type(inner) = &args.args[0] { + return (inner.clone(), true); + } + } + } + } } - }; - - if args.args.len() != 2 { - return Err(Error::new_spanned( - ret_type, - "Result must have exactly two type parameters: Result", - )); } - - let ok_type = match &args.args[0] { - syn::GenericArgument::Type(t) => t.clone(), - _ => return Err(Error::new_spanned(&args.args[0], "expected type argument")), - }; - - let err_type = match &args.args[1] { - syn::GenericArgument::Type(t) => t.clone(), - _ => return Err(Error::new_spanned(&args.args[1], "expected type argument")), - }; - - Ok((ok_type, err_type)) + (ty.clone(), false) } diff --git a/src/descriptor.rs b/src/descriptor.rs index 4146ca2..9ec9046 100644 --- a/src/descriptor.rs +++ b/src/descriptor.rs @@ -3,7 +3,7 @@ use std::{ collections::{HashMap, HashSet}, }; -use crate::ts_formatter::TsFormatter; +use crate::ts_formatter::{RpcTsParam, TsFormatter}; use crate::utils::remove_ext; // `TS` trait defines the behavior of your types when generating files. @@ -63,12 +63,18 @@ pub struct RpcInterfaceDescriptor { pub methods: Vec, } -/// Descriptor for a single RPC method: `fn(param1: T1, param2: T2) -> Result` +/// Descriptor for a single RPC method parameter +pub struct RpcParamDescriptor { + pub name: String, + pub type_id: TypeId, + pub optional: bool, +} + +/// Descriptor for a single RPC method pub struct RpcMethodDescriptor { pub name: String, - pub params: Vec<(String, TypeId)>, - pub ok_type: TypeId, - pub err_type: TypeId, + pub params: Vec, + pub ret_type: TypeId, } #[derive(Default)] @@ -269,8 +275,8 @@ impl DescriptorManager { // Collect all type dependencies for imports for m in &rpc.methods { - for (_, type_id) in &m.params { - if let Some(&idx) = id_map.get(type_id) { + for p in &m.params { + if let Some(&idx) = id_map.get(&p.type_id) { let desc = descriptors.get(idx).unwrap(); if let Descriptor::BuiltinType(_) = desc { continue; @@ -282,12 +288,9 @@ impl DescriptorManager { } } } - for type_id in [m.ok_type, m.err_type] { - if let Some(&idx) = id_map.get(&type_id) { - let desc = descriptors.get(idx).unwrap(); - if let Descriptor::BuiltinType(_) = desc { - continue; - } + if let Some(&idx) = id_map.get(&m.ret_type) { + let desc = descriptors.get(idx).unwrap(); + if !matches!(desc, Descriptor::BuiltinType(_)) { let import_deps = get_import_deps_idx(&descriptors, idx); for dep in import_deps { let (ts_name, file_name) = get_import_deps(&descriptors, dep); @@ -300,33 +303,30 @@ impl DescriptorManager { fmt.start_interface(&rpc.name, ""); for m in rpc.methods { - let params: Vec<(String, String)> = m + let params: Vec = m .params .iter() - .map(|(name, type_id)| { + .map(|p| { let ts_type = id_map - .get(type_id) + .get(&p.type_id) .and_then(|idx| descriptors.get(*idx)) .map(|d| d.ts_name().to_string()) .unwrap_or_else(|| "unknown".to_string()); - (name.clone(), ts_type) + RpcTsParam { + name: p.name.clone(), + ts_type, + optional: p.optional, + } }) .collect(); - let ok_ts = id_map - .get(&m.ok_type) - .and_then(|idx| descriptors.get(*idx)) - .map(|d| d.ts_name().to_string()) - .unwrap_or_else(|| "unknown".to_string()); - - let err_ts = id_map - .get(&m.err_type) + let ret_ts = id_map + .get(&m.ret_type) .and_then(|idx| descriptors.get(*idx)) .map(|d| d.ts_name().to_string()) .unwrap_or_else(|| "unknown".to_string()); - let ret_type = format!("{} | {}", ok_ts, err_ts); - fmt.add_rpc_method(&m.name, params, &ret_type); + fmt.add_rpc_method(&m.name, params, &ret_ts); } fmt.end_interface(); diff --git a/src/ts_formatter.rs b/src/ts_formatter.rs index 7c727a3..9efce13 100644 --- a/src/ts_formatter.rs +++ b/src/ts_formatter.rs @@ -1,5 +1,12 @@ use std::collections::BTreeSet; +/// Resolved RPC parameter for TS generation +pub struct RpcTsParam { + pub name: String, + pub ts_type: String, + pub optional: bool, +} + #[derive(Default)] pub struct TsFormatter { imports: BTreeSet, @@ -91,10 +98,16 @@ impl TsFormatter { )); } - pub fn add_rpc_method(&mut self, name: &str, params: Vec<(String, String)>, ret_type: &str) { + pub fn add_rpc_method(&mut self, name: &str, params: Vec, ret_type: &str) { let param_str = params .iter() - .map(|(n, t)| format!("{}: {}", n, t)) + .map(|p| { + if p.optional { + format!("{}?: {}", p.name, p.ts_type) + } else { + format!("{}: {}", p.name, p.ts_type) + } + }) .collect::>() .join(", "); self.write_line(&format!("{}({}): Promise<{}>;", name, param_str, ret_type)); diff --git a/tests/src/tests.rs b/tests/src/tests.rs index 9c56188..3fcc39e 100644 --- a/tests/src/tests.rs +++ b/tests/src/tests.rs @@ -503,6 +503,8 @@ mod test_rpc_interface { pub struct WorkbookMethods { pub get_cell: fn(row_idx: u32, col_idx: u32) -> Result, pub save_file: fn(file_path: String) -> Result, + pub delete: fn(id: u32) -> CellInfo, + pub update: fn(id: u32, value: Option) -> Result<(), ErrorMessage>, } // With rename_all = "camelCase": converts to camelCase @@ -510,7 +512,8 @@ mod test_rpc_interface { #[ts(file_name = "workbook_methods_camel.ts", rename_all = "camelCase")] pub struct WorkbookMethodsCamel { pub get_cell: fn(row_idx: u32, col_idx: u32) -> Result, - pub save_file: fn(file_path: String) -> Result, + pub save_file: + fn(file_path: String, overwrite: Option) -> Result, } #[test] @@ -526,13 +529,20 @@ mod test_rpc_interface { .collect(); let content = files.get("workbook_methods.ts").unwrap(); - // Without rename_all: keeps snake_case + // Multiple params, Result return assert!(content.contains( "get_cell(row_idx: number, col_idx: number): Promise;" )); + // Single param, Result return assert!( content.contains("save_file(file_path: string): Promise;") ); + // Non-Result return type + assert!(content.contains("delete(id: number): Promise;")); + // Optional param + void ok type + assert!( + content.contains("update(id: number, value?: string): Promise;") + ); } #[test] @@ -548,13 +558,14 @@ mod test_rpc_interface { .collect(); let content = files.get("workbook_methods_camel.ts").unwrap(); - // With rename_all = "camelCase": converts to camelCase + // camelCase + Result return assert!(content.contains( "getCell(rowIdx: number, colIdx: number): Promise;" )); - assert!( - content.contains("saveFile(filePath: string): Promise;") - ); + // camelCase + optional param + assert!(content.contains( + "saveFile(filePath: string, overwrite?: boolean): Promise;" + )); } }