diff --git a/internal/dbtest/migrate_test.go b/internal/dbtest/migrate_test.go index 9bf6e79ae..662a28d7c 100644 --- a/internal/dbtest/migrate_test.go +++ b/internal/dbtest/migrate_test.go @@ -58,6 +58,8 @@ func TestMigrate(t *testing.T) { {run: testRunMigration}, {run: testRunMigrationAssignsNewGroup}, {run: testRunMigrationUpErrorPreservesAppliedState}, + {run: testHooks}, + {run: testSQLMigrations}, } testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) { @@ -79,22 +81,22 @@ func testMigrateUpAndDown(t *testing.T, db *bun.DB) { migrations := migrate.NewMigrations() migrations.Add(migrate.Migration{ Name: "20060102150405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { history = append(history, "up1") return nil }, - Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Down: func(ctx context.Context, db *bun.DB) error { history = append(history, "down1") return nil }, }) migrations.Add(migrate.Migration{ Name: "20060102160405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { history = append(history, "up2") return nil }, - Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Down: func(ctx context.Context, db *bun.DB) error { history = append(history, "down2") return nil }, @@ -129,33 +131,33 @@ func testMigrateUpError(t *testing.T, db *bun.DB) { migrations := migrate.NewMigrations() migrations.Add(migrate.Migration{ Name: "20060102150405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { history = append(history, "up1") return nil }, - Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Down: func(ctx context.Context, db *bun.DB) error { history = append(history, "down1") return nil }, }) migrations.Add(migrate.Migration{ Name: "20060102160405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { history = append(history, "up2") return errors.New("failed") }, - Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Down: func(ctx context.Context, db *bun.DB) error { history = append(history, "down2") return nil }, }) migrations.Add(migrate.Migration{ Name: "20060102170405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { history = append(history, "up3") return errors.New("failed") }, - Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Down: func(ctx context.Context, db *bun.DB) error { history = append(history, "down3") return nil }, @@ -169,8 +171,7 @@ func testMigrateUpError(t *testing.T, db *bun.DB) { require.NoError(t, err) group, err := m.Migrate(ctx) - require.Error(t, err) - require.Equal(t, "20060102160405: up: failed", err.Error()) + require.ErrorContains(t, err, "20060102160405: up: failed") require.Equal(t, int64(1), group.ID) require.Len(t, group.Migrations, 2) require.Equal(t, []string{"up1", "up2"}, history) @@ -192,14 +193,14 @@ func testRunMigration(t *testing.T, db *bun.DB) { migrations := migrate.NewMigrations() migrations.Add(migrate.Migration{ Name: "20060102150405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { history = append(history, "up1") return nil }, }) migrations.Add(migrate.Migration{ Name: "20060102160405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { history = append(history, "up2") return nil }, @@ -239,13 +240,13 @@ func testRunMigrationAssignsNewGroup(t *testing.T, db *bun.DB) { migrations := migrate.NewMigrations() migrations.Add(migrate.Migration{ Name: "20060102150405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { return nil }, }) migrations.Add(migrate.Migration{ Name: "20060102160405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { return nil }, }) @@ -294,7 +295,7 @@ func testRunMigrationUpErrorPreservesAppliedState(t *testing.T, db *bun.DB) { migrations := migrate.NewMigrations() migrations.Add(migrate.Migration{ Name: "20060102150405", - Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error { + Up: func(ctx context.Context, db *bun.DB) error { if shouldFail { return upErr } @@ -330,6 +331,128 @@ func testRunMigrationUpErrorPreservesAppliedState(t *testing.T, db *bun.DB) { require.Len(t, appliedAfter, 1, "failed re-run must not delete the existing applied record") } +// testHooks tests that BeforeHook and AfterHook are executed before each migration and rollback. +func testHooks(t *testing.T, db *bun.DB) { + nop := func(ctx context.Context, db *bun.DB) error { + return nil + } + for _, tt := range []struct { + name string // Test case name. + ms migrate.MigrationSlice // Migrations to register. + run func(ctx context.Context, m *migrate.Migrator) error // Run UP migrations. + wantHooks int // How many hooks tt.run is supposed to trigger. + }{ + { + name: "migrate", + ms: migrate.MigrationSlice{ + {Name: "20060102150405", Up: nop, Down: nop}, + {Name: "20060102150406", Up: nop, Down: nop}, + }, + run: func(ctx context.Context, m *migrate.Migrator) error { + _, err := m.Migrate(t.Context()) + return err + }, + wantHooks: 2, + }, + { + name: "run migration", + ms: migrate.MigrationSlice{ + {Name: "20060102150405", Up: nop, Down: nop}, + {Name: "20060102150406", Up: nop, Down: nop}, + }, + run: func(ctx context.Context, m *migrate.Migrator) error { + return m.RunMigration(ctx, "20060102150405") + }, + wantHooks: 1, + }, + } { + t.Run(tt.name, func(t *testing.T) { + migrations := migrate.NewMigrations() + for _, m := range tt.ms { + migrations.Add(m) + } + + var hooks int + m := migrate.NewMigrator(db, migrations, + migrate.WithTableName(migrationsTable), + migrate.WithLocksTableName(migrationLocksTable), + migrate.WithUpsert(true), + migrate.BeforeMigration(func(context.Context, bun.IConn, *migrate.Migration) error { + hooks++ + return nil + })) + require.NoError(t, m.Reset(t.Context())) + + require.NoError(t, tt.run(t.Context(), m)) + require.Equal(t, tt.wantHooks, hooks, "beforeMigrationHook must run on Migrate") + + _, err := m.Rollback(t.Context()) + require.NoError(t, err, "rollback") + require.Equal(t, tt.wantHooks*2, hooks, "beforeMigrationHook must run on Rollback") + }) + } +} + +func testSQLMigrations(t *testing.T, db *bun.DB) { + for _, tt := range []struct { + name string + up, down string + templateData any + table string + }{ + { + name: "plain sql migration", + up: "CREATE TABLE books (isbn CHAR)", + down: "DROP TABLE books", + table: "books", + }, + { + name: "with template data", + up: "CREATE TABLE {{ .Prefix }}books (isbn CHAR)", + down: "DROP TABLE {{ .Prefix }}books", + templateData: map[string]string{ + "Prefix": "my_", + }, + table: "my_books", + }, + } { + t.Run(tt.name, func(t *testing.T) { + defer db.NewDropTable().Table(tt.table).Exec(t.Context()) + + tmp := t.TempDir() + name := "20060102150405_test" + var err error + + err = os.WriteFile(filepath.Join(tmp, name+".up.sql"), []byte(tt.up), 0o644) + require.NoError(t, err, "create up migration") + + err = os.WriteFile(filepath.Join(tmp, name+".down.sql"), []byte(tt.down), 0o644) + require.NoError(t, err, "create down migration") + + migrations := migrate.NewMigrations() + err = migrations.Discover(os.DirFS(tmp)) + require.NoError(t, err, "discover") + + m := migrate.NewMigrator(db, migrations, + migrate.WithTableName(migrationsTable), + migrate.WithLocksTableName(migrationLocksTable)) + require.NoError(t, m.Reset(t.Context())) + + _, err = m.Migrate(t.Context(), migrate.WithSQLTemplateData(tt.templateData)) + require.NoError(t, err, "migrate") + + _, err = db.NewSelect().Table(tt.table).Exec(t.Context()) + require.NoError(t, err, "books table must exist after migration") + + _, err = m.Rollback(t.Context(), migrate.WithSQLTemplateData(tt.templateData)) + require.NoError(t, err, "rollback") + + _, err = db.NewSelect().Table(tt.table).Exec(t.Context()) + require.Error(t, err, "books table must not exist after rollback") + }) + } +} + // newAutoMigratorOrSkip creates an AutoMigrator configured to use test migratins/locks // tables and dedicated migrations directory. If an AutoMigrator cannob be created because // the dialect doesn't support either schema inspections or migrations, the test will be *skipped* diff --git a/migrate/auto.go b/migrate/auto.go index 174a25649..394deeae4 100644 --- a/migrate/auto.go +++ b/migrate/auto.go @@ -272,8 +272,8 @@ func (am *AutoMigrator) createSQLMigrations(ctx context.Context, transactional b migrations := NewMigrations(am.migrationsOpts...) migrations.Add(Migration{ Name: name, - Up: wrapGoMigrationFunc(changes.Up(am.dbMigrator)), - Down: wrapGoMigrationFunc(changes.Down(am.dbMigrator)), + Up: changes.Up(am.dbMigrator), + Down: changes.Down(am.dbMigrator), Comment: "Changes detected by bun.AutoMigrator", }) diff --git a/migrate/migration.go b/migrate/migration.go index 945b59559..a0ec88923 100644 --- a/migrate/migration.go +++ b/migrate/migration.go @@ -9,6 +9,7 @@ import ( "io/fs" "slices" "strings" + "sync" "text/template" "time" @@ -25,8 +26,8 @@ type Migration struct { GroupID int64 MigratedAt time.Time `bun:",notnull,nullzero,default:current_timestamp"` - Up internalMigrationFunc `bun:"-"` - Down internalMigrationFunc `bun:"-"` + Up MigrationFunc `bun:"-"` + Down MigrationFunc `bun:"-"` } // String returns the migration name and comment. @@ -42,45 +43,26 @@ func (m Migration) IsApplied() bool { // MigrationFunc is a function that executes a migration against a database. type MigrationFunc func(ctx context.Context, db *bun.DB) error -type internalMigrationFunc func(ctx context.Context, migrator *Migrator, migration *Migration) error - -func wrapGoMigrationFunc(fn MigrationFunc) internalMigrationFunc { - return func(ctx context.Context, migrator *Migrator, migration *Migration) error { - if migrator.beforeMigrationHook != nil { - if err := migrator.beforeMigrationHook(ctx, migrator.db, migration); err != nil { - return err - } - } - - if err := fn(ctx, migrator.db); err != nil { - return err - } - - if migrator.afterMigrationHook != nil { - if err := migrator.afterMigrationHook(ctx, migrator.db, migration); err != nil { - return err - } - } - - return nil +func newSQLMigrationFunc(fsys fs.FS, name string) (MigrationFunc, error) { + sqlFile, err := fsys.Open(name) + if err != nil { + return nil, err } -} -func newSQLMigrationFunc(fsys fs.FS, name string) internalMigrationFunc { - return func(ctx context.Context, migrator *Migrator, migration *Migration) error { - sqlFile, err := fsys.Open(name) - if err != nil { - return err - } + contents, err := io.ReadAll(sqlFile) + if err != nil { + return nil, err + } - contents, err := io.ReadAll(sqlFile) - if err != nil { - return err - } + tmpl := sync.OnceValues(func() (*template.Template, error) { + return template.New(name).Parse(string(contents)) + }) + return func(ctx context.Context, db *bun.DB) error { var reader io.Reader = bytes.NewReader(contents) - if migrator.templateData != nil { - buf, err := renderTemplate(contents, migrator.templateData) + + if data := ctx.Value(templateDataKey); data != nil { + buf, err := renderTemplate(tmpl, data) if err != nil { return err } @@ -118,15 +100,16 @@ func newSQLMigrationFunc(fsys fs.FS, name string) internalMigrationFunc { var idb bun.IConn - isTx := strings.HasSuffix(name, ".tx.up.sql") || strings.HasSuffix(name, ".tx.down.sql") + isTx := strings.HasSuffix(name, ".tx.up.sql") || + strings.HasSuffix(name, ".tx.down.sql") if isTx { - tx, err := migrator.db.BeginTx(ctx, nil) + tx, err := db.BeginTx(ctx, nil) if err != nil { return err } idb = tx } else { - conn, err := migrator.db.Conn(ctx) + conn, err := db.Conn(ctx) if err != nil { return err } @@ -154,25 +137,34 @@ func newSQLMigrationFunc(fsys fs.FS, name string) internalMigrationFunc { panic("not reached") }() - execErr = migrator.exec(ctx, idb, migration, queries) - if execErr != nil { - return execErr + for _, query := range queries { + if strings.TrimSpace(query) == "" { + continue + } + if _, execErr = db.ExecContext(ctx, query); execErr != nil { + return execErr + } } + return retErr - } + }, nil } -func renderTemplate(contents []byte, templateData any) (*bytes.Buffer, error) { - tmpl, err := template.New("migration").Parse(string(contents)) +//------------------------------------------------------------------------------ + +type contextKey struct{} + +var templateDataKey = contextKey{} + +func renderTemplate(parseFunc func() (*template.Template, error), templateData any) (*bytes.Buffer, error) { + tmpl, err := parseFunc() if err != nil { - return nil, fmt.Errorf("failed to parse template: %w", err) + return nil, fmt.Errorf("parse template: %w", err) } - var rendered bytes.Buffer if err := tmpl.Execute(&rendered, templateData); err != nil { - return nil, fmt.Errorf("failed to execute template: %w", err) + return nil, fmt.Errorf("execute template: %w", err) } - return &rendered, nil } @@ -293,6 +285,15 @@ func (ms MigrationSlice) LastGroup() *MigrationGroup { return group } +func (ms MigrationSlice) Index(migrationName string) int { + for i := range ms { + if ms[i].Name == migrationName { + return i + } + } + return -1 +} + // MigrationGroup is a group of migrations that were applied together in a single Migrate call. type MigrationGroup struct { ID int64 @@ -321,11 +322,14 @@ type MigrationFile struct { //------------------------------------------------------------------------------ type migrationConfig struct { - nop bool + nop bool + templateData any } -func newMigrationConfig(opts []MigrationOption) *migrationConfig { - cfg := new(migrationConfig) +func (m *Migrator) newMigrationConfig(opts []MigrationOption) *migrationConfig { + cfg := &migrationConfig{ + templateData: m.templateData, + } for _, opt := range opts { opt(cfg) } @@ -342,6 +346,13 @@ func WithNopMigration() MigrationOption { } } +// WithSQLTemplateData provides data for templated SQL migrations. +func WithSQLTemplateData(templateData any) MigrationOption { + return func(cfg *migrationConfig) { + cfg.templateData = templateData + } +} + //------------------------------------------------------------------------------ func sortAsc(ms MigrationSlice) { diff --git a/migrate/migrations.go b/migrate/migrations.go index 6af5447c6..04cdeb2f1 100644 --- a/migrate/migrations.go +++ b/migrate/migrations.go @@ -1,7 +1,6 @@ package migrate import ( - "errors" "fmt" "io/fs" "os" @@ -65,8 +64,8 @@ func (m *Migrations) Register(up, down MigrationFunc) error { m.Add(Migration{ Name: name, Comment: comment, - Up: wrapGoMigrationFunc(up), - Down: wrapGoMigrationFunc(down), + Up: up, + Down: down, }) return nil @@ -115,7 +114,10 @@ func (m *Migrations) Discover(fsys fs.FS) error { } migration.Comment = comment - migrationFunc := newSQLMigrationFunc(fsys, path) + migrationFunc, err := newSQLMigrationFunc(fsys, path) + if err != nil { + return err + } if strings.HasSuffix(path, ".up.sql") { migration.Up = migrationFunc @@ -126,7 +128,7 @@ func (m *Migrations) Discover(fsys fs.FS) error { return nil } - return errors.New("migrate: not reached") + panic("unreachable") }) } diff --git a/migrate/migrator.go b/migrate/migrator.go index 9e439f2fe..508c758eb 100644 --- a/migrate/migrator.go +++ b/migrate/migrator.go @@ -54,6 +54,7 @@ func WithUpsert(enabled bool) MigratorOption { } // WithTemplateData sets data passed to SQL migration templates during rendering. +// Use [WithSQLTemplateData] to re-use the [Migrator] instance with different data. func WithTemplateData(data any) MigratorOption { return func(m *Migrator) { m.templateData = data @@ -190,49 +191,23 @@ func (m *Migrator) Reset(ctx context.Context) error { // Migrate runs unapplied migrations. If a migration fails, migrate immediately exits. func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { - cfg := newMigrationConfig(opts) - - group := new(MigrationGroup) - if err := m.validate(); err != nil { - return group, err + return nil, err } migrations, lastGroupID, err := m.migrationsWithStatus(ctx) if err != nil { - return group, err - } - migrations = migrations.Unapplied() - if len(migrations) == 0 { - return group, nil + return nil, err } - group.ID = lastGroupID + 1 - for i := range migrations { - migration := &migrations[i] - migration.GroupID = group.ID - - if !m.markAppliedOnSuccess { - if err := m.MarkApplied(ctx, migration); err != nil { - return group, err - } - } - - group.Migrations = migrations[:i+1] - - if !cfg.nop && migration.Up != nil { - if err := migration.Up(ctx, m, migration); err != nil { - return group, fmt.Errorf("%s: up: %w", migration.Name, err) - } - } - - if m.markAppliedOnSuccess { - if err := m.MarkApplied(ctx, migration); err != nil { - return group, err - } - } + group := &MigrationGroup{ + ID: lastGroupID + 1, + Migrations: migrations.Unapplied(), } + if err := m.migrateGroup(ctx, group, opts...); err != nil { + return group, fmt.Errorf("migrate: %w", err) + } return group, nil } @@ -242,8 +217,6 @@ func (m *Migrator) Migrate(ctx context.Context, opts ...MigrationOption) (*Migra func (m *Migrator) RunMigration( ctx context.Context, migrationName string, opts ...MigrationOption, ) error { - cfg := newMigrationConfig(opts) - if err := m.validate(); err != nil { return err } @@ -259,38 +232,72 @@ func (m *Migrator) RunMigration( return err } - var migration *Migration - for i := range migrations { - if migrations[i].Name == migrationName { - migration = &migrations[i] - break - } - } - if migration == nil { + idx := migrations.Index(migrationName) + if idx == -1 { return fmt.Errorf("migrate: migration with name %q not found", migrationName) } - if migration.Up == nil { - return fmt.Errorf("migrate: migration %s does not have up migration", migration.Name) + + group := &MigrationGroup{ + ID: lastGroupID + 1, + Migrations: migrations[idx : idx+1], } - if cfg.nop { - return nil + + if err := m.migrateGroup(ctx, group, opts...); err != nil { + return fmt.Errorf("migrate: %w", err) } + return nil +} - migration.GroupID = lastGroupID + 1 +func (m *Migrator) migrateGroup(ctx context.Context, group *MigrationGroup, opts ...MigrationOption) error { + cfg := m.newMigrationConfig(opts) - if !m.markAppliedOnSuccess { - if err := m.MarkApplied(ctx, migration); err != nil { - return err + migrations := group.Migrations[:] + group.Migrations = group.Migrations[:0] + + for i := range migrations { + migration := &migrations[i] + migration.GroupID = group.ID + + if migration.Up == nil { + return fmt.Errorf("migrate: migration %s does not have up migration", migration.Name) } - } - if err := migration.Up(ctx, m, migration); err != nil { - return fmt.Errorf("%s: up: %w", migration.Name, err) - } + if !m.markAppliedOnSuccess { + if err := m.MarkApplied(ctx, migration); err != nil { + return err + } + } - if m.markAppliedOnSuccess { - if err := m.MarkApplied(ctx, migration); err != nil { - return err + group.Migrations = group.Migrations[:i+1] + + // TODO(dyma): Migrate marks a migration applied even in a nop run; RunMigration + // doesn't MarkApplied on a nop run. Is this intentional? + if !cfg.nop { + if cfg.templateData != nil { + ctx = context.WithValue(ctx, templateDataKey, cfg.templateData) + } + + if m.beforeMigrationHook != nil { + if err := m.beforeMigrationHook(ctx, m.db, migration); err != nil { + return err + } + } + + if err := migration.Up(ctx, m.db); err != nil { + return fmt.Errorf("%s: up: %w", migration.Name, err) + } + + if m.afterMigrationHook != nil { + if err := m.afterMigrationHook(ctx, m.db, migration); err != nil { + return err + } + } + } + + if m.markAppliedOnSuccess { + if err := m.MarkApplied(ctx, migration); err != nil { + return err + } } } @@ -299,7 +306,7 @@ func (m *Migrator) RunMigration( // Rollback rolls back the last migration group. func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*MigrationGroup, error) { - cfg := newMigrationConfig(opts) + cfg := m.newMigrationConfig(opts) lastGroup := new(MigrationGroup) @@ -323,12 +330,27 @@ func (m *Migrator) Rollback(ctx context.Context, opts ...MigrationOption) (*Migr } } + if m.beforeMigrationHook != nil { + if err := m.beforeMigrationHook(ctx, m.db, migration); err != nil { + return lastGroup, fmt.Errorf("before migration: %w", err) + } + } + if !cfg.nop && migration.Down != nil { - if err := migration.Down(ctx, m, migration); err != nil { + if cfg.templateData != nil { + ctx = context.WithValue(ctx, templateDataKey, cfg.templateData) + } + if err := migration.Down(ctx, m.db); err != nil { return lastGroup, fmt.Errorf("%s: down: %w", migration.Name, err) } } + if m.afterMigrationHook != nil { + if err := m.afterMigrationHook(ctx, m.db, migration); err != nil { + return lastGroup, fmt.Errorf("after migration: %w", err) + } + } + if m.markAppliedOnSuccess { if err := m.MarkUnapplied(ctx, migration); err != nil { return lastGroup, err @@ -576,33 +598,6 @@ func (m *Migrator) validate() error { return nil } -func (m *Migrator) exec( - ctx context.Context, db bun.IConn, migration *Migration, queries []string, -) error { - if m.beforeMigrationHook != nil { - if err := m.beforeMigrationHook(ctx, db, migration); err != nil { - return err - } - } - - for _, query := range queries { - if strings.TrimSpace(query) == "" { - continue - } - if _, err := db.ExecContext(ctx, query); err != nil { - return err - } - } - - if m.afterMigrationHook != nil { - if err := m.afterMigrationHook(ctx, db, migration); err != nil { - return err - } - } - - return nil -} - //------------------------------------------------------------------------------ type migrationLock struct {