530 lines
14 KiB
Go
530 lines
14 KiB
Go
/*
Copyright 2021 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
|
|
|
|
package cosmosdb
|
|
|
|
import (
|
|
"context"
|
|
_ "embed"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
|
|
"github.com/Azure/azure-sdk-for-go/sdk/data/azcosmos"
|
|
"github.com/agrea/ptr"
|
|
"github.com/google/uuid"
|
|
jsoniter "github.com/json-iterator/go"
|
|
|
|
"github.com/dapr/components-contrib/contenttype"
|
|
"github.com/dapr/components-contrib/internal/authentication/azure"
|
|
"github.com/dapr/components-contrib/state"
|
|
"github.com/dapr/components-contrib/state/query"
|
|
"github.com/dapr/kit/logger"
|
|
)
|
|
|
|
// StateStore is a CosmosDB state store.
|
|
type StateStore struct {
|
|
state.DefaultBulkStore
|
|
client *azcosmos.ContainerClient
|
|
metadata metadata
|
|
contentType string
|
|
logger logger.Logger
|
|
}
|
|
|
|
// metadata is the component configuration. Init fills it by decoding the
// JSON form of the component properties.
type metadata struct {
	// URL is the account endpoint. Init rejects an empty value.
	URL string `json:"url"`
	// MasterKey selects key-based authentication when non-empty; otherwise
	// Init falls back to Azure AD.
	MasterKey string `json:"masterKey"`
	// Database is the database name. Init rejects an empty value.
	Database string `json:"database"`
	// Collection is the container name. Init rejects an empty value.
	Collection string `json:"collection"`
	// ContentType is preset to "application/json" by Init before decoding.
	ContentType string `json:"contentType"`
}
|
|
|
|
// cosmosOperationType names the kind of operation carried by a CosmosOperation.
type cosmosOperationType string

// CosmosOperation pairs a CosmosDB document with the operation to apply to it.
type CosmosOperation struct {
	Item CosmosItem          `json:"item"`
	Type cosmosOperationType `json:"type"`
}

// CosmosItem is the document shape this store writes to and reads from CosmosDB.
type CosmosItem struct {
	// ID is the document id; the store uses the state key.
	ID string `json:"id"`
	// Value is the stored state. When IsBinary is true, Get expects it to be
	// a base64-encoded string.
	Value interface{} `json:"value"`
	// IsBinary marks a Value that was written from a raw byte slice.
	IsBinary bool `json:"isBinary"`
	// PartitionKey is the partition the document is written to.
	PartitionKey string `json:"partitionKey"`
	// TTL is left out of the JSON document when nil.
	TTL *int `json:"ttl,omitempty"`
	// Etag has no JSON tag, so encoding/json uses the key "Etag".
	// NOTE(review): confirm this matches the property name the service
	// returns for the document ETag.
	Etag string
}
|
|
|
|
const (
	// metadataPartitionKey is the request-metadata key whose value overrides
	// the partition key of an operation.
	metadataPartitionKey = "partitionKey"
	// unknownPartitionKey is the initial partition key Multi starts from
	// before it inspects the first operation.
	unknownPartitionKey = "__UNKNOWN__"
	// metadataTTLKey is the request-metadata key holding the TTL in seconds.
	metadataTTLKey = "ttlInSeconds"
	// statusTooManyRequests is HTTP status 429 in string form (RFC 6585, section 4).
	statusTooManyRequests = "429"
	// defaultTimeout bounds individual calls to the CosmosDB service.
	defaultTimeout = 20 * time.Second
)
|
|
|
|
// policy that tracks the number of times it was invoked
|
|
type crossPartitionQueryPolicy struct{}
|
|
|
|
func (p *crossPartitionQueryPolicy) Do(req *policy.Request) (*http.Response, error) {
|
|
raw := req.Raw()
|
|
hdr := raw.Header
|
|
if strings.ToLower(hdr.Get("x-ms-documentdb-query")) == "true" {
|
|
// modify req here since we know it is a query
|
|
hdr.Add("x-ms-documentdb-query-enablecrosspartition", "true")
|
|
hdr.Del("x-ms-documentdb-partitionkey")
|
|
raw.Header = hdr
|
|
|
|
}
|
|
return req.Next()
|
|
}
|
|
|
|
// NewCosmosDBStateStore returns a new CosmosDB state store.
|
|
func NewCosmosDBStateStore(logger logger.Logger) state.Store {
|
|
s := &StateStore{
|
|
logger: logger,
|
|
}
|
|
s.DefaultBulkStore = state.NewDefaultBulkStore(s)
|
|
return s
|
|
}
|
|
|
|
// Init does metadata and connection parsing.
|
|
func (c *StateStore) Init(meta state.Metadata) error {
|
|
c.logger.Debugf("CosmosDB init start")
|
|
|
|
connInfo := meta.Properties
|
|
b, err := json.Marshal(connInfo)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m := metadata{
|
|
ContentType: "application/json",
|
|
}
|
|
|
|
err = json.Unmarshal(b, &m)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if m.URL == "" {
|
|
return errors.New("url is required")
|
|
}
|
|
if m.Database == "" {
|
|
return errors.New("database is required")
|
|
}
|
|
if m.Collection == "" {
|
|
return errors.New("collection is required")
|
|
}
|
|
if m.ContentType == "" {
|
|
return errors.New("contentType is required")
|
|
}
|
|
|
|
// Internal query policy was created due to lack of cross partition query capability in go sdk
|
|
queryPolicy := &crossPartitionQueryPolicy{}
|
|
opts := azcosmos.ClientOptions{
|
|
ClientOptions: policy.ClientOptions{
|
|
PerCallPolicies: []policy.Policy{queryPolicy},
|
|
},
|
|
}
|
|
|
|
// Create the client; first, try authenticating with a master key, if present
|
|
var client *azcosmos.Client
|
|
if m.MasterKey != "" {
|
|
var cred azcosmos.KeyCredential
|
|
cred, err = azcosmos.NewKeyCredential(m.MasterKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
client, err = azcosmos.NewClientWithKey(m.URL, cred, &opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
// Fallback to using Azure AD
|
|
var env azure.EnvironmentSettings
|
|
env, err = azure.NewEnvironmentSettings("cosmosdb", meta.Properties)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
token, tokenErr := env.GetTokenCredential()
|
|
if tokenErr != nil {
|
|
return tokenErr
|
|
}
|
|
client, err = azcosmos.NewClient(m.URL, token, &opts)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Create a container client
|
|
dbClient, err := client.NewDatabase(m.Database)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Container is synonymous with collection.
|
|
dbContainer, err := dbClient.NewContainer(m.Collection)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
c.client = dbContainer
|
|
|
|
c.metadata = m
|
|
c.contentType = m.ContentType
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
|
_, err = c.client.Read(ctx, nil)
|
|
cancel()
|
|
return err
|
|
}
|
|
|
|
// Features returns the features available in this state store.
|
|
func (c *StateStore) Features() []state.Feature {
|
|
return []state.Feature{state.FeatureETag, state.FeatureTransactional, state.FeatureQueryAPI}
|
|
}
|
|
|
|
// Get retrieves a CosmosDB item.
|
|
func (c *StateStore) Get(req *state.GetRequest) (*state.GetResponse, error) {
|
|
partitionKey := populatePartitionMetadata(req.Key, req.Metadata)
|
|
|
|
options := azcosmos.ItemOptions{}
|
|
if req.Options.Consistency == state.Strong {
|
|
options.ConsistencyLevel = azcosmos.ConsistencyLevelStrong.ToPtr()
|
|
}
|
|
if req.Options.Consistency == state.Eventual {
|
|
options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr()
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
|
readItem, err := c.client.ReadItem(ctx, azcosmos.NewPartitionKeyString(partitionKey), req.Key, &options)
|
|
cancel()
|
|
if err != nil {
|
|
var responseErr *azcore.ResponseError
|
|
errors.As(err, &responseErr)
|
|
if responseErr.ErrorCode == "NotFound" {
|
|
return &state.GetResponse{}, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
item := CosmosItem{}
|
|
err = jsoniter.ConfigFastest.Unmarshal(readItem.Value, &item)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if item.IsBinary {
|
|
if item.Value == nil {
|
|
return &state.GetResponse{
|
|
Data: make([]byte, 0),
|
|
ETag: ptr.String(item.Etag),
|
|
}, nil
|
|
}
|
|
|
|
bytes, decodeErr := base64.StdEncoding.DecodeString(item.Value.(string))
|
|
if decodeErr != nil {
|
|
c.logger.Warnf("CosmosDB state store Get request could not decode binary string: %v. Returning raw string instead.", decodeErr)
|
|
bytes = []byte(item.Value.(string))
|
|
}
|
|
|
|
return &state.GetResponse{
|
|
Data: bytes,
|
|
ETag: ptr.String(item.Etag),
|
|
}, nil
|
|
}
|
|
|
|
b, err := jsoniter.ConfigFastest.Marshal(&item.Value)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &state.GetResponse{
|
|
Data: b,
|
|
ETag: ptr.String(item.Etag),
|
|
}, nil
|
|
}
|
|
|
|
// Set saves a CosmosDB item.
|
|
func (c *StateStore) Set(req *state.SetRequest) error {
|
|
err := state.CheckRequestOptions(req.Options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
partitionKey := populatePartitionMetadata(req.Key, req.Metadata)
|
|
options := azcosmos.ItemOptions{}
|
|
|
|
if req.ETag != nil {
|
|
etag := azcore.ETag(*req.ETag)
|
|
options.IfMatchEtag = &etag
|
|
}
|
|
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
|
|
newTag := azcore.ETag(uuid.NewString())
|
|
options.IfMatchEtag = &newTag
|
|
}
|
|
if req.Options.Consistency == state.Strong {
|
|
options.ConsistencyLevel = azcosmos.ConsistencyLevelStrong.ToPtr()
|
|
}
|
|
if req.Options.Consistency == state.Eventual {
|
|
options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr()
|
|
}
|
|
|
|
doc, err := createUpsertItem(c.contentType, *req, partitionKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
|
pk := azcosmos.NewPartitionKeyString(partitionKey)
|
|
_, err = c.client.UpsertItem(ctx, pk, doc, &options)
|
|
cancel()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Delete performs a delete operation.
|
|
func (c *StateStore) Delete(req *state.DeleteRequest) error {
|
|
err := state.CheckRequestOptions(req.Options)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
partitionKey := populatePartitionMetadata(req.Key, req.Metadata)
|
|
options := azcosmos.ItemOptions{}
|
|
|
|
if req.ETag != nil {
|
|
etag := azcore.ETag(*req.ETag)
|
|
options.IfMatchEtag = &etag
|
|
}
|
|
if req.Options.Consistency == state.Strong {
|
|
options.ConsistencyLevel = azcosmos.ConsistencyLevelStrong.ToPtr()
|
|
}
|
|
if req.Options.Consistency == state.Eventual {
|
|
options.ConsistencyLevel = azcosmos.ConsistencyLevelEventual.ToPtr()
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
|
pk := azcosmos.NewPartitionKeyString(partitionKey)
|
|
_, err = c.client.DeleteItem(ctx, pk, req.Key, &options)
|
|
cancel()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Multi performs a transactional operation. succeeds only if all operations succeed, and fails if one or more operations fail.
|
|
func (c *StateStore) Multi(request *state.TransactionalStateRequest) error {
|
|
if len(request.Operations) == 0 {
|
|
c.logger.Debugf("No Operations Provided")
|
|
return nil
|
|
}
|
|
partitionKey := unknownPartitionKey
|
|
|
|
switch request.Operations[0].Operation {
|
|
case state.Upsert:
|
|
stateItem := request.Operations[0].Request.(*state.SetRequest)
|
|
partitionKey = populatePartitionMetadata(stateItem.Key, stateItem.Metadata)
|
|
case state.Delete:
|
|
stateItem := request.Operations[0].Request.(*state.DeleteRequest)
|
|
partitionKey = populatePartitionMetadata(stateItem.Key, stateItem.Metadata)
|
|
}
|
|
|
|
batch := c.client.NewTransactionalBatch(azcosmos.NewPartitionKeyString(partitionKey))
|
|
|
|
numOperations := 0
|
|
// Loop through the list of operations. Create and add the operation to the batch
|
|
for _, o := range request.Operations {
|
|
var options *azcosmos.TransactionalBatchItemOptions
|
|
|
|
if o.Operation == state.Upsert {
|
|
req := o.Request.(state.SetRequest)
|
|
doc, err := createUpsertItem(c.contentType, req, partitionKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if req.ETag != nil && *req.ETag != "" {
|
|
etag := azcore.ETag(*req.ETag)
|
|
options.IfMatchETag = &etag
|
|
}
|
|
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
|
|
newTag := azcore.ETag(uuid.NewString())
|
|
options.IfMatchETag = &newTag
|
|
}
|
|
|
|
batch.UpsertItem(doc, nil)
|
|
numOperations++
|
|
} else if o.Operation == state.Delete {
|
|
req := o.Request.(state.DeleteRequest)
|
|
|
|
if req.ETag != nil && *req.ETag != "" {
|
|
etag := azcore.ETag(*req.ETag)
|
|
options.IfMatchETag = &etag
|
|
}
|
|
if req.Options.Concurrency == state.FirstWrite && (req.ETag == nil || *req.ETag == "") {
|
|
newTag := azcore.ETag(uuid.NewString())
|
|
options.IfMatchETag = &newTag
|
|
}
|
|
|
|
batch.DeleteItem(req.Key, options)
|
|
numOperations++
|
|
}
|
|
}
|
|
|
|
c.logger.Debugf("#operations=%d,partitionkey=%s", numOperations, partitionKey)
|
|
|
|
var itemResponseBody map[string]string
|
|
|
|
batchResponse, err := c.client.ExecuteTransactionalBatch(context.Background(), batch, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if batchResponse.Success {
|
|
// Transaction succeeded
|
|
// We can inspect the individual operation results
|
|
for index, operation := range batchResponse.OperationResults {
|
|
c.logger.Debugf("Operation %v completed with status code %v", index, operation.StatusCode)
|
|
err = json.Unmarshal(operation.ResourceBody, &itemResponseBody)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
} else {
|
|
// Transaction failed, look for the offending operation
|
|
for index, operation := range batchResponse.OperationResults {
|
|
if string(operation.StatusCode) != statusTooManyRequests {
|
|
c.logger.Debugf("Transaction failed due to operation %v which failed with status code %v", index, operation.StatusCode)
|
|
return nil
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *StateStore) Query(req *state.QueryRequest) (*state.QueryResponse, error) {
|
|
q := &Query{}
|
|
|
|
qbuilder := query.NewQueryBuilder(q)
|
|
if err := qbuilder.BuildQuery(&req.Query); err != nil {
|
|
return &state.QueryResponse{}, err
|
|
}
|
|
var data []state.QueryItem
|
|
var token string
|
|
|
|
var innerErr error
|
|
data, token, innerErr = q.execute(c.client)
|
|
if innerErr != nil {
|
|
return nil, innerErr
|
|
}
|
|
|
|
return &state.QueryResponse{
|
|
Results: data,
|
|
Token: token,
|
|
}, nil
|
|
}
|
|
|
|
func (c *StateStore) Ping() error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
|
|
_, err := c.client.Read(ctx, nil)
|
|
cancel()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func createUpsertItem(contentType string, req state.SetRequest, partitionKey string) ([]byte, error) {
|
|
byteArray, isBinary := req.Value.([]uint8)
|
|
if len(byteArray) == 0 {
|
|
isBinary = false
|
|
}
|
|
|
|
ttl, err := parseTTL(req.Metadata)
|
|
if err != nil {
|
|
return []byte{}, fmt.Errorf("error parsing TTL from metadata: %s", err)
|
|
}
|
|
|
|
if isBinary {
|
|
if contenttype.IsJSONContentType(contentType) {
|
|
var value map[string]interface{}
|
|
err := json.Unmarshal(byteArray, &value)
|
|
// if byte array is not a valid JSON, so keep it as-is to be Base64 encoded in CosmosDB.
|
|
// otherwise, we save it as JSON
|
|
if err == nil {
|
|
item := CosmosItem{
|
|
ID: req.Key,
|
|
Value: value,
|
|
PartitionKey: partitionKey,
|
|
IsBinary: false,
|
|
TTL: ttl,
|
|
}
|
|
return json.Marshal(&item)
|
|
|
|
}
|
|
} else if contenttype.IsStringContentType(contentType) {
|
|
item := CosmosItem{
|
|
ID: req.Key,
|
|
Value: string(byteArray),
|
|
PartitionKey: partitionKey,
|
|
IsBinary: false,
|
|
TTL: ttl,
|
|
}
|
|
return json.Marshal(&item)
|
|
}
|
|
}
|
|
|
|
item := CosmosItem{
|
|
ID: req.Key,
|
|
Value: req.Value,
|
|
PartitionKey: partitionKey,
|
|
IsBinary: isBinary,
|
|
TTL: ttl,
|
|
}
|
|
return json.Marshal(&item)
|
|
}
|
|
|
|
// This is a helper to return the partition key to use. If if metadata["partitionkey"] is present,
|
|
// use that, otherwise use what's in "key".
|
|
func populatePartitionMetadata(key string, requestMetadata map[string]string) string {
|
|
if val, found := requestMetadata[metadataPartitionKey]; found {
|
|
return val
|
|
}
|
|
|
|
return key
|
|
}
|
|
|
|
func parseTTL(requestMetadata map[string]string) (*int, error) {
|
|
if val, found := requestMetadata[metadataTTLKey]; found && val != "" {
|
|
parsedVal, err := strconv.ParseInt(val, 10, 0)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
i := int(parsedVal)
|
|
|
|
return &i, nil
|
|
}
|
|
|
|
return nil, nil
|
|
}
|