feat: add context-value flag (#1448)

- add the `--context-value` command line flag to pass arbitrary key
value pairs to the evaluation context

Signed-off-by: Aleksei Muratov <muratoff.alexey@gmail.com>
This commit is contained in:
Aleksei 2024-12-05 17:05:46 +01:00 committed by GitHub
parent f7dd1eb630
commit 7ca092e478
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 181 additions and 54 deletions

View File

@ -68,6 +68,9 @@ uninstall:
lint: lint:
go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2 go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2
$(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --deadline=5m --timeout=5m $(module)/... || exit;) $(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --deadline=5m --timeout=5m $(module)/... || exit;)
lint-fix:
go install -v github.com/golangci/golangci-lint/cmd/golangci-lint@v1.55.2
$(foreach module, $(ALL_GO_MOD_DIRS), ${GOPATH}/bin/golangci-lint run --fix --deadline=5m --timeout=5m $(module)/... || exit;)
install-mockgen: install-mockgen:
go install go.uber.org/mock/mockgen@v0.4.0 go install go.uber.org/mock/mockgen@v0.4.0
mockgen: install-mockgen mockgen: install-mockgen

View File

@ -32,6 +32,7 @@ type Configuration struct {
SocketPath string SocketPath string
CORS []string CORS []string
Options []connect.HandlerOption Options []connect.HandlerOption
ContextValues map[string]any
} }
/* /*

View File

@ -184,6 +184,9 @@ For example, when accessing flagd via HTTP, the POST body may look like this:
The evaluation context can be accessed in targeting rules using the `var` operation followed by the evaluation context property name. The evaluation context can be accessed in targeting rules using the `var` operation followed by the evaluation context property name.
The evaluation context can be appended by arbitrary key value pairs
via the `-X` command line flag.
| Description | Example | | Description | Example |
| -------------------------------------------------------------- | ---------------------------------------------------- | | -------------------------------------------------------------- | ---------------------------------------------------- |
| Retrieve property from the evaluation context | `#!json { "var": "email" }` | | Retrieve property from the evaluation context | `#!json { "var": "email" }` |

View File

@ -11,6 +11,7 @@ flagd start [flags]
### Options ### Options
``` ```
-X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default [])
-C, --cors-origin strings CORS allowed origins, * will allow all origins -C, --cors-origin strings CORS allowed origins, * will allow all origins
-h, --help help for start -h, --help help for start
-z, --log-format string Set the logging format, e.g. console or json (default "console") -z, --log-format string Set the logging format, e.g. console or json (default "console")

View File

@ -34,11 +34,11 @@ const (
sourcesFlagName = "sources" sourcesFlagName = "sources"
syncPortFlagName = "sync-port" syncPortFlagName = "sync-port"
uriFlagName = "uri" uriFlagName = "uri"
contextValueFlagName = "context-value"
) )
func init() { func init() {
flags := startCmd.Flags() flags := startCmd.Flags()
// allows environment variables to use _ instead of - // allows environment variables to use _ instead of -
viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) // sync-provider-args becomes SYNC_PROVIDER_ARGS
viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT viper.SetEnvPrefix("FLAGD") // port becomes FLAGD_PORT
@ -78,6 +78,8 @@ func init() {
flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector") flags.StringP(otelCAPathFlagName, "A", "", "tls certificate authority path to use with OpenTelemetry collector")
flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+ flags.DurationP(otelReloadIntervalFlagName, "I", time.Hour, "how long between reloading the otel tls certificate "+
"from disk") "from disk")
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
"to the flag evaluation context")
_ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName)) _ = viper.BindPFlag(corsFlagName, flags.Lookup(corsFlagName))
_ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName)) _ = viper.BindPFlag(logFormatFlagName, flags.Lookup(logFormatFlagName))
@ -95,6 +97,7 @@ func init() {
_ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName)) _ = viper.BindPFlag(uriFlagName, flags.Lookup(uriFlagName))
_ = viper.BindPFlag(syncPortFlagName, flags.Lookup(syncPortFlagName)) _ = viper.BindPFlag(syncPortFlagName, flags.Lookup(syncPortFlagName))
_ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName)) _ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName))
_ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName))
} }
// startCmd represents the start command // startCmd represents the start command
@ -139,6 +142,11 @@ var startCmd = &cobra.Command{
} }
syncProviders = append(syncProviders, syncProvidersFromConfig...) syncProviders = append(syncProviders, syncProvidersFromConfig...)
contextValuesToMap := make(map[string]any)
for k, v := range viper.GetStringMapString(contextValueFlagName) {
contextValuesToMap[k] = v
}
// Build Runtime ----------------------------------------------------------- // Build Runtime -----------------------------------------------------------
rt, err := runtime.FromConfig(logger, Version, runtime.Config{ rt, err := runtime.FromConfig(logger, Version, runtime.Config{
CORS: viper.GetStringSlice(corsFlagName), CORS: viper.GetStringSlice(corsFlagName),
@ -156,6 +164,7 @@ var startCmd = &cobra.Command{
ServiceSocketPath: viper.GetString(socketPathFlagName), ServiceSocketPath: viper.GetString(socketPathFlagName),
SyncServicePort: viper.GetUint16(syncPortFlagName), SyncServicePort: viper.GetUint16(syncPortFlagName),
SyncProviders: syncProviders, SyncProviders: syncProviders,
ContextValues: contextValuesToMap,
}) })
if err != nil { if err != nil {
rtLogger.Fatal(err.Error()) rtLogger.Fatal(err.Error())

View File

@ -40,6 +40,8 @@ type Config struct {
SyncProviders []sync.SourceConfig SyncProviders []sync.SourceConfig
CORS []string CORS []string
ContextValues map[string]any
} }
// FromConfig builds a runtime from startup configurations // FromConfig builds a runtime from startup configurations
@ -101,17 +103,20 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
ofrepService, err := ofrep.NewOfrepService(jsonEvaluator, config.CORS, ofrep.SvcConfiguration{ ofrepService, err := ofrep.NewOfrepService(jsonEvaluator, config.CORS, ofrep.SvcConfiguration{
Logger: logger.WithFields(zap.String("component", "OFREPService")), Logger: logger.WithFields(zap.String("component", "OFREPService")),
Port: config.OfrepServicePort, Port: config.OfrepServicePort,
}) },
config.ContextValues,
)
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating ofrep service") return nil, fmt.Errorf("error creating ofrep service")
} }
// flag sync service // flag sync service
flagSyncService, err := flagsync.NewSyncService(flagsync.SvcConfigurations{ flagSyncService, err := flagsync.NewSyncService(flagsync.SvcConfigurations{
Logger: logger.WithFields(zap.String("component", "FlagSyncService")), Logger: logger.WithFields(zap.String("component", "FlagSyncService")),
Port: config.SyncServicePort, Port: config.SyncServicePort,
Sources: sources, Sources: sources,
Store: s, Store: s,
ContextValues: config.ContextValues,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("error creating sync service: %w", err) return nil, fmt.Errorf("error creating sync service: %w", err)
@ -145,6 +150,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
SocketPath: config.ServiceSocketPath, SocketPath: config.ServiceSocketPath,
CORS: config.CORS, CORS: config.CORS,
Options: options, Options: options,
ContextValues: config.ContextValues,
}, },
SyncImpl: iSyncs, SyncImpl: iSyncs,
}, nil }, nil

View File

@ -154,6 +154,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval, s.eval,
s.eventingConfiguration, s.eventingConfiguration,
s.metrics, s.metrics,
svcConf.ContextValues,
) )
marshalOpts := WithJSON( marshalOpts := WithJSON(
@ -170,6 +171,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eval, s.eval,
s.eventingConfiguration, s.eventingConfiguration,
s.metrics, s.metrics,
svcConf.ContextValues,
) )
_, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...) _, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...)

View File

@ -32,11 +32,16 @@ type OldFlagEvaluationService struct {
metrics telemetry.IMetricsRecorder metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents eventingConfiguration IEvents
flagEvalTracer trace.Tracer flagEvalTracer trace.Tracer
contextValues map[string]any
} }
// NewOldFlagEvaluationService creates a OldFlagEvaluationService with provided parameters // NewOldFlagEvaluationService creates a OldFlagEvaluationService with provided parameters
func NewOldFlagEvaluationService(log *logger.Logger, func NewOldFlagEvaluationService(
eval evaluator.IEvaluator, eventingCfg IEvents, metricsRecorder telemetry.IMetricsRecorder, log *logger.Logger,
eval evaluator.IEvaluator,
eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
) *OldFlagEvaluationService { ) *OldFlagEvaluationService {
svc := &OldFlagEvaluationService{ svc := &OldFlagEvaluationService{
logger: log, logger: log,
@ -44,6 +49,7 @@ func NewOldFlagEvaluationService(log *logger.Logger,
metrics: &telemetry.NoopMetricsRecorder{}, metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg, eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagEvaluationService"), flagEvalTracer: otel.Tracer("flagEvaluationService"),
contextValues: contextValues,
} }
if metricsRecorder != nil { if metricsRecorder != nil {
@ -65,12 +71,8 @@ func (s *OldFlagEvaluationService) ResolveAll(
res := &schemaV1.ResolveAllResponse{ res := &schemaV1.ResolveAllResponse{
Flags: make(map[string]*schemaV1.AnyFlag), Flags: make(map[string]*schemaV1.AnyFlag),
} }
evalCtx := map[string]any{}
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx) values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
if err != nil { if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err)) s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID) return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -172,6 +174,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End() defer span.End()
res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{}) res := connect.NewResponse(&schemaV1.ResolveBooleanResponse{})
err := resolve[bool]( err := resolve[bool](
sCtx, sCtx,
s.logger, s.logger,
@ -180,6 +183,7 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
req.Msg.GetContext(), req.Msg.GetContext(),
&booleanResponse{schemaV1Resp: res}, &booleanResponse{schemaV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -206,6 +210,7 @@ func (s *OldFlagEvaluationService) ResolveString(
req.Msg.GetContext(), req.Msg.GetContext(),
&stringResponse{schemaV1Resp: res}, &stringResponse{schemaV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -232,6 +237,7 @@ func (s *OldFlagEvaluationService) ResolveInt(
req.Msg.GetContext(), req.Msg.GetContext(),
&intResponse{schemaV1Resp: res}, &intResponse{schemaV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -258,6 +264,7 @@ func (s *OldFlagEvaluationService) ResolveFloat(
req.Msg.GetContext(), req.Msg.GetContext(),
&floatResponse{schemaV1Resp: res}, &floatResponse{schemaV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -284,6 +291,7 @@ func (s *OldFlagEvaluationService) ResolveObject(
req.Msg.GetContext(), req.Msg.GetContext(),
&objectResponse{schemaV1Resp: res}, &objectResponse{schemaV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -293,21 +301,36 @@ func (s *OldFlagEvaluationService) ResolveObject(
return res, err return res, err
} }
// mergeContexts combines values from the request context with the values from the config --context-values flag.
// Request context values have a higher priority.
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
merged := make(map[string]any)
for k, v := range reqCtx {
merged[k] = v
}
for k, v := range configFlagsCtx {
merged[k] = v
}
return merged
}
// resolve is a generic flag resolver // resolve is a generic flag resolver
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string, func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder, evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
configContextValues map[string]any,
) error { ) error {
reqID := xid.New().String() reqID := xid.New().String()
defer logger.ClearFields(reqID) defer logger.ClearFields(reqID)
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
logger.WriteFields( logger.WriteFields(
reqID, reqID,
zap.String("flag-key", flagKey), zap.String("flag-key", flagKey),
zap.Strings("context-keys", formatContextKeys(evaluationContext)), zap.Strings("context-keys", formatContextKeys(mergedContext)),
) )
var evalErrFormatted error var evalErrFormatted error
result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, evaluationContext.AsMap()) result, variant, reason, metadata, evalErr := resolver(ctx, reqID, flagKey, mergedContext)
if evalErr != nil { if evalErr != nil {
logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr)) logger.WarnWithID(reqID, fmt.Sprintf("returning error response, reason: %v", evalErr))
reason = model.ErrorReason reason = model.ErrorReason
@ -329,9 +352,9 @@ func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver
return evalErrFormatted return evalErrFormatted
} }
func formatContextKeys(context *structpb.Struct) []string { func formatContextKeys(context map[string]any) []string {
res := []string{} res := []string{}
for k := range context.AsMap() { for k := range context {
res = append(res, k) res = append(res, k)
} }
return res return res

View File

@ -128,6 +128,7 @@ func TestConnectService_ResolveAll(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req)) got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
if err != nil && !errors.Is(err, tt.wantErr) { if err != nil && !errors.Is(err, tt.wantErr) {
@ -235,6 +236,7 @@ func TestFlag_Evaluation_ResolveBoolean(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -290,6 +292,7 @@ func BenchmarkFlag_Evaluation_ResolveBoolean(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -388,6 +391,7 @@ func TestFlag_Evaluation_ResolveString(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -443,6 +447,7 @@ func BenchmarkFlag_Evaluation_ResolveString(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -540,6 +545,7 @@ func TestFlag_Evaluation_ResolveFloat(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -595,6 +601,7 @@ func BenchmarkFlag_Evaluation_ResolveFloat(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -692,6 +699,7 @@ func TestFlag_Evaluation_ResolveInt(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -747,6 +755,7 @@ func BenchmarkFlag_Evaluation_ResolveInt(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -847,6 +856,7 @@ func TestFlag_Evaluation_ResolveObject(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
outParsed, err := structpb.NewStruct(tt.evalFields.result) outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -910,6 +920,7 @@ func BenchmarkFlag_Evaluation_ResolveObject(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
if name != "eval returns error" { if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result) outParsed, err := structpb.NewStruct(tt.evalFields.result)

View File

@ -25,6 +25,7 @@ type FlagEvaluationService struct {
metrics telemetry.IMetricsRecorder metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents eventingConfiguration IEvents
flagEvalTracer trace.Tracer flagEvalTracer trace.Tracer
contextValues map[string]any
} }
// NewFlagEvaluationService creates a FlagEvaluationService with provided parameters // NewFlagEvaluationService creates a FlagEvaluationService with provided parameters
@ -32,6 +33,7 @@ func NewFlagEvaluationService(log *logger.Logger,
eval evaluator.IEvaluator, eval evaluator.IEvaluator,
eventingCfg IEvents, eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder, metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
) *FlagEvaluationService { ) *FlagEvaluationService {
svc := &FlagEvaluationService{ svc := &FlagEvaluationService{
logger: log, logger: log,
@ -39,6 +41,7 @@ func NewFlagEvaluationService(log *logger.Logger,
metrics: &telemetry.NoopMetricsRecorder{}, metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg, eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagd.evaluation.v1"), flagEvalTracer: otel.Tracer("flagd.evaluation.v1"),
contextValues: contextValues,
} }
if metricsRecorder != nil { if metricsRecorder != nil {
@ -63,12 +66,7 @@ func (s *FlagEvaluationService) ResolveAll(
Flags: make(map[string]*evalV1.AnyFlag), Flags: make(map[string]*evalV1.AnyFlag),
} }
evalCtx := map[string]any{} values, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
if e := req.Msg.GetContext(); e != nil {
evalCtx = e.AsMap()
}
values, err := s.eval.ResolveAllValues(sCtx, reqID, evalCtx)
if err != nil { if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err)) s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID) return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -167,8 +165,9 @@ func (s *FlagEvaluationService) ResolveBoolean(
) (*connect.Response[evalV1.ResolveBooleanResponse], error) { ) (*connect.Response[evalV1.ResolveBooleanResponse], error) {
sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer)) sCtx, span := s.flagEvalTracer.Start(ctx, "resolveBoolean", trace.WithSpanKind(trace.SpanKindServer))
defer span.End() defer span.End()
res := connect.NewResponse(&evalV1.ResolveBooleanResponse{}) res := connect.NewResponse(&evalV1.ResolveBooleanResponse{})
err := resolve[bool]( err := resolve(
sCtx, sCtx,
s.logger, s.logger,
s.eval.ResolveBooleanValue, s.eval.ResolveBooleanValue,
@ -176,6 +175,7 @@ func (s *FlagEvaluationService) ResolveBoolean(
req.Msg.GetContext(), req.Msg.GetContext(),
&booleanResponse{evalV1Resp: res}, &booleanResponse{evalV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -193,7 +193,7 @@ func (s *FlagEvaluationService) ResolveString(
defer span.End() defer span.End()
res := connect.NewResponse(&evalV1.ResolveStringResponse{}) res := connect.NewResponse(&evalV1.ResolveStringResponse{})
err := resolve[string]( err := resolve(
sCtx, sCtx,
s.logger, s.logger,
s.eval.ResolveStringValue, s.eval.ResolveStringValue,
@ -201,6 +201,7 @@ func (s *FlagEvaluationService) ResolveString(
req.Msg.GetContext(), req.Msg.GetContext(),
&stringResponse{evalV1Resp: res}, &stringResponse{evalV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -218,7 +219,7 @@ func (s *FlagEvaluationService) ResolveInt(
defer span.End() defer span.End()
res := connect.NewResponse(&evalV1.ResolveIntResponse{}) res := connect.NewResponse(&evalV1.ResolveIntResponse{})
err := resolve[int64]( err := resolve(
sCtx, sCtx,
s.logger, s.logger,
s.eval.ResolveIntValue, s.eval.ResolveIntValue,
@ -226,6 +227,7 @@ func (s *FlagEvaluationService) ResolveInt(
req.Msg.GetContext(), req.Msg.GetContext(),
&intResponse{evalV1Resp: res}, &intResponse{evalV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -243,7 +245,7 @@ func (s *FlagEvaluationService) ResolveFloat(
defer span.End() defer span.End()
res := connect.NewResponse(&evalV1.ResolveFloatResponse{}) res := connect.NewResponse(&evalV1.ResolveFloatResponse{})
err := resolve[float64]( err := resolve(
sCtx, sCtx,
s.logger, s.logger,
s.eval.ResolveFloatValue, s.eval.ResolveFloatValue,
@ -251,6 +253,7 @@ func (s *FlagEvaluationService) ResolveFloat(
req.Msg.GetContext(), req.Msg.GetContext(),
&floatResponse{evalV1Resp: res}, &floatResponse{evalV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)
@ -268,7 +271,7 @@ func (s *FlagEvaluationService) ResolveObject(
defer span.End() defer span.End()
res := connect.NewResponse(&evalV1.ResolveObjectResponse{}) res := connect.NewResponse(&evalV1.ResolveObjectResponse{})
err := resolve[map[string]any]( err := resolve(
sCtx, sCtx,
s.logger, s.logger,
s.eval.ResolveObjectValue, s.eval.ResolveObjectValue,
@ -276,6 +279,7 @@ func (s *FlagEvaluationService) ResolveObject(
req.Msg.GetContext(), req.Msg.GetContext(),
&objectResponse{evalV1Resp: res}, &objectResponse{evalV1Resp: res},
s.metrics, s.metrics,
s.contextValues,
) )
if err != nil { if err != nil {
span.RecordError(err) span.RecordError(err)

View File

@ -3,6 +3,7 @@ package service
import ( import (
"context" "context"
"errors" "errors"
"reflect"
"testing" "testing"
evalV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v1" evalV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/evaluation/v1"
@ -93,7 +94,7 @@ func TestConnectServiceV2_ResolveAll(t *testing.T) {
).AnyTimes() ).AnyTimes()
metrics, exp := getMetricReader() metrics, exp := getMetricReader()
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics) s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil)
// when // when
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req)) got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
@ -208,6 +209,7 @@ func TestFlag_EvaluationV2_ResolveBoolean(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -263,6 +265,7 @@ func BenchmarkFlag_EvaluationV2_ResolveBoolean(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -361,6 +364,7 @@ func TestFlag_EvaluationV2_ResolveString(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -416,6 +420,7 @@ func BenchmarkFlag_EvaluationV2_ResolveString(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -513,6 +518,7 @@ func TestFlag_EvaluationV2_ResolveFloat(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -568,6 +574,7 @@ func BenchmarkFlag_EvaluationV2_ResolveFloat(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -665,6 +672,7 @@ func TestFlag_EvaluationV2_ResolveInt(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req)) got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) { if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -720,6 +728,7 @@ func BenchmarkFlag_EvaluationV2_ResolveInt(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
b.Run(name, func(b *testing.B) { b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -820,6 +829,7 @@ func TestFlag_EvaluationV2_ResolveObject(t *testing.T) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
outParsed, err := structpb.NewStruct(tt.evalFields.result) outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -883,6 +893,7 @@ func BenchmarkFlag_EvaluationV2_ResolveObject(b *testing.B) {
eval, eval,
&eventingConfiguration{}, &eventingConfiguration{},
metrics, metrics,
nil,
) )
if name != "eval returns error" { if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result) outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -955,3 +966,35 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) {
} }
} }
} }
func Test_mergeContexts(t *testing.T) {
type args struct {
clientContext, configContext map[string]any
}
tests := []struct {
name string
args args
want map[string]any
}{
{
name: "merge contexts",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := mergeContexts(tt.args.clientContext, tt.args.configContext)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want)
}
})
}
}

View File

@ -20,14 +20,16 @@ const (
) )
type handler struct { type handler struct {
Logger *logger.Logger Logger *logger.Logger
evaluator evaluator.IEvaluator evaluator evaluator.IEvaluator
contextValues map[string]any
} }
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator) http.Handler { func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler {
h := handler{ h := handler{
logger, Logger: logger,
evaluator, evaluator: evaluator,
contextValues: contextValues,
} }
router := mux.NewRouter() router := mux.NewRouter()
@ -56,7 +58,7 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) {
return return
} }
context := flagdContext(h.Logger, requestID, request) context := flagdContext(h.Logger, requestID, request, h.contextValues)
evaluation := h.evaluator.ResolveAsAnyValue(r.Context(), requestID, flagKey, context) evaluation := h.evaluator.ResolveAsAnyValue(r.Context(), requestID, flagKey, context)
if evaluation.Error != nil { if evaluation.Error != nil {
status, evaluationError := ofrep.EvaluationErrorResponseFrom(evaluation) status, evaluationError := ofrep.EvaluationErrorResponseFrom(evaluation)
@ -76,7 +78,7 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) {
return return
} }
context := flagdContext(h.Logger, requestID, request) context := flagdContext(h.Logger, requestID, request, h.contextValues)
evaluations, err := h.evaluator.ResolveAllValues(r.Context(), requestID, context) evaluations, err := h.evaluator.ResolveAllValues(r.Context(), requestID, context)
if err != nil { if err != nil {
h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err)) h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err))
@ -117,13 +119,21 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) {
return request, nil return request, nil
} }
func flagdContext(log *logger.Logger, requestID string, request ofrep.Request) map[string]any { func flagdContext(
context := map[string]any{} log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any,
) map[string]any {
context := make(map[string]any)
if res, ok := request.Context.(map[string]any); ok { if res, ok := request.Context.(map[string]any); ok {
context = res for k, v := range res {
context[k] = v
}
} else { } else {
log.WarnWithID(requestID, "provided context does not comply with flagd, continuing ignoring the context") log.WarnWithID(requestID, "provided context does not comply with flagd, continuing ignoring the context")
} }
for k, v := range staticContextValues {
context[k] = v
}
return context return context
} }

View File

@ -29,12 +29,14 @@ type Service struct {
server *http.Server server *http.Server
} }
func NewOfrepService(evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration) (*Service, error) { func NewOfrepService(
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any,
) (*Service, error) {
corsMW := cors.New(cors.Options{ corsMW := cors.New(cors.Options{
AllowedOrigins: origins, AllowedOrigins: origins,
AllowedMethods: []string{http.MethodPost}, AllowedMethods: []string{http.MethodPost},
}) })
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator)) h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues))
server := http.Server{ server := http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port), Addr: fmt.Sprintf(":%d", cfg.Port),

View File

@ -27,7 +27,7 @@ func Test_OfrepServiceStartStop(t *testing.T) {
Port: uint16(port), Port: uint16(port),
} }
service, err := NewOfrepService(eval, []string{"*"}, cfg) service, err := NewOfrepService(eval, []string{"*"}, cfg, nil)
if err != nil { if err != nil {
t.Fatalf("error creating the ofrep service: %v", err) t.Fatalf("error creating the ofrep service: %v", err)
} }

View File

@ -5,15 +5,16 @@ import (
"fmt" "fmt"
"buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc" "buf.build/gen/go/open-feature/flagd/grpc/go/flagd/sync/v1/syncv1grpc"
"buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1" syncv1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/flagd/sync/v1"
"github.com/open-feature/flagd/core/pkg/logger" "github.com/open-feature/flagd/core/pkg/logger"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
) )
// syncHandler implements the sync contract // syncHandler implements the sync contract
type syncHandler struct { type syncHandler struct {
mux *Multiplexer mux *Multiplexer
log *logger.Logger log *logger.Logger
contextValues map[string]any
} }
func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.FlagSyncService_SyncFlagsServer) error { func (s syncHandler) SyncFlags(req *syncv1.SyncFlagsRequest, server syncv1grpc.FlagSyncService_SyncFlagsServer) error {
@ -59,9 +60,15 @@ func (s syncHandler) FetchAllFlags(_ context.Context, req *syncv1.FetchAllFlagsR
func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) ( func (s syncHandler) GetMetadata(_ context.Context, _ *syncv1.GetMetadataRequest) (
*syncv1.GetMetadataResponse, error, *syncv1.GetMetadataResponse, error,
) { ) {
metadata, err := structpb.NewStruct(map[string]interface{}{ metadataSrc := make(map[string]any)
"sources": s.mux.SourcesAsMetadata(), for k, v := range s.contextValues {
}) metadataSrc[k] = v
}
if sources := s.mux.SourcesAsMetadata(); sources != "" {
metadataSrc["sources"] = sources
}
metadata, err := structpb.NewStruct(metadataSrc)
if err != nil { if err != nil {
s.log.Warn(fmt.Sprintf("error from struct creation: %v", err)) s.log.Warn(fmt.Sprintf("error from struct creation: %v", err))
return nil, fmt.Errorf("error constructing metadata response") return nil, fmt.Errorf("error constructing metadata response")

View File

@ -23,10 +23,11 @@ type ISyncService interface {
} }
type SvcConfigurations struct { type SvcConfigurations struct {
Logger *logger.Logger Logger *logger.Logger
Port uint16 Port uint16
Sources []string Sources []string
Store *store.Flags Store *store.Flags
ContextValues map[string]any
} }
type Service struct { type Service struct {
@ -47,8 +48,9 @@ func NewSyncService(cfg SvcConfigurations) (*Service, error) {
server := grpc.NewServer() server := grpc.NewServer()
syncv1grpc.RegisterFlagSyncServiceServer(server, &syncHandler{ syncv1grpc.RegisterFlagSyncServiceServer(server, &syncHandler{
mux: mux, mux: mux,
log: l, log: l,
contextValues: cfg.ContextValues,
}) })
l.Info(fmt.Sprintf("starting flag sync service on port %d", cfg.Port)) l.Info(fmt.Sprintf("starting flag sync service on port %d", cfg.Port))