diff --git a/derives/src/lib.rs b/derives/src/lib.rs index 951572f..654e2d7 100644 --- a/derives/src/lib.rs +++ b/derives/src/lib.rs @@ -3,6 +3,7 @@ use proc_macro2::Span; use syn::{parse_macro_input, DeriveInput}; mod case; mod container; +mod rpc_interface; mod serde_json; mod symbol; mod ts_interface; @@ -260,3 +261,8 @@ fn get_generic_placeholder( pub fn ts_interface(attr: TokenStream, item: TokenStream) -> TokenStream { ts_interface::ts_interface(attr, item) } + +#[proc_macro_derive(Interface, attributes(ts))] +pub fn derive_interface(input: TokenStream) -> TokenStream { + rpc_interface::derive_interface(input) +} diff --git a/derives/src/rpc_interface.rs b/derives/src/rpc_interface.rs new file mode 100644 index 0000000..01e7981 --- /dev/null +++ b/derives/src/rpc_interface.rs @@ -0,0 +1,325 @@ +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::{ + parse::{Parse, ParseStream}, + parse_macro_input, DeriveInput, Error, Fields, Ident, LitStr, Result, Token, Type, TypeBareFn, +}; + +use crate::convert_camel_from_snake; + +/// Parses the `#[ts(...)]` attributes for the Interface macro +struct InterfaceAttrs { + file_name: String, + rename_all: Option, +} + +#[derive(Clone, Copy)] +enum RenameAll { + CamelCase, +} + +impl Parse for InterfaceAttrs { + fn parse(input: ParseStream) -> Result { + let mut file_name: Option = None; + let mut rename_all: Option = None; + + while !input.is_empty() { + let ident: Ident = input.parse()?; + input.parse::()?; + let lit: LitStr = input.parse()?; + + if ident == "file_name" { + file_name = Some(lit.value()); + } else if ident == "rename_all" { + match lit.value().as_str() { + "camelCase" => rename_all = Some(RenameAll::CamelCase), + _ => return Err(Error::new_spanned(lit, "expected `camelCase`")), + } + } else { + return Err(Error::new_spanned( + ident, + "expected `file_name = \"...\"` or `rename_all = \"camelCase\"`", + )); + } + + if input.is_empty() { + break; + } + input.parse::()?; + } + + let file_name = + file_name.ok_or_else(|| input.error("`file_name = \"...\"` is required"))?; + + Ok(Self { + file_name, + rename_all, + }) + } +} + +/// Represents a parsed parameter +struct RpcParam { + name: String, + ty: Type, +} + +/// Represents a parsed RPC method field +struct RpcMethod { + name: String, + params: Vec, + ok_type: Type, + err_type: Type, +} + +pub fn derive_interface(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + + match expand_interface(&input) { + Ok(ts) => ts.into(), + Err(e) => e.to_compile_error().into(), + } +} + +fn expand_interface(input: &DeriveInput) -> Result { + // Parse `#[ts(...)]` attribute + let attrs = parse_ts_attrs(input)?; + let file_name = attrs.file_name; + let rename_all = attrs.rename_all; + let ident = &input.ident; + let ts_name = ident.to_string(); + + // Must be a struct + let fields = match &input.data { + syn::Data::Struct(ds) => &ds.fields, + _ => { + return Err(Error::new_spanned( + ident, + "Interface can only be derived for structs", + )) + } + }; + + // Must be named fields + let named_fields = match fields { + Fields::Named(f) => &f.named, + _ => { + return Err(Error::new_spanned( + ident, + "Interface struct must have named fields", + )) + } + }; + + // Parse each field as an RPC method + let mut methods = Vec::new(); + for field in named_fields { + let field_ident = field + .ident + .as_ref() + .ok_or_else(|| Error::new_spanned(&field.ty, "expected named field"))?; + + let method = parse_fn_field(field_ident, &field.ty, rename_all)?; + methods.push(method); + } + + // Generate the method descriptors + let method_tokens: Vec<_> = methods + .iter() + .map(|m| { + let name = &m.name; + let ok_ty = &m.ok_type; + let err_ty = &m.err_type; + + let param_tokens: Vec<_> = m + .params + .iter() + .map(|p| { + let param_name = &p.name; + let param_ty = &p.ty; + quote! { + (#param_name.to_string(), std::any::TypeId::of::<#param_ty>()) + } + }) + .collect(); + + quote! { + ::gents::RpcMethodDescriptor { + name: #name.to_string(), + params: vec![ #(#param_tokens),* ], + ok_type: std::any::TypeId::of::<#ok_ty>(), + err_type: std::any::TypeId::of::<#err_ty>(), + } + } + }) + .collect(); + + // Generate type registrations + let register_tokens: Vec<_> = methods + .iter() + .flat_map(|m| { + let ok_ty = &m.ok_type; + let err_ty = &m.err_type; + let mut tokens = vec![ + quote! { <#ok_ty as ::gents::TS>::_register(manager, true); }, + quote! { <#err_ty as ::gents::TS>::_register(manager, true); }, + ]; + for p in &m.params { + let param_ty = &p.ty; + tokens.push(quote! { <#param_ty as ::gents::TS>::_register(manager, true); }); + } + tokens + }) + .collect(); + + let expanded = quote! { + impl ::gents::_TsRpcInterface for #ident { + fn __get_rpc_descriptor(manager: &mut ::gents::DescriptorManager) -> ::gents::RpcInterfaceDescriptor { + #(#register_tokens)* + ::gents::RpcInterfaceDescriptor { + name: #ts_name.to_string(), + file_name: #file_name.to_string(), + methods: vec![ #(#method_tokens),* ], + } + } + } + }; + + Ok(expanded) +} + +fn parse_ts_attrs(input: &DeriveInput) -> Result { + for attr in &input.attrs { + if attr.path().is_ident("ts") { + return attr.parse_args::(); + } + } + Err(Error::new( + Span::call_site(), + "missing #[ts(file_name = \"...\")] attribute", + )) +} + +fn parse_fn_field( + field_ident: &Ident, + ty: &Type, + rename_all: Option, +) -> Result { + // Convert method name based on rename_all + let name = match rename_all { + Some(RenameAll::CamelCase) => convert_camel_from_snake(field_ident.to_string()), + None => field_ident.to_string(), + }; + + // Extract the bare function type + let bare_fn = match ty { + Type::BareFn(f) => f, + _ => { + return Err(Error::new_spanned( + ty, + "Interface fields must be function types: fn(name: Type, ...) -> Result", + )) + } + }; + + // Parse parameters + let params = parse_params(bare_fn, rename_all)?; + + // Parse return type (must be Result) + let (ok_type, err_type) = parse_result_return_type(bare_fn)?; + + Ok(RpcMethod { + name, + params, + ok_type, + err_type, + }) +} + +fn parse_params(bare_fn: &TypeBareFn, rename_all: Option) -> Result> { + let mut params = Vec::new(); + + for (idx, arg) in bare_fn.inputs.iter().enumerate() { + // Get parameter name, convert based on rename_all + let param_name = match &arg.name { + Some((ident, _)) => match rename_all { + Some(RenameAll::CamelCase) => convert_camel_from_snake(ident.to_string()), + None => ident.to_string(), + }, + None => format!("arg{}", idx), // fallback if no name + }; + + params.push(RpcParam { + name: param_name, + ty: arg.ty.clone(), + }); + } + + Ok(params) +} + +fn parse_result_return_type(bare_fn: &TypeBareFn) -> Result<(Type, Type)> { + let ret_type = match &bare_fn.output { + syn::ReturnType::Type(_, ty) => ty.as_ref(), + syn::ReturnType::Default => { + return Err(Error::new_spanned( + bare_fn, + "RPC method must have a return type: Result", + )) + } + }; + + // Extract Result + let type_path = match ret_type { + Type::Path(p) => p, + _ => { + return Err(Error::new_spanned( + ret_type, + "return type must be Result", + )) + } + }; + + let last_segment = type_path + .path + .segments + .last() + .ok_or_else(|| Error::new_spanned(ret_type, "invalid return type"))?; + + if last_segment.ident != "Result" { + return Err(Error::new_spanned( + ret_type, + "return type must be Result", + )); + } + + let args = match &last_segment.arguments { + syn::PathArguments::AngleBracketed(args) => args, + _ => { + return Err(Error::new_spanned( + ret_type, + "Result must have type parameters: Result", + )) + } + }; + + if args.args.len() != 2 { + return Err(Error::new_spanned( + ret_type, + "Result must have exactly two type parameters: Result", + )); + } + + let ok_type = match &args.args[0] { + syn::GenericArgument::Type(t) => t.clone(), + _ => return Err(Error::new_spanned(&args.args[0], "expected type argument")), + }; + + let err_type = match &args.args[1] { + syn::GenericArgument::Type(t) => t.clone(), + _ => return Err(Error::new_spanned(&args.args[1], "expected type argument")), + }; + + Ok((ok_type, err_type)) +} diff --git a/src/descriptor.rs b/src/descriptor.rs index 7be17c5..423f3ff 100644 --- a/src/descriptor.rs +++ b/src/descriptor.rs @@ -35,6 +35,12 @@ pub trait _TsAPI { fn __get_api_descriptor() -> ApiDescriptor; } +/// Trait for defining TypeScript RPC interfaces from struct with function fields. +/// Use `#[derive(Interface)]` to implement this trait. +pub trait _TsRpcInterface { + fn __get_rpc_descriptor(manager: &mut DescriptorManager) -> RpcInterfaceDescriptor; +} + pub struct ApiDescriptor { pub name: String, pub file_name: String, @@ -50,10 +56,26 @@ pub struct MethodDescriptor { pub return_type: Option, } +/// Descriptor for RPC interface generated from `#[derive(Interface)]` +pub struct RpcInterfaceDescriptor { + pub name: String, + pub file_name: String, + pub methods: Vec, +} + +/// Descriptor for a single RPC method: `fn(param1: T1, param2: T2) -> Result` +pub struct RpcMethodDescriptor { + pub name: String, + pub params: Vec<(String, TypeId)>, + pub ok_type: TypeId, + pub err_type: TypeId, +} + #[derive(Default)] pub struct DescriptorManager { pub descriptors: Vec, pub api_descriptors: Vec, + pub rpc_descriptors: Vec, pub id_map: HashMap, generics_map: HashMap, } @@ -76,6 +98,10 @@ impl DescriptorManager { self.api_descriptors.push(descriptor); } + pub fn add_rpc_descriptor(&mut self, descriptor: RpcInterfaceDescriptor) { + self.rpc_descriptors.push(descriptor); + } + pub fn add_generics_map(&mut self, idx: usize, generics: String) { self.generics_map.insert(idx, generics); } @@ -85,6 +111,7 @@ impl DescriptorManager { let DescriptorManager { descriptors, api_descriptors, + rpc_descriptors, id_map, generics_map, } = self; @@ -235,6 +262,77 @@ impl DescriptorManager { fmt.end_interface(); result.push((api.file_name.to_string(), fmt.end_file())); }); + + // Generate RPC interface files + rpc_descriptors.into_iter().for_each(|rpc| { + let mut fmt = TsFormatter::new(); + + // Collect all type dependencies for imports + for m in &rpc.methods { + for (_, type_id) in &m.params { + if let Some(&idx) = id_map.get(type_id) { + let desc = descriptors.get(idx).unwrap(); + if let Descriptor::BuiltinType(_) = desc { + continue; + } + let import_deps = get_import_deps_idx(&descriptors, idx); + for dep in import_deps { + let (ts_name, file_name) = get_import_deps(&descriptors, dep); + fmt.add_import(&ts_name, &file_name); + } + } + } + for type_id in [m.ok_type, m.err_type] { + if let Some(&idx) = id_map.get(&type_id) { + let desc = descriptors.get(idx).unwrap(); + if let Descriptor::BuiltinType(_) = desc { + continue; + } + let import_deps = get_import_deps_idx(&descriptors, idx); + for dep in import_deps { + let (ts_name, file_name) = get_import_deps(&descriptors, dep); + fmt.add_import(&ts_name, &file_name); + } + } + } + } + + fmt.start_interface(&rpc.name, ""); + + for m in rpc.methods { + let params: Vec<(String, String)> = m + .params + .iter() + .map(|(name, type_id)| { + let ts_type = id_map + .get(type_id) + .and_then(|idx| descriptors.get(*idx)) + .map(|d| d.ts_name().to_string()) + .unwrap_or_else(|| "unknown".to_string()); + (name.clone(), ts_type) + }) + .collect(); + + let ok_ts = id_map + .get(&m.ok_type) + .and_then(|idx| descriptors.get(*idx)) + .map(|d| d.ts_name().to_string()) + .unwrap_or_else(|| "unknown".to_string()); + + let err_ts = id_map + .get(&m.err_type) + .and_then(|idx| descriptors.get(*idx)) + .map(|d| d.ts_name().to_string()) + .unwrap_or_else(|| "unknown".to_string()); + + let ret_type = format!("{} | {}", ok_ts, err_ts); + fmt.add_rpc_method(&m.name, params, &ret_type); + } + + fmt.end_interface(); + result.push((rpc.file_name.to_string(), fmt.end_file())); + }); + result } } diff --git a/src/file_generator.rs b/src/file_generator.rs index d34214b..d83aef6 100644 --- a/src/file_generator.rs +++ b/src/file_generator.rs @@ -2,6 +2,7 @@ use std::path::Path; use std::{fs, io::Write}; use crate::_TsAPI; +use crate::_TsRpcInterface; use crate::descriptor::{DescriptorManager, TS}; use crate::utils::remove_ext; @@ -32,6 +33,11 @@ impl FileGroup { self.manager.add_api_descriptor(d); } + pub fn add_rpc(&mut self) { + let d = T::__get_rpc_descriptor(&mut self.manager); + self.manager.add_rpc_descriptor(d); + } + pub fn gen_files(self, dir: &str, index_file: bool) { let mut data = self.manager.gen_data(); if index_file { diff --git a/src/ts_formatter.rs b/src/ts_formatter.rs index 7451d66..7c727a3 100644 --- a/src/ts_formatter.rs +++ b/src/ts_formatter.rs @@ -91,6 +91,15 @@ impl TsFormatter { )); } + pub fn add_rpc_method(&mut self, name: &str, params: Vec<(String, String)>, ret_type: &str) { + let param_str = params + .iter() + .map(|(n, t)| format!("{}: {}", n, t)) + .collect::>() + .join(", "); + self.write_line(&format!("{}({}): Promise<{}>;", name, param_str, ret_type)); + } + pub fn end_interface(&mut self) { if self.indent > 0 { self.indent -= 1; diff --git a/tests/src/tests.rs b/tests/src/tests.rs index e010c67..9c56188 100644 --- a/tests/src/tests.rs +++ b/tests/src/tests.rs @@ -461,6 +461,103 @@ export interface V3 { } } +#[cfg(test)] +mod test_rpc_interface { + use gents::*; + use gents_derives::{Interface, TS}; + + #[derive(TS, Clone)] + #[ts(file_name = "get_cell_params.ts", rename_all = "camelCase")] + pub struct GetCellParams { + pub row: u32, + pub col: u32, + } + + #[derive(TS, Clone)] + #[ts(file_name = "cell_info.ts", rename_all = "camelCase")] + pub struct CellInfo { + pub value: String, + } + + #[derive(TS, Clone)] + #[ts(file_name = "save_params.ts", rename_all = "camelCase")] + pub struct SaveParams { + pub path: String, + } + + #[derive(TS, Clone)] + #[ts(file_name = "save_result.ts", rename_all = "camelCase")] + pub struct SaveResult { + pub success: bool, + } + + #[derive(TS, Clone)] + #[ts(file_name = "error_message.ts", rename_all = "camelCase")] + pub struct ErrorMessage { + pub message: String, + } + + // Without rename_all: keeps original snake_case names + #[derive(Interface)] + #[ts(file_name = "workbook_methods.ts")] + pub struct WorkbookMethods { + pub get_cell: fn(row_idx: u32, col_idx: u32) -> Result, + pub save_file: fn(file_path: String) -> Result, + } + + // With rename_all = "camelCase": converts to camelCase + #[derive(Interface)] + #[ts(file_name = "workbook_methods_camel.ts", rename_all = "camelCase")] + pub struct WorkbookMethodsCamel { + pub get_cell: fn(row_idx: u32, col_idx: u32) -> Result, + pub save_file: fn(file_path: String) -> Result, + } + + #[test] + fn test_rpc_interface_no_rename() { + let mut manager = DescriptorManager::default(); + let desc = WorkbookMethods::__get_rpc_descriptor(&mut manager); + manager.add_rpc_descriptor(desc); + + let data = manager.gen_data(); + let files: std::collections::HashMap<&str, &str> = data + .iter() + .map(|(name, content)| (name.as_str(), content.as_str())) + .collect(); + + let content = files.get("workbook_methods.ts").unwrap(); + // Without rename_all: keeps snake_case + assert!(content.contains( + "get_cell(row_idx: number, col_idx: number): Promise;" + )); + assert!( + content.contains("save_file(file_path: string): Promise;") + ); + } + + #[test] + fn test_rpc_interface_with_rename_all() { + let mut manager = DescriptorManager::default(); + let desc = WorkbookMethodsCamel::__get_rpc_descriptor(&mut manager); + manager.add_rpc_descriptor(desc); + + let data = manager.gen_data(); + let files: std::collections::HashMap<&str, &str> = data + .iter() + .map(|(name, content)| (name.as_str(), content.as_str())) + .collect(); + + let content = files.get("workbook_methods_camel.ts").unwrap(); + // With rename_all = "camelCase": converts to camelCase + assert!(content.contains( + "getCell(rowIdx: number, colIdx: number): Promise;" + )); + assert!( + content.contains("saveFile(filePath: string): Promise;") + ); + } +} + #[cfg(test)] mod test_api { use gents::*;