356 lines
10 KiB
Go
356 lines
10 KiB
Go
/*
|
|
Copyright 2021 The Dapr Authors
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
*/
|
|
|
|
package oracledatabase
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/url"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/dapr/components-contrib/state"
|
|
stateutils "github.com/dapr/components-contrib/state/utils"
|
|
"github.com/dapr/kit/logger"
|
|
"github.com/dapr/kit/metadata"
|
|
|
|
// Blank import for the underlying Oracle Database driver.
|
|
_ "github.com/sijms/go-ora/v2"
|
|
)
|
|
|
|
// Constants used by the Oracle Database state store.
const (
	// defaultTableName is the table name parseMetadata starts from before the
	// user-supplied metadata is decoded on top of it.
	defaultTableName = "state"

	// errMissingConnectionString is the message of the error Init returns
	// when no connection string is configured.
	errMissingConnectionString = "missing connection string"

	// NOTE(review): the two keys below are not referenced in the visible part
	// of this file; presumably they name the corresponding metadata
	// properties - confirm against the other files of the package.
	connectionStringKey     = "connectionString"
	oracleWalletLocationKey = "oracleWalletLocation"
)
|
|
|
|
// oracleDatabaseAccess implements dbaccess.
|
|
type oracleDatabaseAccess struct {
|
|
logger logger.Logger
|
|
metadata oracleDatabaseMetadata
|
|
db *sql.DB
|
|
connectionString string
|
|
}
|
|
|
|
// oracleDatabaseMetadata holds the component settings that parseMetadata
// decodes from the metadata properties.
type oracleDatabaseMetadata struct {
	// ConnectionString is the database connection string. Init rejects an
	// empty value.
	ConnectionString string
	// OracleWalletLocation is optional. When it is set, Init appends wallet
	// and SSL options to the connection string.
	OracleWalletLocation string
	// TableName is the state table to use; parseMetadata pre-populates it
	// with defaultTableName.
	TableName string
}
|
|
|
|
// newOracleDatabaseAccess creates a new instance of oracleDatabaseAccess.
|
|
func newOracleDatabaseAccess(logger logger.Logger) *oracleDatabaseAccess {
|
|
return &oracleDatabaseAccess{
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) Ping(ctx context.Context) error {
|
|
return o.db.PingContext(ctx)
|
|
}
|
|
|
|
func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) {
|
|
m := oracleDatabaseMetadata{
|
|
TableName: defaultTableName,
|
|
}
|
|
err := metadata.DecodeMetadata(meta, &m)
|
|
return m, err
|
|
}
|
|
|
|
// Init sets up OracleDatabase connection and ensures that the state table exists.
|
|
func (o *oracleDatabaseAccess) Init(ctx context.Context, metadata state.Metadata) error {
|
|
meta, err := parseMetadata(metadata.Properties)
|
|
o.metadata = meta
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if o.metadata.ConnectionString != "" {
|
|
o.connectionString = meta.ConnectionString
|
|
} else {
|
|
o.logger.Error("Missing Oracle Database connection string")
|
|
return errors.New(errMissingConnectionString)
|
|
}
|
|
if o.metadata.OracleWalletLocation != "" {
|
|
o.connectionString += "?TRACE FILE=trace.log&SSL=enable&SSL Verify=false&WALLET=" + url.QueryEscape(o.metadata.OracleWalletLocation)
|
|
}
|
|
db, err := sql.Open("oracle", o.connectionString)
|
|
if err != nil {
|
|
o.logger.Error(err)
|
|
|
|
return err
|
|
}
|
|
|
|
o.db = db
|
|
|
|
err = db.PingContext(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
err = o.ensureStateTable(o.metadata.TableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Set makes an insert or update to the database.
|
|
func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error {
|
|
return o.doSet(ctx, o.db, req)
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) doSet(ctx context.Context, db querier, req *state.SetRequest) error {
|
|
err := state.CheckRequestOptions(req.Options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if req.Key == "" {
|
|
return errors.New("missing key in set operation")
|
|
}
|
|
|
|
requestValue := req.Value
|
|
byteArray, isBinary := req.Value.([]uint8)
|
|
binaryYN := "N"
|
|
if isBinary {
|
|
requestValue = base64.StdEncoding.EncodeToString(byteArray)
|
|
binaryYN = "Y"
|
|
}
|
|
|
|
// Convert to json string
|
|
bt, _ := stateutils.Marshal(requestValue, json.Marshal)
|
|
value := string(bt)
|
|
|
|
// TTL
|
|
var ttlSeconds int
|
|
ttl, err := stateutils.ParseTTL(req.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing TTL: %w", err)
|
|
}
|
|
if ttl != nil {
|
|
ttlSeconds = *ttl
|
|
}
|
|
ttlStatement := "NULL"
|
|
if ttlSeconds > 0 {
|
|
// We're passing ttlStatements via string concatenation - no risk of SQL injection because we control the value and make sure it's a number
|
|
ttlStatement = "systimestamp + numtodsinterval(" + strconv.Itoa(ttlSeconds) + ", 'SECOND')"
|
|
}
|
|
|
|
// Generate new etag
|
|
etagObj, err := uuid.NewRandom()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate UUID: %w", err)
|
|
}
|
|
etag := etagObj.String()
|
|
|
|
var result sql.Result
|
|
if !req.HasETag() {
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
|
// Other parameters use sql.DB parameter substitution.
|
|
var stmt string
|
|
if req.Options.Concurrency == state.FirstWrite {
|
|
stmt = `INSERT INTO ` + o.metadata.TableName + `
|
|
(key, value, binary_yn, etag, expiration_time)
|
|
VALUES
|
|
(:key, :value, :binary_yn, :etag, ` + ttlStatement + `) `
|
|
} else {
|
|
// As per Discord Thread https://discord.com/channels/778680217417809931/901141713089863710/938520959562952735 expiration time is reset in case of an update.
|
|
stmt = `MERGE INTO ` + o.metadata.TableName + ` t
|
|
USING (SELECT :key key, :value value, :binary_yn binary_yn, :etag etag FROM dual) new_state_to_store
|
|
ON (t.key = new_state_to_store.key)
|
|
WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = ` + ttlStatement + `
|
|
WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) VALUES (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, ` + ttlStatement + ` ) `
|
|
}
|
|
result, err = db.ExecContext(ctx, stmt, req.Key, value, binaryYN, etag)
|
|
} else {
|
|
// When first write policy is indicated, an existing record has to be updated - one that has the etag provided.
|
|
updateStatement := `UPDATE ` + o.metadata.TableName + ` SET
|
|
value = :value,
|
|
binary_yn = :binary_yn,
|
|
etag = :new_etag,
|
|
update_time = systimestamp,
|
|
expiration_time = ` + ttlStatement + `
|
|
WHERE key = :key
|
|
AND etag = :etag
|
|
AND (expiration_time IS NULL OR expiration_time >= systimestamp)`
|
|
result, err = db.ExecContext(ctx, updateStatement, value, binaryYN, etag, req.Key, *req.ETag)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rows != 1 {
|
|
if req.HasETag() {
|
|
return state.NewETagError(state.ETagMismatch, err)
|
|
}
|
|
return errors.New("no item was updated")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
|
|
func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
|
if req.Key == "" {
|
|
return nil, errors.New("missing key in get operation")
|
|
}
|
|
var (
|
|
value string
|
|
binaryYN string
|
|
etag string
|
|
expireTime sql.NullTime
|
|
)
|
|
err := o.db.QueryRowContext(ctx, "SELECT value, binary_yn, etag, expiration_time FROM "+o.metadata.TableName+" WHERE key = :key AND (expiration_time IS NULL OR expiration_time > systimestamp)", req.Key).Scan(&value, &binaryYN, &etag, &expireTime)
|
|
if err != nil {
|
|
// If no rows exist, return an empty response, otherwise return the error.
|
|
if err == sql.ErrNoRows {
|
|
return &state.GetResponse{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
var metadata map[string]string
|
|
if expireTime.Valid {
|
|
metadata = map[string]string{
|
|
state.GetRespMetaKeyTTLExpireTime: expireTime.Time.UTC().Format(time.RFC3339),
|
|
}
|
|
}
|
|
if binaryYN == "Y" {
|
|
var (
|
|
s string
|
|
data []byte
|
|
)
|
|
if err = json.Unmarshal([]byte(value), &s); err != nil {
|
|
return nil, err
|
|
}
|
|
if data, err = base64.StdEncoding.DecodeString(s); err != nil {
|
|
return nil, err
|
|
}
|
|
return &state.GetResponse{
|
|
Data: data,
|
|
ETag: &etag,
|
|
Metadata: metadata,
|
|
}, nil
|
|
}
|
|
return &state.GetResponse{
|
|
Data: []byte(value),
|
|
ETag: &etag,
|
|
Metadata: metadata,
|
|
}, nil
|
|
}
|
|
|
|
// Delete removes an item from the state store.
|
|
func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
|
|
return o.doDelete(ctx, o.db, req)
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) doDelete(ctx context.Context, db querier, req *state.DeleteRequest) (err error) {
|
|
if req.Key == "" {
|
|
return errors.New("missing key in delete operation")
|
|
}
|
|
|
|
var result sql.Result
|
|
if !req.HasETag() {
|
|
result, err = db.ExecContext(ctx, "DELETE FROM "+o.metadata.TableName+" WHERE key = :key", req.Key)
|
|
} else {
|
|
result, err = db.ExecContext(ctx, "DELETE FROM "+o.metadata.TableName+" WHERE key = :key AND etag = :etag", req.Key, *req.ETag)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rows == 0 && req.ETag != nil && *req.ETag != "" {
|
|
return state.NewETagError(state.ETagMismatch, nil)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) ExecuteMulti(parentCtx context.Context, reqs []state.TransactionalStateOperation) error {
|
|
tx, err := o.db.BeginTx(parentCtx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback()
|
|
|
|
for _, op := range reqs {
|
|
switch req := op.(type) {
|
|
case state.SetRequest:
|
|
err = o.doSet(parentCtx, tx, &req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
case state.DeleteRequest:
|
|
err = o.doDelete(parentCtx, tx, &req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
default:
|
|
return fmt.Errorf("unsupported operation: %s", req.Operation())
|
|
}
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
// Close implements io.Closer.
|
|
func (o *oracleDatabaseAccess) Close() error {
|
|
if o.db != nil {
|
|
return o.db.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) ensureStateTable(stateTableName string) error {
|
|
// TODO: This is not atomic and can lead to race conditions if multiple sidecars are starting up at the same time
|
|
exists, err := tableExists(o.db, stateTableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !exists {
|
|
o.logger.Info("Creating state table")
|
|
createTable := `CREATE TABLE ` + stateTableName + ` (
|
|
key varchar2(100) NOT NULL PRIMARY KEY,
|
|
value clob NOT NULL,
|
|
binary_yn varchar2(1) NOT NULL,
|
|
etag varchar2(50) NOT NULL,
|
|
creation_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL ,
|
|
expiration_time TIMESTAMP WITH TIME ZONE NULL,
|
|
update_time TIMESTAMP WITH TIME ZONE NULL)`
|
|
_, err = o.db.Exec(createTable)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// tableExists reports whether user_tables contains a table whose name equals
// the upper-cased tableName. The query error, if any, is returned alongside
// the flag, so callers must check the error first.
func tableExists(db *sql.DB, tableName string) (bool, error) {
	const query = "SELECT count(table_name) tbl_count FROM user_tables WHERE table_name = upper(:tablename)"

	var matches int32
	err := db.QueryRow(query, tableName).Scan(&matches)
	return matches > 0, err
}
|