diff --git a/client/client.go b/client/client.go index 35b1c14..11e88ad 100644 --- a/client/client.go +++ b/client/client.go @@ -37,7 +37,7 @@ import ( const ( daprPortDefault = "50001" - daprPortEnvVarName = "DAPR_GRPC_PORT" + daprPortEnvVarName = "DAPR_GRPC_PORT" /* #nosec */ traceparentKey = "traceparent" apiTokenKey = "dapr-api-token" /* #nosec */ apiTokenEnvVarName = "DAPR_API_TOKEN" /* #nosec */ diff --git a/service/common/service.go b/service/common/service.go index 2c9d251..1b4eda9 100644 --- a/service/common/service.go +++ b/service/common/service.go @@ -20,6 +20,12 @@ import ( "github.com/dapr/go-sdk/actor/config" ) +const ( + // AppAPITokenEnvVar is the environment variable for app api token. + AppAPITokenEnvVar = "APP_API_TOKEN" /* #nosec */ + APITokenKey = "dapr-api-token" /* #nosec */ +) + // Service represents Dapr callback service. type Service interface { // AddServiceInvocationHandler appends provided service invocation handler with its name to the service. diff --git a/service/grpc/invoke.go b/service/grpc/invoke.go index 3ec0bc2..ef08e6d 100644 --- a/service/grpc/invoke.go +++ b/service/grpc/invoke.go @@ -19,6 +19,7 @@ import ( "github.com/golang/protobuf/ptypes/any" "github.com/pkg/errors" + "google.golang.org/grpc/metadata" cpb "github.com/dapr/dapr/pkg/proto/common/v1" cc "github.com/dapr/go-sdk/service/common" @@ -41,6 +42,17 @@ func (s *Server) OnInvoke(ctx context.Context, in *cpb.InvokeRequest) (*cpb.Invo if in == nil { return nil, errors.New("nil invoke request") } + if s.authToken != "" { + if md, ok := metadata.FromIncomingContext(ctx); !ok { + return nil, errors.New("authentication failed") + } else if vals := md.Get(cc.APITokenKey); len(vals) > 0 { + if vals[0] != s.authToken { + return nil, errors.New("authentication failed: app token mismatch") + } + } else { + return nil, errors.New("authentication failed. app token key not exist") + } + } if fn, ok := s.invokeHandlers[in.Method]; ok { e := &cc.InvocationEvent{} e.ContentType = in.ContentType diff --git a/service/grpc/invoke_test.go b/service/grpc/invoke_test.go index 47af77d..da79db8 100644 --- a/service/grpc/invoke_test.go +++ b/service/grpc/invoke_test.go @@ -15,10 +15,12 @@ package grpc import ( "context" + "os" "testing" "github.com/pkg/errors" "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/anypb" "github.com/dapr/dapr/pkg/proto/common/v1" @@ -48,6 +50,39 @@ func TestInvokeErrors(t *testing.T) { assert.Error(t, err) } +func TestInvokeWithToken(t *testing.T) { + _ = os.Setenv(cc.AppAPITokenEnvVar, "app-dapr-token") + server := getTestServer() + startTestServer(server) + methodName := "test" + err := server.AddServiceInvocationHandler(methodName, testInvokeHandler) + assert.Nil(t, err) + t.Run("invoke with token, return success", func(t *testing.T) { + grpcMetadata := metadata.New(map[string]string{ + cc.APITokenKey: os.Getenv(cc.AppAPITokenEnvVar), + }) + ctx := metadata.NewIncomingContext(context.Background(), grpcMetadata) + in := &common.InvokeRequest{Method: methodName} + _, err := server.OnInvoke(ctx, in) + assert.Nil(t, err) + }) + t.Run("invoke with empty token, return failed", func(t *testing.T) { + in := &common.InvokeRequest{Method: methodName} + _, err := server.OnInvoke(context.Background(), in) + assert.Error(t, err) + }) + t.Run("invoke with mismatch token, return failed", func(t *testing.T) { + grpcMetadata := metadata.New(map[string]string{ + cc.APITokenKey: "mismatch-token", + }) + ctx := metadata.NewOutgoingContext(context.Background(), grpcMetadata) + in := &common.InvokeRequest{Method: methodName} + _, err := server.OnInvoke(ctx, in) + assert.Error(t, err) + }) + _ = os.Unsetenv(cc.AppAPITokenEnvVar) +} + // go test -timeout 30s ./service/grpc -count 1 -run ^TestInvoke$ func TestInvoke(t *testing.T) { methodName := "test" diff --git a/service/grpc/service.go b/service/grpc/service.go index 056ebda..008ca24 100644 --- a/service/grpc/service.go +++ b/service/grpc/service.go @@ -16,6 +16,7 @@ package grpc import ( "context" "net" + "os" pb "github.com/dapr/dapr/pkg/proto/runtime/v1" "github.com/dapr/go-sdk/actor" @@ -52,6 +53,7 @@ func newService(lis net.Listener) *Server { invokeHandlers: make(map[string]func(ctx context.Context, in *common.InvocationEvent) (out *common.Content, err error)), topicSubscriptions: make(map[string]*topicEventHandler), bindingHandlers: make(map[string]func(ctx context.Context, in *common.BindingEvent) (out []byte, err error)), + authToken: os.Getenv(common.AppAPITokenEnvVar), } } @@ -62,6 +64,7 @@ type Server struct { invokeHandlers map[string]func(ctx context.Context, in *common.InvocationEvent) (out *common.Content, err error) topicSubscriptions map[string]*topicEventHandler bindingHandlers map[string]func(ctx context.Context, in *common.BindingEvent) (out []byte, err error) + authToken string } func (s *Server) RegisterActorImplFactory(f actor.Factory, opts ...config.Option) { diff --git a/service/http/invoke.go b/service/http/invoke.go index aacdabf..d9af42b 100644 --- a/service/http/invoke.go +++ b/service/http/invoke.go @@ -38,6 +38,13 @@ func (s *Server) AddServiceInvocationHandler(route string, fn func(ctx context.C s.mux.Handle(route, optionsHandler(http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { + if s.authToken != "" { + token := r.Header.Get(common.APITokenKey) + if token == "" || token != s.authToken { + http.Error(w, "authentication failed.", http.StatusNonAuthoritativeInfo) + return + } + } // capture http args e := &common.InvocationEvent{ Verb: r.Method, diff --git a/service/http/invoke_test.go b/service/http/invoke_test.go index bd0989b..213af89 100644 --- a/service/http/invoke_test.go +++ b/service/http/invoke_test.go @@ -19,6 +19,7 @@ import ( "io/ioutil" "net/http" "net/http/httptest" + "os" "strings" "testing" @@ -33,6 +34,41 @@ func TestInvocationHandlerWithoutHandler(t *testing.T) { assert.Errorf(t, err, "expected error adding event handler") } +func TestInvocationHandlerWithToken(t *testing.T) { + data := `{"name": "test", "data": hellow}` + _ = os.Setenv(common.AppAPITokenEnvVar, "app-dapr-token") + s := newServer("", nil) + err := s.AddServiceInvocationHandler("/", func(ctx context.Context, in *common.InvocationEvent) (out *common.Content, err error) { + if in == nil || in.Data == nil || in.ContentType == "" { + err = errors.New("nil input") + return + } + out = &common.Content{ + Data: in.Data, + ContentType: in.ContentType, + DataTypeURL: in.DataTypeURL, + } + return + }) + assert.NoErrorf(t, err, "error adding event handler") + + // forbbiden. + req, err := http.NewRequest(http.MethodPost, "/", strings.NewReader(data)) + assert.NoErrorf(t, err, "error creating request") + req.Header.Set("Content-Type", "application/json") + + resp := httptest.NewRecorder() + s.mux.ServeHTTP(resp, req) + assert.Equal(t, http.StatusNonAuthoritativeInfo, resp.Code) + + // pass. + req.Header.Set(common.APITokenKey, os.Getenv(common.AppAPITokenEnvVar)) + resp = httptest.NewRecorder() + s.mux.ServeHTTP(resp, req) + assert.Equal(t, http.StatusOK, resp.Code) + _ = os.Unsetenv(common.AppAPITokenEnvVar) +} + func TestInvocationHandlerWithData(t *testing.T) { data := `{"name": "test", "data": hellow}` s := newServer("", nil) diff --git a/service/http/service.go b/service/http/service.go index 51a760e..a6c2f68 100644 --- a/service/http/service.go +++ b/service/http/service.go @@ -16,6 +16,7 @@ package http import ( "context" "net/http" + "os" "time" "github.com/gorilla/mux" @@ -49,6 +50,7 @@ func newServer(address string, router *mux.Router) *Server { }, mux: router, topicSubscriptions: make([]*common.Subscription, 0), + authToken: os.Getenv(common.AppAPITokenEnvVar), } } @@ -58,6 +60,7 @@ type Server struct { mux *mux.Router httpServer *http.Server topicSubscriptions []*common.Subscription + authToken string } func (s *Server) RegisterActorImplFactory(f actor.Factory, opts ...config.Option) {