363 lines
12 KiB
Go
363 lines
12 KiB
Go
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package oracledatabase
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/url"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"github.com/dapr/components-contrib/metadata"
|
|
"github.com/dapr/components-contrib/state"
|
|
stateutils "github.com/dapr/components-contrib/state/utils"
|
|
"github.com/dapr/kit/logger"
|
|
|
|
// Blank import for the underlying Oracle Database driver.
|
|
_ "github.com/sijms/go-ora/v2"
|
|
)
|
|
|
|
// Metadata property names, error text and defaults for the Oracle Database state store.
const (
	// connectionStringKey names the connection string metadata property.
	connectionStringKey = "connectionString"
	// oracleWalletLocationKey names the Oracle wallet location metadata property.
	oracleWalletLocationKey = "oracleWalletLocation"
	// errMissingConnectionString is the error text returned by Init when no connection string is configured.
	errMissingConnectionString = "missing connection string"
	// tableName is the table that Init creates and that Set and Get operate on.
	tableName = "state"
)
|
|
|
|
// oracleDatabaseAccess implements dbaccess.
|
|
type oracleDatabaseAccess struct {
|
|
logger logger.Logger
|
|
metadata oracleDatabaseMetadata
|
|
db *sql.DB
|
|
connectionString string
|
|
tx *sql.Tx
|
|
}
|
|
|
|
// oracleDatabaseMetadata is the component configuration that parseMetadata decodes
// from the metadata properties.
type oracleDatabaseMetadata struct {
	// ConnectionString is mandatory; Init fails when it is empty.
	ConnectionString string
	// OracleWalletLocation is optional; when set, Init appends wallet and SSL options
	// to the connection string.
	OracleWalletLocation string
	// TableName defaults to "state" in parseMetadata.
	TableName string
}
|
|
|
|
// newOracleDatabaseAccess creates a new instance of oracleDatabaseAccess.
|
|
func newOracleDatabaseAccess(logger logger.Logger) *oracleDatabaseAccess {
|
|
logger.Debug("Instantiating new Oracle Database state store")
|
|
|
|
return &oracleDatabaseAccess{
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) Ping() error {
|
|
return o.db.Ping()
|
|
}
|
|
|
|
func parseMetadata(meta map[string]string) (oracleDatabaseMetadata, error) {
|
|
m := oracleDatabaseMetadata{
|
|
TableName: "state",
|
|
}
|
|
err := metadata.DecodeMetadata(meta, &m)
|
|
return m, err
|
|
}
|
|
|
|
// Init sets up OracleDatabase connection and ensures that the state table exists.
|
|
func (o *oracleDatabaseAccess) Init(metadata state.Metadata) error {
|
|
o.logger.Debug("Initializing OracleDatabase state store")
|
|
meta, err := parseMetadata(metadata.Properties)
|
|
o.metadata = meta
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if o.metadata.ConnectionString != "" {
|
|
o.connectionString = meta.ConnectionString
|
|
} else {
|
|
o.logger.Error("Missing Oracle Database connection string")
|
|
|
|
return fmt.Errorf(errMissingConnectionString)
|
|
}
|
|
if o.metadata.OracleWalletLocation != "" {
|
|
o.connectionString += "?TRACE FILE=trace.log&SSL=enable&SSL Verify=false&WALLET=" + url.QueryEscape(o.metadata.OracleWalletLocation)
|
|
}
|
|
db, err := sql.Open("oracle", o.connectionString)
|
|
if err != nil {
|
|
o.logger.Error(err)
|
|
|
|
return err
|
|
}
|
|
|
|
o.db = db
|
|
|
|
if pingErr := db.Ping(); pingErr != nil {
|
|
return pingErr
|
|
}
|
|
err = o.ensureStateTable(tableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Set makes an insert or update to the database.
|
|
func (o *oracleDatabaseAccess) Set(ctx context.Context, req *state.SetRequest) error {
|
|
o.logger.Debug("Setting state value in OracleDatabase")
|
|
err := state.CheckRequestOptions(req.Options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if req.Key == "" {
|
|
return fmt.Errorf("missing key in set operation")
|
|
}
|
|
|
|
if v, ok := req.Value.(string); ok && v == "" {
|
|
return fmt.Errorf("empty string is not allowed in set operation")
|
|
}
|
|
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || len(*req.ETag) == 0) {
|
|
o.logger.Debugf("when FirstWrite is to be enforced, a value must be provided for the ETag")
|
|
return fmt.Errorf("when FirstWrite is to be enforced, a value must be provided for the ETag")
|
|
}
|
|
var ttlSeconds int
|
|
ttl, ttlerr := stateutils.ParseTTL(req.Metadata)
|
|
if ttlerr != nil {
|
|
return fmt.Errorf("error parsing TTL: %w", ttlerr)
|
|
}
|
|
if ttl != nil {
|
|
ttlSeconds = *ttl
|
|
}
|
|
requestValue := req.Value
|
|
byteArray, isBinary := req.Value.([]uint8)
|
|
binaryYN := "N"
|
|
if isBinary {
|
|
requestValue = base64.StdEncoding.EncodeToString(byteArray)
|
|
binaryYN = "Y"
|
|
}
|
|
|
|
// Convert to json string.
|
|
bt, _ := stateutils.Marshal(requestValue, json.Marshal)
|
|
value := string(bt)
|
|
|
|
var result sql.Result
|
|
var tx *sql.Tx
|
|
if o.tx == nil { // not joining a preexisting transaction.
|
|
tx, err = o.db.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start database transaction : %w", err)
|
|
}
|
|
} else { // join the transaction passed in.
|
|
tx = o.tx
|
|
}
|
|
etag := uuid.New().String()
|
|
// Only check for etag if FirstWrite specified - as per Discord message thread https://discord.com/channels/778680217417809931/901141713089863710/938520959562952735.
|
|
if req.Options.Concurrency != state.FirstWrite {
|
|
// Sprintf is required for table name because sql.DB does not substitute parameters for table names.
|
|
// Other parameters use sql.DB parameter substitution.
|
|
// As per Discord Thread https://discord.com/channels/778680217417809931/901141713089863710/938520959562952735 expiration time is reset in case of an update.
|
|
mergeStatement := fmt.Sprintf(
|
|
`MERGE INTO %s t using (select :key key, :value value, :binary_yn binary_yn, :etag etag , :ttl_in_seconds ttl_in_seconds from dual) new_state_to_store
|
|
ON (t.key = new_state_to_store.key )
|
|
WHEN MATCHED THEN UPDATE SET value = new_state_to_store.value, binary_yn = new_state_to_store.binary_yn, update_time = systimestamp, etag = new_state_to_store.etag, t.expiration_time = case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end
|
|
WHEN NOT MATCHED THEN INSERT (t.key, t.value, t.binary_yn, t.etag, t.expiration_time) values (new_state_to_store.key, new_state_to_store.value, new_state_to_store.binary_yn, new_state_to_store.etag, case when new_state_to_store.ttl_in_seconds >0 then systimestamp + numtodsinterval(new_state_to_store.ttl_in_seconds, 'SECOND') end ) `,
|
|
tableName)
|
|
result, err = tx.ExecContext(ctx, mergeStatement, req.Key, value, binaryYN, etag, ttlSeconds)
|
|
} else {
|
|
// when first write policy is indicated, an existing record has to be updated - one that has the etag provided.
|
|
// TODO: Needs to update ttl_in_seconds
|
|
updateStatement := fmt.Sprintf(
|
|
`UPDATE %s SET value = :value, binary_yn = :binary_yn, etag = :new_etag
|
|
WHERE key = :key AND etag = :etag`,
|
|
tableName)
|
|
result, err = tx.ExecContext(ctx, updateStatement, value, binaryYN, etag, req.Key, *req.ETag)
|
|
}
|
|
if err != nil {
|
|
if req.ETag != nil && *req.ETag != "" {
|
|
return state.NewETagError(state.ETagMismatch, err)
|
|
}
|
|
if o.tx == nil { // not in a preexisting transaction so rollback the local, failed tx.
|
|
tx.Rollback()
|
|
}
|
|
return err
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if o.tx == nil { // local transaction, take responsibility.
|
|
tx.Commit()
|
|
}
|
|
if rows != 1 {
|
|
return fmt.Errorf("no item was updated")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Get returns data from the database. If data does not exist for the key an empty state.GetResponse will be returned.
|
|
func (o *oracleDatabaseAccess) Get(ctx context.Context, req *state.GetRequest) (*state.GetResponse, error) {
|
|
o.logger.Debug("Getting state value from OracleDatabase")
|
|
if req.Key == "" {
|
|
return nil, fmt.Errorf("missing key in get operation")
|
|
}
|
|
var value string
|
|
var binaryYN string
|
|
var etag string
|
|
err := o.db.QueryRowContext(ctx, fmt.Sprintf("SELECT value, binary_yn, etag FROM %s WHERE key = :key and (expiration_time is null or expiration_time > systimestamp)", tableName), req.Key).Scan(&value, &binaryYN, &etag)
|
|
if err != nil {
|
|
// If no rows exist, return an empty response, otherwise return the error.
|
|
if err == sql.ErrNoRows {
|
|
return &state.GetResponse{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
if binaryYN == "Y" {
|
|
var s string
|
|
var data []byte
|
|
if err = json.Unmarshal([]byte(value), &s); err != nil {
|
|
return nil, err
|
|
}
|
|
if data, err = base64.StdEncoding.DecodeString(s); err != nil {
|
|
return nil, err
|
|
}
|
|
return &state.GetResponse{
|
|
Data: data,
|
|
ETag: &etag,
|
|
Metadata: req.Metadata,
|
|
}, nil
|
|
}
|
|
return &state.GetResponse{
|
|
Data: []byte(value),
|
|
ETag: &etag,
|
|
Metadata: req.Metadata,
|
|
}, nil
|
|
}
|
|
|
|
// Delete removes an item from the state store.
|
|
func (o *oracleDatabaseAccess) Delete(ctx context.Context, req *state.DeleteRequest) error {
|
|
o.logger.Debug("Deleting state value from OracleDatabase")
|
|
if req.Key == "" {
|
|
return fmt.Errorf("missing key in delete operation")
|
|
}
|
|
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || len(*req.ETag) == 0) {
|
|
o.logger.Debugf("when FirstWrite is to be enforced, a value must be provided for the ETag")
|
|
return fmt.Errorf("when FirstWrite is to be enforced, a value must be provided for the ETag")
|
|
}
|
|
var result sql.Result
|
|
var err error
|
|
var tx *sql.Tx
|
|
if o.tx == nil { // not joining a preexisting transaction.
|
|
tx, err = o.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else { // join the transaction passed in.
|
|
tx = o.tx
|
|
}
|
|
// QUESTION: only check for etag if FirstWrite specified - or always when etag is supplied??
|
|
if req.Options.Concurrency != state.FirstWrite {
|
|
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key", req.Key)
|
|
} else {
|
|
result, err = tx.ExecContext(ctx, "DELETE FROM state WHERE key = :key and etag = :etag", req.Key, *req.ETag)
|
|
}
|
|
if err != nil {
|
|
if o.tx == nil { // not joining a preexisting transaction.
|
|
tx.Rollback()
|
|
}
|
|
return err
|
|
}
|
|
if o.tx == nil { // not joining a preexisting transaction.
|
|
tx.Commit()
|
|
}
|
|
rows, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rows != 1 && req.ETag != nil && *req.ETag != "" && req.Options.Concurrency == state.FirstWrite {
|
|
return state.NewETagError(state.ETagMismatch, nil)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) ExecuteMulti(ctx context.Context, sets []state.SetRequest, deletes []state.DeleteRequest) error {
|
|
o.logger.Debug("Executing multiple OracleDatabase operations, within a single transaction")
|
|
tx, err := o.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
o.tx = tx
|
|
if len(deletes) > 0 {
|
|
for _, d := range deletes {
|
|
da := d // Fix for gosec G601: Implicit memory aliasing in for looo.
|
|
err = o.Delete(ctx, &da)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
if len(sets) > 0 {
|
|
for _, s := range sets {
|
|
sa := s // Fix for gosec G601: Implicit memory aliasing in for looo.
|
|
err = o.Set(ctx, &sa)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
}
|
|
}
|
|
err = tx.Commit()
|
|
o.tx = nil
|
|
return err
|
|
}
|
|
|
|
// Close implements io.Closer.
|
|
func (o *oracleDatabaseAccess) Close() error {
|
|
if o.db != nil {
|
|
return o.db.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (o *oracleDatabaseAccess) ensureStateTable(stateTableName string) error {
|
|
exists, err := tableExists(o.db, stateTableName)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !exists {
|
|
o.logger.Info("Creating OracleDatabase state table")
|
|
createTable := fmt.Sprintf(`CREATE TABLE %s (
|
|
key varchar2(100) NOT NULL PRIMARY KEY,
|
|
value clob NOT NULL,
|
|
binary_yn varchar2(1) NOT NULL,
|
|
etag varchar2(50) NOT NULL,
|
|
creation_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL ,
|
|
expiration_time TIMESTAMP WITH TIME ZONE NULL,
|
|
update_time TIMESTAMP WITH TIME ZONE NULL)`, stateTableName)
|
|
_, err = o.db.Exec(createTable)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// tableExists reports whether user_tables contains a table whose name equals the
// upper-cased tableName. The query error, if any, is returned alongside the result.
func tableExists(db *sql.DB, tableName string) (bool, error) {
	const lookup = "SELECT count(table_name) tbl_count FROM user_tables where table_name = upper(:tablename)"

	var matches int32
	scanErr := db.QueryRow(lookup, tableName).Scan(&matches)

	return matches > 0, scanErr
}
|