feat: updating context using headers (#1641)

<!-- Please use this template for your pull request. -->
<!-- Please use the sections that you need and delete other sections -->

Able to update context map using headers present in 
- OFREP requests
- Connect Requests (via Flag Evaluator V2 service)

### Related Issues
Fixes #1583 

### Notes
Context values passed via headers is high priority

If same context key is updated via
- Headers
- Request Body
- Static Config

_Context via Headers will be considered_

### Usage 
```
flagd start --port 8013 --uri file:./samples/example_flags.flagd.json -H Header=contextKey
```
or
```
flagd start --port 8013 --uri file:./samples/example_flags.flagd.json --context-from-header Header=contextKey
```

---------

Signed-off-by: Rahul Baradol <rahul.baradol.14@gmail.com>
This commit is contained in:
Rahul Baradol 2025-06-13 19:32:35 +05:30 committed by GitHub
parent bf10ff30dc
commit ba348152b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 206 additions and 98 deletions

View File

@ -33,6 +33,7 @@ type Configuration struct {
CORS []string
Options []connect.HandlerOption
ContextValues map[string]any
HeaderToContextKeyMappings map[string]string
}
/*

View File

@ -11,6 +11,7 @@ flagd start [flags]
### Options
```
-H, --context-from-header stringToString add key-value pairs to map header values to context values, where key is Header name, value is context key (default [])
-X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default [])
-C, --cors-origin strings CORS allowed origins, * will allow all origins
-h, --help help for start

View File

@ -37,6 +37,7 @@ const (
syncSocketPathFlagName = "sync-socket-path"
uriFlagName = "uri"
contextValueFlagName = "context-value"
headerToContextKeyFlagName = "context-from-header"
)
func init() {
@ -84,6 +85,8 @@ func init() {
"from disk")
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
"to the flag evaluation context")
flags.StringToStringP(headerToContextKeyFlagName, "H", map[string]string{}, "add key-value pairs to map " +
"header values to context values, where key is Header name, value is context key")
bindFlags(flags)
}
@ -107,6 +110,7 @@ func bindFlags(flags *pflag.FlagSet) {
_ = viper.BindPFlag(syncSocketPathFlagName, flags.Lookup(syncSocketPathFlagName))
_ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName))
_ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName))
_ = viper.BindPFlag(headerToContextKeyFlagName, flags.Lookup(headerToContextKeyFlagName))
}
// startCmd represents the start command
@ -156,6 +160,11 @@ var startCmd = &cobra.Command{
contextValuesToMap[k] = v
}
headerToContextKeyMappings := make(map[string]string)
for k, v := range viper.GetStringMapString(headerToContextKeyFlagName) {
headerToContextKeyMappings[k] = v
}
// Build Runtime -----------------------------------------------------------
rt, err := runtime.FromConfig(logger, Version, runtime.Config{
CORS: viper.GetStringSlice(corsFlagName),
@ -175,6 +184,7 @@ var startCmd = &cobra.Command{
SyncServiceSocketPath: viper.GetString(syncSocketPathFlagName),
SyncProviders: syncProviders,
ContextValues: contextValuesToMap,
HeaderToContextKeyMappings: headerToContextKeyMappings,
})
if err != nil {
rtLogger.Fatal(err.Error())

View File

@ -43,6 +43,7 @@ type Config struct {
CORS []string
ContextValues map[string]any
HeaderToContextKeyMappings map[string]string
}
// FromConfig builds a runtime from startup configurations
@ -106,6 +107,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
Port: config.OfrepServicePort,
},
config.ContextValues,
config.HeaderToContextKeyMappings,
)
if err != nil {
return nil, fmt.Errorf("error creating ofrep service")
@ -155,6 +157,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
CORS: config.CORS,
Options: options,
ContextValues: config.ContextValues,
HeaderToContextKeyMappings: config.HeaderToContextKeyMappings,
},
SyncImpl: iSyncs,
}, nil

View File

@ -172,6 +172,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
s.eventingConfiguration,
s.metrics,
svcConf.ContextValues,
svcConf.HeaderToContextKeyMappings,
)
_, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...)

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"fmt"
"net/http"
"time"
schemaV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/schema/v1"
@ -72,7 +73,7 @@ func (s *OldFlagEvaluationService) ResolveAll(
Flags: make(map[string]*schemaV1.AnyFlag),
}
values, _, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
values, _, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)))
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -179,11 +180,13 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
sCtx,
s.logger,
s.eval.ResolveBooleanValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&booleanResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -206,11 +209,13 @@ func (s *OldFlagEvaluationService) ResolveString(
sCtx,
s.logger,
s.eval.ResolveStringValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&stringResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -233,11 +238,13 @@ func (s *OldFlagEvaluationService) ResolveInt(
sCtx,
s.logger,
s.eval.ResolveIntValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&intResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -260,11 +267,13 @@ func (s *OldFlagEvaluationService) ResolveFloat(
sCtx,
s.logger,
s.eval.ResolveFloatValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&floatResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -287,11 +296,13 @@ func (s *OldFlagEvaluationService) ResolveObject(
sCtx,
s.logger,
s.eval.ResolveObjectValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&objectResponse{schemaV1Resp: res},
s.metrics,
s.contextValues,
make(map[string]string),
)
if err != nil {
span.RecordError(err)
@ -301,9 +312,9 @@ func (s *OldFlagEvaluationService) ResolveObject(
return res, err
}
// mergeContexts combines values from the request context with the values from the config --context-values flag.
// Request context values have a higher priority.
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
// mergeContexts combines context values from headers, static context (from cli) and request context.
// highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority
func mergeContexts(reqCtx, configFlagsCtx map[string]any, headers http.Header, headerToContextKeyMappings map[string]string) map[string]any {
merged := make(map[string]any)
for k, v := range reqCtx {
merged[k] = v
@ -311,18 +322,24 @@ func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
for k, v := range configFlagsCtx {
merged[k] = v
}
for header, contextKey := range headerToContextKeyMappings {
if values, ok := headers[header]; ok {
merged[contextKey] = values[0]
}
}
return merged
}
// resolve is a generic flag resolver
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], header http.Header, flagKey string,
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
configContextValues map[string]any,
configContextValues map[string]any, configHeaderToContextKeyMappings map[string]string,
) error {
reqID := xid.New().String()
defer logger.ClearFields(reqID)
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues, header, configHeaderToContextKeyMappings)
logger.WriteFields(
reqID,
zap.String("flag-key", flagKey),

View File

@ -26,6 +26,7 @@ type FlagEvaluationService struct {
eventingConfiguration IEvents
flagEvalTracer trace.Tracer
contextValues map[string]any
headerToContextKeyMappings map[string]string
}
// NewFlagEvaluationService creates a FlagEvaluationService with provided parameters
@ -34,6 +35,7 @@ func NewFlagEvaluationService(log *logger.Logger,
eventingCfg IEvents,
metricsRecorder telemetry.IMetricsRecorder,
contextValues map[string]any,
headerToContextKeyMappings map[string]string,
) *FlagEvaluationService {
svc := &FlagEvaluationService{
logger: log,
@ -42,6 +44,7 @@ func NewFlagEvaluationService(log *logger.Logger,
eventingConfiguration: eventingCfg,
flagEvalTracer: otel.Tracer("flagd.evaluation.v1"),
contextValues: contextValues,
headerToContextKeyMappings: headerToContextKeyMappings,
}
if metricsRecorder != nil {
@ -66,8 +69,9 @@ func (s *FlagEvaluationService) ResolveAll(
Flags: make(map[string]*evalV1.AnyFlag),
}
resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(),
s.contextValues))
context := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings)
resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(sCtx, reqID, context)
if err != nil {
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
@ -185,11 +189,13 @@ func (s *FlagEvaluationService) ResolveBoolean(
sCtx,
s.logger,
s.eval.ResolveBooleanValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&booleanResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -211,11 +217,13 @@ func (s *FlagEvaluationService) ResolveString(
sCtx,
s.logger,
s.eval.ResolveStringValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&stringResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -237,11 +245,13 @@ func (s *FlagEvaluationService) ResolveInt(
sCtx,
s.logger,
s.eval.ResolveIntValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&intResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -263,11 +273,13 @@ func (s *FlagEvaluationService) ResolveFloat(
sCtx,
s.logger,
s.eval.ResolveFloatValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&floatResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)
@ -289,11 +301,13 @@ func (s *FlagEvaluationService) ResolveObject(
sCtx,
s.logger,
s.eval.ResolveObjectValue,
req.Header(),
req.Msg.GetFlagKey(),
req.Msg.GetContext(),
&objectResponse{evalV1Resp: res},
s.metrics,
s.contextValues,
s.headerToContextKeyMappings,
)
if err != nil {
span.RecordError(err)

View File

@ -3,6 +3,7 @@ package service
import (
"context"
"errors"
"net/http"
"reflect"
"testing"
@ -103,7 +104,7 @@ func TestConnectServiceV2_ResolveAll(t *testing.T) {
).AnyTimes()
metrics, exp := getMetricReader()
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil)
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil)
// when
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
@ -220,6 +221,7 @@ func TestFlag_EvaluationV2_ResolveBoolean(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -276,6 +278,7 @@ func BenchmarkFlag_EvaluationV2_ResolveBoolean(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -375,6 +378,7 @@ func TestFlag_EvaluationV2_ResolveString(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -431,6 +435,7 @@ func BenchmarkFlag_EvaluationV2_ResolveString(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -529,6 +534,7 @@ func TestFlag_EvaluationV2_ResolveFloat(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -585,6 +591,7 @@ func BenchmarkFlag_EvaluationV2_ResolveFloat(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -683,6 +690,7 @@ func TestFlag_EvaluationV2_ResolveInt(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
if (err != nil) && !errors.Is(err, tt.wantErr) {
@ -739,6 +747,7 @@ func BenchmarkFlag_EvaluationV2_ResolveInt(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
b.Run(name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
@ -840,6 +849,7 @@ func TestFlag_EvaluationV2_ResolveObject(t *testing.T) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -904,6 +914,7 @@ func BenchmarkFlag_EvaluationV2_ResolveObject(b *testing.B) {
&eventingConfiguration{},
metrics,
nil,
nil,
)
if name != "eval returns error" {
outParsed, err := structpb.NewStruct(tt.evalFields.result)
@ -979,7 +990,10 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) {
func Test_mergeContexts(t *testing.T) {
type args struct {
clientContext, configContext map[string]any
headers http.Header
headerToContextKeyMappings map[string]string
clientContext map[string]any
configContext map[string]any
}
tests := []struct {
@ -988,19 +1002,54 @@ func Test_mergeContexts(t *testing.T) {
want map[string]any
}{
{
name: "merge contexts",
name: "merge contexts with no headers, with no header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{},
headerToContextKeyMappings: map[string]string{},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
{
name: "merge contexts with headers, with no header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{"X-key": []string{"value"}, "X-token": []string{"token"}},
headerToContextKeyMappings: map[string]string{},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
{
name: "merge contexts with no headers, with header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{},
headerToContextKeyMappings: map[string]string{"X-key": "k2"},
},
// static context should "win"
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
},
{
name: "merge contexts with headers, with header-context mappings",
args: args{
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
configContext: map[string]any{"k2": "v22", "k3": "v3"},
headers: http.Header{"X-key": []string{"value"}, "X-token": []string{"token"}},
headerToContextKeyMappings: map[string]string{"X-key": "k2"},
},
// header context should "win"
want: map[string]any{"k1": "v1", "k2": "value", "k3": "v3"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := mergeContexts(tt.args.clientContext, tt.args.configContext)
got := mergeContexts(tt.args.clientContext, tt.args.configContext, tt.args.headers, tt.args.headerToContextKeyMappings)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want)

View File

@ -26,14 +26,16 @@ type handler struct {
Logger *logger.Logger
evaluator evaluator.IEvaluator
contextValues map[string]any
headerToContextKeyMappings map[string]string
tracer trace.Tracer
}
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler {
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any, headerToContextKeyMappings map[string]string) http.Handler {
h := handler{
Logger: logger,
evaluator: evaluator,
contextValues: contextValues,
headerToContextKeyMappings: headerToContextKeyMappings,
tracer: otel.Tracer("flagd.ofrep.v1"),
}
@ -62,7 +64,8 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) {
h.writeJSONToResponse(http.StatusBadRequest, ofrep.ContextErrorResponseFrom(flagKey), w)
return
}
context := flagdContext(h.Logger, requestID, request, h.contextValues)
context := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
evaluation := h.evaluator.ResolveAsAnyValue(r.Context(), requestID, flagKey, context)
if evaluation.Error != nil {
status, evaluationError := ofrep.EvaluationErrorResponseFrom(evaluation)
@ -82,7 +85,8 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) {
return
}
context := flagdContext(h.Logger, requestID, request, h.contextValues)
context := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
evaluations, metadata, err := h.evaluator.ResolveAllValues(r.Context(), requestID, context)
if err != nil {
h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err))
@ -123,8 +127,10 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) {
return request, nil
}
// flagdContext returns combined context values from headers, static context (from cli) and request context.
// highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority
func flagdContext(
log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any,
log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any, headers http.Header, headerToContextKeyMappings map[string]string,
) map[string]any {
context := make(map[string]any)
if res, ok := request.Context.(map[string]any); ok {
@ -139,5 +145,11 @@ func flagdContext(
context[k] = v
}
for header, contextKey := range headerToContextKeyMappings {
if values, ok := headers[header]; ok {
context[contextKey] = values[0]
}
}
return context
}

View File

@ -30,13 +30,13 @@ type Service struct {
}
func NewOfrepService(
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any,
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any, headerToContextKeyMappings map[string]string,
) (*Service, error) {
corsMW := cors.New(cors.Options{
AllowedOrigins: origins,
AllowedMethods: []string{http.MethodPost},
})
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues))
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues, headerToContextKeyMappings))
server := http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),

View File

@ -28,7 +28,7 @@ func Test_OfrepServiceStartStop(t *testing.T) {
Port: uint16(port),
}
service, err := NewOfrepService(eval, []string{"*"}, cfg, nil)
service, err := NewOfrepService(eval, []string{"*"}, cfg, nil, nil)
if err != nil {
t.Fatalf("error creating the ofrep service: %v", err)
}