components-contrib/state/sqlserver/sqlserver.go

540 lines
12 KiB
Go

// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation and Dapr Contributors.
// Licensed under the MIT License.
// ------------------------------------------------------------
package sqlserver
import (
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
"unicode"
"github.com/agrea/ptr"
mssql "github.com/denisenkom/go-mssqldb"
"github.com/dapr/components-contrib/state"
"github.com/dapr/components-contrib/state/utils"
"github.com/dapr/kit/logger"
)
// KeyType defines type of the table identifier.
type KeyType string

// KeyTypeFromString tries to create a KeyType from a string value.
//
// The comparison is exact (case-sensitive) against the supported key types
// "string", "uuid" and "integer". Any other input yields InvalidKeyType
// together with a non-nil error.
func KeyTypeFromString(k string) (KeyType, error) {
	switch candidate := KeyType(k); candidate {
	case StringKeyType, UUIDKeyType, IntegerKeyType:
		return candidate, nil
	default:
		return InvalidKeyType, errors.New("invalid key type")
	}
}

const (
	// StringKeyType defines a key of type string
	StringKeyType KeyType = "string"
	// UUIDKeyType defines a key of type UUID/GUID
	UUIDKeyType KeyType = "uuid"
	// IntegerKeyType defines a key of type integer
	IntegerKeyType KeyType = "integer"
	// InvalidKeyType defines an invalid key type
	InvalidKeyType KeyType = "invalid"
)
// Names of the component metadata properties read by Init.
const (
	connectionStringKey  = "connectionString"
	tableNameKey         = "tableName"
	schemaKey            = "schema"
	keyTypeKey           = "keyType"
	keyLengthKey         = "keyLength"
	indexedPropertiesKey = "indexedProperties"
)

// Names used for the key and row-version named SQL parameters (sql.Named).
const (
	keyColumnName        = "Key"
	rowVersionColumnName = "RowVersion"
)

// Values Init falls back to when the corresponding metadata is absent.
const (
	defaultKeyLength = 200
	defaultSchema    = "dbo"
)
// NewSQLServerStateStore creates a new instance of a Sql Server transaction store.
//
// The returned store advertises ETag and transactional support and builds its
// migrator with newMigration. Its database handle is only opened by Init, so
// Init must succeed before the store is used.
func NewSQLServerStateStore(logger logger.Logger) *SQLServer {
	return &SQLServer{
		features:        []state.Feature{state.FeatureETag, state.FeatureTransactional},
		logger:          logger,
		migratorFactory: newMigration,
	}
}
// IndexedProperty defines an indexed property.
//
// Entries are decoded from the JSON array in the "indexedProperties" metadata
// entry and validated by Init. All three fields are required.
type IndexedProperty struct {
	// ColumnName is the table column name; it must pass isValidSQLName.
	ColumnName string `json:"column"`
	// Property names the property inside the stored value; it must pass
	// isValidIndexedPropertyName (which also allows '.', '[' and ']', so it
	// presumably supports nested paths — confirm against the migration code).
	Property string `json:"property"`
	// Type is the SQL type for the column; it must pass isValidIndexedPropertyType.
	Type string `json:"type"`
}
// SQLServer defines a Ms SQL Server based state store.
//
// Create it with NewSQLServerStateStore and configure it with Init. Init
// parses the metadata, runs the schema migrations, caches the resulting SQL
// commands and opens the database handle.
type SQLServer struct {
	// Settings parsed from component metadata by Init.
	connectionString string
	tableName        string
	schema           string
	keyType          KeyType
	keyLength        int // only assigned when keyType is StringKeyType
	indexedProperties []IndexedProperty

	// migratorFactory builds the migrator whose executeMigrations result
	// supplies the SQL object names and commands cached below.
	migratorFactory func(*SQLServer) migrator

	// Commands and SQL object names derived from the migration result in Init.
	bulkDeleteCommand        string
	itemRefTableTypeName     string
	upsertCommand            string
	getCommand               string
	deleteWithETagCommand    string
	deleteWithoutETagCommand string

	features []state.Feature
	logger   logger.Logger
	db       *sql.DB // opened by Init; nil before that
}
// isLetterOrNumber reports whether c is a Unicode letter or number, as
// classified by unicode.IsLetter and unicode.IsNumber (non-ASCII letters and
// numbers are therefore accepted too).
func isLetterOrNumber(c rune) bool {
	return unicode.IsNumber(c) || unicode.IsLetter(c)
}

