components-contrib/state/sqlserver/sqlserver.go

639 lines
15 KiB
Go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlserver
import (
"context"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"unicode"
mssql "github.com/denisenkom/go-mssqldb"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
"github.com/dapr/kit/ptr"
)
// KeyType defines type of the table identifier.
// Recognized values are StringKeyType, UUIDKeyType and IntegerKeyType; use
// KeyTypeFromString to parse one from a metadata string.
type KeyType string
// KeyTypeFromString converts k into one of the supported KeyType values.
// Any other input (including the literal "invalid") yields InvalidKeyType
// together with an error.
func KeyTypeFromString(k string) (KeyType, error) {
	switch candidate := KeyType(k); candidate {
	case StringKeyType, UUIDKeyType, IntegerKeyType:
		return candidate, nil
	default:
		return InvalidKeyType, errors.New("invalid key type")
	}
}
// Supported KeyType values. InvalidKeyType is the sentinel that
// KeyTypeFromString returns alongside an error for unrecognized input.
const (
	StringKeyType  KeyType = "string"  // a key of type string
	UUIDKeyType    KeyType = "uuid"    // a key of type UUID/GUID
	IntegerKeyType KeyType = "integer" // a key of type integer
	InvalidKeyType KeyType = "invalid" // marks an unrecognized key type
)
// Metadata property names, column names and default settings of the store.
const (
	// Component metadata keys.
	connectionStringKey  = "connectionString"
	databaseNameKey      = "databaseName"
	schemaKey            = "schema"
	tableNameKey         = "tableName"
	keyTypeKey           = "keyType"
	keyLengthKey         = "keyLength"
	indexedPropertiesKey = "indexedProperties"

	// Column names; they double as the names of the SQL named parameters
	// passed to the generated commands.
	keyColumnName        = "Key"
	rowVersionColumnName = "RowVersion"

	// Defaults pre-filled by parseMetadata before decoding the metadata.
	defaultKeyLength = 200
	defaultSchema    = "dbo"
	defaultDatabase  = "dapr"
	defaultTable     = "state"
)
// NewSQLServerStateStore creates a new instance of a Sql Server transaction store.
//
// The store advertises ETag and transactional support. It uses newMigration
// to build the migrator that Init runs. Nothing is connected until Init is called.
func NewSQLServerStateStore(logger logger.Logger) state.Store {
	return &SQLServer{
		features:        []state.Feature{state.FeatureETag, state.FeatureTransactional},
		logger:          logger,
		migratorFactory: newMigration,
	}
}
// IndexedProperty defines an indexed property.
// Values are decoded from the JSON array in the "indexedProperties" metadata
// (see setIndexedProperties) and checked by validateIndexedProperties.
type IndexedProperty struct {
// ColumnName is presumably the SQL column created for the property (confirm
// in the migration code); only letters, digits and '_' are accepted.
ColumnName string `json:"column"`
// Property names the value to index; presumably a path into the stored JSON
// document. Letters, digits, '_', '.', '[' and ']' are accepted.
Property string `json:"property"`
// Type is the SQL column type, e.g. "nvarchar(100)"; only letters, digits,
// '(' and ')' are accepted.
Type string `json:"type"`
}
// SQLServer defines a Ms SQL Server based state store.
//
// Connection/naming settings are filled in by parseMetadata, the SQL command
// strings by Init from the migration result, and db by Init via sql.Open.
type SQLServer struct {
connectionString string
databaseName string
tableName string
schema string
keyType KeyType
// keyLength is only set when keyType is StringKeyType (see setKeyType).
keyLength int
// indexedProperties holds the validated "indexedProperties" metadata; it
// stays nil when that metadata is empty.
indexedProperties []IndexedProperty
// migratorFactory builds the migrator Init runs; NewSQLServerStateStore
// sets it to newMigration.
migratorFactory func(*SQLServer) migrator
// bulkDeleteCommand is "exec <bulk delete proc> @itemsToDelete;".
bulkDeleteCommand string
// itemRefTableTypeName is the TVP type name used by executeBulkDelete.
itemRefTableTypeName string
// upsertCommand is the full name of the upsert stored procedure.
upsertCommand string
getCommand string
deleteWithETagCommand string
deleteWithoutETagCommand string
features []state.Feature
logger logger.Logger
// db is shared by all operations; transactions are started from it.
db *sql.DB
}
// sqlServerMetadata is the target that parseMetadata decodes the component's
// metadata properties into. TableName, Schema, DatabaseName and KeyLength are
// pre-filled with defaults before decoding.
type sqlServerMetadata struct {
ConnectionString string
DatabaseName string
TableName string
Schema string
// KeyType is parsed by KeyTypeFromString; empty means StringKeyType.
KeyType string
KeyLength int
// IndexedProperties is a raw JSON array of IndexedProperty values.
IndexedProperties string
}
// isLetterOrNumber reports whether c is a Unicode letter or a Unicode number
// (which includes non-ASCII letters and digits).
func isLetterOrNumber(c rune) bool {
	switch {
	case unicode.IsLetter(c), unicode.IsNumber(c):
		return true
	default:
		return false
	}
}
// isValidSQLName reports whether every rune of s is a letter, a number or an
// underscore. The empty string is accepted.
func isValidSQLName(s string) bool {
	for _, r := range s {
		allowed := r == '_' || unicode.IsLetter(r) || unicode.IsNumber(r)
		if !allowed {
			return false
		}
	}
	return true
}
// isValidIndexedPropertyName reports whether s only contains letters,
// numbers, '_', '.', '[' and ']'. The empty string is accepted.
func isValidIndexedPropertyName(s string) bool {
	for _, r := range s {
		switch {
		case unicode.IsLetter(r), unicode.IsNumber(r):
			// allowed
		case r == '_', r == '.', r == '[', r == ']':
			// allowed path punctuation
		default:
			return false
		}
	}
	return true
}
// isValidIndexedPropertyType reports whether s only contains letters,
// numbers, '(' and ')'. Note that '_' and ',' are rejected. The empty string
// is accepted.
func isValidIndexedPropertyType(s string) bool {
	for _, r := range s {
		if r == '(' || r == ')' {
			continue
		}
		if !unicode.IsLetter(r) && !unicode.IsNumber(r) {
			return false
		}
	}
	return true
}
// Init initializes the SQL server state store.
//
// Steps, in order:
//  1. Parse and validate the component metadata.
//  2. Run the schema migrations through the configured migrator.
//  3. Cache the command strings and object names the migration reports.
//  4. Open the database handle used by every other operation.
//
// The context argument is not used.
func (s *SQLServer) Init(_ context.Context, metadata state.Metadata) error {
	if err := s.parseMetadata(metadata.Properties); err != nil {
		return err
	}

	result, err := s.migratorFactory(s).executeMigrations()
	if err != nil {
		return err
	}

	s.itemRefTableTypeName = result.itemRefTableTypeName
	s.bulkDeleteCommand = fmt.Sprintf("exec %s @itemsToDelete;", result.bulkDeleteProcFullName)
	s.upsertCommand = result.upsertProcFullName
	s.getCommand = result.getCommand
	s.deleteWithETagCommand = result.deleteWithETagCommand
	s.deleteWithoutETagCommand = result.deleteWithoutETagCommand

	// Per the database/sql docs, sql.Open only validates its arguments and
	// does not necessarily dial the server.
	s.db, err = sql.Open("sqlserver", s.connectionString)
	return err
}
// parseMetadata decodes the component metadata into s and validates it.
//
// The table, schema, database and key length fall back to their defaults
// when absent. The connection string is mandatory. Settings are validated in a
// fixed order, and the first failure is returned.
func (s *SQLServer) parseMetadata(meta map[string]string) error {
	cfg := sqlServerMetadata{
		DatabaseName: defaultDatabase,
		Schema:       defaultSchema,
		TableName:    defaultTable,
		KeyLength:    defaultKeyLength,
	}
	if err := metadata.DecodeMetadata(meta, &cfg); err != nil {
		return err
	}

	if cfg.ConnectionString == "" {
		return errors.New("missing connection string")
	}
	s.connectionString = cfg.ConnectionString

	steps := []func() error{
		func() error { return s.setTable(cfg.TableName) },
		func() error { return s.setDatabase(cfg.DatabaseName) },
		func() error { return s.setKeyType(cfg.KeyType, cfg.KeyLength) },
		func() error { return s.setSchema(cfg.Schema) },
		func() error { return s.setIndexedProperties(cfg.IndexedProperties) },
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return nil
}
// setIndexedProperties parses indexedPropertiesString as a JSON array of
// IndexedProperty, validates the entries and stores them on s.
// An empty string is a no-op. JSON or validation errors are returned as-is,
// and s is left unchanged.
func (s *SQLServer) setIndexedProperties(indexedPropertiesString string) error {
	if indexedPropertiesString == "" {
		return nil
	}

	var props []IndexedProperty
	if err := json.Unmarshal([]byte(indexedPropertiesString), &props); err != nil {
		return err
	}
	if err := s.validateIndexedProperties(props); err != nil {
		return err
	}

	s.indexedProperties = props
	return nil
}
// validateIndexedProperties checks that every indexed property has a
// non-empty column name, property name and type, and that each one uses only
// the characters accepted by its validator (isValidSQLName,
// isValidIndexedPropertyName, isValidIndexedPropertyType).
// Presumably the restriction exists because these values end up in generated
// SQL; confirm in the migration code.
// Validation stops at the first invalid entry and returns its error.
func (s *SQLServer) validateIndexedProperties(indexedProperties []IndexedProperty) error {
	for _, p := range indexedProperties {
		if p.ColumnName == "" {
			return errors.New("indexed property column cannot be empty")
		}
		if p.Property == "" {
			return errors.New("indexed property name cannot be empty")
		}
		if p.Type == "" {
			return errors.New("indexed property type cannot be empty")
		}
		if !isValidSQLName(p.ColumnName) {
			return errors.New("invalid indexed property column name, accepted characters are (A-Z, a-z, 0-9, _)")
		}
		if !isValidIndexedPropertyName(p.Property) {
			return errors.New("invalid indexed property name, accepted characters are (A-Z, a-z, 0-9, _, ., [, ])")
		}
		if !isValidIndexedPropertyType(p.Type) {
			// isValidIndexedPropertyType does not accept '_', so the message
			// must not advertise it.
			return errors.New("invalid indexed property type, accepted characters are (A-Z, a-z, 0-9, (, ))")
		}
	}
	return nil
}
// setKeyType validates keyType and stores it on s. An empty keyType means
// StringKeyType. Only string keys carry a length. For them, keyLength must be
// positive and is stored on s. For other key types keyLength is ignored.
func (s *SQLServer) setKeyType(keyType string, keyLength int) error {
	kt := StringKeyType
	if keyType != "" {
		parsed, err := KeyTypeFromString(keyType)
		if err != nil {
			return err
		}
		kt = parsed
	}
	s.keyType = kt

	if kt != StringKeyType {
		return nil
	}
	if keyLength <= 0 {
		return fmt.Errorf("invalid key length value of %d", keyLength)
	}
	s.keyLength = keyLength
	return nil
}
// setSchema validates schemaName with isValidSQLName and, if it is valid,
// stores it as the schema of the state table.
func (s *SQLServer) setSchema(schemaName string) error {
	if isValidSQLName(schemaName) {
		s.schema = schemaName
		return nil
	}
	return errors.New("invalid schema name, accepted characters are (A-Z, a-z, 0-9, _)")
}
// setDatabase validates databaseName with isValidSQLName and, if it is
// valid, stores it as the database that holds the state table.
func (s *SQLServer) setDatabase(databaseName string) error {
	if isValidSQLName(databaseName) {
		s.databaseName = databaseName
		return nil
	}
	return errors.New("invalid database name, accepted characters are (A-Z, a-z, 0-9, _)")
}
// setTable validates tableName with isValidSQLName and, if it is valid,
// stores it as the name of the state table.
func (s *SQLServer) setTable(tableName string) error {
	if isValidSQLName(tableName) {
		s.tableName = tableName
		return nil
	}
	return errors.New("invalid table name, accepted characters are (A-Z, a-z, 0-9, _)")
}
// Features returns the features available in this state store.
// This is the slice set by NewSQLServerStateStore (ETag and transactional
// support). It is returned directly, not copied.
func (s *SQLServer) Features() []state.Feature {
return s.features
}
// Multi performs multiple updates on a Sql server store.
//
// All operations run inside one transaction bound to ctx. Upserts go through
// executeSet and deletes through executeDelete. If any operation fails, or an
// operation type other than Upsert/Delete is encountered, the transaction is
// rolled back and the error is returned. Otherwise the commit result is returned.
func (s *SQLServer) Multi(ctx context.Context, request *state.TransactionalStateRequest) error {
	// BeginTx (unlike Begin) ties the transaction to ctx, so cancellation
	// aborts it instead of leaving it running.
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback is a no-op returning sql.ErrTxDone.
	// Deferring it therefore covers every early-return path.
	defer tx.Rollback()

	for _, req := range request.Operations {
		switch req.Operation {
		case state.Upsert:
			setReq, err := s.getSets(req)
			if err != nil {
				return err
			}
			if err = s.executeSet(ctx, tx, &setReq); err != nil {
				return err
			}
		case state.Delete:
			delReq, err := s.getDeletes(req)
			if err != nil {
				return err
			}
			if err = s.executeDelete(ctx, tx, &delReq); err != nil {
				return err
			}
		default:
			return fmt.Errorf("unsupported operation: %s", req.Operation)
		}
	}

	return tx.Commit()
}
// getSets extracts the state.SetRequest carried by a transactional
// operation. It fails if the payload is not a state.SetRequest or if the
// request has an empty key. The (possibly zero) request is returned in
// every case.
func (s *SQLServer) getSets(req state.TransactionalStateOperation) (state.SetRequest, error) {
	setReq, ok := req.Request.(state.SetRequest)
	switch {
	case !ok:
		return setReq, errors.New("expecting set request")
	case setReq.Key == "":
		return setReq, errors.New("missing key in upsert operation")
	default:
		return setReq, nil
	}
}
// getDeletes extracts the state.DeleteRequest carried by a transactional
// operation. It fails if the payload is not a state.DeleteRequest or if the
// request has an empty key. The (possibly zero) request is returned in
// every case.
func (s *SQLServer) getDeletes(req state.TransactionalStateOperation) (state.DeleteRequest, error) {
	delReq, ok := req.Request.(state.DeleteRequest)
	if !ok {
		return delReq, errors.New("expecting delete request")
	}
	if delReq.Key == "" {
		// This is the delete path, so the message must say "delete" rather than "upsert".
		return delReq, errors.New("missing key in delete operation")
	}
	return delReq, nil
}
// Delete removes an entity from the store.
// It runs directly on the store's *sql.DB, outside any transaction. See
// executeDelete for ETag handling and the errors it can return.
func (s *SQLServer) Delete(ctx context.Context, req *state.DeleteRequest) error {
return s.executeDelete(ctx, s.db, req)
}
// executeDelete deletes req.Key using db, which is either the store's *sql.DB
// or an open transaction.
//
// With a non-empty ETag:
//   - the ETag is hex-decoded into a row version and the ETag-checked delete
//     command is used;
//   - a malformed ETag yields an ETagInvalid error;
//   - a delete that does not affect exactly one row yields an ETagMismatch error.
//
// Without an ETag (nil or empty, the same rule executeSet applies), the
// unconditional delete command is used. Deleting a missing key is then a no-op.
func (s *SQLServer) executeDelete(ctx context.Context, db dbExecutor, req *state.DeleteRequest) error {
	// An empty ETag must count as "no ETag". Otherwise an empty row version
	// would be sent to the ETag-checked command and match no row. The mismatch
	// check would also ignore it, so the delete would be silently skipped.
	hasETag := req.ETag != nil && *req.ETag != ""

	var (
		res sql.Result
		err error
	)
	if hasETag {
		rowVersion, decodeErr := hex.DecodeString(*req.ETag)
		if decodeErr != nil {
			return state.NewETagError(state.ETagInvalid, decodeErr)
		}
		res, err = db.ExecContext(ctx, s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, rowVersion))
	} else {
		res, err = db.ExecContext(ctx, s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key))
	}
	// err represents errors raised by the stored procedure or the database itself.
	if err != nil {
		return err
	}

	// If no row matches the key (and ETag, if specified), 0 rows are affected.
	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if hasETag && rows != 1 {
		return state.NewETagError(state.ETagMismatch, nil)
	}
	return nil
}
// TvpDeleteTableStringKey defines a table type with string key.
// Each value is one row of the table-valued parameter that executeBulkDelete
// passes to the bulk delete procedure.
// NOTE(review): TVP columns are presumably matched by struct field position
// against the table type created by the migration, so do not reorder fields
// without checking that type.
type TvpDeleteTableStringKey struct {
// ID is the state key to delete.
ID string
// RowVersion is the hex-decoded ETag, or nil when the request had none.
RowVersion []byte
}
// BulkDelete removes multiple entries from the store.
//
// All deletes run through executeBulkDelete inside one transaction bound to
// ctx. On any error the transaction is rolled back. Otherwise the result of
// committing it is returned.
func (s *SQLServer) BulkDelete(ctx context.Context, req []state.DeleteRequest) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback()

	if err = s.executeBulkDelete(ctx, tx, req); err != nil {
		return err
	}
	// A failed commit means nothing was deleted, so its error has to reach
	// the caller instead of being dropped.
	return tx.Commit()
}
// executeBulkDelete deletes every request in req with a single call to the
// bulk delete procedure. The keys (and row versions) are passed as a
// table-valued parameter of type s.itemRefTableTypeName.
//
// Behavior:
//   - A non-empty ETag is hex-decoded into that row's RowVersion; a malformed
//     ETag yields an ETagInvalid error before anything is executed.
//   - A nil or empty ETag leaves RowVersion nil, as in executeSet and
//     executeDelete.
//   - If the number of affected rows differs from len(req), a
//     BulkDeleteRowMismatchError is returned.
func (s *SQLServer) executeBulkDelete(ctx context.Context, db dbExecutor, req []state.DeleteRequest) error {
	values := make([]TvpDeleteTableStringKey, len(req))
	for i := range req {
		d := &req[i]
		var etag []byte
		// An empty ETag would otherwise become a non-NULL, empty row version
		// that can never match a stored row.
		if d.ETag != nil && *d.ETag != "" {
			var err error
			etag, err = hex.DecodeString(*d.ETag)
			if err != nil {
				return state.NewETagError(state.ETagInvalid, err)
			}
		}
		values[i] = TvpDeleteTableStringKey{ID: d.Key, RowVersion: etag}
	}

	itemsToDelete := mssql.TVP{
		TypeName: s.itemRefTableTypeName,
		Value:    values,
	}
	res, err := db.ExecContext(ctx, s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete))
	if err != nil {
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if int(rows) != len(req) {
		return state.NewBulkDeleteRowMismatchError(uint64(rows), uint64(len(req)))
	}
	return nil
}
// Get returns an entity from store.
//
// It runs the get command for req.Key. The stored data is returned with the
// row version hex-encoded as the ETag. Only the first returned row is used.
// A missing key yields an empty GetResponse and a nil error.
func (s *SQLServer) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
	var (
		data       string
		rowVersion []byte
	)
	// QueryRowContext honours ctx and always releases the underlying rows.
	// Scan reports iteration errors, and sql.ErrNoRows when nothing matched.
	err := s.db.QueryRowContext(ctx, s.getCommand, sql.Named(keyColumnName, req.Key)).Scan(&data, &rowVersion)
	if errors.Is(err, sql.ErrNoRows) {
		return &state.GetResponse{}, nil
	}
	if err != nil {
		return nil, err
	}

	return &state.GetResponse{
		Data: []byte(data),
		ETag: ptr.Of(hex.EncodeToString(rowVersion)),
	}, nil
}
// BulkGet performs a bulk get operation.
//
// This store has no native bulk read. It always returns false with no
// responses and no error.
// NOTE(review): presumably the caller then falls back to individual Get
// calls; confirm against the state.Store contract.
func (s *SQLServer) BulkGet(ctx context.Context, req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
	const supported = false
	var responses []state.BulkGetResponse
	return supported, responses, nil
}
// Set adds/updates an entity on store.
// It runs directly on the store's *sql.DB, outside any transaction. See
// executeSet for ETag/concurrency handling and the errors it can return.
func (s *SQLServer) Set(ctx context.Context, req *state.SetRequest) error {
return s.executeSet(ctx, s.db, req)
}
// dbExecutor is the single method shared by *sql.DB and *sql.Tx that the
// write helpers need. It lets executeSet and executeDelete run either
// directly on the database or inside a transaction.
type dbExecutor interface {
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// executeSet upserts req using db, which is either the store's *sql.DB or an
// open transaction.
//
// The value is marshalled to JSON and passed to the upsert procedure as
// @Data.
//
// ETag handling:
//   - A non-empty ETag is hex-decoded into @RowVersion. A malformed ETag
//     yields an ETagInvalid error.
//   - Without an ETag, @RowVersion is NULL.
//   - When an ETag was supplied, any execution error is reported as ETagMismatch.
//
// @FirstWrite is 1 for first-write concurrency and 0 otherwise. Exactly one
// row must be affected, or an error is returned.
func (s *SQLServer) executeSet(ctx context.Context, db dbExecutor, req *state.SetRequest) error {
	payload, err := utils.Marshal(req.Value, json.Marshal)
	if err != nil {
		return err
	}

	hasETag := req.ETag != nil && *req.ETag != ""
	etag := sql.Named(rowVersionColumnName, nil)
	if hasETag {
		rowVersion, decodeErr := hex.DecodeString(*req.ETag)
		if decodeErr != nil {
			return state.NewETagError(state.ETagInvalid, decodeErr)
		}
		etag = sql.Named(rowVersionColumnName, rowVersion)
	}

	firstWrite := 0
	if req.Options.Concurrency == state.FirstWrite {
		firstWrite = 1
	}

	res, err := db.ExecContext(ctx, s.upsertCommand,
		sql.Named(keyColumnName, req.Key),
		sql.Named("Data", string(payload)),
		etag,
		sql.Named("FirstWrite", firstWrite),
	)
	if err != nil {
		if hasETag {
			return state.NewETagError(state.ETagMismatch, err)
		}
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		return fmt.Errorf("no item was updated")
	}
	return nil
}
// BulkSet adds/updates multiple entities on store.
//
// Each request is applied in order with executeSet inside one transaction
// bound to ctx. The first failure rolls everything back and is returned.
// Otherwise the result of committing the transaction is returned.
func (s *SQLServer) BulkSet(ctx context.Context, req []state.SetRequest) error {
	// BeginTx ties the transaction to ctx so cancellation aborts it.
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is a no-op once Commit has succeeded.
	defer tx.Rollback()

	for i := range req {
		if err = s.executeSet(ctx, tx, &req[i]); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// GetComponentMetadata returns the component's metadata description. This
// store exposes none, so the result is an empty, non-nil map.
func (s *SQLServer) GetComponentMetadata() map[string]string {
	return make(map[string]string)
}