From 335cc146c06a8ab56ce213cd59d42f941b00ce8d Mon Sep 17 00:00:00 2001 From: Nolan Blankenau Date: Fri, 7 Nov 2025 11:03:49 -0800 Subject: [PATCH 1/2] Make migrate safe --- database/rqlite/rqlite_test.go | 96 ------------------- database/testing/migrate_testing.go | 12 --- internal/cli/commands.go | 29 ++---- internal/cli/commands_test.go | 38 -------- internal/cli/main.go | 60 +----------- migrate.go | 32 ------- migrate_test.go | 141 ---------------------------- 7 files changed, 11 insertions(+), 397 deletions(-) diff --git a/database/rqlite/rqlite_test.go b/database/rqlite/rqlite_test.go index c19f7476b..5cc6e80e5 100644 --- a/database/rqlite/rqlite_test.go +++ b/database/rqlite/rqlite_test.go @@ -156,102 +156,6 @@ func TestNoConfig(t *testing.T) { }) } -func TestWithInstanceEmptyConfig(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ip, port, err := c.Port(defaultPort) - assert.NoError(t, err) - - // gorqlite expects http(s) schemes - connectString := fmt.Sprintf("http://%s:%s?level=strong&disableClusterDiscovery=true", ip, port) - t.Logf("DB connect string : %s\n", connectString) - db, err := gorqlite.Open(connectString) - assert.NoError(t, err) - - driver, err := WithInstance(db, &Config{}) - assert.NoError(t, err) - - defer func() { - if err := driver.Close(); err != nil { - t.Fatal(err) - } - }() - - m, err := migrate.NewWithDatabaseInstance( - "file://./examples/migrations", - "ql", driver) - assert.NoError(t, err) - - t.Log("UP") - err = m.Up() - assert.NoError(t, err) - - _, err = db.QueryOne(fmt.Sprintf("SELECT * FROM %s", DefaultMigrationsTable)) - assert.NoError(t, err) - - t.Log("DOWN") - err = m.Down() - assert.NoError(t, err) - }) -} - -func TestMigrationTable(t *testing.T) { - dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { - ip, port, err := c.Port(defaultPort) - assert.NoError(t, err) - - // gorqlite expects http(s) schemes - connectString := fmt.Sprintf("http://%s:%s?level=strong&disableClusterDiscovery=true", ip, port) - t.Logf("DB connect string : %s\n", connectString) - db, err := gorqlite.Open(connectString) - assert.NoError(t, err) - - config := Config{MigrationsTable: "my_migration_table"} - driver, err := WithInstance(db, &config) - assert.NoError(t, err) - - defer func() { - if err := driver.Close(); err != nil { - t.Fatal(err) - } - }() - - m, err := migrate.NewWithDatabaseInstance( - "file://./examples/migrations", - "ql", driver) - assert.NoError(t, err) - - t.Log("UP") - err = m.Up() - assert.NoError(t, err) - - _, err = db.QueryOne(fmt.Sprintf("SELECT * FROM %s", config.MigrationsTable)) - assert.NoError(t, err) - - _, err = db.WriteOne(`INSERT INTO pets (name, predator) VALUES ("franklin", true)`) - assert.NoError(t, err) - - res, err := db.QueryOne(`SELECT name, predator FROM pets LIMIT 1`) - assert.NoError(t, err) - - _ = res.Next() - - // make sure we can use the migrated table - var petName string - var petPredator int - err = res.Scan(&petName, &petPredator) - assert.NoError(t, err) - assert.Equal(t, petName, "franklin") - assert.Equal(t, petPredator, 1) - - t.Log("DOWN") - err = m.Down() - assert.NoError(t, err) - - _, err = db.QueryOne(fmt.Sprintf("SELECT * FROM %s", config.MigrationsTable)) - assert.NoError(t, err) - }) -} - func TestParseUrl(t *testing.T) { tests := []struct { name string diff --git a/database/testing/migrate_testing.go b/database/testing/migrate_testing.go index be8ed195f..870b41921 100644 --- a/database/testing/migrate_testing.go +++ b/database/testing/migrate_testing.go @@ -5,25 +5,13 @@ package testing import ( "testing" -) -import ( "github.com/golang-migrate/migrate/v4" ) // TestMigrate runs integration-tests between the Migrate layer and database implementations. func TestMigrate(t *testing.T, m *migrate.Migrate) { TestMigrateUp(t, m) - TestMigrateDrop(t, m) -} - -// Regression test for preventing a regression for #164 https://github.com/golang-migrate/migrate/pull/173 -// Similar to TestDrop(), but tests the dropping mechanism through the Migrate logic instead, to check for -// double-locking during the Drop logic. -func TestMigrateDrop(t *testing.T, m *migrate.Migrate) { - if err := m.Drop(); err != nil { - t.Fatal(err) - } } func TestMigrateUp(t *testing.T, m *migrate.Migrate) { diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 7adec2f84..a6d499d04 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -185,19 +185,7 @@ func downCmd(m *migrate.Migrate, limit int) error { log.Println(err) } } else { - if err := m.Down(); err != nil { - if err != migrate.ErrNoChange { - return err - } - log.Println(err) - } - } - return nil -} - -func dropCmd(m *migrate.Migrate) error { - if err := m.Drop(); err != nil { - return err + log.Println("With migrate-safe you can not apply all down migrations. Please specify a number of migrations to apply.") } return nil } @@ -223,26 +211,25 @@ func versionCmd(m *migrate.Migrate) error { } // numDownMigrationsFromArgs returns an int for number of migrations to apply -// and a bool indicating if we need a confirm before applying -func numDownMigrationsFromArgs(applyAll bool, args []string) (int, bool, error) { +func numDownMigrationsFromArgs(applyAll bool, args []string) (int, error) { if applyAll { if len(args) > 0 { - return 0, false, errors.New("-all cannot be used with other arguments") + return 0, errors.New("-all cannot be used with other arguments") } - return -1, false, nil + return -1, nil } switch len(args) { case 0: - return -1, true, nil + return -1, nil case 1: downValue := args[0] n, err := strconv.ParseUint(downValue, 10, 64) if err != nil { - return 0, false, errors.New("can't read limit argument N") + return 0, errors.New("can't read limit argument N") } - return int(n), false, nil + return int(n), nil default: - return 0, false, errors.New("too many arguments") + return 0, errors.New("too many arguments") } } diff --git a/internal/cli/commands_test.go b/internal/cli/commands_test.go index e89a690f5..df97a7634 100644 --- a/internal/cli/commands_test.go +++ b/internal/cli/commands_test.go @@ -252,41 +252,3 @@ func (s *CreateCmdSuite) TestCreateCmd() { }) } } - -func TestNumDownFromArgs(t *testing.T) { - cases := []struct { - name string - args []string - applyAll bool - expectedNeedConfirm bool - expectedNum int - expectedErrStr string - }{ - {"no args", []string{}, false, true, -1, ""}, - {"down all", []string{}, true, false, -1, ""}, - {"down 5", []string{"5"}, false, false, 5, ""}, - {"down N", []string{"N"}, false, false, 0, "can't read limit argument N"}, - {"extra arg after -all", []string{"5"}, true, false, 0, "-all cannot be used with other arguments"}, - {"extra arg before -all", []string{"5", "-all"}, false, false, 0, "too many arguments"}, - } - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - num, needsConfirm, err := numDownMigrationsFromArgs(c.applyAll, c.args) - if needsConfirm != c.expectedNeedConfirm { - t.Errorf("Incorrect needsConfirm was: %v wanted %v", needsConfirm, c.expectedNeedConfirm) - } - - if num != c.expectedNum { - t.Errorf("Incorrect num was: %v wanted %v", num, c.expectedNum) - } - - if err != nil { - if err.Error() != c.expectedErrStr { - t.Error("Incorrect error: " + err.Error() + " != " + c.expectedErrStr) - } - } else if c.expectedErrStr != "" { - t.Error("Expected error: " + c.expectedErrStr + " but got nil instead") - } - }) - } -} diff --git a/internal/cli/main.go b/internal/cli/main.go index c7a3bd74a..47c3cd0ca 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -254,38 +254,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU } case "down": - downFlagSet, helpPtr := newFlagSetWithHelp("down") - applyAll := downFlagSet.Bool("all", false, "Apply all down migrations") - - if err := downFlagSet.Parse(args); err != nil { - log.fatalErr(err) - } - - handleSubCmdHelp(*helpPtr, downUsage, downFlagSet) - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - downArgs := downFlagSet.Args() - num, needsConfirm, err := numDownMigrationsFromArgs(*applyAll, downArgs) - if err != nil { - log.fatalErr(err) - } - if needsConfirm { - log.Println("Are you sure you want to apply all down migrations? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Applying all down migrations") - } else { - log.fatal("Not applying all down migrations") - } - } - - if err := downCmd(migrater, num); err != nil { + if err := downCmd(migrater, 1); err != nil { log.fatalErr(err) } @@ -295,7 +264,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU case "drop": dropFlagSet, help := newFlagSetWithHelp("drop") - forceDrop := dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") + _ = dropFlagSet.Bool("f", false, "Force the drop command by bypassing the confirmation prompt") if err := dropFlagSet.Parse(args); err != nil { log.fatalErr(err) @@ -303,30 +272,7 @@ Database drivers: `+strings.Join(database.List(), ", ")+"\n", createUsage, gotoU handleSubCmdHelp(*help, dropUsage, dropFlagSet) - if !*forceDrop { - log.Println("Are you sure you want to drop the entire database schema? [y/N]") - var response string - _, _ = fmt.Scanln(&response) - response = strings.ToLower(strings.TrimSpace(response)) - - if response == "y" { - log.Println("Dropping the entire database schema") - } else { - log.fatal("Aborted dropping the entire database schema") - } - } - - if migraterErr != nil { - log.fatalErr(migraterErr) - } - - if err := dropCmd(migrater); err != nil { - log.fatalErr(err) - } - - if log.verbose { - log.Println("Finished after", time.Since(startTime)) - } + log.fatal("With migrate-safe you can not drop the entire database schema. Please specify a number of migrations to apply.") case "force": forceSet, helpPtr := newFlagSetWithHelp("force") diff --git a/migrate.go b/migrate.go index 266cc04eb..636c27bac 100644 --- a/migrate.go +++ b/migrate.go @@ -284,38 +284,6 @@ func (m *Migrate) Up() error { return m.unlockErr(m.runMigrations(ret)) } -// Down looks at the currently active migration version -// and will migrate all the way down (applying all down migrations). -func (m *Migrate) Down() error { - if err := m.lock(); err != nil { - return err - } - - curVersion, dirty, err := m.databaseDrv.Version() - if err != nil { - return m.unlockErr(err) - } - - if dirty { - return m.unlockErr(ErrDirty{curVersion}) - } - - ret := make(chan interface{}, m.PrefetchMigrations) - go m.readDown(curVersion, -1, ret) - return m.unlockErr(m.runMigrations(ret)) -} - -// Drop deletes everything in the database. -func (m *Migrate) Drop() error { - if err := m.lock(); err != nil { - return err - } - if err := m.databaseDrv.Drop(); err != nil { - return m.unlockErr(err) - } - return m.unlock() -} - // Run runs any migration provided by you against the database. // It does not check any currently active version in database. // Usually you don't need this function at all. Use Migrate, diff --git a/migrate_test.go b/migrate_test.go index f2728179e..c17c4a4fa 100644 --- a/migrate_test.go +++ b/migrate_test.go @@ -764,120 +764,6 @@ func TestStepsDirty(t *testing.T) { } } -func TestUpAndDown(t *testing.T) { - m, _ := New("stub://", "stub://") - m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations - dbDrv := m.databaseDrv.(*dStub.Stub) - - // go Up first - if err := m.Up(); err != nil { - t.Fatal(err) - } - expectedSequence := migrationSequence{ - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - } - equalDbSeq(t, 0, expectedSequence, dbDrv) - - // go Down - if err := m.Down(); err != nil { - t.Fatal(err) - } - expectedSequence = migrationSequence{ - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - mr("DROP 7"), - mr("DROP 5"), - mr("DROP 4"), - mr("DROP 1"), - } - equalDbSeq(t, 1, expectedSequence, dbDrv) - - // go 1 Up and then all the way Up - if err := m.Steps(1); err != nil { - t.Fatal(err) - } - expectedSequence = migrationSequence{ - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - mr("DROP 7"), - mr("DROP 5"), - mr("DROP 4"), - mr("DROP 1"), - mr("CREATE 1"), - } - equalDbSeq(t, 2, expectedSequence, dbDrv) - - if err := m.Up(); err != nil { - t.Fatal(err) - } - expectedSequence = migrationSequence{ - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - mr("DROP 7"), - mr("DROP 5"), - mr("DROP 4"), - mr("DROP 1"), - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - } - equalDbSeq(t, 3, expectedSequence, dbDrv) - - // go 1 Down and then all the way Down - if err := m.Steps(-1); err != nil { - t.Fatal(err) - } - expectedSequence = migrationSequence{ - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - mr("DROP 7"), - mr("DROP 5"), - mr("DROP 4"), - mr("DROP 1"), - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - mr("DROP 7"), - } - equalDbSeq(t, 1, expectedSequence, dbDrv) - - if err := m.Down(); err != nil { - t.Fatal(err) - } - expectedSequence = migrationSequence{ - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - mr("DROP 7"), - mr("DROP 5"), - mr("DROP 4"), - mr("DROP 1"), - mr("CREATE 1"), - mr("CREATE 3"), - mr("CREATE 4"), - mr("CREATE 7"), - mr("DROP 7"), - mr("DROP 5"), - mr("DROP 4"), - mr("DROP 1"), - } - equalDbSeq(t, 1, expectedSequence, dbDrv) -} - func TestUpDirty(t *testing.T) { m, _ := New("stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) @@ -891,33 +777,6 @@ func TestUpDirty(t *testing.T) { } } -func TestDownDirty(t *testing.T) { - m, _ := New("stub://", "stub://") - dbDrv := m.databaseDrv.(*dStub.Stub) - if err := dbDrv.SetVersion(0, true); err != nil { - t.Fatal(err) - } - - err := m.Down() - if _, ok := err.(ErrDirty); !ok { - t.Fatalf("expected ErrDirty, got %v", err) - } -} - -func TestDrop(t *testing.T) { - m, _ := New("stub://", "stub://") - m.sourceDrv.(*sStub.Stub).Migrations = sourceStubMigrations - dbDrv := m.databaseDrv.(*dStub.Stub) - - if err := m.Drop(); err != nil { - t.Fatal(err) - } - - if dbDrv.MigrationSequence[len(dbDrv.MigrationSequence)-1] != dStub.DROP { - t.Fatalf("expected database to DROP, got sequence %v", dbDrv.MigrationSequence) - } -} - func TestVersion(t *testing.T) { m, _ := New("stub://", "stub://") dbDrv := m.databaseDrv.(*dStub.Stub) From 740c87b44934271082abaacff2bdbcc1c198442c Mon Sep 17 00:00:00 2001 From: Nolan Blankenau Date: Fri, 7 Nov 2025 11:04:31 -0800 Subject: [PATCH 2/2] removed unused fun --- internal/cli/commands.go | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/internal/cli/commands.go b/internal/cli/commands.go index a6d499d04..edf73e4a5 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -209,27 +209,3 @@ func versionCmd(m *migrate.Migrate) error { } return nil } - -// numDownMigrationsFromArgs returns an int for number of migrations to apply -func numDownMigrationsFromArgs(applyAll bool, args []string) (int, error) { - if applyAll { - if len(args) > 0 { - return 0, errors.New("-all cannot be used with other arguments") - } - return -1, nil - } - - switch len(args) { - case 0: - return -1, nil - case 1: - downValue := args[0] - n, err := strconv.ParseUint(downValue, 10, 64) - if err != nil { - return 0, errors.New("can't read limit argument N") - } - return int(n), nil - default: - return 0, errors.New("too many arguments") - } -}