From ff10bea43f6b7472ec1c3a018f3928b0f18106ce Mon Sep 17 00:00:00 2001 From: Javier Aliaga Date: Fri, 1 Aug 2025 18:14:50 +0200 Subject: [PATCH] chore: Use base64 to store sqlserver state data (#3919) Signed-off-by: Javier Aliaga Co-authored-by: Yaron Schneider Co-authored-by: Cassie Coyle --- common/proto/state/sqlserver/test.pb.go | 162 ++++++++++++++++++ common/proto/state/sqlserver/test.proto | 10 ++ state/sqlserver/sqlserver.go | 34 +++- state/sqlserver/sqlserver_integration_test.go | 30 +++- 4 files changed, 225 insertions(+), 11 deletions(-) create mode 100644 common/proto/state/sqlserver/test.pb.go create mode 100644 common/proto/state/sqlserver/test.proto diff --git a/common/proto/state/sqlserver/test.pb.go b/common/proto/state/sqlserver/test.pb.go new file mode 100644 index 000000000..1c936adb5 --- /dev/null +++ b/common/proto/state/sqlserver/test.pb.go @@ -0,0 +1,162 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.32.0 +// protoc v4.25.4 +// source: test.proto + +package sqlserver + +import ( + "reflect" + "sync" + + "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoimpl" + "google.golang.org/protobuf/types/known/timestamppb" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type TestEvent struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + EventId int32 `protobuf:"varint,1,opt,name=eventId,proto3" json:"eventId,omitempty"` + Timestamp *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"` +} + +func (x *TestEvent) Reset() { + *x = TestEvent{} + if protoimpl.UnsafeEnabled { + mi := &file_test_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TestEvent) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TestEvent) ProtoMessage() {} + +func (x *TestEvent) ProtoReflect() protoreflect.Message { + mi := &file_test_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TestEvent.ProtoReflect.Descriptor instead. +func (*TestEvent) Descriptor() ([]byte, []int) { + return file_test_proto_rawDescGZIP(), []int{0} +} + +func (x *TestEvent) GetEventId() int32 { + if x != nil { + return x.EventId + } + return 0 +} + +func (x *TestEvent) GetTimestamp() *timestamppb.Timestamp { + if x != nil { + return x.Timestamp + } + return nil +} + +var File_test_proto protoreflect.FileDescriptor + +var file_test_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x74, 0x65, 0x73, 0x74, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x1f, 0x67, 0x6f, + 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, + 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x5f, 0x0a, + 0x09, 0x54, 0x65, 0x73, 0x74, 0x45, 0x76, 0x65, 0x6e, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x65, 0x76, + 0x65, 0x6e, 0x74, 0x49, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x07, 0x65, 0x76, 0x65, + 0x6e, 0x74, 0x49, 0x64, 0x12, 0x38, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, + 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, + 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, + 0x61, 0x6d, 0x70, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x42, 0x41, + 0x5a, 0x3f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x64, 0x61, 0x70, + 0x72, 0x2f, 0x63, 0x6f, 0x6d, 0x70, 0x6f, 0x6e, 0x65, 0x6e, 0x74, 0x73, 0x2d, 0x63, 0x6f, 0x6e, + 0x74, 0x72, 0x69, 0x62, 0x2f, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x2f, 0x73, 0x74, 0x61, 0x74, 0x65, 0x2f, 0x73, 0x71, 0x6c, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_test_proto_rawDescOnce sync.Once + file_test_proto_rawDescData = file_test_proto_rawDesc +) + +func file_test_proto_rawDescGZIP() []byte { + file_test_proto_rawDescOnce.Do(func() { + file_test_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_proto_rawDescData) + }) + return file_test_proto_rawDescData +} + +var file_test_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_test_proto_goTypes = []interface{}{ + (*TestEvent)(nil), // 0: TestEvent + (*timestamppb.Timestamp)(nil), // 1: google.protobuf.Timestamp +} +var file_test_proto_depIdxs = []int32{ + 1, // 0: TestEvent.timestamp:type_name -> google.protobuf.Timestamp + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_test_proto_init() } +func file_test_proto_init() { + if File_test_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_test_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TestEvent); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_test_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_test_proto_goTypes, + DependencyIndexes: file_test_proto_depIdxs, + MessageInfos: file_test_proto_msgTypes, + }.Build() + File_test_proto = out.File + file_test_proto_rawDesc = nil + file_test_proto_goTypes = nil + file_test_proto_depIdxs = nil +} diff --git a/common/proto/state/sqlserver/test.proto b/common/proto/state/sqlserver/test.proto new file mode 100644 index 000000000..7c8480259 --- /dev/null +++ b/common/proto/state/sqlserver/test.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +option go_package = "github.com/dapr/components-contrib/common/proto/state/sqlserver"; + +import "google/protobuf/timestamp.proto"; + +message TestEvent { + int32 eventId = 1; + google.protobuf.Timestamp timestamp = 2; +} \ No newline at end of file diff --git a/state/sqlserver/sqlserver.go b/state/sqlserver/sqlserver.go index 35e2f7ed2..f607d47af 100644 --- a/state/sqlserver/sqlserver.go +++ b/state/sqlserver/sqlserver.go @@ -16,6 +16,7 @@ package sqlserver import ( "context" "database/sql" + "encoding/base64" "encoding/hex" "encoding/json" "errors" @@ -287,8 +288,15 @@ func (s *SQLServer) Get(ctx context.Context, req *state.GetRequest) (*state.GetR } } + bytes, err := base64.StdEncoding.DecodeString(data) + if err != nil { + s.logger. + WithFields(map[string]any{"error": err}). + Debug("error decoding base64 data. Fallback to []byte") + bytes = []byte(data) + } return &state.GetResponse{ - Data: []byte(data), + Data: bytes, ETag: ptr.Of(etag), Metadata: metadata, }, nil @@ -305,16 +313,23 @@ type dbExecutor interface { } func (s *SQLServer) executeSet(ctx context.Context, db dbExecutor, req *state.SetRequest) error { - var err error - var bytes []byte - bytes, err = utils.Marshal(req.Value, json.Marshal) - if err != nil { - return err + var reqValue string + + bytes, ok := req.Value.([]byte) + if !ok { + bt, err := json.Marshal(req.Value) + if err != nil { + return err + } + reqValue = string(bt) + } else { + reqValue = base64.StdEncoding.EncodeToString(bytes) } + etag := sql.Named(rowVersionColumnName, nil) if req.HasETag() { var b []byte - b, err = hex.DecodeString(*req.ETag) + b, err := hex.DecodeString(*req.ETag) if err != nil { return state.NewETagError(state.ETagInvalid, err) } @@ -327,13 +342,14 @@ func (s *SQLServer) executeSet(ctx context.Context, db dbExecutor, req *state.Se } var res sql.Result + var err error if req.Options.Concurrency == state.FirstWrite { res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), - sql.Named("Data", string(bytes)), etag, + sql.Named("Data", reqValue), etag, sql.Named("FirstWrite", 1), sql.Named("TTL", ttl)) } else { res, err = db.ExecContext(ctx, s.upsertCommand, sql.Named(keyColumnName, req.Key), - sql.Named("Data", string(bytes)), etag, + sql.Named("Data", reqValue), etag, sql.Named("FirstWrite", 0), sql.Named("TTL", ttl)) } diff --git a/state/sqlserver/sqlserver_integration_test.go b/state/sqlserver/sqlserver_integration_test.go index 7af622047..e317128a4 100644 --- a/state/sqlserver/sqlserver_integration_test.go +++ b/state/sqlserver/sqlserver_integration_test.go @@ -30,10 +30,12 @@ import ( "testing" "time" - uuid "github.com/google/uuid" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "github.com/dapr/components-contrib/common/proto/state/sqlserver" "github.com/dapr/components-contrib/metadata" "github.com/dapr/components-contrib/state" "github.com/dapr/kit/logger" @@ -42,7 +44,7 @@ import ( const ( // connectionStringEnvKey defines the key containing the integration test connection string // To use docker, server=localhost;user id=sa;password=Pass@Word1;port=1433; - // To use Azure SQL, server=.database.windows.net;user id=;port=1433;password=;database=dapr_test;. + // To use Azure SQL, server=.database.windows.net;User id=;port=1433;password=;database=dapr_test;. connectionStringEnvKey = "DAPR_TEST_SQL_CONNSTRING" usersTableName = "Users" beverageTea = "tea" @@ -77,6 +79,7 @@ func TestIntegrationCases(t *testing.T) { t.Run("Multi operations", testMultiOperations) t.Run("Insert and Update Set Record Dates", testInsertAndUpdateSetRecordDates) t.Run("Multiple initializations", testMultipleInitializations) + t.Run("Should preserve byte data when not base64 encoded", testNonBase64ByteData) // Run concurrent set tests 10 times const executions = 10 @@ -112,6 +115,9 @@ func createMetadata(schema string, kt KeyType, indexedProperties string) state.M // Ensure the database is running // For docker, use: docker run --name sqlserver -e "ACCEPT_EULA=Y" -e "SA_PASSWORD=Pass@Word1" -p 1433:1433 -d mcr.microsoft.com/mssql/server:2019-GA-ubuntu-16.04. +// For azure-sql-edge use: +// docker volume create sqlvolume +// docker run --name sqlserver -e "ACCEPT_EULA=Y" -e "MSSQL_SA_PASSWORD=Pass@Word1" -e "MSSQL_PID=Developer" -e "MSSQL_AGENT_ENABLED=TRUE" -e "MSSQL_COLLATION=SQL_Latin1_General_CP1_CI_AS" -e "MSSQL_LCID=1033" -p 1433:1433 -v sqlvolume:/var/opt/mssql -d mcr.microsoft.com/azure-sql-edge:latest func getTestStore(t *testing.T, indexedProperties string) *SQLServer { return getTestStoreWithKeyType(t, StringKeyType, indexedProperties) } @@ -597,3 +603,23 @@ func testMultipleInitializations(t *testing.T) { }) } } + +func testNonBase64ByteData(t *testing.T) { + t.Run("Set And Get", func(t *testing.T) { + store := getTestStore(t, "") + request := &sqlserver.TestEvent{ + EventId: -1, + } + requestBytes, err := proto.Marshal(request) + require.NoError(t, err) + require.NoError(t, store.Set(t.Context(), &state.SetRequest{Key: "1", Value: requestBytes})) + resp, err := store.Get(t.Context(), &state.GetRequest{Key: "1"}) + require.NoError(t, err) + + response := &sqlserver.TestEvent{} + err = proto.Unmarshal(resp.Data, response) + require.NoError(t, err) + + assert.EqualValues(t, request.GetEventId(), response.GetEventId()) + }) +}