// isValidSQLName reports whether s consists only of letters, numbers and '_'.
// It vets user-supplied identifiers (table, schema and column names), which
// are presumably interpolated into generated SQL by the migration code.
// An empty string is considered valid; callers check for emptiness themselves.
func isValidSQLName(s string) bool {
	for _, c := range s {
		if !(isLetterOrNumber(c) || c == '_') {
			return false
		}
	}
	return true
}

// isValidIndexedPropertyName reports whether s consists only of letters,
// numbers, '_', '.', '[' and ']'.
func isValidIndexedPropertyName(s string) bool {
	for _, c := range s {
		if !(isLetterOrNumber(c) || c == '_' || c == '.' || c == '[' || c == ']') {
			return false
		}
	}
	return true
}

// isValidIndexedPropertyType reports whether s consists only of letters,
// numbers, '_', '(' and ')', e.g. "int", "nvarchar(100)" or "my_type".
// '_' is accepted so the validator matches the character set that Init
// advertises in its error message. Characters such as ',', ' ' or quotes stay
// rejected because they could alter the surrounding SQL.
func isValidIndexedPropertyType(s string) bool {
	for _, c := range s {
		if !(isLetterOrNumber(c) || c == '_' || c == '(' || c == ')') {
			return false
		}
	}
	return true
}
// Init initializes the SQL server state store.
//
// It performs the following steps:
//   - validates the component metadata: connection string and table name
//     (required), key type (default "string"), key length (string keys only,
//     default defaultKeyLength), schema (default defaultSchema) and indexed
//     properties;
//   - runs the migrations built by s.migratorFactory;
//   - caches the SQL commands they produced;
//   - opens the database handle used by the other operations.
//
// The first validation, migration or open error is returned.
func (s *SQLServer) Init(metadata state.Metadata) error {
	props := metadata.Properties

	connStr := props[connectionStringKey]
	if connStr == "" {
		return errors.New("missing connection string")
	}
	s.connectionString = connStr

	table := props[tableNameKey]
	if table == "" {
		return errors.New("missing table name")
	}
	if !isValidSQLName(table) {
		return errors.New("invalid table name, accepted characters are (A-Z, a-z, 0-9, _)")
	}
	s.tableName = table

	keyType := StringKeyType
	if raw := props[keyTypeKey]; raw != "" {
		parsed, err := KeyTypeFromString(raw)
		if err != nil {
			return err
		}
		keyType = parsed
	}
	s.keyType = keyType

	// The key length only matters for string keys (it sizes the key column).
	//nolint:nestif
	if s.keyType == StringKeyType {
		s.keyLength = defaultKeyLength
		if raw := props[keyLengthKey]; raw != "" {
			var err error
			if s.keyLength, err = strconv.Atoi(raw); err != nil {
				return err
			}
			if s.keyLength <= 0 {
				return fmt.Errorf("invalid key length value of %d", s.keyLength)
			}
		}
	}

	schema := defaultSchema
	if raw := props[schemaKey]; raw != "" {
		if !isValidSQLName(raw) {
			return errors.New("invalid schema name, accepted characters are (A-Z, a-z, 0-9, _)")
		}
		schema = raw
	}
	s.schema = schema

	//nolint:nestif
	if raw := props[indexedPropertiesKey]; raw != "" {
		var indexed []IndexedProperty
		if err := json.Unmarshal([]byte(raw), &indexed); err != nil {
			return err
		}
		for _, p := range indexed {
			switch {
			case p.ColumnName == "":
				return errors.New("indexed property column cannot be empty")
			case p.Property == "":
				return errors.New("indexed property name cannot be empty")
			case p.Type == "":
				return errors.New("indexed property type cannot be empty")
			case !isValidSQLName(p.ColumnName):
				return errors.New("invalid indexed property column name, accepted characters are (A-Z, a-z, 0-9, _)")
			case !isValidIndexedPropertyName(p.Property):
				return errors.New("invalid indexed property name, accepted characters are (A-Z, a-z, 0-9, _, ., [, ])")
			case !isValidIndexedPropertyType(p.Type):
				return errors.New("invalid indexed property type, accepted characters are (A-Z, a-z, 0-9, _, (, ))")
			}
		}
		s.indexedProperties = indexed
	}

	mr, err := s.migratorFactory(s).executeMigrations()
	if err != nil {
		return err
	}

	s.itemRefTableTypeName = mr.itemRefTableTypeName
	s.bulkDeleteCommand = fmt.Sprintf("exec %s @itemsToDelete;", mr.bulkDeleteProcFullName)
	s.upsertCommand = mr.upsertProcFullName
	s.getCommand = mr.getCommand
	s.deleteWithETagCommand = mr.deleteWithETagCommand
	s.deleteWithoutETagCommand = mr.deleteWithoutETagCommand

	s.db, err = sql.Open("sqlserver", s.connectionString)
	return err
}
// Ping checks that the SQL Server database backing the store is reachable.
//
// It returns an error when Init has not opened the database handle yet, or
// when the connectivity check against the server fails. Previously this
// always returned nil, so health checks reported success even when the
// database was down.
func (s *SQLServer) Ping() error {
	if s.db == nil {
		return errors.New("sql server state store is not initialized")
	}
	return s.db.Ping()
}
// Features returns the features available in this state store, i.e. the
// feature list the store was constructed with.
func (s *SQLServer) Features() []state.Feature {
	supported := s.features
	return supported
}
// Multi performs multiple updates on a Sql server store.
//
// All operations are validated and sorted into upserts and deletes before
// anything touches the database. Nothing is executed if any operation:
//   - carries the wrong request type,
//   - has an empty key, or
//   - is an unsupported operation kind.
//
// Otherwise the collected deletes and upserts are applied together by
// executeMulti in a single transaction. An empty request is a no-op.
func (s *SQLServer) Multi(request *state.TransactionalStateRequest) error {
	var (
		deletes []state.DeleteRequest
		sets    []state.SetRequest
	)

	for _, req := range request.Operations {
		switch req.Operation {
		case state.Upsert:
			setReq, ok := req.Request.(state.SetRequest)
			if !ok {
				return errors.New("expecting set request")
			}
			if setReq.Key == "" {
				return errors.New("missing key in upsert operation")
			}
			sets = append(sets, setReq)

		case state.Delete:
			delReq, ok := req.Request.(state.DeleteRequest)
			if !ok {
				return errors.New("expecting delete request")
			}
			if delReq.Key == "" {
				return errors.New("missing key in delete operation")
			}
			deletes = append(deletes, delReq)

		default:
			return fmt.Errorf("unsupported operation: %s", req.Operation)
		}
	}

	if len(sets) == 0 && len(deletes) == 0 {
		return nil
	}
	return s.executeMulti(sets, deletes)
}
// executeMulti applies deletes and then upserts inside one transaction.
//
// The first failure rolls the transaction back (best effort; the rollback
// error is not reported) and is returned. Otherwise the transaction is
// committed and the commit result is returned.
func (s *SQLServer) executeMulti(sets []state.SetRequest, deletes []state.DeleteRequest) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}

	// abort discards the transaction and hands back the error that caused it.
	abort := func(cause error) error {
		_ = tx.Rollback()
		return cause
	}

	if len(deletes) > 0 {
		if err = s.executeBulkDelete(tx, deletes); err != nil {
			return abort(err)
		}
	}

	for i := range sets {
		if err = s.executeSet(tx, &sets[i]); err != nil {
			return abort(err)
		}
	}

	return tx.Commit()
}
// Delete removes an entity from the store.
//
// When req.ETag is a non-empty hex string, the delete runs the ETag-guarded
// command with the decoded bytes as the RowVersion parameter. Otherwise the
// row is deleted by key alone. An empty ETag string is treated as "no ETag",
// the same convention executeSet uses.
//
// Errors:
//   - ETagInvalid when the ETag is not valid hex.
//   - ETagMismatch when an ETag was supplied and the command failed or
//     deleted no row (the key is missing or its RowVersion differs).
//   - A plain error when no row was deleted without an ETag, or on any
//     other database failure.
func (s *SQLServer) Delete(req *state.DeleteRequest) error {
	hasETag := req.ETag != nil && *req.ETag != ""

	var (
		res sql.Result
		err error
	)
	if hasETag {
		rowVersion, decodeErr := hex.DecodeString(*req.ETag)
		if decodeErr != nil {
			return state.NewETagError(state.ETagInvalid, decodeErr)
		}
		res, err = s.db.Exec(s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, rowVersion))
	} else {
		res, err = s.db.Exec(s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key))
	}
	if err != nil {
		if hasETag {
			return state.NewETagError(state.ETagMismatch, err)
		}
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		notDeleted := errors.New("items was not updated")
		if hasETag {
			// With an ETag, "no row deleted" means the version check failed.
			return state.NewETagError(state.ETagMismatch, notDeleted)
		}
		return notDeleted
	}
	return nil
}
// TvpDeleteTableStringKey defines a table type with string key.
//
// Each value is one row of the table-valued parameter that executeBulkDelete
// passes to the bulk delete command. The field order presumably has to match
// the column order of the SQL table type created by the migration (confirm
// there), so do not reorder fields.
type TvpDeleteTableStringKey struct {
	ID         string // the entity key (DeleteRequest.Key)
	RowVersion []byte // decoded ETag; nil when the request carried no ETag
}
// BulkDelete removes multiple entries from the store in a single transaction.
//
// If the bulk delete fails, the transaction is rolled back (best effort) and
// that failure is returned. Otherwise the commit result is returned. A failed
// commit means nothing was deleted, so it must not be reported as success.
func (s *SQLServer) BulkDelete(req []state.DeleteRequest) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}

	if err = s.executeBulkDelete(tx, req); err != nil {
		// The delete error is the one worth reporting; a rollback failure
		// would only obscure it.
		_ = tx.Rollback()
		return err
	}

	return tx.Commit()
}
// executeBulkDelete deletes every request in one call to the bulk delete
// command. The keys and optional decoded ETags are passed as a table-valued
// parameter named "itemsToDelete".
//
// It returns ETagInvalid for an ETag that is not valid hex. It returns an
// error when the number of affected rows differs from the number of requests.
func (s *SQLServer) executeBulkDelete(db dbExecutor, req []state.DeleteRequest) error {
	rowsToDelete := make([]TvpDeleteTableStringKey, 0, len(req))
	for _, item := range req {
		row := TvpDeleteTableStringKey{ID: item.Key}
		if item.ETag != nil {
			version, decodeErr := hex.DecodeString(*item.ETag)
			if decodeErr != nil {
				return state.NewETagError(state.ETagInvalid, decodeErr)
			}
			row.RowVersion = version
		}
		rowsToDelete = append(rowsToDelete, row)
	}

	tvp := mssql.TVP{
		TypeName: s.itemRefTableTypeName,
		Value:    rowsToDelete,
	}
	res, err := db.Exec(s.bulkDeleteCommand, sql.Named("itemsToDelete", tvp))
	if err != nil {
		return err
	}

	affected, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if want := len(req); int(affected) != want {
		return fmt.Errorf("delete affected only %d rows, expected %d", affected, want)
	}
	return nil
}
// Get returns an entity from store.
//
// A missing key yields an empty GetResponse (nil Data and ETag) and no error.
// Otherwise Data holds the stored value and ETag is the hex encoding of the
// row's RowVersion.
//
// QueryRow closes the result set, and Scan reports iteration errors. The
// previous version leaked the rows when rows.Err() was non-nil, and it
// mistook an iteration error for "not found".
func (s *SQLServer) Get(req *state.GetRequest) (*state.GetResponse, error) {
	var (
		data       string
		rowVersion []byte
	)

	err := s.db.QueryRow(s.getCommand, sql.Named(keyColumnName, req.Key)).Scan(&data, &rowVersion)
	if errors.Is(err, sql.ErrNoRows) {
		return &state.GetResponse{}, nil
	}
	if err != nil {
		return nil, err
	}

	return &state.GetResponse{
		Data: []byte(data),
		ETag: ptr.String(hex.EncodeToString(rowVersion)),
	}, nil
}
// BulkGet performs a bulk get operation.
//
// This store has no native bulk read. It always reports false with no
// responses and no error. The caller presumably treats false as "not
// supported" and falls back to individual Get calls (confirm against the
// state.Store contract).
func (s *SQLServer) BulkGet(req []state.GetRequest) (bool, []state.BulkGetResponse, error) {
	const supported = false
	var responses []state.BulkGetResponse
	return supported, responses, nil
}
// Set adds/updates an entity on store.
//
// It upserts directly against the connection pool rather than an explicit
// transaction.
func (s *SQLServer) Set(req *state.SetRequest) error {
	var executor dbExecutor = s.db
	return s.executeSet(executor, req)
}
// dbExecutor is the command-execution subset shared by *sql.DB and *sql.Tx.
//
// executeSet and executeBulkDelete accept it so the same code runs both on
// the connection pool (Set) and inside a transaction (BulkSet, BulkDelete,
// executeMulti).
type dbExecutor interface {
	Exec(query string, args ...interface{}) (sql.Result, error)
}
// executeSet upserts a single entity by running the upsert command on db.
// db is either the connection pool or an open transaction.
//
// The value is serialized with utils.Marshal using JSON. When req.ETag is a
// non-empty hex string, its decoded bytes are sent as the RowVersion
// parameter. Otherwise RowVersion is sent as NULL, which presumably makes the
// upsert unconditional (confirm against the upsert procedure in the
// migration). An empty ETag string counts as "no ETag": sending it as an
// empty binary would make every write fail the version check.
//
// Errors:
//   - ETagInvalid when the ETag is not valid hex.
//   - ETagMismatch when an ETag was supplied and the command failed or did
//     not update exactly one row.
//   - Otherwise, the marshal, execution or RowsAffected error, or a plain
//     error when no row was written.
func (s *SQLServer) executeSet(db dbExecutor, req *state.SetRequest) error {
	data, err := utils.Marshal(req.Value, json.Marshal)
	if err != nil {
		return err
	}

	hasETag := req.ETag != nil && *req.ETag != ""

	etag := sql.Named(rowVersionColumnName, nil)
	if hasETag {
		rowVersion, decodeErr := hex.DecodeString(*req.ETag)
		if decodeErr != nil {
			return state.NewETagError(state.ETagInvalid, decodeErr)
		}
		etag.Value = rowVersion
	}

	res, err := db.Exec(s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(data)), etag)
	if err != nil {
		if hasETag {
			return state.NewETagError(state.ETagMismatch, err)
		}
		return err
	}

	rows, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if rows != 1 {
		notUpdated := errors.New("no item was updated")
		if hasETag {
			// With an ETag, "no row written" means the version check failed.
			return state.NewETagError(state.ETagMismatch, notUpdated)
		}
		return notUpdated
	}
	return nil
}
// BulkSet adds/updates multiple entities on store atomically.
//
// All requests are upserted inside one transaction. The first failure rolls
// it back (best effort) and is returned. Otherwise the commit result is
// returned.
func (s *SQLServer) BulkSet(req []state.SetRequest) error {
	tx, err := s.db.Begin()
	if err != nil {
		return err
	}

	for i := range req {
		if setErr := s.executeSet(tx, &req[i]); setErr != nil {
			_ = tx.Rollback()
			return setErr
		}
	}

	return tx.Commit()
}