mirror of https://github.com/grpc/grpc-go.git
239 lines
7.4 KiB
Go
239 lines
7.4 KiB
Go
/*
|
|
*
|
|
* Copyright 2023 gRPC authors.
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*
|
|
*/
|
|
|
|
// Binary server demonstrates how to authenticate incoming RPCs using the
// bearer token in their authorization metadata, and how to authorize them
// with gRPC authz policies.
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/authz"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials"
|
|
"google.golang.org/grpc/examples/data"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
|
|
"google.golang.org/grpc/examples/features/authz/token"
|
|
pb "google.golang.org/grpc/examples/features/proto/echo"
|
|
)
|
|
|
|
// Role names placed into the incoming metadata of authenticated RPCs by
// newContextWithRoles. authzPolicy below is built from these same constants,
// so the policy's header keys cannot drift out of sync with the roles that
// are actually granted.
const (
	unaryEchoWriterRole      = "UNARY_ECHO:W"
	streamEchoReadWriterRole = "STREAM_ECHO:RW"
)

// authzPolicy is the static gRPC authz policy used with -authz-option=static.
// It allows UnaryEcho only for requests whose unaryEchoWriterRole header has
// the value "true", and BidirectionalStreamingEcho only for requests whose
// streamEchoReadWriterRole header has the value "true".
const authzPolicy = `
{
	"name": "authz",
	"allow_rules": [
		{
			"name": "allow_UnaryEcho",
			"request": {
				"paths": ["/grpc.examples.echo.Echo/UnaryEcho"],
				"headers": [
					{
						"key": "` + unaryEchoWriterRole + `",
						"values": ["true"]
					}
				]
			}
		},
		{
			"name": "allow_BidirectionalStreamingEcho",
			"request": {
				"paths": ["/grpc.examples.echo.Echo/BidirectionalStreamingEcho"],
				"headers": [
					{
						"key": "` + streamEchoReadWriterRole + `",
						"values": ["true"]
					}
				]
			}
		}
	],
	"deny_rules": []
}
`

