mirror of https://github.com/grpc/grpc-go.git
				
				
				
			allow access of some info of client certificate
This commit is contained in:
		
							parent
							
								
									69288679b3
								
							
						
					
					
						commit
						d12ff72146
					
				| 
						 | 
				
			
			@ -82,15 +82,24 @@ type ProtocolInfo struct {
 | 
			
		|||
// protocols and supported transport security protocols (e.g., TLS, SSL).
 | 
			
		||||
type TransportAuthenticator interface {
 | 
			
		||||
	// ClientHandshake does the authentication handshake specified by the corresponding
 | 
			
		||||
	// authentication protocol on rawConn for clients.
 | 
			
		||||
	ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, error)
 | 
			
		||||
	// ServerHandshake does the authentication handshake for servers.
 | 
			
		||||
	ServerHandshake(rawConn net.Conn) (net.Conn, error)
 | 
			
		||||
	// authentication protocol on rawConn for clients. It returns the authenticated
 | 
			
		||||
	// connection and the corresponding auth information about the connection.
 | 
			
		||||
	ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (net.Conn, map[string][]string, error)
 | 
			
		||||
	// ServerHandshake does the authentication handshake for servers. It returns
 | 
			
		||||
	// the authenticated connection and the corresponding auth information about
 | 
			
		||||
	// the connection.
 | 
			
		||||
	ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error)
 | 
			
		||||
	// Info provides the ProtocolInfo of this TransportAuthenticator.
 | 
			
		||||
	Info() ProtocolInfo
 | 
			
		||||
	Credentials
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	transportSecurityType = "transport_security_type"
 | 
			
		||||
	x509CN = "x509_common_name"
 | 
			
		||||
	x509SAN = "x509_suject_alternative_name"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// tlsCreds is the credentials required for authenticating a connection using TLS.
 | 
			
		||||
