// File: components-contrib/state/sqlserver/sqlserver.go

/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package sqlserver
import (
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
"unicode"
"github.com/agrea/ptr"
mssql "github.com/denisenkom/go-mssqldb"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
)
// KeyType names the SQL data type used for the state table's key column.
type KeyType string

// KeyTypeFromString maps a metadata string onto one of the supported key
// types ("string", "uuid" or "integer"). Any other input yields
// InvalidKeyType together with an error.
func KeyTypeFromString(k string) (KeyType, error) {
	for _, candidate := range [...]KeyType{StringKeyType, UUIDKeyType, IntegerKeyType} {
		if string(candidate) == k {
			return candidate, nil
		}
	}
	return InvalidKeyType, errors.New("invalid key type")
}

const (
	// StringKeyType keys the table by a string value.
	StringKeyType KeyType = "string"
	// UUIDKeyType keys the table by a UUID/GUID value.
	UUIDKeyType KeyType = "uuid"
	// IntegerKeyType keys the table by an integer value.
	IntegerKeyType KeyType = "integer"
	// InvalidKeyType is returned by KeyTypeFromString for unrecognised input.
	InvalidKeyType KeyType = "invalid"
)
// Metadata property names read during Init.
const (
	connectionStringKey  = "connectionString"
	databaseNameKey      = "databaseName"
	schemaKey            = "schema"
	tableNameKey         = "tableName"
	keyTypeKey           = "keyType"
	keyLengthKey         = "keyLength"
	indexedPropertiesKey = "indexedProperties"
)

// Column names of the state table; they also serve as the names of the
// corresponding named SQL parameters.
const (
	keyColumnName        = "Key"
	rowVersionColumnName = "RowVersion"
)

// Fallback values used when the matching metadata property is absent or empty.
const (
	defaultDatabase  = "dapr"
	defaultSchema    = "dbo"
	defaultTable     = "state"
	defaultKeyLength = 200
)
// NewSQLServerStateStore returns an uninitialised SQL Server state store that
// advertises ETag and transactional support and uses newMigration to prepare
// the database. Init must be called before the store is used.
func NewSQLServerStateStore(logger logger.Logger) *SQLServer {
	return &SQLServer{
		features:        []state.Feature{state.FeatureETag, state.FeatureTransactional},
		logger:          logger,
		migratorFactory: newMigration,
	}
}
// IndexedProperty defines an indexed property. A list of these is decoded
// from the JSON value of the "indexedProperties" metadata entry.
type IndexedProperty struct {
	// ColumnName is the SQL column name; only letters, digits and '_' are accepted.
	ColumnName string `json:"column"`
	// Property names the stored property to index; letters, digits, '_', '.',
	// '[' and ']' are accepted (presumably a path into the stored JSON value —
	// confirm against the migration code).
	Property string `json:"property"`
	// Type is the SQL type of the column; letters, digits, '(' and ')' are accepted.
	Type string `json:"type"`
}
// SQLServer defines a Microsoft SQL Server based state store. Create it with
// NewSQLServerStateStore and call Init before using it.
type SQLServer struct {
	// Settings read from the component metadata during Init.
	connectionString  string
	databaseName      string
	tableName         string
	schema            string
	keyType           KeyType
	keyLength         int // only assigned when keyType is StringKeyType
	indexedProperties []IndexedProperty

	// migratorFactory builds the migrator whose executeMigrations Init runs;
	// NewSQLServerStateStore sets it to newMigration.
	migratorFactory func(*SQLServer) migrator

	// Command texts and object names taken from the migration result in Init.
	bulkDeleteCommand        string
	itemRefTableTypeName     string
	upsertCommand            string
	getCommand               string
	deleteWithETagCommand    string
	deleteWithoutETagCommand string

	features []state.Feature // returned as-is by Features
	logger   logger.Logger
	db       *sql.DB // opened by Init with sql.Open("sqlserver", connectionString)
}
// isLetterOrNumber reports whether c is a Unicode letter or Unicode number.
func isLetterOrNumber(c rune) bool {
	switch {
	case unicode.IsLetter(c), unicode.IsNumber(c):
		return true
	default:
		return false
	}
}

