steve/pkg/sqlcache/db/sqlwraps.go

127 lines
2.9 KiB
Go

package db
import (
"context"
"database/sql"
"time"
"github.com/rancher/steve/pkg/sqlcache/db/logging"
)
// Row declares a subset of the methods provided by sql.Row.
type Row interface {
// Err reports the error associated with the row, if any.
Err() error
// Scan reads the row's values into dest.
Scan(dest ...any) error
}
// Rows represents SQL rows. It exposes methods to navigate the rows, read their values, and close them.
type Rows interface {
// Next advances to the next row and reports whether one is available.
Next() bool
// Err reports the error associated with the rows, if any.
Err() error
// Close closes the rows.
Close() error
// Scan reads the current row's values into dest.
Scan(dest ...any) error
}
// Stmt is an interface over a subset of sql.Stmt methods.
// Rationale: it allows statements to be mocked.
type Stmt interface {
// Exec executes the statement with the given arguments.
Exec(args ...any) (sql.Result, error)
// Query runs the statement and returns the concrete *sql.Rows.
Query(args ...any) (*sql.Rows, error)
// QueryContext runs the statement with ctx and returns the result set as a Rows.
QueryContext(ctx context.Context, args ...any) (Rows, error)
// QueryRowContext runs the statement with ctx and returns its result as a Row.
QueryRowContext(ctx context.Context, args ...any) Row
// Close closes the statement.
Close() error
// SQLStmt unwraps the original sql.Stmt
SQLStmt() *sql.Stmt
// GetQueryString returns the original text used to prepare this statement
GetQueryString() string
}
// row wraps a sql.Row, keeping track of the original query used to produce it.
type row struct {
*sql.Row
// queryString is the query text attached to errors reported by Err.
queryString string
}
// Err reports the error carried by the underlying *sql.Row, if any.
// A non-nil error is returned as a *QueryError that also records the query
// string this row was produced from; otherwise Err returns nil.
func (r row) Err() error {
	cause := r.Row.Err()
	if cause == nil {
		return nil
	}
	return &QueryError{
		QueryString: r.queryString,
		Err:         cause,
	}
}
// rows wraps a sql.Rows, keeping track of the original query used to produce it.
type rows struct {
*sql.Rows
// queryString is the query text attached to errors reported by Err.
queryString string
}
// Err reports the error recorded by the underlying *sql.Rows, if any.
// A non-nil error is wrapped in a *QueryError together with the query string
// that produced these rows; otherwise Err returns nil.
func (r rows) Err() error {
	cause := r.Rows.Err()
	if cause == nil {
		return nil
	}
	return &QueryError{
		QueryString: r.queryString,
		Err:         cause,
	}
}
// stmt implements the Stmt interface, wrapping a sql.Stmt and keeping track of the original query string.
// Most of the methods wrap original errors with a QueryError.
type stmt struct {
*sql.Stmt
// queryString is the query text reported in QueryError values and passed to the query logger.
queryString string
// queryLogger receives a record of executed queries; when nil, logging is skipped.
queryLogger logging.QueryLogger
}
// log forwards one query execution record (start time, query text and
// arguments) to the statement's query logger.
// It does nothing when no logger is configured.
func (s *stmt) log(startTime time.Time, query string, args []any) {
	if logger := s.queryLogger; logger != nil {
		logger.Log(startTime, query, args)
	}
}
// Exec runs the wrapped statement with the given arguments and returns its
// result. A failure is reported as a *QueryError carrying the statement's
// query string. The call is handed to the query logger when Exec returns.
func (s *stmt) Exec(args ...any) (sql.Result, error) {
	// time.Now() is evaluated when the defer is registered, so the logger
	// receives the moment execution started.
	defer s.log(time.Now(), s.queryString, args)

	result, execErr := s.Stmt.Exec(args...)
	if execErr == nil {
		return result, nil
	}
	return result, &QueryError{QueryString: s.queryString, Err: execErr}
}
// QueryContext runs the wrapped statement with ctx and the given arguments.
//
// On success the result set is returned wrapped in a Rows that remembers the
// statement's query string. On failure the returned Rows is nil and the error
// is a *QueryError carrying the query string and the original error.
// The call is handed to the query logger when QueryContext returns.
func (s *stmt) QueryContext(ctx context.Context, args ...any) (Rows, error) {
	// time.Now() is evaluated when the defer is registered, so the logger
	// receives the moment execution started.
	defer s.log(time.Now(), s.queryString, args)

	res, err := s.Stmt.QueryContext(ctx, args...)
	if err != nil {
		// Return a literal nil: passing the *sql.Rows through the Rows
		// interface here would yield a non-nil interface value wrapping a nil
		// pointer, which defeats `rows != nil` checks and is unsafe to use.
		return nil, &QueryError{
			QueryString: s.queryString,
			Err:         err,
		}
	}
	return rows{Rows: res, queryString: s.queryString}, nil
}
// QueryRowContext runs the wrapped statement with ctx and the given arguments
// and returns the resulting *sql.Row wrapped in a Row that remembers the
// statement's query string.
//
// NOTE(review): unlike Exec and QueryContext, this method does not call the
// query logger. Confirm whether that is intentional.
func (s *stmt) QueryRowContext(ctx context.Context, args ...any) Row {
	inner := s.Stmt.QueryRowContext(ctx, args...)
	return row{
		Row:         inner,
		queryString: s.queryString,
	}
}
// Close closes the wrapped statement. A failure is reported as a *QueryError
// carrying the statement's query string; otherwise Close returns nil.
func (s *stmt) Close() error {
	closeErr := s.Stmt.Close()
	if closeErr == nil {
		return nil
	}
	return &QueryError{QueryString: s.queryString, Err: closeErr}
}
// SQLStmt returns the wrapped *sql.Stmt.
func (s *stmt) SQLStmt() *sql.Stmt {
return s.Stmt
}
// GetQueryString returns the query string stored on this statement wrapper.
func (s *stmt) GetQueryString() string {
return s.queryString
}