type tlsCreds struct {
 | 
			
		||||
	// TLS configuration
 | 
			
		||||
| 
						 | 
				
			
			@ -116,7 +125,7 @@ func (timeoutError) Error() string   { return "credentials: Dial timed out" }
 | 
			
		|||
func (timeoutError) Timeout() bool   { return true }
 | 
			
		||||
func (timeoutError) Temporary() bool { return true }
 | 
			
		||||
 | 
			
		||||
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, err error) {
 | 
			
		||||
func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.Duration) (_ net.Conn, _ map[string][]string, err error) {
 | 
			
		||||
	// borrow some code from tls.DialWithDialer
 | 
			
		||||
	var errChannel chan error
 | 
			
		||||
	if timeout != 0 {
 | 
			
		||||
| 
						 | 
				
			
			@ -143,18 +152,32 @@ func (c *tlsCreds) ClientHandshake(addr string, rawConn net.Conn, timeout time.D
 | 
			
		|||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		rawConn.Close()
 | 
			
		||||
		return nil, err
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return conn, nil
 | 
			
		||||
	// TODO(zhaoq): Omit the auth info for client now. It is more for
 | 
			
		||||
	// information than anything else.
 | 
			
		||||
	return conn, nil, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, error) {
 | 
			
		||||
func (c *tlsCreds) ServerHandshake(rawConn net.Conn) (net.Conn, map[string][]string, error) {
 | 
			
		||||
	conn := tls.Server(rawConn, &c.config)
 | 
			
		||||
	if err := conn.Handshake(); err != nil {
 | 
			
		||||
		rawConn.Close()
 | 
			
		||||
		return nil, err
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return conn, nil
 | 
			
		||||
	state := conn.ConnectionState()
 | 
			
		||||
	info := make(map[string][]string)
 | 
			
		||||
	info[transportSecurityType] = []string{"tls"}
 | 
			
		||||
	for _, certs := range state.VerifiedChains {
 | 
			
		||||
		fmt.Println("DEBUG: reach here")
 | 
			
		||||
		for _, cert := range certs {
 | 
			
		||||
			info[x509CN] = append(info[x509CN], cert.Subject.CommonName)
 | 
			
		||||
			for _, san := range cert.DNSNames {
 | 
			
		||||
				info[x509SAN] = append(info[x509SAN], san)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return conn, info, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewTLS uses c to construct a TransportAuthenticator based on TLS.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -566,7 +566,7 @@ func main() {
 | 
			
		|||
		doPerRPCCreds(tc)
 | 
			
		||||
	case "oauth2_auth_token":
 | 
			
		||||
		if !*useTLS {
 | 
			
		||||
			grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_token_creds test case.")
 | 
			
		||||
			grpclog.Fatalf("TLS is not enabled. TLS is required to execute oauth2_auth_token test case.")
 | 
			
		||||
		}
 | 
			
		||||
		doOauth2TokenCreds(tc)
 | 
			
		||||
	case "cancel_after_begin":
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -199,8 +199,9 @@ func (s *Server) Serve(lis net.Listener) error {
 | 
			
		|||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		var authInfo map[string][]string
 | 
			
		||||
		if creds, ok := s.opts.creds.(credentials.TransportAuthenticator); ok {
 | 
			
		||||
			c, err = creds.ServerHandshake(c)
 | 
			
		||||
			c, authInfo, err = creds.ServerHandshake(c)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				grpclog.Println("grpc: Server.Serve failed to complete security handshake.")
 | 
			
		||||
				continue
 | 
			
		||||
| 
						 | 
				
			
			@ -212,7 +213,7 @@ func (s *Server) Serve(lis net.Listener) error {
 | 
			
		|||
			c.Close()
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams)
 | 
			
		||||
		st, err := transport.NewServerTransport("http2", c, s.opts.maxConcurrentStreams, authInfo)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			s.mu.Unlock()
 | 
			
		||||
			c.Close()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -124,7 +124,7 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
 | 
			
		|||
			if timeout > 0 {
 | 
			
		||||
				timeout -= time.Since(startT)
 | 
			
		||||
			}
 | 
			
		||||
			conn, connErr = ccreds.ClientHandshake(addr, conn, timeout)
 | 
			
		||||
			conn, _, connErr = ccreds.ClientHandshake(addr, conn, timeout)
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -58,6 +58,7 @@ var ErrIllegalHeaderWrite = errors.New("transport: the stream is done or WriteHe
 | 
			
		|||
type http2Server struct {
 | 
			
		||||
	conn        net.Conn
 | 
			
		||||
	maxStreamID uint32 // max stream ID ever seen
 | 
			
		||||
	authInfo    map[string][]string  // basic auth info about the connection
 | 
			
		||||
	// writableChan synchronizes write access to the transport.
 | 
			
		||||
	// A writer acquires the write lock by sending a value on writableChan
 | 
			
		||||
	// and releases it by receiving from writableChan.
 | 
			
		||||
| 
						 | 
				
			
			@ -88,7 +89,7 @@ type http2Server struct {
 | 
			
		|||
 | 
			
		||||
// newHTTP2Server constructs a ServerTransport based on HTTP2. ConnectionError is
 | 
			
		||||
// returned if something goes wrong.
 | 
			
		||||
func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err error) {
 | 
			
		||||
func newHTTP2Server(conn net.Conn, maxStreams uint32, authInfo map[string][]string) (_ ServerTransport, err error) {
 | 
			
		||||
	framer := newFramer(conn)
 | 
			
		||||
	// Send initial settings as connection preface to client.
 | 
			
		||||
	var settings []http2.Setting
 | 
			
		||||
| 
						 | 
				
			
			@ -114,6 +115,7 @@ func newHTTP2Server(conn net.Conn, maxStreams uint32) (_ ServerTransport, err er
 | 
			
		|||
	var buf bytes.Buffer
 | 
			
		||||
	t := &http2Server{
 | 
			
		||||
		conn:            conn,
 | 
			
		||||
		authInfo:        authInfo,
 | 
			
		||||
		framer:          framer,
 | 
			
		||||
		hBuf:            &buf,
 | 
			
		||||
		hEnc:            hpack.NewEncoder(&buf),
 | 
			
		||||
| 
						 | 
				
			
			@ -235,6 +237,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
 | 
			
		|||
	t.handleSettings(sf)
 | 
			
		||||
 | 
			
		||||
	hDec := newHPACKDecoder()
 | 
			
		||||
	hDec.state.mdata = t.authInfo
 | 
			
		||||
	var curStream *Stream
 | 
			
		||||
	var wg sync.WaitGroup
 | 
			
		||||
	defer wg.Wait()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -308,8 +308,8 @@ const (
 | 
			
		|||
 | 
			
		||||
// NewServerTransport creates a ServerTransport with conn or non-nil error
 | 
			
		||||
// if it fails.
 | 
			
		||||
func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (ServerTransport, error) {
 | 
			
		||||
	return newHTTP2Server(conn, maxStreams)
 | 
			
		||||
func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32, authInfo map[string][]string) (ServerTransport, error) {
 | 
			
		||||
	return newHTTP2Server(conn, maxStreams, authInfo)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConnectOptions covers all relevant options for dialing a server.
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue