grpc-go/examples/features/authz/server/main.go

239 lines
7.4 KiB
Go

/*
*
* Copyright 2023 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
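// Binary server demonstrates how to authenticate incoming RPCs from their
// "authorization" metadata and then authorize them with a gRPC authz policy,
// either a static policy or one loaded by watching a policy file.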
package main
import (
"context"
"errors"
"flag"
"fmt"
"io"
"log"
"net"
"strings"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/authz"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/examples/data"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/grpc/examples/features/authz/token"
pb "google.golang.org/grpc/examples/features/proto/echo"
)
const (
unaryEchoWriterRole = "UNARY_ECHO:W"
streamEchoReadWriterRole = "STREAM_ECHO:RW"
authzPolicy = `
{
"name": "authz",
"allow_rules": [
{
"name": "allow_UnaryEcho",
"request": {
"paths": ["/grpc.examples.echo.Echo/UnaryEcho"],
"headers": [
{
"key": "UNARY_ECHO:W",
"values": ["true"]
}
]
}
},
{
"name": "allow_BidirectionalStreamingEcho",
"request": {
"paths": ["/grpc.examples.echo.Echo/BidirectionalStreamingEcho"],
"headers": [
{
"key": "STREAM_ECHO:RW",
"values": ["true"]
}
]
}
}
],
"deny_rules": []
}
`
authzOptStatic = "static"
authzOptFileWatcher = "filewatcher"
)
var (
	// port is the TCP port the server listens on.
	port = flag.Int("port", 50051, "the port to serve on")
	// authzOpt selects how the authorization policy is obtained; it must be
	// authzOptStatic or authzOptFileWatcher.
	authzOpt = flag.String("authz-option", authzOptStatic, "the authz option (static or filewatcher)")
)

// errMissingMetadata is returned by the auth interceptors when the incoming
// RPC context carries no metadata at all.
var errMissingMetadata = status.Error(codes.InvalidArgument, "missing metadata")
// newContextWithRoles returns a child of ctx whose incoming metadata is
// exactly the role set granted to username. Only "super-user" receives roles:
// both unaryEchoWriterRole and streamEchoReadWriterRole, each set to "true".
// Every other username gets empty metadata.
//
// The previous incoming metadata is replaced, not merged. Headers the client
// sent, including any that imitate the role keys, are therefore not visible
// to code that reads incoming metadata from the returned context.
func newContextWithRoles(ctx context.Context, username string) context.Context {
	roles := metadata.MD{}
	if username != "super-user" {
		return metadata.NewIncomingContext(ctx, roles)
	}
	for _, role := range []string{unaryEchoWriterRole, streamEchoReadWriterRole} {
		roles.Set(role, "true")
	}
	return metadata.NewIncomingContext(ctx, roles)
}
// server implements the Echo service. pb.UnimplementedEchoServer is embedded
// so the type satisfies the generated service interface for RPCs not
// overridden here (presumably answering them as unimplemented — see the
// generated code).
type server struct {
	pb.UnimplementedEchoServer
}

// UnaryEcho prints the received message and returns it unchanged.
func (s *server) UnaryEcho(_ context.Context, in *pb.EchoRequest) (*pb.EchoResponse, error) {
	fmt.Printf("unary echoing message %q\n", in.Message)
	return &pb.EchoResponse{Message: in.Message}, nil
}

// BidirectionalStreamingEcho sends every message received on the stream back
// to the client. It returns nil when Recv reports io.EOF. Otherwise it returns
// the first receive or send error, so a broken stream ends the RPC instead of
// being silently ignored.
func (s *server) BidirectionalStreamingEcho(stream pb.Echo_BidirectionalStreamingEchoServer) error {
	for {
		in, err := stream.Recv()
		if errors.Is(err, io.EOF) {
			return nil
		}
		if err != nil {
			fmt.Printf("Receiving message from stream: %v\n", err)
			return err
		}
		fmt.Printf("bidi echoing message %q\n", in.Message)
		if err := stream.Send(&pb.EchoResponse{Message: in.Message}); err != nil {
			fmt.Printf("Sending message to stream: %v\n", err)
			return err
		}
	}
}
// isAuthenticated validates the bearer token carried in the "authorization"
// metadata values and returns the username it encodes.
//
// Only the first value is inspected, and a leading "Bearer " prefix is
// stripped if present. The remainder is decoded with token.Token's Decode
// method and accepted only if its Secret field equals "super-secret".
//
// The interceptors in this file forward err.Error() to the client in an
// Unauthenticated status. The messages therefore deliberately do not echo
// the received credential or the expected secret.
func isAuthenticated(authorization []string) (username string, err error) {
	if len(authorization) == 0 {
		return "", errors.New("received empty authorization token from client")
	}
	tokenBase64 := strings.TrimPrefix(authorization[0], "Bearer ")
	// Perform the token validation here. For the sake of this example, the code
	// here forgoes any of the usual OAuth2 token validation and instead checks
	// for a token matching an arbitrary string.
	// Named tok (not token) so the token package is not shadowed.
	var tok token.Token
	err = tok.Decode(tokenBase64)
	if err != nil {
		return "", fmt.Errorf("decoding received token: %w", err)
	}
	if tok.Secret != "super-secret" {
		return "", errors.New("received token secret does not match the expected secret")
	}
	return tok.Username, nil
}
// authUnaryInterceptor is a unary server interceptor that authenticates the
// caller from the "authorization" entry of the incoming metadata.
//
// On success it invokes handler with a context whose incoming metadata holds
// the caller's roles, built by newContextWithRoles. If the context has no
// metadata it fails with errMissingMetadata. If authentication fails it
// returns an Unauthenticated status. In both failure cases handler is not
// called.
func authUnaryInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
	incoming, found := metadata.FromIncomingContext(ctx)
	if !found {
		return nil, errMissingMetadata
	}
	user, authErr := isAuthenticated(incoming["authorization"])
	if authErr != nil {
		return nil, status.Error(codes.Unauthenticated, authErr.Error())
	}
	roleCtx := newContextWithRoles(ctx, user)
	return handler(roleCtx, req)
}
// wrappedStream decorates a grpc.ServerStream so that Context reports a
// caller-supplied context instead of the stream's own. Every other
// ServerStream method is promoted unchanged from the embedded stream.
type wrappedStream struct {
	grpc.ServerStream
	ctx context.Context
}

// Context returns the overriding context stored in the wrapper.
func (w *wrappedStream) Context() context.Context {
	return w.ctx
}

// newWrappedStream wraps s so that its Context method yields ctx.
func newWrappedStream(ctx context.Context, s grpc.ServerStream) grpc.ServerStream {
	wrapped := &wrappedStream{
		ServerStream: s,
		ctx:          ctx,
	}
	return wrapped
}
// authStreamInterceptor is the streaming counterpart of authUnaryInterceptor.
// It authenticates the caller from the "authorization" entry of the stream's
// incoming metadata.
//
// On success it invokes handler with the stream wrapped so that Context
// returns a context carrying the caller's roles, built by
// newContextWithRoles. If the stream context has no metadata it fails with
// errMissingMetadata. If authentication fails it returns an Unauthenticated
// status. In both failure cases handler is not called.
func authStreamInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
	streamCtx := ss.Context()
	incoming, found := metadata.FromIncomingContext(streamCtx)
	if !found {
		return errMissingMetadata
	}
	user, authErr := isAuthenticated(incoming["authorization"])
	if authErr != nil {
		return status.Error(codes.Unauthenticated, authErr.Error())
	}
	roleStream := newWrappedStream(newContextWithRoles(streamCtx, user), ss)
	return handler(srv, roleStream)
}
func main() {
flag.Parse()
if *authzOpt != authzOptStatic && *authzOpt != authzOptFileWatcher {
log.Fatalf("Invalid authz option: %s", *authzOpt)
}
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
if err != nil {
log.Fatalf("Listening on local port %q: %v", *port, err)
}
// Create tls based credential.
creds, err := credentials.NewServerTLSFromFile(data.Path("x509/server_cert.pem"), data.Path("x509/server_key.pem"))
if err != nil {
log.Fatalf("Loading credentials: %v", err)
}
// Create authorization interceptors according to the authz-option command-line flag.
var unaryAuthzInterceptor grpc.UnaryServerInterceptor
var streamAuthzInterceptor grpc.StreamServerInterceptor
if *authzOpt == authzOptStatic {
// Create an authorization interceptor using a static policy.
staticInterceptor, err := authz.NewStatic(authzPolicy)
if err != nil {
log.Fatalf("Creating a static authz interceptor: %v", err)
}
unaryAuthzInterceptor, streamAuthzInterceptor = staticInterceptor.UnaryInterceptor, staticInterceptor.StreamInterceptor
} else if *authzOpt == authzOptFileWatcher {
// Create an authorization interceptor by watching a policy file.
fileWatcherInterceptor, err := authz.NewFileWatcher(data.Path("rbac/policy.json"), 100*time.Millisecond)
if err != nil {
log.Fatalf("Creating a file watcher authz interceptor: %v", err)
}
unaryAuthzInterceptor, streamAuthzInterceptor = fileWatcherInterceptor.UnaryInterceptor, fileWatcherInterceptor.StreamInterceptor
}
unaryInterceptors := grpc.ChainUnaryInterceptor(authUnaryInterceptor, unaryAuthzInterceptor)
streamInterceptors := grpc.ChainStreamInterceptor(authStreamInterceptor, streamAuthzInterceptor)
s := grpc.NewServer(grpc.Creds(creds), unaryInterceptors, streamInterceptors)
// Register EchoServer on the server.
pb.RegisterEchoServer(s, &server{})
if err := s.Serve(lis); err != nil {
log.Fatalf("Serving Echo service on local port: %v", err)
}
}