164 lines
4.4 KiB
Go
164 lines
4.4 KiB
Go
|
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||
|
|
||
|
package dbutil
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
)
|
||
|
|
||
|
type upgradeFunc func(Execable, *Database) error
|
||
|
|
||
|
type upgrade struct {
|
||
|
message string
|
||
|
fn upgradeFunc
|
||
|
|
||
|
upgradesTo int
|
||
|
transaction bool
|
||
|
}
|
||
|
|
||
|
var (
	// ErrUnsupportedDatabaseVersion is returned (wrapped) by Upgrade when the
	// stored schema version is newer than any version in the upgrade table.
	ErrUnsupportedDatabaseVersion = errors.New("unsupported database schema version")

	// ErrForeignTables is returned (wrapped) when the database contains tables
	// that appear to belong to another program.
	ErrForeignTables = errors.New("the database contains foreign tables")

	// ErrNotOwned is returned (wrapped) when the database is registered to a
	// different owner. Its message is intentionally incomplete: it is wrapped as
	// "%w %s" with the actual owner's name appended.
	ErrNotOwned = errors.New("the database is owned by")
)
|
||
|
|
||
|
func (db *Database) getVersion() (int, error) {
|
||
|
_, err := db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version INTEGER)", db.VersionTable))
|
||
|
if err != nil {
|
||
|
return -1, err
|
||
|
}
|
||
|
|
||
|
version := 0
|
||
|
err = db.QueryRow(fmt.Sprintf("SELECT version FROM %s LIMIT 1", db.VersionTable)).Scan(&version)
|
||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||
|
return -1, err
|
||
|
}
|
||
|
return version, nil
|
||
|
}
|
||
|
|
||
|
// Catalog queries used by tableExists. Each takes the table name as its only
// parameter and selects a single boolean.
const (
	tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
	tableExistsSQLite   = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=$1)"
)
|
||
|
|
||
|
func (db *Database) tableExists(table string) (exists bool, err error) {
|
||
|
if db.Dialect == SQLite {
|
||
|
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
|
||
|
} else if db.Dialect == Postgres {
|
||
|
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
|
||
|
}
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (db *Database) tableExistsNoError(table string) bool {
|
||
|
exists, err := db.tableExists(table)
|
||
|
if err != nil {
|
||
|
panic(fmt.Errorf("failed to check if table exists: %w", err))
|
||
|
}
|
||
|
return exists
|
||
|
}
|
||
|
|
||
|
const createOwnerTable = `
|
||
|
CREATE TABLE IF NOT EXISTS database_owner (
|
||
|
key INTEGER PRIMARY KEY DEFAULT 0,
|
||
|
owner TEXT NOT NULL
|
||
|
)
|
||
|
`
|
||
|
|
||
|
func (db *Database) checkDatabaseOwner() error {
|
||
|
var owner string
|
||
|
if !db.IgnoreForeignTables {
|
||
|
if db.tableExistsNoError("state_groups_state") {
|
||
|
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
|
||
|
} else if db.tableExistsNoError("roomserver_rooms") {
|
||
|
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
|
||
|
}
|
||
|
}
|
||
|
if db.Owner == "" {
|
||
|
return nil
|
||
|
}
|
||
|
if _, err := db.Exec(createOwnerTable); err != nil {
|
||
|
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
|
||
|
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
|
||
|
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
|
||
|
if err != nil {
|
||
|
return fmt.Errorf("failed to insert database owner: %w", err)
|
||
|
}
|
||
|
} else if err != nil {
|
||
|
return fmt.Errorf("failed to check database owner: %w", err)
|
||
|
} else if owner != db.Owner {
|
||
|
return fmt.Errorf("%w %s", ErrNotOwned, owner)
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (db *Database) setVersion(tx Execable, version int) error {
|
||
|
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", db.VersionTable), version)
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (db *Database) Upgrade() error {
|
||
|
err := db.checkDatabaseOwner()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
version, err := db.getVersion()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if version > len(db.UpgradeTable) {
|
||
|
if db.IgnoreUnsupportedDatabase {
|
||
|
db.Log.WarnUnsupportedVersion(version, len(db.UpgradeTable))
|
||
|
return nil
|
||
|
}
|
||
|
return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, len(db.UpgradeTable))
|
||
|
}
|
||
|
|
||
|
db.Log.PrepareUpgrade(version, len(db.UpgradeTable))
|
||
|
logVersion := version
|
||
|
for version < len(db.UpgradeTable) {
|
||
|
upgradeItem := db.UpgradeTable[version]
|
||
|
if upgradeItem.fn == nil {
|
||
|
version++
|
||
|
continue
|
||
|
}
|
||
|
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
|
||
|
var tx Transaction
|
||
|
var upgradeConn Execable
|
||
|
if upgradeItem.transaction {
|
||
|
tx, err = db.Begin()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
upgradeConn = tx
|
||
|
} else {
|
||
|
upgradeConn = db
|
||
|
}
|
||
|
err = upgradeItem.fn(upgradeConn, db)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
version = upgradeItem.upgradesTo
|
||
|
logVersion = version
|
||
|
err = db.setVersion(upgradeConn, version)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if tx != nil {
|
||
|
err = tx.Commit()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|