// hasOnlyLetterNumberOr reports whether every rune of s is a letter, a number,
// or one of the runes in extra. The empty string trivially passes.
func hasOnlyLetterNumberOr(s string, extra ...rune) bool {
	for _, r := range s {
		if isLetterOrNumber(r) {
			continue
		}
		permitted := false
		for _, e := range extra {
			if r == e {
				permitted = true
				break
			}
		}
		if !permitted {
			return false
		}
	}
	return true
}

// isValidSQLName accepts letters, numbers and '_'.
func isValidSQLName(s string) bool {
	return hasOnlyLetterNumberOr(s, '_')
}

// isValidIndexedPropertyName accepts letters, numbers, '_', '.', '[' and ']'.
func isValidIndexedPropertyName(s string) bool {
	return hasOnlyLetterNumberOr(s, '_', '.', '[', ']')
}

// isValidIndexedPropertyType accepts letters, numbers, '(' and ')'.
func isValidIndexedPropertyType(s string) bool {
	return hasOnlyLetterNumberOr(s, '(', ')')
}
// Init reads and validates the component metadata, runs the schema
// migrations, stores the command texts they produce and opens the database
// handle used by every other operation.
func (s *SQLServer) Init(metadata state.Metadata) error {
	connStr, found := metadata.Properties[connectionStringKey]
	if !found || connStr == "" {
		return fmt.Errorf("missing connection string")
	}
	s.connectionString = connStr

	// Each loader validates one group of settings and stores it on s; the
	// first failure aborts Init.
	loaders := []func(state.Metadata) error{
		s.getTable,
		s.getDatabase,
		s.getKeyType,
		s.getSchema,
		s.getIndexedProperties,
	}
	for _, load := range loaders {
		if err := load(metadata); err != nil {
			return err
		}
	}

	result, err := s.migratorFactory(s).executeMigrations()
	if err != nil {
		return err
	}
	s.itemRefTableTypeName = result.itemRefTableTypeName
	s.bulkDeleteCommand = fmt.Sprintf("exec %s @itemsToDelete;", result.bulkDeleteProcFullName)
	s.upsertCommand = result.upsertProcFullName
	s.getCommand = result.getCommand
	s.deleteWithETagCommand = result.deleteWithETagCommand
	s.deleteWithoutETagCommand = result.deleteWithoutETagCommand

	handle, err := sql.Open("sqlserver", s.connectionString)
	s.db = handle
	return err
}
// getIndexedProperties decodes the optional "indexedProperties" metadata
// value (a JSON array of IndexedProperty), validates it and stores it on s.
// An absent or empty value leaves s.indexedProperties untouched.
func (s *SQLServer) getIndexedProperties(metadata state.Metadata) error {
	raw := metadata.Properties[indexedPropertiesKey]
	if raw == "" {
		return nil
	}
	var props []IndexedProperty
	if err := json.Unmarshal([]byte(raw), &props); err != nil {
		return err
	}
	if err := s.validateIndexedProperties(props); err != nil {
		return err
	}
	s.indexedProperties = props
	return nil
}
// validateIndexedProperties checks that every indexed property supplies its
// mandatory fields (column, property, type) and that each field only uses the
// characters accepted by its validator. The first problem found is returned.
func (s *SQLServer) validateIndexedProperties(indexedProperties []IndexedProperty) error {
	for _, p := range indexedProperties {
		if p.ColumnName == "" {
			return errors.New("indexed property column cannot be empty")
		}
		if p.Property == "" {
			return errors.New("indexed property name cannot be empty")
		}
		if p.Type == "" {
			return errors.New("indexed property type cannot be empty")
		}
		if !isValidSQLName(p.ColumnName) {
			return errors.New("invalid indexed property column name, accepted characters are (A-Z, a-z, 0-9, _)")
		}
		if !isValidIndexedPropertyName(p.Property) {
			return errors.New("invalid indexed property name, accepted characters are (A-Z, a-z, 0-9, _, ., [, ])")
		}
		// isValidIndexedPropertyType accepts letters, digits, '(' and ')' only;
		// the message must not advertise '_' as allowed.
		if !isValidIndexedPropertyType(p.Type) {
			return errors.New("invalid indexed property type, accepted characters are (A-Z, a-z, 0-9, (, ))")
		}
	}
	return nil
}
// getKeyType reads the "keyType" metadata (StringKeyType when absent) and,
// for string keys only, the "keyLength" metadata (defaultKeyLength when
// absent), which must parse as a positive integer.
func (s *SQLServer) getKeyType(metadata state.Metadata) error {
	keyType := StringKeyType
	if raw := metadata.Properties[keyTypeKey]; raw != "" {
		parsed, err := KeyTypeFromString(raw)
		if err != nil {
			return err
		}
		keyType = parsed
	}
	s.keyType = keyType

	// Only string keys have a configurable column length.
	if keyType != StringKeyType {
		return nil
	}

	rawLength := metadata.Properties[keyLengthKey]
	if rawLength == "" {
		s.keyLength = defaultKeyLength
		return nil
	}
	length, err := strconv.Atoi(rawLength)
	s.keyLength = length
	if err != nil {
		return err
	}
	if length <= 0 {
		return fmt.Errorf("invalid key length value of %d", length)
	}
	return nil
}
// Returns the schema name if set or the default value otherwise.
func (s *SQLServer) getSchema(metadata state.Metadata) error {
if val, ok := metadata.Properties[schemaKey]; ok && val != "" {
if !isValidSQLName(val) {
return fmt.Errorf("invalid schema name, accepted characters are (A-Z, a-z, 0-9, _)")
}
s.schema = val
} else {
s.schema = defaultSchema
}
return nil
}
// Returns the database name if set or the default value otherwise.
func (s *SQLServer) getDatabase(metadata state.Metadata) error {
if val, ok := metadata.Properties[databaseNameKey]; ok && val != "" {
if !isValidSQLName(val) {
return fmt.Errorf("invalid database name, accepted characters are (A-Z, a-z, 0-9, _)")
}
s.databaseName = val
} else {
s.databaseName = defaultDatabase
}
return nil
}
// Returns the table name if set or the default value otherwise.
func (s *SQLServer) getTable(metadata state.Metadata) error {
if val, ok := metadata.Properties[tableNameKey]; ok && val != "" {
if !isValidSQLName(val) {
return fmt.Errorf("invalid table name, accepted characters are (A-Z, a-z, 0-9, _)")
}
s.tableName = val
} else {
s.tableName = defaultTable
}
return nil
}
// Ping verifies that the database opened by Init is reachable. It returns an
// error instead of panicking when called before Init has opened the handle.
func (s *SQLServer) Ping() error {
	if s.db == nil {
		return errors.New("sqlserver state store is not initialized")
	}
	if err := s.db.Ping(); err != nil {
		return fmt.Errorf("error pinging the database: %w", err)
	}
	return nil
}
// Features reports the optional capabilities of this store, as set up by
// NewSQLServerStateStore.
func (s *SQLServer) Features() []state.Feature {
	supported := s.features
	return supported
}
// Multi applies a batch of upsert and delete operations in one transaction.
//
// Operations are scanned from last to first and only the final operation for
// each key is kept; earlier operations on the same key are redundant inside an
// atomic transaction and are ignored. Any unsupported operation type aborts
// the whole batch before anything is executed.
func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error {
	seen := make(map[string]struct{})
	var (
		upserts []state.SetRequest
		removes []state.DeleteRequest
	)

	ops := request.Operations
	for i := len(ops) - 1; i >= 0; i-- {
		op := ops[i]
		switch op.Operation {
		case state.Upsert:
			setReq, err := s.getSets(op)
			if err != nil {
				return err
			}
			if _, dup := seen[setReq.Key]; dup {
				continue
			}
			seen[setReq.Key] = struct{}{}
			upserts = append(upserts, setReq)
		case state.Delete:
			delReq, err := s.getDeletes(op)
			if err != nil {
				return err
			}
			if _, dup := seen[delReq.Key]; dup {
				continue
			}
			seen[delReq.Key] = struct{}{}
			removes = append(removes, delReq)
		default:
			return fmt.Errorf("unsupported operation: %s", op.Operation)
		}
	}

	if len(upserts) == 0 && len(removes) == 0 {
		return nil
	}
	return s.executeMulti(upserts, removes)
}
// getSets extracts the state.SetRequest carried by an upsert operation and
// rejects requests that have no key.
func (s *SQLServer) getSets(req state.TransactionalStateOperation) (state.SetRequest, error) {
	setReq, ok := req.Request.(state.SetRequest)
	switch {
	case !ok:
		return setReq, fmt.Errorf("expecting set request")
	case setReq.Key == "":
		return setReq, fmt.Errorf("missing key in upsert operation")
	default:
		return setReq, nil
	}
}
// getDeletes extracts the state.DeleteRequest carried by a delete operation
// and rejects requests that have no key.
func (s *SQLServer) getDeletes(req state.TransactionalStateOperation) (state.DeleteRequest, error) {
	delReq, ok := req.Request.(state.DeleteRequest)
	if !ok {
		return delReq, fmt.Errorf("expecting delete request")
	}
	if delReq.Key == "" {
		// This is a delete, not an upsert; name the operation correctly.
		return delReq, errors.New("missing key in delete operation")
	}
	return delReq, nil
}
// executeMulti runs the given deletes, then the given upserts, inside a single
// database transaction and commits only if every statement succeeds.
//
// A *state.BulkDeleteRowMismatchError from the bulk delete is tolerated: it
// only reports that the number of deleted rows differed from the number of
// delete requests. NOTE(review): if ETag-checked deletes that fail to match
// also surface only as a row-count mismatch, this masks them — confirm that
// is intended.
func (s *SQLServer) executeMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}

	if len(deletes) > 0 {
		err = s.executeBulkDelete(tx, deletes)
		if _, partial := err.(*state.BulkDeleteRowMismatchError); err != nil && !partial {
			_ = tx.Rollback() // the delete error is the one worth reporting
			return err
		}
	}

	for i := range sets {
		if err = s.executeSet(tx, &sets[i]); err != nil {
			_ = tx.Rollback() // the upsert error is the one worth reporting
			return err
		}
	}

	return tx.Commit()
}
// Delete removes the entry stored under req.Key.
//
// When req.ETag is set and non-empty, it is hex-decoded and passed as the
// expected row version; if no row is deleted, a state.ETagMismatch error is
// returned. A malformed ETag yields a state.ETagInvalid error. When req.ETag
// is nil or empty, the delete is unconditional and deleting a missing key is
// a no-op. An empty ETag is treated as "no ETag", consistent with
// executeSet; previously it ran the ETag delete with an empty row version,
// which deleted nothing yet reported success.
func (s *SQLServer) Delete(req *state.DeleteRequest) error {
	hasETag := req.ETag != nil && *req.ETag != ""

	var (
		res sql.Result
		err error
	)
	if hasETag {
		rowVersion, decodeErr := hex.DecodeString(*req.ETag)
		if decodeErr != nil {
			return state.NewETagError(state.ETagInvalid, decodeErr)
		}
		res, err = s.db.Exec(s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, rowVersion))
	} else {
		res, err = s.db.Exec(s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key))
	}
	// err represents errors raised by the stored procedure or the database itself.
	if err != nil {
		return err
	}

	// If no row matches the key (and ETag, if given), 0 rows are affected.
	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if hasETag && rows != 1 {
		return state.NewETagError(state.ETagMismatch, nil)
	}
	return nil
}
// TvpDeleteTableStringKey is one row of the table-valued parameter sent to the
// bulk delete procedure: the key to delete and the expected row version (nil
// when the request carried no ETag).
//
// NOTE(review): field order is significant — presumably go-mssqldb maps these
// fields positionally onto the columns of the table type; confirm before
// reordering or adding fields.
type TvpDeleteTableStringKey struct {
	ID         string
	RowVersion []byte
}
// BulkDelete removes every entry in req within a single transaction. If the
// bulk delete fails, including when the affected row count differs from
// len(req), the transaction is rolled back and the error returned. A failed
// commit is also reported to the caller; it was previously ignored.
func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}
	if err = s.executeBulkDelete(tx, req); err != nil {
		_ = tx.Rollback() // the delete error is the one worth reporting
		return err
	}
	return tx.Commit()
}
// executeBulkDelete deletes every key in req with one call to the bulk delete
// procedure. Keys and optional row versions are passed as a table-valued
// parameter of type s.itemRefTableTypeName.
//
// It returns a state.ETagInvalid error for an ETag that is not valid hex, and
// a *state.BulkDeleteRowMismatchError when the number of affected rows differs
// from len(req). An empty ETag is treated like a missing one (NULL row
// version), consistent with executeSet.
func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) error {
	values := make([]TvpDeleteTableStringKey, len(req))
	for i, d := range req {
		var rowVersion []byte
		if d.ETag != nil && *d.ETag != "" {
			decoded, err := hex.DecodeString(*d.ETag)
			if err != nil {
				return state.NewETagError(state.ETagInvalid, err)
			}
			rowVersion = decoded
		}
		values[i] = TvpDeleteTableStringKey{ID: d.Key, RowVersion: rowVersion}
	}

	itemsToDelete := mssql.TVP{
		TypeName: s.itemRefTableTypeName,
		Value:    values,
	}
	res, err := db.Exec(s.bulkDeleteCommand, sql.Named("itemsToDelete", itemsToDelete))
	if err != nil {
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows != int64(len(req)) {
		// NewBulkDeleteRowMismatchError takes (expected, affected); the
		// arguments were previously swapped, garbling the error message.
		return state.NewBulkDeleteRowMismatchError(uint64(len(req)), uint64(rows))
	}
	return nil
}
// Get returns the entry stored under req.Key together with its ETag, which is
// the hex-encoded row version. A key that does not exist yields an empty
// GetResponse and a nil error.
func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) {
	rows, err := s.db.Query(s.getCommand, sql.Named(keyColumnName, req.Key))
	if err != nil {
		return nil, err
	}
	// Close on every path; the original code leaked rows when Err was set.
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both for "no rows" and for an iteration failure;
		// Err tells the two apart and is only meaningful after Next.
		if err = rows.Err(); err != nil {
			return nil, err
		}
		return &state.GetResponse{}, nil
	}

	var (
		data       string
		rowVersion []byte
	)
	if err = rows.Scan(&data, &rowVersion); err != nil {
		return nil, err
	}
	return &state.GetResponse{
		Data: []byte(data),
		ETag: ptr.String(hex.EncodeToString(rowVersion)),
	}, nil
}
// BulkGet has no native implementation: it always reports false with no
// results and no error (presumably so the caller falls back to individual Get
// calls — confirm against the state.Store contract).
func (s *SQLServer) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
	const supported = false
	return supported, nil, nil
}
// Set adds or updates a single entry, executing directly on the store's
// database handle rather than inside an explicit transaction.
func (s *SQLServer) Set(req *state.SetRequest) error {
	var exec dbExecutor = s.db
	return s.executeSet(exec, req)
}
// dbExecutor is the statement-execution method shared by *sql.DB and *sql.Tx,
// letting the same write path run either directly or inside a transaction.
type dbExecutor interface {
	Exec(stmt string, params ...interface{}) (sql.Result, error)
}
// executeSet upserts req through db (a *sql.DB or *sql.Tx) by calling the
// upsert procedure.
//
// The value is serialized with utils.Marshal using json.Marshal. A non-empty
// ETag is hex-decoded and passed as the expected row version (NULL otherwise).
// FirstWrite concurrency is forwarded as FirstWrite=1, anything else as 0.
// If the call fails while an ETag was supplied, the failure is reported as an
// ETag mismatch. Exactly one row must be affected.
func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error {
	payload, err := utils.Marshal(req.Value, json.Marshal)
	if err != nil {
		return err
	}

	hasETag := req.ETag != nil && *req.ETag != ""
	etagParam := sql.Named(rowVersionColumnName, nil)
	if hasETag {
		rowVersion, decodeErr := hex.DecodeString(*req.ETag)
		if decodeErr != nil {
			return state.NewETagError(state.ETagInvalid, decodeErr)
		}
		etagParam = sql.Named(rowVersionColumnName, rowVersion)
	}

	firstWrite := 0
	if req.Options.Concurrency == state.FirstWrite {
		firstWrite = 1
	}

	res, err := db.Exec(s.upsertCommand,
		sql.Named(keyColumnName, req.Key),
		sql.Named("Data", string(payload)),
		etagParam,
		sql.Named("FirstWrite", firstWrite))
	if err != nil {
		if hasETag {
			return state.NewETagError(state.ETagMismatch, err)
		}
		return err
	}

	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if affected != 1 {
		return fmt.Errorf("no item was updated")
	}
	return nil
}
// BulkSet upserts every request in req inside a single transaction. The first
// failing upsert rolls the transaction back and its error is returned.
func (s *SQLServer) BulkSet(req []state.SetRequest) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}
	for i := range req {
		if setErr := s.executeSet(tx, &req[i]); setErr != nil {
			_ = tx.Rollback() // the upsert error is the one worth reporting
			return setErr
		}
	}
	return tx.Commit()
}