Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 140 additions & 17 deletions internal/dbtest/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ func TestMigrate(t *testing.T) {
{run: testRunMigration},
{run: testRunMigrationAssignsNewGroup},
{run: testRunMigrationUpErrorPreservesAppliedState},
{run: testHooks},
{run: testSQLMigrations},
}

testEachDB(t, func(t *testing.T, dbName string, db *bun.DB) {
Expand All @@ -79,22 +81,22 @@ func testMigrateUpAndDown(t *testing.T, db *bun.DB) {
migrations := migrate.NewMigrations()
migrations.Add(migrate.Migration{
Name: "20060102150405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
history = append(history, "up1")
return nil
},
Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Down: func(ctx context.Context, db *bun.DB) error {
history = append(history, "down1")
return nil
},
})
migrations.Add(migrate.Migration{
Name: "20060102160405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
history = append(history, "up2")
return nil
},
Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Down: func(ctx context.Context, db *bun.DB) error {
history = append(history, "down2")
return nil
},
Expand Down Expand Up @@ -129,33 +131,33 @@ func testMigrateUpError(t *testing.T, db *bun.DB) {
migrations := migrate.NewMigrations()
migrations.Add(migrate.Migration{
Name: "20060102150405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
history = append(history, "up1")
return nil
},
Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Down: func(ctx context.Context, db *bun.DB) error {
history = append(history, "down1")
return nil
},
})
migrations.Add(migrate.Migration{
Name: "20060102160405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
history = append(history, "up2")
return errors.New("failed")
},
Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Down: func(ctx context.Context, db *bun.DB) error {
history = append(history, "down2")
return nil
},
})
migrations.Add(migrate.Migration{
Name: "20060102170405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
history = append(history, "up3")
return errors.New("failed")
},
Down: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Down: func(ctx context.Context, db *bun.DB) error {
history = append(history, "down3")
return nil
},
Expand All @@ -169,8 +171,7 @@ func testMigrateUpError(t *testing.T, db *bun.DB) {
require.NoError(t, err)

group, err := m.Migrate(ctx)
require.Error(t, err)
require.Equal(t, "20060102160405: up: failed", err.Error())
require.ErrorContains(t, err, "20060102160405: up: failed")
require.Equal(t, int64(1), group.ID)
require.Len(t, group.Migrations, 2)
require.Equal(t, []string{"up1", "up2"}, history)
Expand All @@ -192,14 +193,14 @@ func testRunMigration(t *testing.T, db *bun.DB) {
migrations := migrate.NewMigrations()
migrations.Add(migrate.Migration{
Name: "20060102150405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
history = append(history, "up1")
return nil
},
})
migrations.Add(migrate.Migration{
Name: "20060102160405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
history = append(history, "up2")
return nil
},
Expand Down Expand Up @@ -239,13 +240,13 @@ func testRunMigrationAssignsNewGroup(t *testing.T, db *bun.DB) {
migrations := migrate.NewMigrations()
migrations.Add(migrate.Migration{
Name: "20060102150405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
return nil
},
})
migrations.Add(migrate.Migration{
Name: "20060102160405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
return nil
},
})
Expand Down Expand Up @@ -294,7 +295,7 @@ func testRunMigrationUpErrorPreservesAppliedState(t *testing.T, db *bun.DB) {
migrations := migrate.NewMigrations()
migrations.Add(migrate.Migration{
Name: "20060102150405",
Up: func(ctx context.Context, migrator *migrate.Migrator, migration *migrate.Migration) error {
Up: func(ctx context.Context, db *bun.DB) error {
if shouldFail {
return upErr
}
Expand Down Expand Up @@ -330,6 +331,128 @@ func testRunMigrationUpErrorPreservesAppliedState(t *testing.T, db *bun.DB) {
require.Len(t, appliedAfter, 1, "failed re-run must not delete the existing applied record")
}

// testHooks tests that BeforeHook and AfterHook are executed before each migration and rollback.
func testHooks(t *testing.T, db *bun.DB) {
nop := func(ctx context.Context, db *bun.DB) error {
return nil
}
for _, tt := range []struct {
name string // Test case name.
ms migrate.MigrationSlice // Migrations to register.
run func(ctx context.Context, m *migrate.Migrator) error // Run UP migrations.
wantHooks int // How many hooks tt.run is supposed to trigger.
}{
{
name: "migrate",
ms: migrate.MigrationSlice{
{Name: "20060102150405", Up: nop, Down: nop},
{Name: "20060102150406", Up: nop, Down: nop},
},
run: func(ctx context.Context, m *migrate.Migrator) error {
_, err := m.Migrate(t.Context())
return err
},
wantHooks: 2,
},
{
name: "run migration",
ms: migrate.MigrationSlice{
{Name: "20060102150405", Up: nop, Down: nop},
{Name: "20060102150406", Up: nop, Down: nop},
},
run: func(ctx context.Context, m *migrate.Migrator) error {
return m.RunMigration(ctx, "20060102150405")
},
wantHooks: 1,
},
} {
t.Run(tt.name, func(t *testing.T) {
migrations := migrate.NewMigrations()
for _, m := range tt.ms {
migrations.Add(m)
}

var hooks int
m := migrate.NewMigrator(db, migrations,
migrate.WithTableName(migrationsTable),
migrate.WithLocksTableName(migrationLocksTable),
migrate.WithUpsert(true),
migrate.BeforeMigration(func(context.Context, bun.IConn, *migrate.Migration) error {
hooks++
return nil
}))
require.NoError(t, m.Reset(t.Context()))

require.NoError(t, tt.run(t.Context(), m))
require.Equal(t, tt.wantHooks, hooks, "beforeMigrationHook must run on Migrate")

_, err := m.Rollback(t.Context())
require.NoError(t, err, "rollback")
require.Equal(t, tt.wantHooks*2, hooks, "beforeMigrationHook must run on Rollback")
})
}
}

func testSQLMigrations(t *testing.T, db *bun.DB) {
for _, tt := range []struct {
name string
up, down string
templateData any
table string
}{
{
name: "plain sql migration",
up: "CREATE TABLE books (isbn CHAR)",
down: "DROP TABLE books",
table: "books",
},
{
name: "with template data",
up: "CREATE TABLE {{ .Prefix }}books (isbn CHAR)",
down: "DROP TABLE {{ .Prefix }}books",
templateData: map[string]string{
"Prefix": "my_",
},
table: "my_books",
},
} {
t.Run(tt.name, func(t *testing.T) {
defer db.NewDropTable().Table(tt.table).Exec(t.Context())

tmp := t.TempDir()
name := "20060102150405_test"
var err error

err = os.WriteFile(filepath.Join(tmp, name+".up.sql"), []byte(tt.up), 0o644)
require.NoError(t, err, "create up migration")

err = os.WriteFile(filepath.Join(tmp, name+".down.sql"), []byte(tt.down), 0o644)
require.NoError(t, err, "create down migration")

migrations := migrate.NewMigrations()
err = migrations.Discover(os.DirFS(tmp))
require.NoError(t, err, "discover")

m := migrate.NewMigrator(db, migrations,
migrate.WithTableName(migrationsTable),
migrate.WithLocksTableName(migrationLocksTable))
require.NoError(t, m.Reset(t.Context()))

_, err = m.Migrate(t.Context(), migrate.WithSQLTemplateData(tt.templateData))
require.NoError(t, err, "migrate")

_, err = db.NewSelect().Table(tt.table).Exec(t.Context())
require.NoError(t, err, "books table must exist after migration")

_, err = m.Rollback(t.Context(), migrate.WithSQLTemplateData(tt.templateData))
require.NoError(t, err, "rollback")

_, err = db.NewSelect().Table(tt.table).Exec(t.Context())
require.Error(t, err, "books table must not exist after rollback")
})
}
}

// newAutoMigratorOrSkip creates an AutoMigrator configured to use test migratins/locks
// tables and dedicated migrations directory. If an AutoMigrator cannob be created because
// the dialect doesn't support either schema inspections or migrations, the test will be *skipped*
Expand Down
4 changes: 2 additions & 2 deletions migrate/auto.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ func (am *AutoMigrator) createSQLMigrations(ctx context.Context, transactional b
migrations := NewMigrations(am.migrationsOpts...)
migrations.Add(Migration{
Name: name,
Up: wrapGoMigrationFunc(changes.Up(am.dbMigrator)),
Down: wrapGoMigrationFunc(changes.Down(am.dbMigrator)),
Up: changes.Up(am.dbMigrator),
Down: changes.Down(am.dbMigrator),
Comment: "Changes detected by bun.AutoMigrator",
})

Expand Down
Loading
Loading