From b54c5d8c5ec358cf6b4f6013edbae5b4ff1a5641 Mon Sep 17 00:00:00 2001 From: Kailash Nadh Date: Mon, 3 Aug 2020 19:07:14 +0530 Subject: [PATCH] Add upgrade file --- upgrade.go | 148 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 upgrade.go diff --git a/upgrade.go b/upgrade.go new file mode 100644 index 0000000..48f525f --- /dev/null +++ b/upgrade.go @@ -0,0 +1,148 @@ +package main + +import ( + "fmt" + "strings" + + "github.com/jmoiron/sqlx" + "github.com/knadh/koanf" + "github.com/knadh/listmonk/internal/migrations" + "github.com/knadh/stuffbin" + "github.com/lib/pq" + "golang.org/x/mod/semver" +) + +// migFunc represents a migration function for a particular version. +// fn (generally) executes database migrations and additionally +// takes the filesystem and config objects in case there are additional bits +// of logic to be performed before executing upgrades. fn is idempotent. +type migFunc struct { + version string + fn func(*sqlx.DB, stuffbin.FileSystem, *koanf.Koanf) error +} + +// migList is the list of available migList ordered by the semver. +// Each migration is a Go file in internal/migrations named after the semver. +// The functions are named as: v0.7.0 => migrations.V0_7_0() and are idempotent. +var migList = []migFunc{ + {"v0.4.0", migrations.V0_4_0}, + {"v0.7.0", migrations.V0_7_0}, +} + +// upgrade upgrades the database to the current version by running SQL migration files +// for all version from the last known version to the current one. +func upgrade(db *sqlx.DB, fs stuffbin.FileSystem, prompt bool) { + if prompt { + var ok string + fmt.Printf("** IMPORTANT: Take a backup of the database before upgrading.\n") + fmt.Print("continue (y/n)? ") + if _, err := fmt.Scanf("%s", &ok); err != nil { + lo.Fatalf("error reading value from terminal: %v", err) + } + if strings.ToLower(ok) != "y" { + fmt.Println("upgrade cancelled") + return + } + } + + _, toRun, err := getPendingMigrations(db) + if err != nil { + lo.Fatalf("error checking migrations: %v", err) + } + + // No migrations to run. + if len(toRun) == 0 { + lo.Printf("no upgrades to run. Database is up-to-date.") + return + } + + // Execute migrations in succession. + for _, m := range toRun { + lo.Printf("running migration %s", m.version) + if err := m.fn(db, fs, ko); err != nil { + lo.Fatalf("error running migration %s: %v", m.version, err) + } + + // Record the migration version in the settings table. There was no + // settings table until v0.7.0, so ignore the no-table errors. + if err := recordMigrationVersion(m.version, db); err != nil { + if isTableNotExistErr(err) { + continue + } + lo.Fatalf("error recording migration version %s: %v", m.version, err) + } + } + + lo.Printf("upgrade complete") +} + +// checkUpgrade checks if the current database schema matches the expected +// binary version. +func checkUpgrade(db *sqlx.DB) { + lastVer, toRun, err := getPendingMigrations(db) + if err != nil { + lo.Fatalf("error checking migrations: %v", err) + } + + // No migrations to run. + if len(toRun) == 0 { + return + } + + var vers []string + for _, m := range toRun { + vers = append(vers, m.version) + } + + lo.Fatalf(`there are %d pending database upgrade(s): %v. The last upgrade was %s. Backup the database and run listmonk --upgrade`, + len(toRun), vers, lastVer) +} + +// getPendingMigrations gets the pending migrations by comparing the last +// recorded migration in the DB against all migrations listed in `migrations`. +func getPendingMigrations(db *sqlx.DB) (string, []migFunc, error) { + lastVer, err := getLastMigrationVersion() + if err != nil { + return "", nil, err + } + + // Iterate through the migration versions and get everything above the last + // last upgraded semver. + var toRun []migFunc + for i, m := range migList { + if semver.Compare(m.version, lastVer) > 0 { + toRun = migList[i:] + break + } + } + + return lastVer, toRun, nil +} + +// getLastMigrationVersion returns the last migration semver recorded in the DB. +// If there isn't any, `v0.0.0` is returned. +func getLastMigrationVersion() (string, error) { + var v string + if err := db.Get(&v, ` + SELECT COALESCE( + (SELECT value->>-1 FROM settings WHERE key='migrations'), + 'v0.0.0')`); err != nil { + if isTableNotExistErr(err) { + return "v0.0.0", nil + } + return v, err + } + return v, nil +} + +// isPqNoTableErr checks if the given error represents a Postgres/pq +// "table does not exist" error. +func isTableNotExistErr(err error) bool { + if p, ok := err.(*pq.Error); ok { + // `settings` table does not exist. It was introduced in v0.7.0. + if p.Code == "42P01" { + return true + } + } + return false +}