diff --git a/migrate/migration.go b/migrate/migration.go index 4d60a5858..295b72555 100644 --- a/migrate/migration.go +++ b/migrate/migration.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "io/fs" - "sort" "strings" "text/template" "time" @@ -232,7 +231,7 @@ func (ms MigrationSlice) Applied() MigrationSlice { applied = append(applied, ms[i]) } } - sortDesc(applied) + SafeDescSort(applied) return applied } @@ -245,7 +244,7 @@ func (ms MigrationSlice) Unapplied() MigrationSlice { unapplied = append(unapplied, ms[i]) } } - sortAsc(unapplied) + SafeAscSort(unapplied) return unapplied } @@ -275,6 +274,7 @@ func (ms MigrationSlice) LastGroup() *MigrationGroup { group.Migrations = append(group.Migrations, ms[i]) } } + return group } @@ -321,17 +321,3 @@ func WithNopMigration() MigrationOption { cfg.nop = true } } - -//------------------------------------------------------------------------------ - -func sortAsc(ms MigrationSlice) { - sort.Slice(ms, func(i, j int) bool { - return ms[i].Name < ms[j].Name - }) -} - -func sortDesc(ms MigrationSlice) { - sort.Slice(ms, func(i, j int) bool { - return ms[i].Name > ms[j].Name - }) -} diff --git a/migrate/migrations.go b/migrate/migrations.go index a22e615cb..5908b86b0 100644 --- a/migrate/migrations.go +++ b/migrate/migrations.go @@ -38,7 +38,7 @@ func NewMigrations(opts ...MigrationsOption) *Migrations { func (m *Migrations) Sorted() MigrationSlice { migrations := make(MigrationSlice, len(m.ms)) copy(migrations, m.ms) - sortAsc(migrations) + SafeAscSort(migrations) return migrations } diff --git a/migrate/migrator.go b/migrate/migrator.go index a325c3993..b64479c0d 100644 --- a/migrate/migrator.go +++ b/migrate/migrator.go @@ -47,6 +47,15 @@ func WithTemplateData(data any) MigratorOption { } } +// SetSort overrides the default ascending sort function for all migrations. +// This affects all sorting operations in the entire migrate package. +func SetSort(ascSortFn, descSortFn MigrationSortFunc) { + sortMutex.Lock() + defer sortMutex.Unlock() + AscSort = ascSortFn + DescSort = descSortFn +} + type Migrator struct { db *bun.DB migrations *Migrations diff --git a/migrate/sort.go b/migrate/sort.go new file mode 100644 index 000000000..e8a7eb46a --- /dev/null +++ b/migrate/sort.go @@ -0,0 +1,49 @@ +package migrate + +import ( + "sort" + "sync" +) + +// MigrationSortFunc defines the signature for functions that sort migrations. +type MigrationSortFunc func(ms MigrationSlice) + +// sortMutex protects access to the global sort functions. +var sortMutex sync.RWMutex + +// Default sort implementations +var defaultAscSort MigrationSortFunc = func(ms MigrationSlice) { + sort.Slice(ms, func(i, j int) bool { + return ms[i].Name < ms[j].Name + }) +} + +var defaultDescSort MigrationSortFunc = func(ms MigrationSlice) { + sort.Slice(ms, func(i, j int) bool { + return ms[i].Name > ms[j].Name + }) +} + +// AscSort is the global ascending sort function. +// Default is to sort by migration name in ascending order. +// Can be overridden to use custom sorting logic. +var AscSort MigrationSortFunc = defaultAscSort + +// DescSort is the global descending sort function. +// Default is to sort by migration name in descending order. +// Can be overridden to use custom sorting logic. +var DescSort MigrationSortFunc = defaultDescSort + +// SafeAscSort applies the current ascending sort function in a thread-safe manner. +func SafeAscSort(ms MigrationSlice) { + sortMutex.RLock() + defer sortMutex.RUnlock() + AscSort(ms) +} + +// SafeDescSort applies the current descending sort function in a thread-safe manner. +func SafeDescSort(ms MigrationSlice) { + sortMutex.RLock() + defer sortMutex.RUnlock() + DescSort(ms) +}