feat: updating context using headers (#1641)
<!-- Please use this template for your pull request. --> <!-- Please use the sections that you need and delete other sections --> Able to update context map using headers present in - OFREP requests - Connect Requests (via Flag Evaluator V2 service) ### Related Issues Fixes #1583 ### Notes Context values passed via headers is high priority If same context key is updated via - Headers - Request Body - Static Config _Context via Headers will be considered_ ### Usage ``` flagd start --port 8013 --uri file:./samples/example_flags.flagd.json -H Header=contextKey ``` or ``` flagd start --port 8013 --uri file:./samples/example_flags.flagd.json --context-from-header Header=contextKey ``` --------- Signed-off-by: Rahul Baradol <rahul.baradol.14@gmail.com>
This commit is contained in:
parent
bf10ff30dc
commit
ba348152b6
|
|
@ -33,6 +33,7 @@ type Configuration struct {
|
|||
CORS []string
|
||||
Options []connect.HandlerOption
|
||||
ContextValues map[string]any
|
||||
HeaderToContextKeyMappings map[string]string
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ flagd start [flags]
|
|||
### Options
|
||||
|
||||
```
|
||||
-H, --context-from-header stringToString add key-value pairs to map header values to context values, where key is Header name, value is context key (default [])
|
||||
-X, --context-value stringToString add arbitrary key value pairs to the flag evaluation context (default [])
|
||||
-C, --cors-origin strings CORS allowed origins, * will allow all origins
|
||||
-h, --help help for start
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ const (
|
|||
syncSocketPathFlagName = "sync-socket-path"
|
||||
uriFlagName = "uri"
|
||||
contextValueFlagName = "context-value"
|
||||
headerToContextKeyFlagName = "context-from-header"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
@ -84,6 +85,8 @@ func init() {
|
|||
"from disk")
|
||||
flags.StringToStringP(contextValueFlagName, "X", map[string]string{}, "add arbitrary key value pairs "+
|
||||
"to the flag evaluation context")
|
||||
flags.StringToStringP(headerToContextKeyFlagName, "H", map[string]string{}, "add key-value pairs to map " +
|
||||
"header values to context values, where key is Header name, value is context key")
|
||||
|
||||
bindFlags(flags)
|
||||
}
|
||||
|
|
@ -107,6 +110,7 @@ func bindFlags(flags *pflag.FlagSet) {
|
|||
_ = viper.BindPFlag(syncSocketPathFlagName, flags.Lookup(syncSocketPathFlagName))
|
||||
_ = viper.BindPFlag(ofrepPortFlagName, flags.Lookup(ofrepPortFlagName))
|
||||
_ = viper.BindPFlag(contextValueFlagName, flags.Lookup(contextValueFlagName))
|
||||
_ = viper.BindPFlag(headerToContextKeyFlagName, flags.Lookup(headerToContextKeyFlagName))
|
||||
}
|
||||
|
||||
// startCmd represents the start command
|
||||
|
|
@ -156,6 +160,11 @@ var startCmd = &cobra.Command{
|
|||
contextValuesToMap[k] = v
|
||||
}
|
||||
|
||||
headerToContextKeyMappings := make(map[string]string)
|
||||
for k, v := range viper.GetStringMapString(headerToContextKeyFlagName) {
|
||||
headerToContextKeyMappings[k] = v
|
||||
}
|
||||
|
||||
// Build Runtime -----------------------------------------------------------
|
||||
rt, err := runtime.FromConfig(logger, Version, runtime.Config{
|
||||
CORS: viper.GetStringSlice(corsFlagName),
|
||||
|
|
@ -175,6 +184,7 @@ var startCmd = &cobra.Command{
|
|||
SyncServiceSocketPath: viper.GetString(syncSocketPathFlagName),
|
||||
SyncProviders: syncProviders,
|
||||
ContextValues: contextValuesToMap,
|
||||
HeaderToContextKeyMappings: headerToContextKeyMappings,
|
||||
})
|
||||
if err != nil {
|
||||
rtLogger.Fatal(err.Error())
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ type Config struct {
|
|||
CORS []string
|
||||
|
||||
ContextValues map[string]any
|
||||
HeaderToContextKeyMappings map[string]string
|
||||
}
|
||||
|
||||
// FromConfig builds a runtime from startup configurations
|
||||
|
|
@ -106,6 +107,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
|
|||
Port: config.OfrepServicePort,
|
||||
},
|
||||
config.ContextValues,
|
||||
config.HeaderToContextKeyMappings,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating ofrep service")
|
||||
|
|
@ -155,6 +157,7 @@ func FromConfig(logger *logger.Logger, version string, config Config) (*Runtime,
|
|||
CORS: config.CORS,
|
||||
Options: options,
|
||||
ContextValues: config.ContextValues,
|
||||
HeaderToContextKeyMappings: config.HeaderToContextKeyMappings,
|
||||
},
|
||||
SyncImpl: iSyncs,
|
||||
}, nil
|
||||
|
|
|
|||
|
|
@ -172,6 +172,7 @@ func (s *ConnectService) setupServer(svcConf service.Configuration) (net.Listene
|
|||
s.eventingConfiguration,
|
||||
s.metrics,
|
||||
svcConf.ContextValues,
|
||||
svcConf.HeaderToContextKeyMappings,
|
||||
)
|
||||
|
||||
_, newHandler := evaluationV1.NewServiceHandler(newFes, append(svcConf.Options, marshalOpts)...)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package service
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
schemaV1 "buf.build/gen/go/open-feature/flagd/protocolbuffers/go/schema/v1"
|
||||
|
|
@ -72,7 +73,7 @@ func (s *OldFlagEvaluationService) ResolveAll(
|
|||
Flags: make(map[string]*schemaV1.AnyFlag),
|
||||
}
|
||||
|
||||
values, _, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues))
|
||||
values, _, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), make(map[string]string)))
|
||||
if err != nil {
|
||||
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
|
||||
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
|
||||
|
|
@ -179,11 +180,13 @@ func (s *OldFlagEvaluationService) ResolveBoolean(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveBooleanValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&booleanResponse{schemaV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
make(map[string]string),
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -206,11 +209,13 @@ func (s *OldFlagEvaluationService) ResolveString(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveStringValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&stringResponse{schemaV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
make(map[string]string),
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -233,11 +238,13 @@ func (s *OldFlagEvaluationService) ResolveInt(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveIntValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&intResponse{schemaV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
make(map[string]string),
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -260,11 +267,13 @@ func (s *OldFlagEvaluationService) ResolveFloat(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveFloatValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&floatResponse{schemaV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
make(map[string]string),
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -287,11 +296,13 @@ func (s *OldFlagEvaluationService) ResolveObject(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveObjectValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&objectResponse{schemaV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
make(map[string]string),
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -301,9 +312,9 @@ func (s *OldFlagEvaluationService) ResolveObject(
|
|||
return res, err
|
||||
}
|
||||
|
||||
// mergeContexts combines values from the request context with the values from the config --context-values flag.
|
||||
// Request context values have a higher priority.
|
||||
func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
|
||||
// mergeContexts combines context values from headers, static context (from cli) and request context.
|
||||
// highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority
|
||||
func mergeContexts(reqCtx, configFlagsCtx map[string]any, headers http.Header, headerToContextKeyMappings map[string]string) map[string]any {
|
||||
merged := make(map[string]any)
|
||||
for k, v := range reqCtx {
|
||||
merged[k] = v
|
||||
|
|
@ -311,18 +322,24 @@ func mergeContexts(reqCtx, configFlagsCtx map[string]any) map[string]any {
|
|||
for k, v := range configFlagsCtx {
|
||||
merged[k] = v
|
||||
}
|
||||
for header, contextKey := range headerToContextKeyMappings {
|
||||
if values, ok := headers[header]; ok {
|
||||
merged[contextKey] = values[0]
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
// resolve is a generic flag resolver
|
||||
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], flagKey string,
|
||||
func resolve[T constraints](ctx context.Context, logger *logger.Logger, resolver resolverSignature[T], header http.Header, flagKey string,
|
||||
evaluationContext *structpb.Struct, resp response[T], metrics telemetry.IMetricsRecorder,
|
||||
configContextValues map[string]any,
|
||||
configContextValues map[string]any, configHeaderToContextKeyMappings map[string]string,
|
||||
) error {
|
||||
reqID := xid.New().String()
|
||||
defer logger.ClearFields(reqID)
|
||||
|
||||
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues)
|
||||
mergedContext := mergeContexts(evaluationContext.AsMap(), configContextValues, header, configHeaderToContextKeyMappings)
|
||||
|
||||
logger.WriteFields(
|
||||
reqID,
|
||||
zap.String("flag-key", flagKey),
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ type FlagEvaluationService struct {
|
|||
eventingConfiguration IEvents
|
||||
flagEvalTracer trace.Tracer
|
||||
contextValues map[string]any
|
||||
headerToContextKeyMappings map[string]string
|
||||
}
|
||||
|
||||
// NewFlagEvaluationService creates a FlagEvaluationService with provided parameters
|
||||
|
|
@ -34,6 +35,7 @@ func NewFlagEvaluationService(log *logger.Logger,
|
|||
eventingCfg IEvents,
|
||||
metricsRecorder telemetry.IMetricsRecorder,
|
||||
contextValues map[string]any,
|
||||
headerToContextKeyMappings map[string]string,
|
||||
) *FlagEvaluationService {
|
||||
svc := &FlagEvaluationService{
|
||||
logger: log,
|
||||
|
|
@ -42,6 +44,7 @@ func NewFlagEvaluationService(log *logger.Logger,
|
|||
eventingConfiguration: eventingCfg,
|
||||
flagEvalTracer: otel.Tracer("flagd.evaluation.v1"),
|
||||
contextValues: contextValues,
|
||||
headerToContextKeyMappings: headerToContextKeyMappings,
|
||||
}
|
||||
|
||||
if metricsRecorder != nil {
|
||||
|
|
@ -66,8 +69,9 @@ func (s *FlagEvaluationService) ResolveAll(
|
|||
Flags: make(map[string]*evalV1.AnyFlag),
|
||||
}
|
||||
|
||||
resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(sCtx, reqID, mergeContexts(req.Msg.GetContext().AsMap(),
|
||||
s.contextValues))
|
||||
context := mergeContexts(req.Msg.GetContext().AsMap(), s.contextValues, req.Header(), s.headerToContextKeyMappings)
|
||||
|
||||
resolutions, flagSetMetadata, err := s.eval.ResolveAllValues(sCtx, reqID, context)
|
||||
if err != nil {
|
||||
s.logger.WarnWithID(reqID, fmt.Sprintf("error resolving all flags: %v", err))
|
||||
return nil, fmt.Errorf("error resolving flags. Tracking ID: %s", reqID)
|
||||
|
|
@ -185,11 +189,13 @@ func (s *FlagEvaluationService) ResolveBoolean(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveBooleanValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&booleanResponse{evalV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
s.headerToContextKeyMappings,
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -211,11 +217,13 @@ func (s *FlagEvaluationService) ResolveString(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveStringValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&stringResponse{evalV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
s.headerToContextKeyMappings,
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -237,11 +245,13 @@ func (s *FlagEvaluationService) ResolveInt(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveIntValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&intResponse{evalV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
s.headerToContextKeyMappings,
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -263,11 +273,13 @@ func (s *FlagEvaluationService) ResolveFloat(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveFloatValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&floatResponse{evalV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
s.headerToContextKeyMappings,
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
@ -289,11 +301,13 @@ func (s *FlagEvaluationService) ResolveObject(
|
|||
sCtx,
|
||||
s.logger,
|
||||
s.eval.ResolveObjectValue,
|
||||
req.Header(),
|
||||
req.Msg.GetFlagKey(),
|
||||
req.Msg.GetContext(),
|
||||
&objectResponse{evalV1Resp: res},
|
||||
s.metrics,
|
||||
s.contextValues,
|
||||
s.headerToContextKeyMappings,
|
||||
)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package service
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
|
@ -103,7 +104,7 @@ func TestConnectServiceV2_ResolveAll(t *testing.T) {
|
|||
).AnyTimes()
|
||||
|
||||
metrics, exp := getMetricReader()
|
||||
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil)
|
||||
s := NewFlagEvaluationService(logger.NewLogger(nil, false), eval, &eventingConfiguration{}, metrics, nil, nil)
|
||||
|
||||
// when
|
||||
got, err := s.ResolveAll(context.Background(), connect.NewRequest(tt.req))
|
||||
|
|
@ -220,6 +221,7 @@ func TestFlag_EvaluationV2_ResolveBoolean(t *testing.T) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
got, err := s.ResolveBoolean(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
|
||||
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
|
|
@ -276,6 +278,7 @@ func BenchmarkFlag_EvaluationV2_ResolveBoolean(b *testing.B) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
|
@ -375,6 +378,7 @@ func TestFlag_EvaluationV2_ResolveString(t *testing.T) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
got, err := s.ResolveString(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
|
||||
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
|
|
@ -431,6 +435,7 @@ func BenchmarkFlag_EvaluationV2_ResolveString(b *testing.B) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
|
@ -529,6 +534,7 @@ func TestFlag_EvaluationV2_ResolveFloat(t *testing.T) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
got, err := s.ResolveFloat(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
|
||||
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
|
|
@ -585,6 +591,7 @@ func BenchmarkFlag_EvaluationV2_ResolveFloat(b *testing.B) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
|
@ -683,6 +690,7 @@ func TestFlag_EvaluationV2_ResolveInt(t *testing.T) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
got, err := s.ResolveInt(tt.functionArgs.ctx, connect.NewRequest(tt.functionArgs.req))
|
||||
if (err != nil) && !errors.Is(err, tt.wantErr) {
|
||||
|
|
@ -739,6 +747,7 @@ func BenchmarkFlag_EvaluationV2_ResolveInt(b *testing.B) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
b.Run(name, func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
|
|
@ -840,6 +849,7 @@ func TestFlag_EvaluationV2_ResolveObject(t *testing.T) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
outParsed, err := structpb.NewStruct(tt.evalFields.result)
|
||||
|
|
@ -904,6 +914,7 @@ func BenchmarkFlag_EvaluationV2_ResolveObject(b *testing.B) {
|
|||
&eventingConfiguration{},
|
||||
metrics,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if name != "eval returns error" {
|
||||
outParsed, err := structpb.NewStruct(tt.evalFields.result)
|
||||
|
|
@ -979,7 +990,10 @@ func TestFlag_EvaluationV2_ErrorCodes(t *testing.T) {
|
|||
|
||||
func Test_mergeContexts(t *testing.T) {
|
||||
type args struct {
|
||||
clientContext, configContext map[string]any
|
||||
headers http.Header
|
||||
headerToContextKeyMappings map[string]string
|
||||
clientContext map[string]any
|
||||
configContext map[string]any
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
|
@ -988,19 +1002,54 @@ func Test_mergeContexts(t *testing.T) {
|
|||
want map[string]any
|
||||
}{
|
||||
{
|
||||
name: "merge contexts",
|
||||
name: "merge contexts with no headers, with no header-context mappings",
|
||||
args: args{
|
||||
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
|
||||
configContext: map[string]any{"k2": "v22", "k3": "v3"},
|
||||
headers: http.Header{},
|
||||
headerToContextKeyMappings: map[string]string{},
|
||||
},
|
||||
// static context should "win"
|
||||
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
|
||||
},
|
||||
{
|
||||
name: "merge contexts with headers, with no header-context mappings",
|
||||
args: args{
|
||||
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
|
||||
configContext: map[string]any{"k2": "v22", "k3": "v3"},
|
||||
headers: http.Header{"X-key": []string{"value"}, "X-token": []string{"token"}},
|
||||
headerToContextKeyMappings: map[string]string{},
|
||||
},
|
||||
// static context should "win"
|
||||
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
|
||||
},
|
||||
{
|
||||
name: "merge contexts with no headers, with header-context mappings",
|
||||
args: args{
|
||||
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
|
||||
configContext: map[string]any{"k2": "v22", "k3": "v3"},
|
||||
headers: http.Header{},
|
||||
headerToContextKeyMappings: map[string]string{"X-key": "k2"},
|
||||
},
|
||||
// static context should "win"
|
||||
want: map[string]any{"k1": "v1", "k2": "v22", "k3": "v3"},
|
||||
},
|
||||
{
|
||||
name: "merge contexts with headers, with header-context mappings",
|
||||
args: args{
|
||||
clientContext: map[string]any{"k1": "v1", "k2": "v2"},
|
||||
configContext: map[string]any{"k2": "v22", "k3": "v3"},
|
||||
headers: http.Header{"X-key": []string{"value"}, "X-token": []string{"token"}},
|
||||
headerToContextKeyMappings: map[string]string{"X-key": "k2"},
|
||||
},
|
||||
// header context should "win"
|
||||
want: map[string]any{"k1": "v1", "k2": "value", "k3": "v3"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := mergeContexts(tt.args.clientContext, tt.args.configContext)
|
||||
got := mergeContexts(tt.args.clientContext, tt.args.configContext, tt.args.headers, tt.args.headerToContextKeyMappings)
|
||||
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("\ngot: %+v\nwant: %+v", got, tt.want)
|
||||
|
|
|
|||
|
|
@ -26,14 +26,16 @@ type handler struct {
|
|||
Logger *logger.Logger
|
||||
evaluator evaluator.IEvaluator
|
||||
contextValues map[string]any
|
||||
headerToContextKeyMappings map[string]string
|
||||
tracer trace.Tracer
|
||||
}
|
||||
|
||||
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any) http.Handler {
|
||||
func NewOfrepHandler(logger *logger.Logger, evaluator evaluator.IEvaluator, contextValues map[string]any, headerToContextKeyMappings map[string]string) http.Handler {
|
||||
h := handler{
|
||||
Logger: logger,
|
||||
evaluator: evaluator,
|
||||
contextValues: contextValues,
|
||||
headerToContextKeyMappings: headerToContextKeyMappings,
|
||||
tracer: otel.Tracer("flagd.ofrep.v1"),
|
||||
}
|
||||
|
||||
|
|
@ -62,7 +64,8 @@ func (h *handler) HandleFlagEvaluation(w http.ResponseWriter, r *http.Request) {
|
|||
h.writeJSONToResponse(http.StatusBadRequest, ofrep.ContextErrorResponseFrom(flagKey), w)
|
||||
return
|
||||
}
|
||||
context := flagdContext(h.Logger, requestID, request, h.contextValues)
|
||||
context := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
|
||||
|
||||
evaluation := h.evaluator.ResolveAsAnyValue(r.Context(), requestID, flagKey, context)
|
||||
if evaluation.Error != nil {
|
||||
status, evaluationError := ofrep.EvaluationErrorResponseFrom(evaluation)
|
||||
|
|
@ -82,7 +85,8 @@ func (h *handler) HandleBulkEvaluation(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
context := flagdContext(h.Logger, requestID, request, h.contextValues)
|
||||
context := flagdContext(h.Logger, requestID, request, h.contextValues, r.Header, h.headerToContextKeyMappings)
|
||||
|
||||
evaluations, metadata, err := h.evaluator.ResolveAllValues(r.Context(), requestID, context)
|
||||
if err != nil {
|
||||
h.Logger.WarnWithID(requestID, fmt.Sprintf("error from resolver: %v", err))
|
||||
|
|
@ -123,8 +127,10 @@ func extractOfrepRequest(req *http.Request) (ofrep.Request, error) {
|
|||
return request, nil
|
||||
}
|
||||
|
||||
// flagdContext returns combined context values from headers, static context (from cli) and request context.
|
||||
// highest priority > header-context-from-cli > static-context-from-cli > request-context > lowest priority
|
||||
func flagdContext(
|
||||
log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any,
|
||||
log *logger.Logger, requestID string, request ofrep.Request, staticContextValues map[string]any, headers http.Header, headerToContextKeyMappings map[string]string,
|
||||
) map[string]any {
|
||||
context := make(map[string]any)
|
||||
if res, ok := request.Context.(map[string]any); ok {
|
||||
|
|
@ -139,5 +145,11 @@ func flagdContext(
|
|||
context[k] = v
|
||||
}
|
||||
|
||||
for header, contextKey := range headerToContextKeyMappings {
|
||||
if values, ok := headers[header]; ok {
|
||||
context[contextKey] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
return context
|
||||
}
|
||||
|
|
@ -30,13 +30,13 @@ type Service struct {
|
|||
}
|
||||
|
||||
func NewOfrepService(
|
||||
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any,
|
||||
evaluator evaluator.IEvaluator, origins []string, cfg SvcConfiguration, contextValues map[string]any, headerToContextKeyMappings map[string]string,
|
||||
) (*Service, error) {
|
||||
corsMW := cors.New(cors.Options{
|
||||
AllowedOrigins: origins,
|
||||
AllowedMethods: []string{http.MethodPost},
|
||||
})
|
||||
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues))
|
||||
h := corsMW.Handler(NewOfrepHandler(cfg.Logger, evaluator, contextValues, headerToContextKeyMappings))
|
||||
|
||||
server := http.Server{
|
||||
Addr: fmt.Sprintf(":%d", cfg.Port),
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ func Test_OfrepServiceStartStop(t *testing.T) {
|
|||
Port: uint16(port),
|
||||
}
|
||||
|
||||
service, err := NewOfrepService(eval, []string{"*"}, cfg, nil)
|
||||
service, err := NewOfrepService(eval, []string{"*"}, cfg, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating the ofrep service: %v", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue