feat: updating context using headers (#1641)

<!-- Please use this template for your pull request. -->
<!-- Please use the sections that you need and delete other sections -->

Able to update context map using headers present in 
- OFREP requests
- Connect Requests (via Flag Evaluator V2 service)

### Related Issues
Fixes #1583 

### Notes
Context values passed via headers is high priority

If same context key is updated via
- Headers
- Request Body
- Static Config

_Context via Headers will be considered_

### Usage 
```
flagd start --port 8013 --uri file:./samples/example_flags.flagd.json -H Header=contextKey
```
or
```
flagd start --port 8013 --uri file:./samples/example_flags.flagd.json --context-from-header Header=contextKey
```

---------

Signed-off-by: Rahul Baradol <rahul.baradol.14@gmail.com>
This commit is contained in:
Rahul Baradol 2025-06-13 19:32:35 +05:30 committed by GitHub
parent bf10ff30dc
commit ba348152b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 206 additions and 98 deletions

View File

@ -23,16 +23,17 @@ type Notification struct {
type ReadinessProbe func() bool
type Configuration struct {
ReadinessProbe ReadinessProbe
Port uint16
ManagementPort uint16
ServiceName string
CertPath string
KeyPath string
SocketPath string
CORS []string
Options []connect.HandlerOption
ContextValues map[string]any
ReadinessProbe ReadinessProbe
Port uint16
ManagementPort uint16
ServiceName string
CertPath string
KeyPath string
SocketPath string
CORS []string
Options []connect.HandlerOption
ContextValues map[string]any
HeaderToContextKeyMappings map[string]string
}
/*

View File

@ -11,26 +11,27 @@ flagd start [flags]
### Options
```
-X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default [])
-C, --cors-origin strings CORS allowed origins, * will allow all origins
-h, --help help for start
-z, --log-format string Set the logging format, e.g. console or json (default "console")
-m, --management-port int32 Port for management operations (default 8014)
-t, --metrics-exporter string Set the metrics exporter. Default(if unset) is Prometheus. Can be override to otel - OpenTelemetry metric exporter. Overriding to otel require otelCollectorURI to be present
-r, --ofrep-port int32 ofrep service port (default 8016)
-A, --otel-ca-path string tls certificate authority path to use with OpenTelemetry collector
-D, --otel-cert-path string tls certificate path to use with OpenTelemetry collector
-o, --otel-collector-uri string Set the grpc URI of the OpenTelemetry collector for flagd runtime. If unset, the collector setup will be ignored and traces will not be exported.
-K, --otel-key-path string tls key path to use with OpenTelemetry collector
-I, --otel-reload-interval duration how long between reloading the otel tls certificate from disk (default 1h0m0s)
-p, --port int32 Port to listen on (default 8013)
-c, --server-cert-path string Server side tls certificate path
-k, --server-key-path string Server side tls key path
-d, --socket-path string Flagd unix socket path. With grpc the evaluations service will become available on this address. With http(s) the grpc-gateway proxy will use this address internally.
-s, --sources string JSON representation of an array of SourceConfig objects. This object contains 2 required fields, uri (string) and provider (string). Documentation for this object: https://flagd.dev/reference/sync-configuration/#source-configuration
-g, --sync-port int32 gRPC Sync port (default 8015)
-e, --sync-socket-path string Flagd sync service socket path. With grpc the sync service will be available on this address.
-f, --uri .yaml/.yml/.json Set a sync provider uri to read data from, this can be a filepath, URL (HTTP and gRPC), FeatureFlag custom resource, or GCS or Azure Blob. When flag keys are duplicated across multiple providers the merge priority follows the index of the flag arguments, as such flags from the uri at index 0 take the lowest precedence, with duplicated keys being overwritten by those from the uri at index 1. Please note that if you are using filepath, flagd only supports files with .yaml/.yml/.json extension.
-H, --context-from-header stringToString add key-value pairs to map header values to context values, where key is Header name, value is context key (default [])
-X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default [])
-C, --cors-origin strings CORS allowed origins, * will allow all origins
-h, --help help for start
-z, --log-format string Set the logging format, e.g. console or json (default "console")
-m, --management-port int32 Port for management operations (default 8014)
-t, --metrics-exporter string Set the metrics exporter. Default(if unset) is Prometheus. Can be override to otel - OpenTelemetry metric exporter. Overriding to otel require otelCollectorURI to be present
-r, --ofrep-port int32 ofrep service port (default 8016)
-A, --otel-ca-path string tls certificate authority path to use with OpenTelemetry collector
-D, --otel-cert-path string tls certificate path to use with OpenTelemetry collector
-o, --otel-collector-uri string Set the grpc URI of the OpenTelemetry collector for flagd runtime. If unset, the collector setup will be ignored and traces will not be exported.
-K, --otel-key-path string tls key path to use with OpenTelemetry collector
-I, --otel-reload-interval duration how long between reloading the otel tls certificate from disk (default 1h0m0s)
-p, --port int32 Port to listen on (default 8013)
-c, --server-cert-path string Server side tls certificate path
-k, --server-key-path string Server side tls key path
-d, --socket-path string Flagd unix socket path. With grpc the evaluations service will become available on this address. With http(s) the grpc-gateway proxy will use this address internally.
-s, --sources string JSON representation of an array of SourceConfig objects. This object contains 2 required fields, uri (string) and provider (string). Documentation for this object: https://flagd.dev/reference/sync-configuration/#source-configuration
-g, --sync-port int32 gRPC Sync port (default 8015)
-e, --sync-socket-path string Flagd sync service socket path. With grpc the sync service will be available on this address.
-f, --uri .yaml/.yml/.json Set a sync provider uri to read data from, this can be a filepath, URL (HTTP and gRPC), FeatureFlag custom resource, or GCS or Azure Blob. When flag keys are duplicated across multiple providers the merge priority follows the index of the flag arguments, as such flags from the uri at index 0 take the lowest precedence, with duplicated keys being overwritten by those from the uri at index 1. Please note that if you are using filepath, flagd only supports files with .yaml/.yml/.json extension.
```
### Options inherited from parent commands

View File

@ -37,6 +37,7 @@ const (
syncSocketPathFlagName = "sync-socket-path"
uriFlagName = "uri"
contextValueFlagName = "context-value"
headerToContextKeyFlagName = "context-from-header"
)
func init() {
@ -84,6 +85,8 @@ func init() {
"from disk")
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
"to the flag evaluation context")
flags.StringToStringP(headerToContextKeyFlagName, "H", map[string]string{}, "add key-value pairs to map " +
"header values to context values, where key is Header name, value is context key")
bindFlags(flags)
}
@ -107,6 +110,7 @@ func bindFlags(flags *pflag.FlagSet) {
_ = viper.BindPFlag(syncSocketPathFlagName, flags.Lookup(syncSocketPathFlagName))
_ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName))
_ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName))
_ = viper.BindPFlag(headerToContextKeyFlagName, flags.Lookup(headerToContextKeyFlagName))
}
// startCmd represents the start command
@ -156,25 +160,31 @@ var startCmd = &cobra.Command{
contextValuesToMap[k] = v
}
headerToContextKeyMappings := make(map[string]string)
for k, v := range viper.GetStringMapString(headerToContextKeyFlagName) {
headerToContextKeyMappings[k] = v
}
// Build Runtime -----------------------------------------------------------
rt, err := runtime.FromConfig(logger, Version, runtime.Config{
CORS: viper.GetStringSlice(corsFlagName),
MetricExporter: viper.GetString(metricsExporter),
ManagementPort: viper.GetUint16(managementPortFlagName),
OfrepServicePort: viper.GetUint16(ofrepPortFlagName),
OtelCollectorURI: viper.GetString(otelCollectorURI),
OtelCertPath: viper.GetString(otelCertPathFlagName),
OtelKeyPath: viper.GetString(otelKeyPathFlagName),
OtelReloadInterval: viper.GetDuration(otelReloadIntervalFlagName),
OtelCAPath: viper.GetString(otelCAPathFlagName),
ServiceCertPath: viper.GetString(serverCertPathFlagName),
ServiceKeyPath: viper.GetString(serverKeyPathFlagName),
ServicePort: viper.GetUint16(portFlagName),
ServiceSocketPath: viper.GetString(socketPathFlagName),
SyncServicePort: viper.GetUint16(syncPortFlagName),
SyncServiceSocketPath: viper.GetString(syncSocketPathFlagName),
SyncProviders: syncProviders,
ContextValues: contextValuesToMap,
CORS: viper.GetStringSlice(corsFlagName),
MetricExporter: viper.GetString(metricsExporter),
ManagementPort: viper.GetUint16(managementPortFlagName),
OfrepServicePort: viper.GetUint16(ofrepPortFlagName),
OtelCollectorURI: viper.GetString(otelCollectorURI),
OtelCertPath: viper.GetString(otelCertPathFlagName),
OtelKeyPath: viper.GetString(otelKeyPathFlagName),
OtelReloadInterval: viper.GetDuration(otelReloadIntervalFlagName),
OtelCAPath: viper.GetString(otelCAPathFlagName),
ServiceCertPath: viper.GetString(serverCertPathFlagName),
ServiceKeyPath: viper.GetString(serverKeyPathFlagName),
ServicePort: viper.GetUint16(portFlagName),
ServiceSocketPath: viper.GetString(socketPathFlagName),
SyncServicePort: viper.GetUint16(syncPortFlagName),
SyncServiceSocketPath: viper.GetString(syncSocketPathFlagName),
SyncProviders: syncProviders,
ContextValues: contextValuesToMap,
HeaderToContextKeyMappings: headerToContextKeyMappings,
})
if err != nil {
rtLogger.Fatal(err.Error())

View File

@ -42,7 +42,8 @@ type Config struct {
SyncProviders []sync.SourceConfig
CORS []string
ContextValues map[string]any
ContextValues map[string]any
HeaderToContextKeyMappings map[string]string
}
// FromConfig builds a runtime from startup configurations
@ -106,6 +107,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
Port: config.OfrepServicePort,
},
config.ContextValues,
config.HeaderToContextKeyMappings,
)
if err != nil {
return nil, fmt.Errorf("error creating ofrep service")
@ -146,15 +148,16 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
OfrepService: ofrepService,
Service: connectService,
ServiceConfig: service.Configuration{
Port: config.ServicePort,
ManagementPort: config.ManagementPort,
ServiceName: svcName,
KeyPath: config.ServiceKeyPath,
CertPath: config.ServiceCertPath,
SocketPath: config.ServiceSocketPath,
CORS: config.CORS,
Options: options,
ContextValues: config.ContextValues,
Port: config.ServicePort,
ManagementPort: config.ManagementPort,
ServiceName: svcName,
KeyPath: config.ServiceKeyPath,
CertPath: config.ServiceCertPath,
SocketPath: config.ServiceSocketPath,
CORS: config.CORS,
Options: options,
ContextValues: config.ContextValues,
HeaderToContextKeyMappings: config.HeaderToContextKeyMappings,
},
SyncImpl: iSyncs,
}, nil

View File

@ -172,6 +172,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
svcConf.HeaderToContextKeyMappings,
)
_, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...)

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"net/http"
"time"
schemaV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/schema/v1"
@ -72,7 +73,7 @@ func (s *OldFlagEvaluationService) ResolveAll(
Flags: make(map[string]*schemaV1.AnyFlag),
}
values, _, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
values, _, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)))
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -179,11 +180,13 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
sCtx,
s.logger,
s.eval.ResolveBooleanValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&booleanResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -206,11 +209,13 @@ func (s *OldFlagEvaluationService) ResolveString(
sCtx,
s.logger,
s.eval.ResolveStringValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&stringResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -233,11 +238,13 @@ func (s *OldFlagEvaluationService) ResolveInt(
sCtx,
s.logger,
s.eval.ResolveIntValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&intResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -260,11 +267,13 @@ func (s *OldFlagEvaluationService) ResolveFloat(
sCtx,
s.logger,
s.eval.ResolveFloatValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&floatResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -287,11 +296,13 @@ func (s *OldFlagEvaluationService) ResolveObject(
sCtx,
s.logger,
s.eval.ResolveObjectValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&objectResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -301,9 +312,9 @@ func (s *OldFlagEvaluationService) ResolveObject(
return res, err
}
// mergeContexts combines values from the request context with the values from the config --context-values flag.
// Request context values have a higher priority.
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
// mergeContexts combines context values from headers, static context (from cli) and request context.
// highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority
func mergeContexts(reqCtx, configFlagsCtx map[string]any, headers http.Header, headerToContextKeyMappings map[string]string) map[string]any {
merged := make(map[string]any)
for k, v := range reqCtx {
merged[k] = v
@ -311,18 +322,24 @@ func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
for k, v := range configFlagsCtx {
merged[k] = v
}
for header, contextKey := range headerToContextKeyMappings {
if values, ok := headers[header]; ok {
merged[contextKey] = values[0]
}
}
return merged
}
// resolve is a generic flag resolver
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], header http.Header, flagKey string,
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
configContextValues map[string]any,
configContextValues map[string]any, configHeaderToContextKeyMappings map[string]string,
) error {
reqID := xid.New().String()
defer logger.ClearFields(reqID)
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues, header, configHeaderToContextKeyMappings)
logger.WriteFields(
reqID,
zap.String("flag-key", flagKey),

View File

@ -20,12 +20,13 @@ import (
)
type FlagEvaluationService struct {
logger *logger.Logger
eval evaluator.IEvaluator
metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents
flagEvalTracer trace.Tracer
contextValues map[string]any
logger *logger.Logger
eval evaluator.IEvaluator
metrics telemetry.IMetricsRecorder
eventingConfiguration IEvents
flagEvalTracer trace.Tracer
contextValues map[string]any
headerToContextKeyMappings map[string]string
}
// NewFlagEvaluationService creates a FlagEvaluationService with provided parameters
@ -34,14 +35,16 @@ func NewFlagEvaluationService(log *logger.Logger,
eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
headerToContextKeyMappings map[string]string,
) *FlagEvaluationService {
svc := &FlagEvaluationService{
logger: log,
eval: eval,
metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagd.evaluation.v1"),
contextValues: contextValues,
logger: log,
eval: eval,
metrics: &telemetry.NoopMetricsRecorder{},
eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagd.evaluation.v1"),
contextValues: contextValues,
headerToContextKeyMappings: headerToContextKeyMappings,
}
if metricsRecorder != nil {
@ -66,8 +69,9 @@ func (s *FlagEvaluationService) ResolveAll(
Flags: make(map[string]*evalV1.AnyFlag),
}
resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(),
s.contextValues))
context := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings)
resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(sCtx, reqID, context)
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -185,11 +189,13 @@ func (s *FlagEvaluationService) ResolveBoolean(
sCtx,
s.logger,
s.eval.ResolveBooleanValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&booleanResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -211,11 +217,13 @@ func (s *FlagEvaluationService) ResolveString(
sCtx,
s.logger,
s.eval.ResolveStringValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&stringResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -237,11 +245,13 @@ func (s *FlagEvaluationService) ResolveInt(
sCtx,
s.logger,
s.eval.ResolveIntValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&intResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -263,11 +273,13 @@ func (s *FlagEvaluationService) ResolveFloat(
sCtx,
s.logger,
s.eval.ResolveFloatValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&floatResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -289,11 +301,13 @@ func (s *FlagEvaluationService) ResolveObject(
sCtx,
s.logger,
s.eval.ResolveObjectValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&objectResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"net/http"
"reflect"
"testing"
@ -103,7 +104,7 @@ func TestConnectServiceV2_ResolveAll(t *testing.T) {
).AnyTimes()
metrics, exp := getMetricReader()
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil)
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil)
// when
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
@ -220,6 +221,7 @@ func TestFlag_EvaluationV2_ResolveBoolean(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -276,6 +278,7 @@ func BenchmarkFlag_EvaluationV2_ResolveBoolean(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -375,6 +378,7 @@ func TestFlag_EvaluationV2_ResolveString(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -431,6 +435,7 @@ func BenchmarkFlag_EvaluationV2_ResolveString(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -529,6 +534,7 @@ func TestFlag_EvaluationV2_ResolveFloat(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -585,6 +591,7 @@ func BenchmarkFlag_EvaluationV2_ResolveFloat(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -683,6 +690,7 @@ func TestFlag_EvaluationV2_ResolveInt(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -739,6 +747,7 @@ func BenchmarkFlag_EvaluationV2_ResolveInt(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -840,6 +849,7 @@ func TestFlag_EvaluationV2_ResolveObject(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -904,6 +914,7 @@ func BenchmarkFlag_EvaluationV2_ResolveObject(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -979,7 +990,10 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) {
func Test_mergeContexts(t *testing.T) {
type args struct {
clientContext, configContext map[string]any
headers http.Header
headerToContextKeyMappings map[string]string
clientContext map[string]any
configContext map[string]any
}
tests := []struct {
@ -988,19 +1002,54 @@ func Test_mergeContexts(t *testing.T) {
want map[string]any
}{
{
name: "merge contexts",
name: "merge contexts with no headers, with no header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{},
headerToContextKeyMappings: map[string]string{},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
{
name: "merge contexts with headers, with no header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{"X-key": []string{"value"}, "X-token": []string{"token"}},
headerToContextKeyMappings: map[string]string{},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
{
name: "merge contexts with no headers, with header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{},
headerToContextKeyMappings: map[string]string{"X-key": "k2"},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
{
name: "merge contexts with headers, with header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{"X-key": []string{"value"}, "X-token": []string{"token"}},
headerToContextKeyMappings: map[string]string{"X-key": "k2"},
},
// header context should "win"
want: map[string]any{"k1": "v1", "k2": "value", "k3": "v3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := mergeContexts(tt.args.clientContext, tt.args.configContext)
got := mergeContexts(tt.args.clientContext, tt.args.configContext, tt.args.headers, tt.args.headerToContextKeyMappings)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want)

View File

@ -23,18 +23,20 @@ const (
)
type handler struct {
Logger *logger.Logger
evaluator evaluator.IEvaluator
contextValues map[string]any
tracer trace.Tracer
Logger *logger.Logger
evaluator evaluator.IEvaluator
contextValues map[string]any
headerToContextKeyMappings map[string]string
tracer trace.Tracer
}
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler {
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any, headerToContextKeyMappings map[string]string) http.Handler {
h := handler{
Logger: logger,
evaluator: evaluator,
contextValues: contextValues,
tracer: otel.Tracer("flagd.ofrep.v1"),
Logger: logger,
evaluator: evaluator,
contextValues: contextValues,
headerToContextKeyMappings: headerToContextKeyMappings,
tracer: otel.Tracer("flagd.ofrep.v1"),
}
router := mux.NewRouter()
@ -62,7 +64,8 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) {
h.writeJSONToResponse(http.StatusBadRequest, ofrep.ContextErrorResponseFrom(flagKey), w)
return
}
context := flagdContext(h.Logger, requestID, request, h.contextValues)
context := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
evaluation := h.evaluator.ResolveAsAnyValue(r.Context(), requestID, flagKey, context)
if evaluation.Error != nil {
status, evaluationError := ofrep.EvaluationErrorResponseFrom(evaluation)
@ -82,7 +85,8 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) {
return
}
context := flagdContext(h.Logger, requestID, request, h.contextValues)
context := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
evaluations, metadata, err := h.evaluator.ResolveAllValues(r.Context(), requestID, context)
if err != nil {
h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err))
@ -123,8 +127,10 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) {
return request, nil
}
// flagdContext returns combined context values from headers, static context (from cli) and request context.
// highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority
func flagdContext(
log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any,
log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any, headers http.Header, headerToContextKeyMappings map[string]string,
) map[string]any {
context := make(map[string]any)
if res, ok := request.Context.(map[string]any); ok {
@ -139,5 +145,11 @@ func flagdContext(
context[k] = v
}
for header, contextKey := range headerToContextKeyMappings {
if values, ok := headers[header]; ok {
context[contextKey] = values[0]
}
}
return context
}

View File

@ -30,13 +30,13 @@ type Service struct {
}
func NewOfrepService(
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any,
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any, headerToContextKeyMappings map[string]string,
) (*Service, error) {
corsMW := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{http.MethodPost},
})
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues))
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues, headerToContextKeyMappings))
server := http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),

View File

@ -28,7 +28,7 @@ func Test_OfrepServiceStartStop(t *testing.T) {
Port: uint16(port),
}
service, err := NewOfrepService(eval, []string{"*"}, cfg, nil)
service, err := NewOfrepService(eval, []string{"*"}, cfg, nil, nil)
if err != nil {
t.Fatalf("error creating the ofrep service: %v", err)
}