components-contrib/state/mysql/mysql.go

619 lines
15 KiB
Go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mysql
import (
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/agrea/ptr"
"github.com/google/uuid"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
)
// Optimistic Concurrency is implemented using a string column that stores
// a UUID.
const (
// Used if the user does not configure a table name in the metadata.
defaultTableName = "state"
// Used if the user does not configure a database name in the metadata.
defaultSchemaName = "dapr_state_store"
// The key name in the metadata if the user wants a different table name
// than the defaultTableName.
tableNameKey = "tableName"
// The key name in the metadata if the user wants a different database name
// than the defaultSchemaName.
schemaNameKey = "schemaName"
// The key for the mandatory connection string of the metadata.
connectionStringKey = "connectionString"
// Standard error message if not connection string is provided.
errMissingConnectionString = "missing connection string"
// To connect to MySQL running in Azure over SSL you have to download a
// SSL certificate. If this is provided the driver will connect using
// SSL. If you have disable SSL you can leave this empty.
// When the user provides a pem path their connection string must end with
// &tls=custom
// The connection string should be in the following format
// "%s:%s@tcp(%s:3306)/%s?allowNativePasswords=true&tls=custom",'myadmin@mydemoserver', 'yourpassword', 'mydemoserver.mysql.database.azure.com', 'targetdb'.
pemPathKey = "pemPath"
)
// MySQL state store.
type MySQL struct {
// Name of the table to store state. If the table does not exist it will
// be created.
tableName string
// Name of the table to create to store state. If the table does not exist
// it will be created.
schemaName string
connectionString string
// Instance of the database to issue commands to
db *sql.DB
features []state.Feature
// Logger used in a functions
logger logger.Logger
factory iMySQLFactory
}
// NewMySQLStateStore creates a new instance of MySQL state store.
func NewMySQLStateStore(logger logger.Logger) *MySQL {
factory := newMySQLFactory(logger)
// Store the provided logger and return the object. The rest of the
// properties will be populated in the Init function
return newMySQLStateStore(logger, factory)
}
// Hidden implementation for testing.
func newMySQLStateStore(logger logger.Logger, factory iMySQLFactory) *MySQL {
// Store the provided logger and return the object. The rest of the
// properties will be populated in the Init function
return &MySQL{
features: []state.Feature{state.FeatureETag, state.FeatureTransactional},
logger: logger,
factory: factory,
}
}
// Init initializes the SQL server state store
// Implements the following interfaces:
// Store
// TransactionalStore
// Populate the rest of the MySQL object by reading the metadata and opening
// a connection to the server.
func (m *MySQL) Init(metadata state.Metadata) error {
m.logger.Debug("Initializing MySql state store")
val, ok := metadata.Properties[tableNameKey]
if ok && val != "" {
m.tableName = val
} else {
// Default to the constant
m.tableName = defaultTableName
}
val, ok = metadata.Properties[schemaNameKey]
if ok && val != "" {
m.schemaName = val
} else {
// Default to the constant
m.schemaName = defaultSchemaName
}
m.connectionString, ok = metadata.Properties[connectionStringKey]
if !ok || m.connectionString == "" {
m.logger.Error("Missing MySql connection string")
return fmt.Errorf(errMissingConnectionString)
}
val, ok = metadata.Properties[pemPathKey]
if ok && val != "" {
err := m.factory.RegisterTLSConfig(val)
if err != nil {
m.logger.Error(err)
return err
}
}
db, err := m.factory.Open(m.connectionString)
// will be nil if everything is good or an err that needs to be returned
return m.finishInit(db, err)
}
func (m *MySQL) Ping() error {
return nil
}
// Features returns the features available in this state store.
func (m *MySQL) Features() []state.Feature {
return m.features
}
// Separated out to make this portion of code testable.
func (m *MySQL) finishInit(db *sql.DB, err error) error {
if err != nil {
m.logger.Error(err)
return err
}
m.db = db
schemaErr := m.ensureStateSchema()
if schemaErr != nil {
m.logger.Error(schemaErr)
return schemaErr
}
pingErr := m.db.Ping()
if pingErr != nil {
m.logger.Error(pingErr)
return pingErr
}
// will be nil if everything is good or an err that needs to be returned
return m.ensureStateTable(m.tableName)
}
func (m *MySQL) ensureStateSchema() error {
exists, err := schemaExists(m.db, m.schemaName)
if err != nil {
return err
}
if !exists {
m.logger.Infof("Creating MySql schema '%s'", m.schemaName)
createTable := fmt.Sprintf("CREATE DATABASE %s;", m.schemaName)
_, err = m.db.Exec(createTable)
if err != nil {
return err
}
}
// Build a connection string that contains the new schema name
// All MySQL connection strings must contain a / so split on it.
parts := strings.Split(m.connectionString, "/")
// Even if the connection string ends with a / parts will have two values
// with the second being an empty string.
m.connectionString = fmt.Sprintf("%s/%s%s", parts[0], m.schemaName, parts[1])
// Close the connection we used to confirm and or create the schema
err = m.db.Close()
if err != nil {
return err
}
// Open a connection to the new schema
m.db, err = m.factory.Open(m.connectionString)
return err
}
func (m *MySQL) ensureStateTable(stateTableName string) error {
exists, err := tableExists(m.db, stateTableName)
if err != nil {
return err
}
if !exists {
m.logger.Infof("Creating MySql state table '%s'", stateTableName)
// updateDate is updated automactically on every UPDATE commands so you
// never need to pass it in.
// eTag is a UUID stored as a 36 characters string. It needs to be passed
// in on inserts and updates and is used for Optimistic Concurrency
createTable := fmt.Sprintf(`CREATE TABLE %s (
id VARCHAR(255) NOT NULL PRIMARY KEY,
value JSON NOT NULL,
isbinary BOOLEAN NOT NULL,
insertDate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
updateDate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
eTag VARCHAR(36) NOT NULL
);`, stateTableName)
_, err = m.db.Exec(createTable)
if err != nil {
return err
}
}
return nil
}
func schemaExists(db *sql.DB, schemaName string) (bool, error) {
exists := ""
query := `SELECT EXISTS (
SELECT SCHEMA_NAME FROM information_schema.schemata WHERE SCHEMA_NAME = ?
) AS 'exists'`
// Returns 1 or 0 as a string if the table exists or not
err := db.QueryRow(query, schemaName).Scan(&exists)
return exists == "1", err
}
func tableExists(db *sql.DB, tableName string) (bool, error) {
exists := ""
query := `SELECT EXISTS (
SELECT TABLE_NAME FROM information_schema.tables WHERE TABLE_NAME = ?
) AS 'exists'`
// Returns 1 or 0 as a string if the table exists or not
err := db.QueryRow(query, tableName).Scan(&exists)
return exists == "1", err
}
// Delete removes an entity from the store
// Store Interface.
func (m *MySQL) Delete(req *state.DeleteRequest) error {
return state.DeleteWithOptions(m.deleteValue, req)
}
// deleteValue is an internal implementation of delete to enable passing the
// logic to state.DeleteWithRetries as a func.
func (m *MySQL) deleteValue(req *state.DeleteRequest) error {
m.logger.Debug("Deleting state value from MySql")
if req.Key == "" {
return fmt.Errorf("missing key in delete operation")
}
var err error
var result sql.Result
if req.ETag == nil || *req.ETag == "" {
result, err = m.db.Exec(fmt.Sprintf(
`DELETE FROM %s WHERE id = ?`,
m.tableName), req.Key)
} else {
result, err = m.db.Exec(fmt.Sprintf(
`DELETE FROM %s WHERE id = ? and eTag = ?`,
m.tableName), req.Key, *req.ETag)
}
if err != nil {
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows != 1 && req.ETag != nil && *req.ETag != "" {
return state.NewETagError(state.ETagMismatch, nil)
}
return nil
}
// BulkDelete removes multiple entries from the store
// Store Interface.
func (m *MySQL) BulkDelete(req []state.DeleteRequest) error {
m.logger.Debug("Executing BulkDelete request")
tx, err := m.db.Begin()
if err != nil {
return err
}
if len(req) > 0 {
for _, d := range req {
da := d // Fix for goSec G601: Implicit memory aliasing in for loop.
err = m.Delete(&da)
if err != nil {
tx.Rollback()
return err
}
}
}
err = tx.Commit()
return err
}
// Get returns an entity from store
// Store Interface.
func (m *MySQL) Get(req *state.GetRequest) (*state.GetResponse, error) {
m.logger.Debug("Getting state value from MySql")
if req.Key == "" {
return nil, fmt.Errorf("missing key in get operation")
}
var eTag, value string
var isBinary bool
err := m.db.QueryRow(fmt.Sprintf(
`SELECT value, eTag, isbinary FROM %s WHERE id = ?`,
m.tableName), req.Key).Scan(&value, &eTag, &isBinary)
if err != nil {
// If no rows exist, return an empty response, otherwise return an error.
if errors.Is(err, sql.ErrNoRows) {
return &state.GetResponse{}, nil
}
return nil, err
}
if isBinary {
var s string
var data []byte
if err = json.Unmarshal([]byte(value), &s); err != nil {
return nil, err
}
if data, err = base64.StdEncoding.DecodeString(s); err != nil {
return nil, err
}
return &state.GetResponse{
Data: data,
ETag: ptr.String(eTag),
Metadata: req.Metadata,
}, nil
}
return &state.GetResponse{
Data: []byte(value),
ETag: ptr.String(eTag),
Metadata: req.Metadata,
}, nil
}
// Set adds/updates an entity on store
// Store Interface.
func (m *MySQL) Set(req *state.SetRequest) error {
return state.SetWithOptions(m.setValue, req)
}
// setValue is an internal implementation of set to enable passing the logic
// to state.SetWithRetries as a func.
func (m *MySQL) setValue(req *state.SetRequest) error {
m.logger.Debug("Setting state value in MySql")
err := state.CheckRequestOptions(req.Options)
if err != nil {
return err
}
if req.Key == "" {
return fmt.Errorf("missing key in set operation")
}
if v, ok := req.Value.(string); ok && v == "" {
return fmt.Errorf("empty string is not allowed in set operation")
}
v := req.Value
byteArray, isBinary := req.Value.([]uint8)
if isBinary {
v = base64.StdEncoding.EncodeToString(byteArray)
}
// Convert to json string
bt, _ := utils.Marshal(v, json.Marshal)
value := string(bt)
var result sql.Result
eTag := uuid.New().String()
// Sprintf is required for table name because sql.DB does not substitute
// parameters for table names.
// Other parameters use sql.DB parameter substitution.
if req.ETag == nil || *req.ETag == "" {
// If this is a duplicate MySQL returns that two rows affected
result, err = m.db.Exec(fmt.Sprintf(
`INSERT INTO %s (value, id, eTag, isbinary)
VALUES (?, ?, ?, ?) on duplicate key update value=?, eTag=?, isbinary=?;`,
m.tableName), value, req.Key, eTag, isBinary, value, eTag, isBinary)
} else {
// When an eTag is provided do an update - not insert
result, err = m.db.Exec(fmt.Sprintf(
`UPDATE %s SET value = ?, eTag = ?, isbinary = ?
WHERE id = ? AND eTag = ?;`,
m.tableName), value, eTag, isBinary, req.Key, *req.ETag)
}
if err != nil {
if req.ETag != nil && *req.ETag != "" {
return state.NewETagError(state.ETagMismatch, err)
}
return err
}
rows, err := result.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
err = fmt.Errorf(`rows affected error: no rows match given key '%s' and eTag '%s'`, req.Key, *req.ETag)
err = state.NewETagError(state.ETagMismatch, err)
m.logger.Error(err)
return err
}
if rows > 2 {
err = fmt.Errorf(`rows affected error: more than 2 row affected, expected 2, actual %d`, rows)
m.logger.Error(err)
return err
}
return nil
}
// BulkSet adds/updates multiple entities on store
// Store Interface.
func (m *MySQL) BulkSet(req []state.SetRequest) error {
m.logger.Debug("Executing BulkSet request")
tx, err := m.db.Begin()
if err != nil {
return err
}
if len(req) > 0 {
for _, s := range req {
sa := s // Fix for goSec G601: Implicit memory aliasing in for loop.
err = m.Set(&sa)
if err != nil {
tx.Rollback()
return err
}
}
}
err = tx.Commit()
return err
}
// Multi handles multiple transactions.
// TransactionalStore Interface.
func (m *MySQL) Multi(request *state.TransactionalStateRequest) error {
m.logger.Debug("Executing Multi request")
tx, err := m.db.Begin()
if err != nil {
return err
}
for _, req := range request.Operations {
switch req.Operation {
case state.Upsert:
setReq, err := m.getSets(req)
if err != nil {
tx.Rollback()
return err
}
err = m.Set(&setReq)
if err != nil {
tx.Rollback()
return err
}
case state.Delete:
delReq, err := m.getDeletes(req)
if err != nil {
tx.Rollback()
return err
}
err = m.Delete(&delReq)
if err != nil {
tx.Rollback()
return err
}
default:
return fmt.Errorf("unsupported operation: %s", req.Operation)
}
}
return tx.Commit()
}
// Returns the set requests.
func (m *MySQL) getSets(req state.TransactionalStateOperation) (state.SetRequest, error) {
setReq, ok := req.Request.(state.SetRequest)
if !ok {
return setReq, fmt.Errorf("expecting set request")
}
if setReq.Key == "" {
return setReq, fmt.Errorf("missing key in upsert operation")
}
return setReq, nil
}
// Returns the delete requests.
func (m *MySQL) getDeletes(req state.TransactionalStateOperation) (state.DeleteRequest, error) {
delReq, ok := req.Request.(state.DeleteRequest)
if !ok {
return delReq, fmt.Errorf("expecting delete request")
}
if delReq.Key == "" {
return delReq, fmt.Errorf("missing key in upsert operation")
}
return delReq, nil
}
// BulkGet performs a bulks get operations.
func (m *MySQL) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
// by default, the store doesn't support bulk get
// return false so daprd will fallback to call get() method one by one
return false, nil, nil
}
// Close implements io.Closer.
func (m *MySQL) Close() error {
if m.db != nil {
return m.db.Close()
}
return nil
}