From fa2cab5fcdd1fc217587a01868e45af923623d67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ciro=20Garc=C3=ADa=20Belmonte?= Date: Tue, 31 Mar 2026 19:44:56 +0200 Subject: [PATCH] feat: add ImportModule functionality --- convgen.go | 8 +++ internal/convgen/parse/config.go | 53 ++++++++++++++-- internal/convgen/parse/injector.go | 11 +++- internal/convgen/parse/module.go | 61 +++++++++++++++---- testdata/program/ModuleImport/main/main.go | 43 +++++++++++++ .../ModuleImport/want/program_output.txt | 1 + 6 files changed, 157 insertions(+), 20 deletions(-) create mode 100644 testdata/program/ModuleImport/main/main.go create mode 100644 testdata/program/ModuleImport/want/program_output.txt diff --git a/convgen.go b/convgen.go index 2bf4495..5945739 100644 --- a/convgen.go +++ b/convgen.go @@ -201,6 +201,14 @@ func Module(opts ...moduleOption) module { panic("convgen: not generated") } +// ImportModule imports configurations and registered functions from another module. +// +// var core = convgen.Module(...) +// var ext = convgen.Module(convgen.ImportModule(core), ...) +func ImportModule(mod module) moduleOption { + panic("convgen: not generated") +} + // Struct directive generates a converter function between two struct types // without error: // diff --git a/internal/convgen/parse/config.go b/internal/convgen/parse/config.go index b853527..4a9c155 100644 --- a/internal/convgen/parse/config.go +++ b/internal/convgen/parse/config.go @@ -160,6 +160,29 @@ func (cfg Config) ForkForEnum() Config { return c } +func (cfg *Config) UpdateImport(other Config) { + cfg.Update(other) + + if other.ForStruct != nil { + if cfg.ForStruct == nil { + cfg.ForStruct = &Config{} + } + cfg.ForStruct.UpdateImport(*other.ForStruct) + } + if other.ForUnion != nil { + if cfg.ForUnion == nil { + cfg.ForUnion = &Config{} + } + cfg.ForUnion.UpdateImport(*other.ForUnion) + } + if other.ForEnum != nil { + if cfg.ForEnum == nil { + cfg.ForEnum = &Config{} + } + cfg.ForEnum.UpdateImport(*other.ForEnum) + } +} + type parsers interface { ParsePathX(p *Parser, expr ast.Expr) (*Path, error) ParsePathY(p *Parser, expr ast.Expr) (*Path, error) @@ -169,7 +192,7 @@ type parsers interface { ParsePkgY(p *Parser, expr ast.Expr) (*types.Package, error) } -func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers) error { +func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers, fetchMod func(token.Pos) (*Module, error)) error { var errs error for _, arg := range args { if _, ok := arg.(*ast.Ident); ok { @@ -195,7 +218,7 @@ func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers) erro if cfg.ForStruct == nil { cfg.ForStruct = &Config{} } - if err := p.ParseConfig(cfg.ForStruct, call.Args, parsers); err != nil { + if err := p.ParseConfig(cfg.ForStruct, call.Args, parsers, fetchMod); err != nil { errs = errors.Join(errs, err) } continue @@ -203,7 +226,7 @@ func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers) erro if cfg.ForUnion == nil { cfg.ForUnion = &Config{} } - if err := p.ParseConfig(cfg.ForUnion, call.Args, parsers); err != nil { + if err := p.ParseConfig(cfg.ForUnion, call.Args, parsers, fetchMod); err != nil { errs = errors.Join(errs, err) } continue @@ -211,20 +234,20 @@ func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers) erro if cfg.ForEnum == nil { cfg.ForEnum = &Config{} } - if err := p.ParseConfig(cfg.ForEnum, call.Args, parsers); err != nil { + if err := p.ParseConfig(cfg.ForEnum, call.Args, parsers, fetchMod); err != nil { errs = errors.Join(errs, err) } continue } - if err := p.ParseOption(cfg, call, parsers); err != nil { + if err := p.ParseOption(cfg, call, parsers, fetchMod); err != nil { errs = errors.Join(errs, err) } } return errs } -func (p *Parser) ParseOption(cfg *Config, call *ast.CallExpr, ps parsers) error { // nolint: gocyclo +func (p *Parser) ParseOption(cfg *Config, call *ast.CallExpr, ps parsers, fetchMod func(token.Pos) (*Module, error)) error { // nolint: gocyclo callee := typeutil.Callee(p.Pkg().TypesInfo, call) if callee == nil || !IsConvgenImport(callee.Pkg().Path()) { return codefmt.Errorf(p, call, "option must be convgen directive") @@ -232,6 +255,8 @@ func (p *Parser) ParseOption(cfg *Config, call *ast.CallExpr, ps parsers) error name := callee.Name() switch name { + case "ImportModule": + return p.ParseOptionImportModule(cfg, call, fetchMod) case "ImportFunc": return p.ParseOptionImportFunc(cfg, call, false) case "ImportFuncErr": @@ -290,6 +315,22 @@ func (p *Parser) ParseOption(cfg *Config, call *ast.CallExpr, ps parsers) error return codefmt.Errorf(p, call.Fun, "%s is not supported option", name) } +func (p *Parser) ParseOptionImportModule(c *Config, call *ast.CallExpr, fetchMod func(token.Pos) (*Module, error)) error { + expr, err := needArgs1(p, call) + if err != nil { + return err + } + + mod, err := p.ParseModuleArg(expr, fetchMod) + if err != nil { + return err + } + if mod != nil { + c.UpdateImport(mod.Config) + } + return nil +} + func (p *Parser) ParseOptionImportFunc(c *Config, call *ast.CallExpr, hasErr bool) error { expr, err := needArgs1(p, call) if err != nil { diff --git a/internal/convgen/parse/injector.go b/internal/convgen/parse/injector.go index 58b63df..f58b2b4 100644 --- a/internal/convgen/parse/injector.go +++ b/internal/convgen/parse/injector.go @@ -215,7 +215,14 @@ func (p *Parser) parseInjector(id *ast.Ident, call *ast.CallExpr, doc, comment * inj.Func = fn errs = errors.Join(errs, err) - mod, err := p.ParseModuleArg(call.Args[0], mods) + fetchMod := func(pos token.Pos) (*Module, error) { + if mod, ok := mods[pos]; ok { + return mod, nil + } + return nil, nil // not found + } + + mod, err := p.ParseModuleArg(call.Args[0], fetchMod) if err != nil { mod = NilModule() // Prevent nil panic to collect as many errors as possible } @@ -266,7 +273,7 @@ func (p *Parser) parseInjector(id *ast.Ident, call *ast.CallExpr, doc, comment * // Parse config cfg.DiscoverBySamplePkgX = inj.X().Pkg() cfg.DiscoverBySamplePkgY = inj.Y().Pkg() - errs = errors.Join(errs, p.ParseConfig(&cfg, opts, parsers)) + errs = errors.Join(errs, p.ParseConfig(&cfg, opts, parsers, fetchMod)) inj.Config = cfg // Register into the module diff --git a/internal/convgen/parse/module.go b/internal/convgen/parse/module.go index 39cdfc1..d775da9 100644 --- a/internal/convgen/parse/module.go +++ b/internal/convgen/parse/module.go @@ -37,7 +37,12 @@ type Module struct { // ParseModules finds and parses all convgen.Module calls in the parsed files. func (p *Parser) ParseModules() (map[token.Pos]*Module, error) { var errs error - mods := make(map[token.Pos]*Module) + + type modDecl struct { + Name string + Call *ast.CallExpr + } + decls := make(map[token.Pos]modDecl) for _, file := range p.ConvgenGoFiles() { for id, call := range p.FindModules(file) { @@ -46,10 +51,39 @@ func (p *Parser) ParseModules() (map[token.Pos]*Module, error) { name = "" } - mod, err := p.ParseModule(call, name) - mods[id.Pos()] = mod + decls[id.Pos()] = modDecl{Name: name, Call: call} + } + } + + mods := make(map[token.Pos]*Module) + visiting := make(map[token.Pos]bool) + + var fetchMod func(token.Pos) (*Module, error) + fetchMod = func(pos token.Pos) (*Module, error) { + if mod, ok := mods[pos]; ok { + return mod, nil + } + if visiting[pos] { + return nil, errors.New("import cycle detected") + } + decl, ok := decls[pos] + if !ok { + return nil, nil + } + + visiting[pos] = true + mod, err := p.ParseModule(decl.Call, decl.Name, fetchMod) + if err != nil { errs = errors.Join(errs, err) } + // Mark as not visiting even if there is an error to avoid blocking other modules that depend on this module. + visiting[pos] = false + mods[pos] = mod + return mod, err + } + + for pos := range decls { + _, _ = fetchMod(pos) } return mods, errs @@ -92,7 +126,7 @@ func (p *Parser) FindModules(file *ast.File) iter.Seq2[*ast.Ident, *ast.CallExpr // ParseModule parses a [convgen.Module] call expression and returns a new // module. -func (p *Parser) ParseModule(call *ast.CallExpr, name string) (*Module, error) { +func (p *Parser) ParseModule(call *ast.CallExpr, name string, fetchMod func(token.Pos) (*Module, error)) (*Module, error) { // Chain of For* after NewModule calls := []*ast.CallExpr{call} for { @@ -113,7 +147,7 @@ func (p *Parser) ParseModule(call *ast.CallExpr, name string) (*Module, error) { var cfg Config var errs error - if err := p.ParseConfig(&cfg, calls[0].Args, nil); err != nil { + if err := p.ParseConfig(&cfg, calls[0].Args, nil, fetchMod); err != nil { errs = errors.Join(errs, err) } @@ -121,15 +155,15 @@ func (p *Parser) ParseModule(call *ast.CallExpr, name string) (*Module, error) { switch call.Fun.(*ast.SelectorExpr).Sel.Name { case "ForStruct": cfg.ForStruct = &Config{} - err := p.ParseConfig(cfg.ForStruct, call.Args, nil) + err := p.ParseConfig(cfg.ForStruct, call.Args, nil, fetchMod) errs = errors.Join(errs, err) case "ForUnion": cfg.ForUnion = &Config{} - err := p.ParseConfig(cfg.ForUnion, call.Args, nil) + err := p.ParseConfig(cfg.ForUnion, call.Args, nil, fetchMod) errs = errors.Join(errs, err) case "ForEnum": cfg.ForEnum = &Config{} - err := p.ParseConfig(cfg.ForEnum, call.Args, nil) + err := p.ParseConfig(cfg.ForEnum, call.Args, nil, fetchMod) errs = errors.Join(errs, err) default: panic("unexpected module chain") @@ -181,7 +215,7 @@ func (p *Parser) newModuleLookup(cfg Config, old *typeinfo.Lookup[typeinfo.Func] // ParseModuleArg parses a Convgen module type argument from the given // expression. -func (p *Parser) ParseModuleArg(expr ast.Expr, mods map[token.Pos]*Module) (*Module, error) { +func (p *Parser) ParseModuleArg(expr ast.Expr, fetchMod func(token.Pos) (*Module, error)) (*Module, error) { expr = ast.Unparen(expr) // Inline Module Declaration @@ -194,7 +228,7 @@ func (p *Parser) ParseModuleArg(expr ast.Expr, mods map[token.Pos]*Module) (*Mod // implicit converters. The implicit converters will inherit the module's // configuration. if call, ok := expr.(*ast.CallExpr); ok && p.IsDirective(call, "Module") { - return p.ParseModule(call, "") + return p.ParseModule(call, "", fetchMod) } // Validate identifier @@ -230,8 +264,11 @@ func (p *Parser) ParseModuleArg(expr ast.Expr, mods map[token.Pos]*Module) (*Mod // This is the most common way to declare and use a module. Multiple // converters can belong to the same package-level module. modPos := p.Pkg().TypesInfo.ObjectOf(id).Pos() - mod, ok := mods[modPos] - if !ok { + mod, err := fetchMod(modPos) + if err != nil { + return nil, codefmt.Errorf(p, expr, "cannot import module %q: %v", id.Name, err) + } + if mod == nil { return nil, codefmt.Errorf(p, expr, "cannot find %q module declared by convgen.Module", id.Name) } return mod, nil diff --git a/testdata/program/ModuleImport/main/main.go b/testdata/program/ModuleImport/main/main.go new file mode 100644 index 0000000..f1eacfa --- /dev/null +++ b/testdata/program/ModuleImport/main/main.go @@ -0,0 +1,43 @@ +//go:build convgen + +package main + +import ( + "fmt" + "strconv" + + "github.com/sublee/convgen" +) + +type User struct { + Id int + Name string +} + +type UserDTO struct { + ID string + NAME string +} + +func IntToString(i int) string { + return strconv.Itoa(i) +} + +var mod1 = convgen.Module( + convgen.ImportFunc(IntToString), +) + +var mod2 = convgen.Module( + convgen.ImportModule(mod1), + convgen.RenameToLower(true, true), +) + +var UserToDTO = convgen.Struct[User, UserDTO](mod2) + +func main() { + dto := UserToDTO(User{ + Id: 42, + Name: "Alice", + }) + fmt.Println(dto.ID, dto.NAME) +} diff --git a/testdata/program/ModuleImport/want/program_output.txt b/testdata/program/ModuleImport/want/program_output.txt new file mode 100644 index 0000000..999f8c7 --- /dev/null +++ b/testdata/program/ModuleImport/want/program_output.txt @@ -0,0 +1 @@ +42 Alice