Skip to content

Commit

Permalink
limit global override with NewGoMigration constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Nov 3, 2023
1 parent c6f7847 commit f944a3d
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 348 deletions.
128 changes: 108 additions & 20 deletions globals.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package goose

import (
"errors"
"fmt"
"path/filepath"
)

var (
Expand All @@ -16,34 +18,120 @@ func ResetGlobalMigrations() {
}

// SetGlobalMigrations registers Go migrations globally. It returns an error if a migration with the
// same version has already been registered.
//
// Avoid constructing migrations manually, use [NewGoMigration] function.
//
// Source may be empty, but if it is set, it must be a path with a numeric component that matches
// the version. Do not register legacy non-context functions: UpFn, DownFn, UpFnNoTx, DownFnNoTx.
// same version has already been registered. Go migrations must be constructed using the
// [NewGoMigration] function.
//
// Not safe for concurrent use.
func SetGlobalMigrations(migrations ...Migration) error {
for _, m := range migrations {
migration := &m
if err := validGoMigration(migration); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
for _, migration := range migrations {
m := &migration
if err := setGoMigration(m); err != nil {
return err
}
registeredGoMigrations[m.Version] = m
}
return nil
}

func setGoMigration(m *Migration) error {
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
}
if err := checkMigration(m); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
return nil
}

func checkMigration(m *Migration) error {
if !m.construct {
return errors.New("must use NewGoMigration to construct migrations")
}
if !m.Registered {
return errors.New("must be registered")
}
if m.Type != TypeGo {
return fmt.Errorf("type must be %q", TypeGo)
}
if m.Version < 1 {
return errors.New("version must be greater than zero")
}
if m.Source != "" {
if filepath.Ext(m.Source) != ".go" {
return fmt.Errorf("source must have .go extension: %q", m.Source)
}
if err := verifyAndUpdateGoFunc(migration.goUp); err != nil {
return fmt.Errorf("up function: %w", err)
// If the source is set, expect it to be a path with a numeric component that matches the
// version. This field is not intended to be used for descriptive purposes.
version, err := NumericComponent(m.Source)
if err != nil {
return fmt.Errorf("invalid source: %w", err)
}
if err := verifyAndUpdateGoFunc(migration.goDown); err != nil {
return fmt.Errorf("down function: %w", err)
if version != m.Version {
return fmt.Errorf("version:%d does not match numeric component in source %q", m.Version, m.Source)
}
if err := updateLegacyFuncs(migration); err != nil {
return fmt.Errorf("invalid go migration: %w", err)
}
if err := setGoFunc(m.goUp); err != nil {
return fmt.Errorf("up function: %w", err)
}
if err := setGoFunc(m.goDown); err != nil {
return fmt.Errorf("down function: %w", err)
}
if m.UpFnContext != nil && m.UpFnNoTxContext != nil {
return errors.New("must specify exactly one of UpFnContext or UpFnNoTxContext")
}
if m.UpFn != nil && m.UpFnNoTx != nil {
return errors.New("must specify exactly one of UpFn or UpFnNoTx")
}
if m.DownFnContext != nil && m.DownFnNoTxContext != nil {
return errors.New("must specify exactly one of DownFnContext or DownFnNoTxContext")
}
if m.DownFn != nil && m.DownFnNoTx != nil {
return errors.New("must specify exactly one of DownFn or DownFnNoTx")
}
return nil
}

func setGoFunc(f *GoFunc) error {
if f == nil {
f = &GoFunc{Mode: TransactionEnabled}
return nil
}
if f.RunTx != nil && f.RunDB != nil {
return errors.New("must specify exactly one of RunTx or RunDB")
}
if f.RunTx == nil && f.RunDB == nil {
switch f.Mode {
case 0:
// Default to TransactionEnabled ONLY if mode is not set explicitly.
f.Mode = TransactionEnabled
case TransactionEnabled, TransactionDisabled:
// No functions but mode is set. This is not an error. It means the user wants to record
// a version with the given mode but not run any functions.
default:
return fmt.Errorf("invalid mode: %d", f.Mode)
}
if _, ok := registeredGoMigrations[m.Version]; ok {
return fmt.Errorf("go migration with version %d already registered", m.Version)
return nil
}
if f.RunDB != nil {
switch f.Mode {
case 0, TransactionDisabled:
f.Mode = TransactionDisabled
default:
return fmt.Errorf("transaction mode must be disabled or unspecified when RunDB is set")
}
}
if f.RunTx != nil {
switch f.Mode {
case 0, TransactionEnabled:
f.Mode = TransactionEnabled
default:
return fmt.Errorf("transaction mode must be enabled or unspecified when RunTx is set")
}
m.Next, m.Previous = -1, -1 // Do not allow these to be set by the user.
registeredGoMigrations[m.Version] = migration
}
// This is a defensive check. If the mode is still 0, it means we failed to infer the mode from
// the functions or return an error. This should never happen.
if f.Mode == 0 {
return errors.New("failed to infer transaction mode")
}
return nil
}
149 changes: 72 additions & 77 deletions globals_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,23 @@ func TestNewGoMigration(t *testing.T) {
check.Equal(t, m.goUp.Mode, TransactionEnabled)
check.Equal(t, m.goDown.Mode, TransactionEnabled)
})
t.Run("all_set", func(t *testing.T) {
// This will eventually be an error when registering migrations.
m := NewGoMigration(
1,
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
&GoFunc{RunTx: func(context.Context, *sql.Tx) error { return nil }, RunDB: func(context.Context, *sql.DB) error { return nil }},
)
// check only functions
check.Bool(t, m.UpFn != nil, true)
check.Bool(t, m.UpFnContext != nil, true)
check.Bool(t, m.UpFnNoTx != nil, true)
check.Bool(t, m.UpFnNoTxContext != nil, true)
check.Bool(t, m.DownFn != nil, true)
check.Bool(t, m.DownFnContext != nil, true)
check.Bool(t, m.DownFnNoTx != nil, true)
check.Bool(t, m.DownFnNoTxContext != nil, true)
})
}

func TestTransactionMode(t *testing.T) {
Expand Down Expand Up @@ -121,6 +138,7 @@ func TestLegacyFunctions(t *testing.T) {
}

t.Run("all_tx", func(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
err := SetGlobalMigrations(
NewGoMigration(1, &GoFunc{RunTx: runTx}, &GoFunc{RunTx: runTx}),
)
Expand All @@ -137,17 +155,18 @@ func TestLegacyFunctions(t *testing.T) {
check.Bool(t, m.goDown == nil, false)
check.Bool(t, m.DownFnContext == nil, false)
// Always nil
check.Bool(t, m.UpFn == nil, true)
check.Bool(t, m.DownFn == nil, true)
check.Bool(t, m.UpFn == nil, false)
check.Bool(t, m.DownFn == nil, false)
check.Bool(t, m.UpFnNoTx == nil, true)
check.Bool(t, m.DownFnNoTx == nil, true)
})
t.Run("all_db", func(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)
err := SetGlobalMigrations(
NewGoMigration(2, &GoFunc{RunDB: runDB}, &GoFunc{RunDB: runDB}),
)
check.NoError(t, err)
check.Number(t, len(registeredGoMigrations), 2)
check.Number(t, len(registeredGoMigrations), 1)
m := registeredGoMigrations[2]
assertMigration(t, m, 2)
// Legacy functions.
Expand All @@ -161,15 +180,15 @@ func TestLegacyFunctions(t *testing.T) {
// Always nil
check.Bool(t, m.UpFn == nil, true)
check.Bool(t, m.DownFn == nil, true)
check.Bool(t, m.UpFnNoTx == nil, true)
check.Bool(t, m.DownFnNoTx == nil, true)
check.Bool(t, m.UpFnNoTx == nil, false)
check.Bool(t, m.DownFnNoTx == nil, false)
})
}

func TestGlobalRegister(t *testing.T) {
t.Cleanup(ResetGlobalMigrations)

runDB := func(context.Context, *sql.DB) error { return nil }
// runDB := func(context.Context, *sql.DB) error { return nil }
runTx := func(context.Context, *sql.Tx) error { return nil }

// Success.
Expand All @@ -185,87 +204,63 @@ func TestGlobalRegister(t *testing.T) {
)
check.HasError(t, err)
check.Contains(t, err.Error(), "go migration with version 1 already registered")
err = SetGlobalMigrations(
Migration{
Registered: true,
Version: 2,
Source: "00002_foo.sql",
Type: TypeGo,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
},
)
check.NoError(t, err)
// Reset.
{
ResetGlobalMigrations()
}
// Failure.
err = SetGlobalMigrations(
Migration{},
)
err = SetGlobalMigrations(Migration{Registered: true, Version: 2, Type: TypeGo})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must be registered")
err = SetGlobalMigrations(
Migration{Registered: true},
)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
}

func TestCheckMigration(t *testing.T) {
// Failures.
err := checkMigration(&Migration{})
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, Type: TypeSQL},
)
check.Contains(t, err.Error(), "must use NewGoMigration to construct migrations")
err = checkMigration(&Migration{construct: true})
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: type must be "go"`)
err = SetGlobalMigrations(
Migration{Registered: true, Version: 0, Type: TypeGo},
)
check.Contains(t, err.Error(), "must be registered")
err = checkMigration(&Migration{construct: true, Registered: true})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: version must be greater than zero")
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, Source: "2_foo.sql", Type: TypeGo},
)
check.Contains(t, err.Error(), `type must be "go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo})
check.HasError(t, err)
check.Contains(
t,
err.Error(),
`invalid go migration: version:1 does not match numeric component in source "2_foo.sql"`,
)
// Legacy functions.
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, UpFn: func(tx *sql.Tx) error { return nil }, Type: TypeGo},
)
check.Contains(t, err.Error(), "version must be greater than zero")
// Success.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1})
check.NoError(t, err)
// Failures.
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo"})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFn")
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, DownFn: func(tx *sql.Tx) error { return nil }, Type: TypeGo},
)
check.Contains(t, err.Error(), `source must have .go extension: "foo"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1, Source: "foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFn")
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, UpFnNoTx: func(db *sql.DB) error { return nil }, Type: TypeGo},
)
check.Contains(t, err.Error(), `no filename separator '_' found`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.sql"})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify UpFnNoTx")
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, DownFnNoTx: func(db *sql.DB) error { return nil }, Type: TypeGo},
)
check.Contains(t, err.Error(), `source must have .go extension: "00001_foo.sql"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 2, Source: "00001_foo.go"})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must not specify DownFnNoTx")
// Context-aware functions.
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, UpFnContext: runTx, UpFnNoTxContext: runDB, Type: TypeGo},
)
check.Contains(t, err.Error(), `version:2 does not match numeric component in source "00001_foo.go"`)
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFnContext: func(context.Context, *sql.Tx) error { return nil },
UpFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of UpFnContext or UpFnNoTxContext")
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, DownFnContext: runTx, DownFnNoTxContext: runDB, Type: TypeGo},
)
check.Contains(t, err.Error(), "must specify exactly one of UpFnContext or UpFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFnContext: func(context.Context, *sql.Tx) error { return nil },
DownFnNoTxContext: func(context.Context, *sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), "invalid go migration: must specify exactly one of DownFnContext or DownFnNoTxContext")
// Source and version mismatch.
err = SetGlobalMigrations(
Migration{Registered: true, Version: 1, Source: "invalid_numeric.sql", Type: TypeGo},
)
check.Contains(t, err.Error(), "must specify exactly one of DownFnContext or DownFnNoTxContext")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
UpFn: func(*sql.Tx) error { return nil },
UpFnNoTx: func(*sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), "must specify exactly one of UpFn or UpFnNoTx")
err = checkMigration(&Migration{construct: true, Registered: true, Type: TypeGo, Version: 1,
DownFn: func(*sql.Tx) error { return nil },
DownFnNoTx: func(*sql.DB) error { return nil },
})
check.HasError(t, err)
check.Contains(t, err.Error(), `invalid go migration: failed to parse version from migration file: invalid_numeric.sql`)
check.Contains(t, err.Error(), "must specify exactly one of DownFn or DownFnNoTx")
}
Loading

0 comments on commit f944a3d

Please sign in to comment.