diff --git a/globals.go b/globals.go index 66eb2152f..c92239caf 100644 --- a/globals.go +++ b/globals.go @@ -1,7 +1,9 @@ package goose import ( + "errors" "fmt" + "path/filepath" ) var ( @@ -16,34 +18,120 @@ func ResetGlobalMigrations() { } // SetGlobalMigrations registers Go migrations globally. It returns an error if a migration with the -// same version has already been registered. -// -// Avoid constructing migrations manually, use [NewGoMigration] function. -// -// Source may be empty, but if it is set, it must be a path with a numeric component that matches -// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx. +// same version has already been registered. Go migrations must be constructed using the +// [NewGoMigration] function. // // Not safe for concurrent use. func SetGlobalMigrations(migrations ...Migration) error { - for _, m := range migrations { - migration := &m - if err := validGoMigration(migration); err != nil { - return fmt.Errorf("invalid go migration: %w", err) + for _, migration := range migrations { + m := &migration + if err := setGoMigration(m); err != nil { + return err + } + registeredGoMigrations[m.Version] = m + } + return nil +} + +func setGoMigration(m *Migration) error { + if _, ok := registeredGoMigrations[m.Version]; ok { + return fmt.Errorf("go migration with version %d already registered", m.Version) + } + if err := checkMigration(m); err != nil { + return fmt.Errorf("invalid go migration: %w", err) + } + return nil +} + +func checkMigration(m *Migration) error { + if !m.construct { + return errors.New("must use NewGoMigration to construct migrations") + } + if !m.Registered { + return errors.New("must be registered") + } + if m.Type != TypeGo { + return fmt.Errorf("type must be %q", TypeGo) + } + if m.Version < 1 { + return errors.New("version must be greater than zero") + } + if m.Source != "" { + if filepath.Ext(m.Source) != ".go" { + return fmt.Errorf("source must have .go extension: %q", m.Source) } - if err := verifyAndUpdateGoFunc(migration.goUp); err != nil { - return fmt.Errorf("up function: %w", err) + // If the source is set, expect it to be a path with a numeric component that matches the + // version. This field is not intended to be used for descriptive purposes. + version, err := NumericComponent(m.Source) + if err != nil { + return fmt.Errorf("invalid source: %w", err) } - if err := verifyAndUpdateGoFunc(migration.goDown); err != nil { - return fmt.Errorf("down function: %w", err) + if version != m.Version { + return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source) } - if err := updateLegacyFuncs(migration); err != nil { - return fmt.Errorf("invalid go migration: %w", err) + } + if err := setGoFunc(m.goUp); err != nil { + return fmt.Errorf("up function: %w", err) + } + if err := setGoFunc(m.goDown); err != nil { + return fmt.Errorf("down function: %w", err) + } + if m.UpFnContext != nil && m.UpFnNoTxContext != nil { + return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext") + } + if m.UpFn != nil && m.UpFnNoTx != nil { + return errors.New("must specify exactly one of UpFn or UpFnNoTx") + } + if m.DownFnContext != nil && m.DownFnNoTxContext != nil { + return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext") + } + if m.DownFn != nil && m.DownFnNoTx != nil { + return errors.New("must specify exactly one of DownFn or DownFnNoTx") + } + return nil +} + +func setGoFunc(f *GoFunc) error { + if f == nil { + f = &GoFunc{Mode: TransactionEnabled} + return nil + } + if f.RunTx != nil && f.RunDB != nil { + return errors.New("must specify exactly one of RunTx or RunDB") + } + if f.RunTx == nil && f.RunDB == nil { + switch f.Mode { + case 0: + // Default to TransactionEnabled ONLY if mode is not set explicitly. + f.Mode = TransactionEnabled + case TransactionEnabled, TransactionDisabled: + // No functions but mode is set. This is not an error. It means the user wants to record + // a version with the given mode but not run any functions. + default: + return fmt.Errorf("invalid mode: %d", f.Mode) } - if _, ok := registeredGoMigrations[m.Version]; ok { - return fmt.Errorf("go migration with version %d already registered", m.Version) + return nil + } + if f.RunDB != nil { + switch f.Mode { + case 0, TransactionDisabled: + f.Mode = TransactionDisabled + default: + return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set") + } + } + if f.RunTx != nil { + switch f.Mode { + case 0, TransactionEnabled: + f.Mode = TransactionEnabled + default: + return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set") } - m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user. - registeredGoMigrations[m.Version] = migration + } + // This is a defensive check. If the mode is still 0, it means we failed to infer the mode from + // the functions or return an error. This should never happen. + if f.Mode == 0 { + return errors.New("failed to infer transaction mode") } return nil } diff --git a/globals_test.go b/globals_test.go index 2a5294074..e91513b40 100644 --- a/globals_test.go +++ b/globals_test.go @@ -31,6 +31,23 @@ func TestNewGoMigration(t *testing.T) { check.Equal(t, m.goUp.Mode, TransactionEnabled) check.Equal(t, m.goDown.Mode, TransactionEnabled) }) + t.Run("all_set", func(t *testing.T) { + // This will eventually be an error when registering migrations. + m := NewGoMigration( + 1, + &GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }}, + &GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }}, + ) + // check only functions + check.Bool(t, m.UpFn != nil, true) + check.Bool(t, m.UpFnContext != nil, true) + check.Bool(t, m.UpFnNoTx != nil, true) + check.Bool(t, m.UpFnNoTxContext != nil, true) + check.Bool(t, m.DownFn != nil, true) + check.Bool(t, m.DownFnContext != nil, true) + check.Bool(t, m.DownFnNoTx != nil, true) + check.Bool(t, m.DownFnNoTxContext != nil, true) + }) } func TestTransactionMode(t *testing.T) { @@ -121,6 +138,7 @@ func TestLegacyFunctions(t *testing.T) { } t.Run("all_tx", func(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) err := SetGlobalMigrations( NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}), ) @@ -137,17 +155,18 @@ func TestLegacyFunctions(t *testing.T) { check.Bool(t, m.goDown == nil, false) check.Bool(t, m.DownFnContext == nil, false) // Always nil - check.Bool(t, m.UpFn == nil, true) - check.Bool(t, m.DownFn == nil, true) + check.Bool(t, m.UpFn == nil, false) + check.Bool(t, m.DownFn == nil, false) check.Bool(t, m.UpFnNoTx == nil, true) check.Bool(t, m.DownFnNoTx == nil, true) }) t.Run("all_db", func(t *testing.T) { + t.Cleanup(ResetGlobalMigrations) err := SetGlobalMigrations( NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}), ) check.NoError(t, err) - check.Number(t, len(registeredGoMigrations), 2) + check.Number(t, len(registeredGoMigrations), 1) m := registeredGoMigrations[2] assertMigration(t, m, 2) // Legacy functions. @@ -161,15 +180,15 @@ func TestLegacyFunctions(t *testing.T) { // Always nil check.Bool(t, m.UpFn == nil, true) check.Bool(t, m.DownFn == nil, true) - check.Bool(t, m.UpFnNoTx == nil, true) - check.Bool(t, m.DownFnNoTx == nil, true) + check.Bool(t, m.UpFnNoTx == nil, false) + check.Bool(t, m.DownFnNoTx == nil, false) }) } func TestGlobalRegister(t *testing.T) { t.Cleanup(ResetGlobalMigrations) - runDB := func(context.Context, *sql.DB) error { return nil } + // runDB := func(context.Context, *sql.DB) error { return nil } runTx := func(context.Context, *sql.Tx) error { return nil } // Success. @@ -185,87 +204,63 @@ func TestGlobalRegister(t *testing.T) { ) check.HasError(t, err) check.Contains(t, err.Error(), "go migration with version 1 already registered") - err = SetGlobalMigrations( - Migration{ - Registered: true, - Version: 2, - Source: "00002_foo.sql", - Type: TypeGo, - UpFnContext: func(context.Context, *sql.Tx) error { return nil }, - DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, - }, - ) - check.NoError(t, err) - // Reset. - { - ResetGlobalMigrations() - } - // Failure. - err = SetGlobalMigrations( - Migration{}, - ) + err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo}) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must be registered") - err = SetGlobalMigrations( - Migration{Registered: true}, - ) + check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") +} + +func TestCheckMigration(t *testing.T) { + // Failures. + err := checkMigration(&Migration{}) check.HasError(t, err) - check.Contains(t, err.Error(), `invalid go migration: type must be "go"`) - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, Type: TypeSQL}, - ) + check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations") + err = checkMigration(&Migration{construct: true}) check.HasError(t, err) - check.Contains(t, err.Error(), `invalid go migration: type must be "go"`) - err = SetGlobalMigrations( - Migration{Registered: true, Version: 0, Type: TypeGo}, - ) + check.Contains(t, err.Error(), "must be registered") + err = checkMigration(&Migration{construct: true, Registered: true}) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero") - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: TypeGo}, - ) + check.Contains(t, err.Error(), `type must be "go"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo}) check.HasError(t, err) - check.Contains( - t, - err.Error(), - `invalid go migration: version:1 does not match numeric component in source "2_foo.sql"`, - ) - // Legacy functions. - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: TypeGo}, - ) + check.Contains(t, err.Error(), "version must be greater than zero") + // Success. + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1}) + check.NoError(t, err) + // Failures. + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"}) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn") - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: TypeGo}, - ) + check.Contains(t, err.Error(), `source must have .go extension: "foo"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"}) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn") - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: TypeGo}, - ) + check.Contains(t, err.Error(), `no filename separator '_' found`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"}) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx") - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: TypeGo}, - ) + check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"}) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx") - // Context-aware functions. - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, UpFnContext: runTx, UpFnNoTxContext: runDB, Type: TypeGo}, - ) + check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`) + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + UpFnContext: func(context.Context, *sql.Tx) error { return nil }, + UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, + }) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext") - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, DownFnContext: runTx, DownFnNoTxContext: runDB, Type: TypeGo}, - ) + check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext") + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + DownFnContext: func(context.Context, *sql.Tx) error { return nil }, + DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil }, + }) check.HasError(t, err) - check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext") - // Source and version mismatch. - err = SetGlobalMigrations( - Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: TypeGo}, - ) + check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext") + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + UpFn: func(*sql.Tx) error { return nil }, + UpFnNoTx: func(*sql.DB) error { return nil }, + }) + check.HasError(t, err) + check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx") + err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, + DownFn: func(*sql.Tx) error { return nil }, + DownFnNoTx: func(*sql.DB) error { return nil }, + }) check.HasError(t, err) - check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`) + check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx") } diff --git a/migrate.go b/migrate.go index 599810ce8..22769ffc5 100644 --- a/migrate.go +++ b/migrate.go @@ -8,7 +8,6 @@ import ( "io/fs" "math" "path" - "runtime" "sort" "strings" "time" @@ -125,115 +124,6 @@ func (ms Migrations) String() string { return str } -// GoMigration is a Go migration func that is run within a transaction. -type GoMigration func(tx *sql.Tx) error - -// GoMigrationContext is a Go migration func that is run within a transaction and receives a context. -type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error - -// GoMigrationNoTx is a Go migration func that is run outside a transaction. -type GoMigrationNoTx func(db *sql.DB) error - -// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a context. -type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error - -// AddMigration adds Go migrations. -// -// Deprecated: Use AddMigrationContext. -func AddMigration(up, down GoMigration) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationContext(filename, withContext(up), withContext(down)) -} - -// AddMigrationContext adds Go migrations. -func AddMigrationContext(up, down GoMigrationContext) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationContext(filename, up, down) -} - -// AddNamedMigration adds named Go migrations. -// -// Deprecated: Use AddNamedMigrationContext. -func AddNamedMigration(filename string, up, down GoMigration) { - AddNamedMigrationContext(filename, withContext(up), withContext(down)) -} - -// AddNamedMigrationContext adds named Go migrations. -func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { - if err := register(filename, true, up, down, nil, nil); err != nil { - panic(err) - } -} - -// AddMigrationNoTx adds Go migrations that will be run outside transaction. -// -// Deprecated: Use AddNamedMigrationNoTxContext. -func AddMigrationNoTx(up, down GoMigrationNoTx) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) -} - -// AddMigrationNoTxContext adds Go migrations that will be run outside transaction. -func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) { - _, filename, _, _ := runtime.Caller(1) - AddNamedMigrationNoTxContext(filename, up, down) -} - -// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. -// -// Deprecated: Use AddNamedMigrationNoTxContext. -func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { - AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) -} - -// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction. -func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) { - if err := register(filename, false, nil, nil, up, down); err != nil { - panic(err) - } -} - -func register( - filename string, - useTx bool, - up, down GoMigrationContext, - upNoTx, downNoTx GoMigrationNoTxContext, -) error { - // Sanity check caller did not mix tx and non-tx based functions. - if (up != nil || down != nil) && (upNoTx != nil || downNoTx != nil) { - return fmt.Errorf("cannot mix tx and non-tx based go migrations functions") - } - v, _ := NumericComponent(filename) - if existing, ok := registeredGoMigrations[v]; ok { - return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", - filename, - v, - existing.Source, - ) - } - // Add to global as a registered migration. - registeredGoMigrations[v] = &Migration{ - Version: v, - Next: -1, - Previous: -1, - Registered: true, - Source: filename, - UseTx: useTx, - UpFnContext: up, - DownFnContext: down, - UpFnNoTxContext: upNoTx, - DownFnNoTxContext: downNoTx, - // These are deprecated and will be removed in the future. - // For backwards compatibility we still save the non-context versions in the struct in case someone is using them. - // Goose does not use these internally anymore and instead uses the context versions. - UpFn: withoutContext(up), - DownFn: withoutContext(down), - UpFnNoTx: withoutContext(upNoTx), - DownFnNoTx: withoutContext(downNoTx), - } - return nil -} - func collectMigrationsFS( fsys fs.FS, dirpath string, @@ -388,29 +278,6 @@ func GetDBVersionContext(ctx context.Context, db *sql.DB) (int64, error) { return version, nil } -// withContext changes the signature of a function that receives one argument to receive a context and the argument. -func withContext[T any](fn func(T) error) func(context.Context, T) error { - if fn == nil { - return nil - } - - return func(ctx context.Context, t T) error { - return fn(t) - } -} - -// withoutContext changes the signature of a function that receives a context and one argument to receive only the argument. -// When called the passed context is always context.Background(). -func withoutContext[T any](fn func(context.Context, T) error) func(T) error { - if fn == nil { - return nil - } - - return func(t T) error { - return fn(context.Background(), t) - } -} - // collectGoMigrations collects Go migrations from the filesystem and merges them with registered // migrations. // diff --git a/migration.go b/migration.go index c8a74b23c..e59599015 100644 --- a/migration.go +++ b/migration.go @@ -23,16 +23,43 @@ func NewGoMigration(version int64, up, down *GoFunc) Migration { Type: TypeGo, Registered: true, Version: version, - Next: -1, - Previous: -1, - goUp: &GoFunc{Mode: TransactionEnabled}, - goDown: &GoFunc{Mode: TransactionEnabled}, + Next: -1, Previous: -1, + goUp: up, + goDown: down, + construct: true, } + // To maintain backwards compatibility, we set ALL legacy functions. In a future major version, + // we will remove these fields in favor of [GoFunc]. + // + // Note, this function does not do any validation. Validation is lazily done when the migration + // is registered. if up != nil { - m.goUp = up + if up.RunDB != nil { + m.UpFnNoTxContext = up.RunDB // func(context.Context, *sql.DB) error + m.UpFnNoTx = func(db *sql.DB) error { return up.RunDB(context.Background(), db) } + } + if up.RunTx != nil { + m.UseTx = true + m.UpFnContext = up.RunTx // func(context.Context, *sql.Tx) error + m.UpFn = func(tx *sql.Tx) error { return up.RunTx(context.Background(), tx) } + } } if down != nil { - m.goDown = down + if down.RunDB != nil { + m.DownFnNoTxContext = down.RunDB // func(context.Context, *sql.DB) error + m.DownFnNoTx = withoutContext(down.RunDB) // func(*sql.DB) error + } + if down.RunTx != nil { + m.UseTx = true + m.DownFnContext = down.RunTx // func(context.Context, *sql.Tx) error + m.DownFn = withoutContext(down.RunTx) // func(*sql.Tx) error + } + } + if m.goUp == nil { + m.goUp = &GoFunc{Mode: TransactionEnabled} + } + if m.goDown == nil { + m.goDown = &GoFunc{Mode: TransactionEnabled} } return m } @@ -51,6 +78,7 @@ type Migration struct { UpFnNoTxContext, DownFnNoTxContext GoMigrationNoTxContext // These fields are used internally by goose and users are not expected to set them. Instead, // use [NewGoMigration] to create a new go migration. + construct bool goUp, goDown *GoFunc // These fields will be removed in a future major version. They are here for backwards @@ -332,115 +360,3 @@ func truncateDuration(d time.Duration) time.Duration { } return d } - -func verifyAndUpdateGoFunc(f *GoFunc) error { - if f == nil { - return nil - } - if f.RunTx != nil && f.RunDB != nil { - return errors.New("must specify exactly one of RunTx or RunDB") - } - if f.RunTx == nil && f.RunDB == nil { - switch f.Mode { - case 0: - // Default to TransactionEnabled ONLY if mode is not set explicitly. - f.Mode = TransactionEnabled - case TransactionEnabled, TransactionDisabled: - // No functions but mode is set. This is not an error. It means the user wants to record - // a version with the given mode but not run any functions. - default: - return fmt.Errorf("invalid mode: %d", f.Mode) - } - return nil - } - if f.RunDB != nil { - switch f.Mode { - case 0, TransactionDisabled: - f.Mode = TransactionDisabled - default: - return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set") - } - } - if f.RunTx != nil { - switch f.Mode { - case 0, TransactionEnabled: - f.Mode = TransactionEnabled - default: - return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set") - } - } - // This is a defensive check. If the mode is still 0, it means we failed to infer the mode from - // the functions or return an error. This should never happen. - if f.Mode == 0 { - return errors.New("failed to infer transaction mode") - } - return nil -} - -func updateLegacyFuncs(m *Migration) error { - // Assign the context-aware functions to the legacy functions. This is an implementation detail - // and will be removed in a future major version. Users are encouraged to use [NewGoMigration] - // instead of constructing a Migration struct directly. - if up := m.goUp; up != nil { - if up.RunTx != nil { - m.UpFnContext = up.RunTx - m.UseTx = true - } - if up.RunDB != nil { - m.UpFnNoTxContext = up.RunDB - } - } - if down := m.goDown; down != nil { - if down.RunTx != nil { - m.DownFnContext = down.RunTx - m.UseTx = true - } - if down.RunDB != nil { - m.DownFnNoTxContext = down.RunDB - } - } - if m.UpFnContext != nil && m.UpFnNoTxContext != nil { - return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext") - } - if m.DownFnContext != nil && m.DownFnNoTxContext != nil { - return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext") - } - // Do not allow legacy functions to be set. - if m.UpFn != nil { - return errors.New("must not specify UpFn") - } - if m.DownFn != nil { - return errors.New("must not specify DownFn") - } - if m.UpFnNoTx != nil { - return errors.New("must not specify UpFnNoTx") - } - if m.DownFnNoTx != nil { - return errors.New("must not specify DownFnNoTx") - } - return nil -} - -func validGoMigration(m *Migration) error { - if !m.Registered { - return errors.New("must be registered") - } - if m.Type != TypeGo { - return fmt.Errorf("type must be %q", TypeGo) - } - if m.Version < 1 { - return errors.New("version must be greater than zero") - } - if m.Source != "" { - // If the source is set, expect it to be a path with a numeric component that matches the - // version. This field is not intended to be used for descriptive purposes. - version, err := NumericComponent(m.Source) - if err != nil { - return err - } - if version != m.Version { - return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source) - } - } - return nil -} diff --git a/register.go b/register.go new file mode 100644 index 000000000..d907e5973 --- /dev/null +++ b/register.go @@ -0,0 +1,132 @@ +package goose + +import ( + "context" + "database/sql" + "fmt" + "runtime" +) + +// GoMigrationContext is a Go migration func that is run within a transaction and receives a +// context. +type GoMigrationContext func(ctx context.Context, tx *sql.Tx) error + +// AddMigrationContext adds Go migrations. +func AddMigrationContext(up, down GoMigrationContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, up, down) +} + +// AddNamedMigrationContext adds named Go migrations. +func AddNamedMigrationContext(filename string, up, down GoMigrationContext) { + if err := register( + filename, + true, + &GoFunc{RunTx: up}, + &GoFunc{RunTx: down}, + ); err != nil { + panic(err) + } +} + +// GoMigrationNoTxContext is a Go migration func that is run outside a transaction and receives a +// context. +type GoMigrationNoTxContext func(ctx context.Context, db *sql.DB) error + +// AddMigrationNoTxContext adds Go migrations that will be run outside transaction. +func AddMigrationNoTxContext(up, down GoMigrationNoTxContext) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationNoTxContext(filename, up, down) +} + +// AddNamedMigrationNoTxContext adds named Go migrations that will be run outside transaction. +func AddNamedMigrationNoTxContext(filename string, up, down GoMigrationNoTxContext) { + if err := register( + filename, + false, + &GoFunc{RunDB: up}, + &GoFunc{RunDB: down}, + ); err != nil { + panic(err) + } +} + +func register(filename string, useTx bool, up, down *GoFunc) error { + v, _ := NumericComponent(filename) + if existing, ok := registeredGoMigrations[v]; ok { + return fmt.Errorf("failed to add migration %q: version %d conflicts with %q", + filename, + v, + existing.Source, + ) + } + // Add to global as a registered migration. + m := NewGoMigration(v, up, down) + m.Source = filename + m.UseTx = useTx + registeredGoMigrations[v] = &m + return nil +} + +// withContext changes the signature of a function that receives one argument to receive a context +// and the argument. +func withContext[T any](fn func(T) error) func(context.Context, T) error { + if fn == nil { + return nil + } + + return func(ctx context.Context, t T) error { + return fn(t) + } +} + +// withoutContext changes the signature of a function that receives a context and one argument to +// receive only the argument. When called the passed context is always context.Background(). +func withoutContext[T any](fn func(context.Context, T) error) func(T) error { + if fn == nil { + return nil + } + return func(t T) error { + return fn(context.Background(), t) + } +} + +// GoMigration is a Go migration func that is run within a transaction. +// +// Deprecated: Use GoMigrationContext. +type GoMigration func(tx *sql.Tx) error + +// GoMigrationNoTx is a Go migration func that is run outside a transaction. +// +// Deprecated: Use GoMigrationNoTxContext. +type GoMigrationNoTx func(db *sql.DB) error + +// AddMigration adds Go migrations. +// +// Deprecated: Use AddMigrationContext. +func AddMigration(up, down GoMigration) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationContext(filename, withContext(up), withContext(down)) +} + +// AddNamedMigration adds named Go migrations. +// +// Deprecated: Use AddNamedMigrationContext. +func AddNamedMigration(filename string, up, down GoMigration) { + AddNamedMigrationContext(filename, withContext(up), withContext(down)) +} + +// AddMigrationNoTx adds Go migrations that will be run outside transaction. +// +// Deprecated: Use AddMigrationNoTxContext. +func AddMigrationNoTx(up, down GoMigrationNoTx) { + _, filename, _, _ := runtime.Caller(1) + AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) +} + +// AddNamedMigrationNoTx adds named Go migrations that will be run outside transaction. +// +// Deprecated: Use AddNamedMigrationNoTxContext. +func AddNamedMigrationNoTx(filename string, up, down GoMigrationNoTx) { + AddNamedMigrationNoTxContext(filename, withContext(up), withContext(down)) +}