matrix-go-test/vendor/maunium.net/go/mautrix/util/dbutil/connlog.go
2023-02-08 18:23:21 +01:00

172 lines
4.4 KiB
Go

// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"time"
)
// LoggingExecable is a wrapper for anything with database Exec methods (i.e. sql.Conn, sql.DB and sql.Tx)
// that can preprocess queries (e.g. replacing $ with ? on SQLite) and log query durations.
//
// Every query is passed through db.mutateQuery before it reaches UnderlyingExecable, and the
// time spent in the underlying call is reported through db.Log.QueryTiming.
type LoggingExecable struct {
	// UnderlyingExecable is the wrapped value that actually executes the queries.
	UnderlyingExecable UnderlyingExecable
	// db provides the query mutator and the logger used for timing entries.
	db                 *Database
}
// ExecContext passes query through db.mutateQuery, executes it on the underlying executable and
// logs an "Exec" timing entry. The result and error of the underlying call are returned as-is;
// the timing entry is logged whether or not the call failed.
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	startedAt := time.Now()
	finalQuery := le.db.mutateQuery(query)
	result, execErr := le.UnderlyingExecable.ExecContext(ctx, finalQuery, args...)
	elapsed := time.Since(startedAt)
	le.db.Log.QueryTiming(ctx, "Exec", finalQuery, args, -1, elapsed)
	return result, execErr
}
// QueryContext passes query through db.mutateQuery, runs it on the underlying executable and
// logs a "Query" timing entry for obtaining the result set (logged even when the query fails).
//
// On success the returned Rows is a *LoggingRows, which logs an additional "EndRows" entry with
// the number of rows read once iteration finishes or the rows are closed.
// On failure it returns a nil Rows together with the error.
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
	start := time.Now()
	query = le.db.mutateQuery(query)
	rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
	le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start))
	if err != nil {
		// The underlying rows are not meaningful on error. Wrapping them would hand the caller
		// a non-nil Rows whose methods delegate to an unusable (typically nil) result set.
		return nil, err
	}
	return &LoggingRows{
		ctx:   ctx,
		db:    le.db,
		query: query,
		args:  args,
		rs:    rows,
		start: start,
	}, nil
}
// QueryRowContext passes query through db.mutateQuery, runs it as a single-row query on the
// underlying executable and logs a "QueryRow" timing entry. The returned *sql.Row is passed
// through unchanged; the logged duration does not include scanning it.
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	began := time.Now()
	finalQuery := le.db.mutateQuery(query)
	row := le.UnderlyingExecable.QueryRowContext(ctx, finalQuery, args...)
	elapsed := time.Since(began)
	le.db.Log.QueryTiming(ctx, "QueryRow", finalQuery, args, -1, elapsed)
	return row
}
// Exec is ExecContext using context.Background().
func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result, error) {
	bg := context.Background()
	return le.ExecContext(bg, query, args...)
}
// Query is QueryContext using context.Background().
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
	bg := context.Background()
	return le.QueryContext(bg, query, args...)
}
// QueryRow is QueryRowContext using context.Background().
func (le *LoggingExecable) QueryRow(query string, args ...interface{}) *sql.Row {
	bg := context.Background()
	return le.QueryRowContext(bg, query, args...)
}
// loggingDB is a wrapper for LoggingExecable that allows access to BeginTx.
//
// While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since
// the LoggingExecable could be for a transaction (where BeginTx wouldn't make sense).
//
// Transactions are started on db.RawDB, not on the embedded UnderlyingExecable.
type loggingDB struct {
	LoggingExecable
}
// BeginTx starts a transaction on db.RawDB and logs a "Begin" timing entry (even if starting
// the transaction fails). On success the transaction is wrapped in a LoggingTxn. Queries run
// on the LoggingTxn are then mutated and timed, and ctx is kept for the Commit/Rollback entries.
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
	began := time.Now()
	tx, beginErr := ld.db.RawDB.BeginTx(ctx, opts)
	elapsed := time.Since(began)
	ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, elapsed)
	if beginErr != nil {
		return nil, beginErr
	}
	wrapped := LoggingExecable{UnderlyingExecable: tx, db: ld.db}
	return &LoggingTxn{LoggingExecable: wrapped, UnderlyingTx: tx, ctx: ctx}, nil
}
// Begin is BeginTx using context.Background() and default transaction options.
func (ld *loggingDB) Begin() (*LoggingTxn, error) {
	bg := context.Background()
	return ld.BeginTx(bg, nil)
}
// LoggingTxn is a transaction wrapper such as the one created by loggingDB.BeginTx.
// Queries go through the embedded LoggingExecable, while Commit and Rollback call UnderlyingTx
// directly and log their own timing entries.
type LoggingTxn struct {
	LoggingExecable
	// UnderlyingTx is the raw transaction. In BeginTx it is the same *sql.Tx that the embedded
	// LoggingExecable wraps.
	UnderlyingTx *sql.Tx
	// ctx is the context the transaction was started with; it is reused when logging the
	// Commit and Rollback timing entries.
	ctx          context.Context
}
// Commit commits the underlying transaction and logs a "Commit" timing entry using the context
// the transaction was started with. The commit error is returned unchanged.
func (lt *LoggingTxn) Commit() error {
	t0 := time.Now()
	commitErr := lt.UnderlyingTx.Commit()
	elapsed := time.Since(t0)
	lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, elapsed)
	return commitErr
}
// Rollback rolls back the underlying transaction and logs a "Rollback" timing entry using the
// context the transaction was started with. The rollback error is returned unchanged.
func (lt *LoggingTxn) Rollback() error {
	t0 := time.Now()
	rollbackErr := lt.UnderlyingTx.Rollback()
	elapsed := time.Since(t0)
	lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, elapsed)
	return rollbackErr
}
// LoggingRows wraps a result set (as created by LoggingExecable.QueryContext) and delegates to it.
// It counts rows while they are iterated and logs a single "EndRows" timing entry, containing
// the row count and the time since the query started, once iteration ends or the rows are closed.
type LoggingRows struct {
	// ctx is the query's context, reused for the EndRows timing entry.
	ctx   context.Context
	// db provides the logger for the EndRows entry.
	db    *Database
	// query is the query text as sent to the database (already mutated in QueryContext).
	query string
	args  []interface{}
	// rs is the underlying result set that all calls are delegated to.
	rs    Rows
	// start is when the query was issued; it is reset to the zero time once EndRows is logged.
	start time.Time
	// nrows is the counter reported as the row count in the EndRows entry.
	nrows int
}
// stopTiming logs the "EndRows" timing entry with the current row count and the time elapsed
// since the query started. It then clears the start time, so any later call does nothing.
func (lrs *LoggingRows) stopTiming() {
	if lrs.start.IsZero() {
		// Already logged.
		return
	}
	lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start))
	lrs.start = time.Time{}
}
// Close closes the underlying result set, then logs the EndRows timing entry if it has not been
// logged yet. The error from closing is returned unchanged.
func (lrs *LoggingRows) Close() error {
	closeErr := lrs.rs.Close()
	lrs.stopTiming()
	return closeErr
}
// ColumnTypes returns the column type information of the underlying result set.
func (lrs *LoggingRows) ColumnTypes() ([]*sql.ColumnType, error) {
	colTypes, err := lrs.rs.ColumnTypes()
	return colTypes, err
}
// Columns returns the column names of the underlying result set.
func (lrs *LoggingRows) Columns() ([]string, error) {
	names, err := lrs.rs.Columns()
	return names, err
}
// Err returns the error, if any, reported by the underlying result set.
func (lrs *LoggingRows) Err() error {
	iterErr := lrs.rs.Err()
	return iterErr
}
// Next advances the underlying result set to the next row. Each successful advance increments
// the row counter. When there are no more rows, the EndRows timing entry is logged (only the
// first time).
func (lrs *LoggingRows) Next() bool {
	if lrs.rs.Next() {
		lrs.nrows++
		return true
	}
	lrs.stopTiming()
	return false
}
// NextResultSet advances the underlying rows to the next result set.
//
// If there is no further result set, the EndRows timing entry is logged (a no-op when it has
// already been logged). Moving to another result set is not a row, so the row counter reported
// in the EndRows entry is deliberately left unchanged here; only Next counts rows.
func (lrs *LoggingRows) NextResultSet() bool {
	hasNext := lrs.rs.NextResultSet()
	if !hasNext {
		lrs.stopTiming()
	}
	return hasNext
}
// Scan copies the columns of the current row into dest by delegating to the underlying result set.
func (lrs *LoggingRows) Scan(dest ...any) error {
	scanErr := lrs.rs.Scan(dest...)
	return scanErr
}