// Accepted values of the -authz-option flag.
const (
	authzOptStatic      = "static"
	authzOptFileWatcher = "filewatcher"
)
|
|
|
|
var (
|
|
port = flag.Int("port", 50051, "the port to serve on")
|
|
authzOpt = flag.String("authz-option", authzOptStatic, "the authz option (static or filewatcher)")
|
|
|
|
errMissingMetadata = status.Errorf(codes.InvalidArgument, "missing metadata")
|
|
)
|
|
|
|
// newContextWithRoles derives a context from ctx whose incoming metadata
// contains only the role headers granted to username. The user "super-user"
// gets both unaryEchoWriterRole and streamEchoReadWriterRole set to "true";
// every other user gets empty metadata.
//
// The previous incoming metadata is replaced rather than merged. This appears
// to be deliberate: a client cannot grant itself a role by sending the role
// header itself, because the authz interceptor only ever sees what is set here.
func newContextWithRoles(ctx context.Context, username string) context.Context {
	roles := make(metadata.MD)
	if username != "super-user" {
		return metadata.NewIncomingContext(ctx, roles)
	}
	for _, role := range []string{unaryEchoWriterRole, streamEchoReadWriterRole} {
		roles.Set(role, "true")
	}
	return metadata.NewIncomingContext(ctx, roles)
}
|
|
|
|
type server struct {
|
|
pb.UnimplementedEchoServer
|
|
}
|
|
|
|
func (s *server) UnaryEcho(_ context.Context, in *pb.EchoRequest) (*pb.EchoResponse, error) {
|
|
fmt.Printf("unary echoing message %q\n", in.Message)
|
|
return &pb.EchoResponse{Message: in.Message}, nil
|
|
}
|
|
|
|
func (s *server) BidirectionalStreamingEcho(stream pb.Echo_BidirectionalStreamingEchoServer) error {
|
|
for {
|
|
in, err := stream.Recv()
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
return nil
|
|
}
|
|
fmt.Printf("Receiving message from stream: %v\n", err)
|
|
return err
|
|
}
|
|
fmt.Printf("bidi echoing message %q\n", in.Message)
|
|
stream.Send(&pb.EchoResponse{Message: in.Message})
|
|
}
|
|
}
|
|
|
|
// isAuthenticated validates the bearer token carried in the "authorization"
// metadata values and returns the username encoded in the token.
//
// Only the first value is inspected, and a leading "Bearer " prefix is
// stripped if present. For the sake of this example, the token is accepted
// when its decoded secret equals "super-secret". A real server would perform
// proper (e.g. OAuth2) token validation here instead.
//
// The auth interceptors send the returned error's text to the client as the
// status message of an Unauthenticated error, so it must never reveal the
// expected secret.
func isAuthenticated(authorization []string) (username string, err error) {
	if len(authorization) < 1 {
		return "", errors.New("received empty authorization token from client")
	}
	tokenBase64 := strings.TrimPrefix(authorization[0], "Bearer ")

	// Named tok so that it does not shadow the imported token package.
	var tok token.Token
	if err = tok.Decode(tokenBase64); err != nil {
		return "", fmt.Errorf("base64 decoding of received token %q: %w", tokenBase64, err)
	}
	if tok.Secret != "super-secret" {
		// Deliberately omit both secrets: this text reaches the unauthenticated
		// caller, and including the expected value would disclose it.
		return "", errors.New("received token secret does not match the expected secret")
	}
	return tok.Username, nil
}
|
|
|
|
// authUnaryInterceptor authenticates a unary RPC using its "authorization"
// metadata. On success it invokes handler with a context whose incoming
// metadata holds the roles granted to the authenticated user (built by
// newContextWithRoles).
//
// It returns errMissingMetadata when the context carries no incoming metadata,
// and a codes.Unauthenticated status when the token fails validation.
func authUnaryInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
	incoming, found := metadata.FromIncomingContext(ctx)
	if !found {
		return nil, errMissingMetadata
	}

	user, authErr := isAuthenticated(incoming.Get("authorization"))
	if authErr != nil {
		return nil, status.Error(codes.Unauthenticated, authErr.Error())
	}

	rolesCtx := newContextWithRoles(ctx, user)
	return handler(rolesCtx, req)
}
|
|
|
|
// wrappedStream wraps a grpc.ServerStream associated with an incoming RPC, and
|
|
// a custom context containing the username derived from the authorization header
|
|
// specified in the incoming RPC metadata
|
|
type wrappedStream struct {
|
|
grpc.ServerStream
|
|
ctx context.Context
|
|
}
|
|
|
|
func (w *wrappedStream) Context() context.Context {
|
|
return w.ctx
|
|
}
|
|
|
|
func newWrappedStream(ctx context.Context, s grpc.ServerStream) grpc.ServerStream {
|
|
return &wrappedStream{s, ctx}
|
|
}
|
|
|
|
// authStreamInterceptor looks up the authorization header from the incoming RPC context,
|
|
// retrieves the username from it and creates a new context with the username before invoking
|
|
// the provided handler.
|
|
func authStreamInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
|
|
md, ok := metadata.FromIncomingContext(ss.Context())
|
|
if !ok {
|
|
return errMissingMetadata
|
|
}
|
|
username, err := isAuthenticated(md["authorization"])
|
|
if err != nil {
|
|
return status.Error(codes.Unauthenticated, err.Error())
|
|
}
|
|
return handler(srv, newWrappedStream(newContextWithRoles(ss.Context(), username), ss))
|
|
}
|
|
|
|
// main parses flags, loads the TLS credentials, builds the authentication and
// authorization interceptor chains selected by -authz-option, and serves the
// Echo service on -port until Serve returns an error.
func main() {
	flag.Parse()

	if *authzOpt != authzOptStatic && *authzOpt != authzOptFileWatcher {
		log.Fatalf("Invalid authz option: %s", *authzOpt)
	}

	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
	if err != nil {
		// *port is an int. Formatting it with %q would print a quoted Unicode
		// character instead of the port number.
		log.Fatalf("Listening on local port %d: %v", *port, err)
	}

	// Create tls based credential.
	creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem"))
	if err != nil {
		log.Fatalf("Loading credentials: %v", err)
	}

	// Create authorization interceptors according to the authz-option
	// command-line flag. The option was validated above, so exactly one case
	// matches.
	var (
		unaryAuthzInterceptor  grpc.UnaryServerInterceptor
		streamAuthzInterceptor grpc.StreamServerInterceptor
	)
	switch *authzOpt {
	case authzOptStatic:
		// Authorize against the policy compiled into the binary.
		staticInterceptor, err := authz.NewStatic(authzPolicy)
		if err != nil {
			log.Fatalf("Creating a static authz interceptor: %v", err)
		}
		unaryAuthzInterceptor, streamAuthzInterceptor = staticInterceptor.UnaryInterceptor, staticInterceptor.StreamInterceptor
	case authzOptFileWatcher:
		// Authorize against the policy in rbac/policy.json. The 100ms argument
		// is presumably the interval at which the file is re-read (see
		// authz.NewFileWatcher).
		fileWatcherInterceptor, err := authz.NewFileWatcher(data.Path("rbac/policy.json"), 100*time.Millisecond)
		if err != nil {
			log.Fatalf("Creating a file watcher authz interceptor: %v", err)
		}
		unaryAuthzInterceptor, streamAuthzInterceptor = fileWatcherInterceptor.UnaryInterceptor, fileWatcherInterceptor.StreamInterceptor
	}

	// The auth interceptors are listed first so that the authz interceptors
	// see the role metadata installed by newContextWithRoles.
	unaryInterceptors := grpc.ChainUnaryInterceptor(authUnaryInterceptor, unaryAuthzInterceptor)
	streamInterceptors := grpc.ChainStreamInterceptor(authStreamInterceptor, streamAuthzInterceptor)
	s := grpc.NewServer(grpc.Creds(creds), unaryInterceptors, streamInterceptors)

	// Register EchoServer on the server.
	pb.RegisterEchoServer(s, &server{})

	if err := s.Serve(lis); err != nil {
		log.Fatalf("Serving Echo service on local port: %v", err)
	}
}
|