feat: add context-value flag (#1448)

- add the `--context-value` command line flag to pass arbitrary key
value pairs to the evaluation context

Signed-off-by: Aleksei Muratov <muratoff.alexey@gmail.com>
This commit is contained in:
Aleksei 2024-12-05 17:05:46 +01:00 committed by GitHub
parent f7dd1eb630
commit 7ca092e478
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 181 additions and 54 deletions

View File

@ -68,6 +68,9 @@ uninstall:
lint:
go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2
$(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --deadline=5m --timeout=5m $(module)/... || exit;)
lint-fix:
go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2
$(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --fix --deadline=5m --timeout=5m $(module)/... || exit;)
install-mockgen:
go install go.uber.org/mock/mockgen@v0.4.0
mockgen: install-mockgen

View File

@ -32,6 +32,7 @@ type Configuration struct {
SocketPath string
CORS []string
Options []connect.HandlerOption
ContextValues map[string]any
}
/*

View File

@ -184,6 +184,9 @@ For example, when accessing flagd via HTTP, the POST body may look like this:
The evaluation context can be accessed in targeting rules using the `var` operation followed by the evaluation context property name.
The evaluation context can be appended by arbitrary key value pairs
via the `-X` command line flag.
| Description | Example |
| -------------------------------------------------------------- | ---------------------------------------------------- |
| Retrieve property from the evaluation context | `#!json { "var": "email" }` |

View File

@ -11,6 +11,7 @@ flagd start [flags]
### Options
```
-X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default [])
-C, --cors-origin strings CORS allowed origins, * will allow all origins
-h, --help help for start
-z, --log-format string Set the logging format, e.g. console or json (default "console")

View File

@ -34,11 +34,11 @@ const (
sourcesFlagName = "sources"
syncPortFlagName = "sync-port"
uriFlagName = "uri"
contextValueFlagName = "context-value"
)
func init() {
flags := startCmd.Flags()
// allows environment variables to use _ instead of -
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS
viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT
@ -78,6 +78,8 @@ func init() {
flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector")
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+
"from disk")
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
"to the flag evaluation context")
_ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName))
_ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName))
@ -95,6 +97,7 @@ func init() {
_ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName))
_ = viper.BindPFlag(syncPortFlagName, flags.Lookup(syncPortFlagName))
_ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName))
_ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName))
}
// startCmd represents the start command
@ -139,6 +142,11 @@ var startCmd = &cobra.Command{
}
syncProviders = append(syncProviders, syncProvidersFromConfig...)
contextValuesToMap := make(map[string]any)
for k, v := range viper.GetStringMapString(contextValueFlagName) {
contextValuesToMap[k] = v
}
// Build Runtime -----------------------------------------------------------
rt, err := runtime.FromConfig(logger, Version, runtime.Config{
CORS: viper.GetStringSlice(corsFlagName),
@ -156,6 +164,7 @@ var startCmd = &cobra.Command{
ServiceSocketPath: viper.GetString(socketPathFlagName),
SyncServicePort: viper.GetUint16(syncPortFlagName),
SyncProviders: syncProviders,
ContextValues: contextValuesToMap,
})
if err != nil {
rtLogger.Fatal(err.Error())

View File

@ -40,6 +40,8 @@ type Config struct {
SyncProviders []sync.SourceConfig
CORS []string
ContextValues map[string]any
}
// FromConfig builds a runtime from startup configurations
@ -101,7 +103,9 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
ofrepService, err := ofrep.NewOfrepService(jsonEvaluator, config.CORS, ofrep.SvcConfiguration{
Logger: logger.WithFields(zap.String("component", "OFREPService")),
Port: config.OfrepServicePort,
})
},
config.ContextValues,
)
if err != nil {
return nil, fmt.Errorf("error creating ofrep service")
}
@ -112,6 +116,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
Port: config.SyncServicePort,
Sources: sources,
Store: s,
ContextValues: config.ContextValues,
})
if err != nil {
return nil, fmt.Errorf("error creating sync service: %w", err)
@ -145,6 +150,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
SocketPath: config.ServiceSocketPath,
CORS: config.CORS,
Options: options,
ContextValues: config.ContextValues,
},
SyncImpl: iSyncs,
}, nil

View File

@ -154,6 +154,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval,
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
)
marshalOpts := WithJSON(
@ -170,6 +171,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval,
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
)
_, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...)

View File

@ -32,11 +32,16 @@ type OldFlagEvaluationService struct {
metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents
flagEvalTracer trace.Tracer
contextValues map[string]any
}
// NewOldFlagEvaluationService creates a OldFlagEvaluationService with provided parameters
func NewOldFlagEvaluationService(log *logger.Logger,
eval evaluator.IEvaluator, eventingCfg IEvents, metricsRecorder telemetry.IMetricsRecorder,
func NewOldFlagEvaluationService(
log *logger.Logger,
eval evaluator.IEvaluator,
eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
) *OldFlagEvaluationService {
svc := &OldFlagEvaluationService{
logger: log,
@ -44,6 +49,7 @@ func NewOldFlagEvaluationService(log *logger.Logger,
metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagEvaluationService"),
contextValues: contextValues,
}
if metricsRecorder != nil {
@ -65,12 +71,8 @@ func (s *OldFlagEvaluationService) ResolveAll(
res := &schemaV1.ResolveAllResponse{
Flags: make(map[string]*schemaV1.AnyFlag),
}
evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx)
values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -172,6 +174,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{})
err := resolve[bool](
sCtx,
s.logger,
@ -180,6 +183,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
req.Msg.GetContext(),
&booleanResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -206,6 +210,7 @@ func (s *OldFlagEvaluationService) ResolveString(
req.Msg.GetContext(),
&stringResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -232,6 +237,7 @@ func (s *OldFlagEvaluationService) ResolveInt(
req.Msg.GetContext(),
&intResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -258,6 +264,7 @@ func (s *OldFlagEvaluationService) ResolveFloat(
req.Msg.GetContext(),
&floatResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -284,6 +291,7 @@ func (s *OldFlagEvaluationService) ResolveObject(
req.Msg.GetContext(),
&objectResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -293,21 +301,36 @@ func (s *OldFlagEvaluationService) ResolveObject(
return res, err
}
// mergeContexts combines values from the request context with the values from the config --context-values flag.
// Request context values have a higher priority.
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
merged := make(map[string]any)
for k, v := range reqCtx {
merged[k] = v
}
for k, v := range configFlagsCtx {
merged[k] = v
}
return merged
}
// resolve is a generic flag resolver
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
configContextValues map[string]any,
) error {
reqID := xid.New().String()
defer logger.ClearFields(reqID)
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
logger.WriteFields(
reqID,
zap.String("flag-key", flagKey),
zap.Strings("context-keys", formatContextKeys(evaluationContext)),
zap.Strings("context-keys", formatContextKeys(mergedContext)),
)
var evalErrFormatted error
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, evaluationContext.AsMap())
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, mergedContext)
if evalErr != nil {
logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr))
reason = model.ErrorReason
@ -329,9 +352,9 @@ func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver
return evalErrFormatted
}
func formatContextKeys(context *structpb.Struct) []string {
func formatContextKeys(context map[string]any) []string {
res := []string{}
for k := range context.AsMap() {
for k := range context {
res = append(res, k)
}
return res

View File

@ -128,6 +128,7 @@ func TestConnectService_ResolveAll(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
if err != nil && !errors.Is(err, tt.wantErr) {
@ -235,6 +236,7 @@ func TestFlag_Evaluation_ResolveBoolean(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -290,6 +292,7 @@ func BenchmarkFlag_Evaluation_ResolveBoolean(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -388,6 +391,7 @@ func TestFlag_Evaluation_ResolveString(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -443,6 +447,7 @@ func BenchmarkFlag_Evaluation_ResolveString(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -540,6 +545,7 @@ func TestFlag_Evaluation_ResolveFloat(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -595,6 +601,7 @@ func BenchmarkFlag_Evaluation_ResolveFloat(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -692,6 +699,7 @@ func TestFlag_Evaluation_ResolveInt(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -747,6 +755,7 @@ func BenchmarkFlag_Evaluation_ResolveInt(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -847,6 +856,7 @@ func TestFlag_Evaluation_ResolveObject(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -910,6 +920,7 @@ func BenchmarkFlag_Evaluation_ResolveObject(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result)

View File

@ -25,6 +25,7 @@ type FlagEvaluationService struct {
metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents
flagEvalTracer trace.Tracer
contextValues map[string]any
}
// NewFlagEvaluationService creates a FlagEvaluationService with provided parameters
@ -32,6 +33,7 @@ func NewFlagEvaluationService(log *logger.Logger,
eval evaluator.IEvaluator,
eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
) *FlagEvaluationService {
svc := &FlagEvaluationService{
logger: log,
@ -39,6 +41,7 @@ func NewFlagEvaluationService(log *logger.Logger,
metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagd.evaluation.v1"),
contextValues: contextValues,
}
if metricsRecorder != nil {
@ -63,12 +66,7 @@ func (s *FlagEvaluationService) ResolveAll(
Flags: make(map[string]*evalV1.AnyFlag),
}
evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx)
values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -167,8 +165,9 @@ func (s *FlagEvaluationService) ResolveBoolean(
) (*connect.Response[evalV1.ResolveBooleanResponse], error) {
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End()
res := connect.NewResponse(&evalV1.ResolveBooleanResponse{})
err := resolve[bool](
err := resolve(
sCtx,
s.logger,
s.eval.ResolveBooleanValue,
@ -176,6 +175,7 @@ func (s *FlagEvaluationService) ResolveBoolean(
req.Msg.GetContext(),
&booleanResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -193,7 +193,7 @@ func (s *FlagEvaluationService) ResolveString(
defer span.End()
res := connect.NewResponse(&evalV1.ResolveStringResponse{})
err := resolve[string](
err := resolve(
sCtx,
s.logger,
s.eval.ResolveStringValue,
@ -201,6 +201,7 @@ func (s *FlagEvaluationService) ResolveString(
req.Msg.GetContext(),
&stringResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -218,7 +219,7 @@ func (s *FlagEvaluationService) ResolveInt(
defer span.End()
res := connect.NewResponse(&evalV1.ResolveIntResponse{})
err := resolve[int64](
err := resolve(
sCtx,
s.logger,
s.eval.ResolveIntValue,
@ -226,6 +227,7 @@ func (s *FlagEvaluationService) ResolveInt(
req.Msg.GetContext(),
&intResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -243,7 +245,7 @@ func (s *FlagEvaluationService) ResolveFloat(
defer span.End()
res := connect.NewResponse(&evalV1.ResolveFloatResponse{})
err := resolve[float64](
err := resolve(
sCtx,
s.logger,
s.eval.ResolveFloatValue,
@ -251,6 +253,7 @@ func (s *FlagEvaluationService) ResolveFloat(
req.Msg.GetContext(),
&floatResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)
@ -268,7 +271,7 @@ func (s *FlagEvaluationService) ResolveObject(
defer span.End()
res := connect.NewResponse(&evalV1.ResolveObjectResponse{})
err := resolve[map[string]any](
err := resolve(
sCtx,
s.logger,
s.eval.ResolveObjectValue,
@ -276,6 +279,7 @@ func (s *FlagEvaluationService) ResolveObject(
req.Msg.GetContext(),
&objectResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
)
if err != nil {
span.RecordError(err)

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"reflect"
"testing"
evalV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v1"
@ -93,7 +94,7 @@ func TestConnectServiceV2_ResolveAll(t *testing.T) {
).AnyTimes()
metrics, exp := getMetricReader()
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics)
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil)
// when
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
@ -208,6 +209,7 @@ func TestFlag_EvaluationV2_ResolveBoolean(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -263,6 +265,7 @@ func BenchmarkFlag_EvaluationV2_ResolveBoolean(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -361,6 +364,7 @@ func TestFlag_EvaluationV2_ResolveString(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -416,6 +420,7 @@ func BenchmarkFlag_EvaluationV2_ResolveString(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -513,6 +518,7 @@ func TestFlag_EvaluationV2_ResolveFloat(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -568,6 +574,7 @@ func BenchmarkFlag_EvaluationV2_ResolveFloat(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -665,6 +672,7 @@ func TestFlag_EvaluationV2_ResolveInt(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -720,6 +728,7 @@ func BenchmarkFlag_EvaluationV2_ResolveInt(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -820,6 +829,7 @@ func TestFlag_EvaluationV2_ResolveObject(t *testing.T) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -883,6 +893,7 @@ func BenchmarkFlag_EvaluationV2_ResolveObject(b *testing.B) {
eval,
&eventingConfiguration{},
metrics,
nil,
)
if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -955,3 +966,35 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) {
}
}
}
func Test_mergeContexts(t *testing.T) {
type args struct {
clientContext, configContext map[string]any
}
tests := []struct {
name string
args args
want map[string]any
}{
{
name: "merge contexts",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := mergeContexts(tt.args.clientContext, tt.args.configContext)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want)
}
})
}
}

View File

@ -22,12 +22,14 @@ const (
type handler struct {
Logger *logger.Logger
evaluator evaluator.IEvaluator
contextValues map[string]any
}
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator) http.Handler {
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler {
h := handler{
logger,
evaluator,
Logger: logger,
evaluator: evaluator,
contextValues: contextValues,
}
router := mux.NewRouter()
@ -56,7 +58,7 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) {
return
}
context := flagdContext(h.Logger, requestID, request)
context := flagdContext(h.Logger, requestID, request, h.contextValues)
evaluation := h.evaluator.ResolveAsAnyValue(r.Context(), requestID, flagKey, context)
if evaluation.Error != nil {
status, evaluationError := ofrep.EvaluationErrorResponseFrom(evaluation)
@ -76,7 +78,7 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) {
return
}
context := flagdContext(h.Logger, requestID, request)
context := flagdContext(h.Logger, requestID, request, h.contextValues)
evaluations, err := h.evaluator.ResolveAllValues(r.Context(), requestID, context)
if err != nil {
h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err))
@ -117,13 +119,21 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) {
return request, nil
}
func flagdContext(log *logger.Logger, requestID string, request ofrep.Request) map[string]any {
context := map[string]any{}
func flagdContext(
log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any,
) map[string]any {
context := make(map[string]any)
if res, ok := request.Context.(map[string]any); ok {
context = res
for k, v := range res {
context[k] = v
}
} else {
log.WarnWithID(requestID, "provided context does not comply with flagd, continuing ignoring the context")
}
for k, v := range staticContextValues {
context[k] = v
}
return context
}

View File

@ -29,12 +29,14 @@ type Service struct {
server *http.Server
}
func NewOfrepService(evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration) (*Service, error) {
func NewOfrepService(
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any,
) (*Service, error) {
corsMW := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{http.MethodPost},
})
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator))
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues))
server := http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),

View File

@ -27,7 +27,7 @@ func Test_OfrepServiceStartStop(t *testing.T) {
Port: uint16(port),
}
service, err := NewOfrepService(eval, []string{"*"}, cfg)
service, err := NewOfrepService(eval, []string{"*"}, cfg, nil)
if err != nil {
t.Fatalf("error creating the ofrep service: %v", err)
}

View File

@ -5,7 +5,7 @@ import (
"fmt"
"buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc"
"buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1"
syncv1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1"
"github.com/open-feature/flagd/core/pkg/logger"
"google.golang.org/protobuf/types/known/structpb"
)
@ -14,6 +14,7 @@ import (
type syncHandler struct {
mux *Multiplexer
log *logger.Logger
contextValues map[string]any
}
func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.FlagSyncService_SyncFlagsServer) error {
@ -59,9 +60,15 @@ func (s syncHandler) FetchAllFlags(_ context.Context, req *syncv1.FetchAllFlagsR
func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) (
*syncv1.GetMetadataResponse, error,
) {
metadata, err := structpb.NewStruct(map[string]interface{}{
"sources": s.mux.SourcesAsMetadata(),
})
metadataSrc := make(map[string]any)
for k, v := range s.contextValues {
metadataSrc[k] = v
}
if sources := s.mux.SourcesAsMetadata(); sources != "" {
metadataSrc["sources"] = sources
}
metadata, err := structpb.NewStruct(metadataSrc)
if err != nil {
s.log.Warn(fmt.Sprintf("error from struct creation: %v", err))
return nil, fmt.Errorf("error constructing metadata response")

View File

@ -27,6 +27,7 @@ type SvcConfigurations struct {
Port uint16
Sources []string
Store *store.Flags
ContextValues map[string]any
}
type Service struct {
@ -49,6 +50,7 @@ func NewSyncService(cfg SvcConfigurations) (*Service, error) {
syncv1grpc.RegisterFlagSyncServiceServer(server, &syncHandler{
mux: mux,
log: l,
contextValues: cfg.ContextValues,
})
l.Info(fmt.Sprintf("starting flag sync service on port %d", cfg.Port))