/* Copyright 2023 The Dapr Authors Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ package sqlserver import ( "context" "database/sql" "encoding/hex" "encoding/json" "errors" "fmt" "reflect" "time" commonsql "github.com/dapr/components-contrib/common/component/sql" sqltransactions "github.com/dapr/components-contrib/common/component/sql/transactions" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" "github.com/dapr/components-contrib/state/utils" "github.com/dapr/kit/logger" "github.com/dapr/kit/ptr" ) // KeyType defines type of the table identifier. type KeyType string // KeyTypeFromString tries to create a KeyType from a string value. func KeyTypeFromString(k string) (KeyType, error) { switch k { case string(StringKeyType): return StringKeyType, nil case string(UUIDKeyType): return UUIDKeyType, nil case string(IntegerKeyType): return IntegerKeyType, nil } return InvalidKeyType, errors.New("invalid key type") } const ( // StringKeyType defines a key of type string. StringKeyType KeyType = "string" // UUIDKeyType defines a key of type UUID/GUID. UUIDKeyType KeyType = "uuid" // IntegerKeyType defines a key of type integer. IntegerKeyType KeyType = "integer" // InvalidKeyType defines an invalid key type. InvalidKeyType KeyType = "invalid" ) // New creates a new instance of a SQL Server transaction store. func New(logger logger.Logger) state.Store { s := &SQLServer{ features: []state.Feature{ state.FeatureETag, state.FeatureTransactional, state.FeatureTTL, }, logger: logger, migratorFactory: newMigration, } s.BulkStore = state.NewDefaultBulkStore(s) return s } // IndexedProperty defines a indexed property. type IndexedProperty struct { ColumnName string `json:"column"` Property string `json:"property"` Type string `json:"type"` } // SQLServer defines a MS SQL Server based state store. type SQLServer struct { state.BulkStore metadata sqlServerMetadata migratorFactory func(*sqlServerMetadata) migrator itemRefTableTypeName string upsertCommand string getCommand string deleteWithETagCommand string deleteWithoutETagCommand string features []state.Feature logger logger.Logger db *sql.DB gc commonsql.GarbageCollector } // Init initializes the SQL server state store. func (s *SQLServer) Init(ctx context.Context, metadata state.Metadata) error { s.metadata = newMetadata() err := s.metadata.Parse(metadata.Properties) if err != nil { return err } migration := s.migratorFactory(&s.metadata) mr, err := migration.executeMigrations(ctx) if err != nil { return err } s.itemRefTableTypeName = mr.itemRefTableTypeName s.upsertCommand = mr.upsertProcFullName s.getCommand = mr.getCommand s.deleteWithETagCommand = mr.deleteWithETagCommand s.deleteWithoutETagCommand = mr.deleteWithoutETagCommand conn, _, err := s.metadata.GetConnector(true) if err != nil { return err } s.db = sql.OpenDB(conn) if s.metadata.CleanupInterval != nil { err = s.startGC() if err != nil { return err } } return nil } func (s *SQLServer) startGC() error { gc, err := commonsql.ScheduleGarbageCollector(commonsql.GCOptions{ Logger: s.logger, UpdateLastCleanupQuery: func(arg any) (string, any) { return fmt.Sprintf(`BEGIN TRANSACTION; BEGIN TRY INSERT INTO [%[1]s].[%[2]s] ([Key], [Value]) VALUES ('last-cleanup', CONVERT(nvarchar(MAX), GETDATE(), 21)); END TRY BEGIN CATCH UPDATE [%[1]s].[%[2]s] SET [Value] = CONVERT(nvarchar(MAX), GETDATE(), 21) WHERE [Key] = 'last-cleanup' AND Datediff_big(MS, [Value], GETUTCDATE()) > @Interval END CATCH COMMIT TRANSACTION;`, s.metadata.SchemaName, s.metadata.MetadataTableName), sql.Named("Interval", arg) }, DeleteExpiredValuesQuery: fmt.Sprintf( `DELETE FROM [%s].[%s] WHERE [ExpireDate] IS NOT NULL AND [ExpireDate] < GETDATE()`, s.metadata.SchemaName, s.metadata.TableName, ), CleanupInterval: *s.metadata.CleanupInterval, DB: commonsql.AdaptDatabaseSQLConn(s.db), }) if err != nil { return err } s.gc = gc return nil } // Features returns the features available in this state store. func (s *SQLServer) Features() []state.Feature { return s.features } // Multi performs batched updates on a SQL Server store. func (s *SQLServer) Multi(ctx context.Context, request *state.TransactionalStateRequest) error { if request == nil { return nil } // If there's only 1 operation, skip starting a transaction switch len(request.Operations) { case 0: return nil case 1: return s.execMultiOperation(ctx, request.Operations[0], s.db) default: _, err := sqltransactions.ExecuteInTransaction(ctx, s.logger, s.db, func(ctx context.Context, tx *sql.Tx) (r struct{}, err error) { for _, op := range request.Operations { err = s.execMultiOperation(ctx, op, tx) if err != nil { return r, err } } return r, nil }) return err } } func (s *SQLServer) execMultiOperation(ctx context.Context, op state.TransactionalStateOperation, db dbExecutor) error { switch req := op.(type) { case state.SetRequest: return s.executeSet(ctx, db, &req) case state.DeleteRequest: return s.executeDelete(ctx, db, &req) default: return fmt.Errorf("unsupported operation: %s", op.Operation()) } } // Delete removes an entity from the store. func (s *SQLServer) Delete(ctx context.Context, req *state.DeleteRequest) error { return s.executeDelete(ctx, s.db, req) } func (s *SQLServer) executeDelete(ctx context.Context, db dbExecutor, req *state.DeleteRequest) error { var err error var res sql.Result if req.HasETag() { var b []byte b, err = hex.DecodeString(*req.ETag) if err != nil { return state.NewETagError(state.ETagInvalid, err) } res, err = db.ExecContext(ctx, s.deleteWithETagCommand, sql.Named(keyColumnName, req.Key), sql.Named(rowVersionColumnName, b)) } else { res, err = db.ExecContext(ctx, s.deleteWithoutETagCommand, sql.Named(keyColumnName, req.Key)) } // err represents errors thrown by the stored procedure or the database itself if err != nil { return err } // if the row with matching key (and ETag if specified) is not found, then the stored procedure returns 0 rows affected rows, err := res.RowsAffected() if err != nil { return err } // When an ETAG is specified, a row must have been deleted or else we return an ETag mismatch error if rows != 1 && req.ETag != nil && *req.ETag != "" { return state.NewETagError(state.ETagMismatch, nil) } // successful deletion, or noop if no ETAG specified return nil } // Get returns an entity from store. func (s *SQLServer) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) { rows, err := s.db.QueryContext(ctx, s.getCommand, sql.Named(keyColumnName, req.Key)) if err != nil { return nil, err } if rows.Err() != nil { return nil, rows.Err() } defer rows.Close() if !rows.Next() { return &state.GetResponse{}, nil } var ( data string rowVersion []byte expireDate sql.NullTime ) err = rows.Scan(&data, &rowVersion, &expireDate) if err != nil { return nil, err } etag := hex.EncodeToString(rowVersion) var metadata map[string]string if expireDate.Valid { metadata = map[string]string{ state.GetRespMetaKeyTTLExpireTime: expireDate.Time.UTC().Format(time.RFC3339), } } return &state.GetResponse{ Data: []byte(data), ETag: ptr.Of(etag), Metadata: metadata, }, nil } // Set adds/updates an entity on store. func (s *SQLServer) Set(ctx context.Context, req *state.SetRequest) error { return s.executeSet(ctx, s.db, req) } // dbExecutor implements a common functionality implemented by db or tx. type dbExecutor interface { ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } func (s *SQLServer) executeSet(ctx context.Context, db dbExecutor, req *state.SetRequest) error { var err error var bytes []byte bytes, err = utils.Marshal(req.Value, json.Marshal) if err != nil { return err } etag := sql.Named(rowVersionColumnName, nil) if req.HasETag() { var b []byte b, err = hex.DecodeString(*req.ETag) if err != nil { return state.NewETagError(state.ETagInvalid, err) } etag = sql.Named(rowVersionColumnName, b) } ttl, ttlerr := utils.ParseTTL(req.Metadata) if ttlerr != nil { return fmt.Errorf("error parsing TTL: %w", ttlerr) } var res sql.Result if req.Options.Concurrency == state.FirstWrite { res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 1), sql.Named("TTL", ttl)) } else { res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), sql.Named("Data", string(bytes)), etag, sql.Named("FirstWrite", 0), sql.Named("TTL", ttl)) } if err != nil { return err } rows, err := res.RowsAffected() if err != nil { return err } if rows != 1 { if req.HasETag() { return state.NewETagError(state.ETagMismatch, err) } return errors.New("no item was updated") } return nil } func (s *SQLServer) GetComponentMetadata() (metadataInfo metadata.MetadataMap) { settingsStruct := sqlServerMetadata{} metadata.GetMetadataInfoFromStructType(reflect.TypeOf(settingsStruct), &metadataInfo, metadata.StateStoreType) return } // Close implements io.Closer. func (s *SQLServer) Close() error { if s.db != nil { s.db.Close() s.db = nil } if s.gc != nil { return s.gc.Close() } return nil } // GetCleanupInterval returns the cleanupInterval property. // This is primarily used for tests. func (s *SQLServer) GetCleanupInterval() *time.Duration { return s.metadata.CleanupInterval } func (s *SQLServer) CleanupExpired() error { if s.gc != nil { return s.gc.CleanupExpired() } return nil }