fix some typos and run gofmt

This commit is contained in:
iamqizhao 2015-04-17 13:50:18 -07:00
parent 94a47542e0
commit 3259049490
8 changed files with 189 additions and 45 deletions

View File

@ -35,7 +35,6 @@ package grpc
import (
"io"
"net"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
@ -114,12 +113,8 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
o.after(&c)
}
}()
host, _, err := net.SplitHostPort(cc.target)
if err != nil {
return toRPCErr(err)
}
callHdr := &transport.CallHdr{
Host: host,
Host: cc.authority,
Method: method,
}
topts := &transport.Options{

View File

@ -36,6 +36,7 @@ package grpc
import (
"errors"
"log"
"net"
"sync"
"time"
@ -95,6 +96,14 @@ func WithTimeout(d time.Duration) DialOption {
}
}
// WithNetwork returns a DialOption that specifies the network on which
// the connection will be established.
func WithNetwork(network string) DialOption {
return func(o *dialOptions) {
o.copts.Network = network
}
}
// Dial creates a client connection the given target.
// TODO(zhaoq): Have an option to make Dial return immediately without waiting
// for connection to complete.
@ -108,6 +117,24 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
for _, opt := range opts {
opt(&cc.dopts)
}
// Validate the network type
switch cc.dopts.copts.Network {
case "":
cc.dopts.copts.Network = "tcp" // Set the default
case "tcp", "tcp4", "tcp6", "unix":
default:
return nil, net.UnknownNetworkError(cc.dopts.copts.Network)
}
cc.authority = target
// Format target for tcp.
if cc.dopts.copts.Network != "unix" {
// format target for tcp.
var err error
cc.authority, _, err = net.SplitHostPort(target)
if err != nil {
return nil, err
}
}
if cc.dopts.codec == nil {
// Set the default codec.
cc.dopts.codec = protoCodec{}
@ -124,6 +151,7 @@ func Dial(target string, opts ...DialOption) (*ClientConn, error) {
// ClientConn represents a client connection to an RPC service.
type ClientConn struct {
target string
authority string
dopts dialOptions
shutdownChan chan struct{}

View File

@ -105,7 +105,7 @@ func (c *tlsCreds) DialWithDialer(dialer *net.Dialer, network, addr string) (_ n
return nil, fmt.Errorf("credentials: failed to parse server address %v", err)
}
}
return tls.DialWithDialer(dialer, "tcp", addr, &c.config)
return tls.DialWithDialer(dialer, network, addr, &c.config)
}
// Dial connects to addr and performs TLS handshake.

View File

@ -371,8 +371,8 @@ func (s *Server) TestingCloseConns() {
s.mu.Lock()
for c := range s.conns {
c.Close()
delete(s.conns, c)
}
s.conns = make(map[transport.ServerTransport]bool)
s.mu.Unlock()
}

View File

@ -36,7 +36,6 @@ package grpc
import (
"errors"
"io"
"net"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
@ -95,12 +94,8 @@ type ClientStream interface {
// by generated code.
func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (ClientStream, error) {
// TODO(zhaoq): CallOption is omitted. Add support when it is needed.
host, _, err := net.SplitHostPort(cc.target)
if err != nil {
return nil, toRPCErr(err)
}
callHdr := &transport.CallHdr{
Host: host,
Host: cc.authority,
Method: method,
}
t, _, err := cc.wait(ctx, 0)

View File

@ -34,12 +34,15 @@
package grpc_test
import (
"fmt"
"io"
"log"
"math"
"net"
"reflect"
"runtime"
"sync"
"syscall"
"testing"
"time"
@ -263,18 +266,32 @@ func TestReconnectTimeout(t *testing.T) {
}
}
func setUp(useTLS bool, maxStream uint32) (s *grpc.Server, cc *grpc.ClientConn) {
lis, err := net.Listen("tcp", ":0")
type env struct {
network string // The type of network such as tcp, unix, etc.
security string // The security protocol such as TLS, SSH, etc.
}
func listTestEnv() []env {
if runtime.GOOS == "windows" {
return []env{env{"tcp", ""}, env{"tcp", "tls"}}
}
return []env{env{"tcp", ""}, env{"tcp", "tls"}, env{"unix", ""}, env{"unix", "tls"}}
}
func setUp(maxStream uint32, e env) (s *grpc.Server, cc *grpc.ClientConn) {
s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream))
la := ":0"
switch e.network {
case "unix":
la = "/tmp/testsock" + fmt.Sprintf("%p", s)
syscall.Unlink(la)
}
lis, err := net.Listen(e.network, la)
if err != nil {
log.Fatalf("Failed to listen: %v", err)
}
_, port, err := net.SplitHostPort(lis.Addr().String())
if err != nil {
log.Fatalf("Failed to parse listener address: %v", err)
}
s = grpc.NewServer(grpc.MaxConcurrentStreams(maxStream))
testpb.RegisterTestServiceServer(s, &testServer{})
if useTLS {
if e.security == "tls" {
creds, err := credentials.NewServerTLSFromFile(tlsDir+"server1.pem", tlsDir+"server1.key")
if err != nil {
log.Fatalf("Failed to generate credentials %v", err)
@ -283,15 +300,24 @@ func setUp(useTLS bool, maxStream uint32) (s *grpc.Server, cc *grpc.ClientConn)
} else {
go s.Serve(lis)
}
addr := "localhost:" + port
if useTLS {
addr := la
switch e.network {
case "unix":
default:
_, port, err := net.SplitHostPort(lis.Addr().String())
if err != nil {
log.Fatalf("Failed to parse listener address: %v", err)
}
addr = "localhost:" + port
}
if e.security == "tls" {
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
if err != nil {
log.Fatalf("Failed to create credentials %v", err)
}
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds))
cc, err = grpc.Dial(addr, grpc.WithTransportCredentials(creds), grpc.WithNetwork(e.network))
} else {
cc, err = grpc.Dial(addr)
cc, err = grpc.Dial(addr, grpc.WithNetwork(e.network))
}
if err != nil {
log.Fatalf("Dial(%q) = %v", addr, err)
@ -305,7 +331,14 @@ func tearDown(s *grpc.Server, cc *grpc.ClientConn) {
}
func TestTimeoutOnDeadServer(t *testing.T) {
s, cc := setUp(false, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testTimeoutOnDeadServer(t, e)
}
}
func testTimeoutOnDeadServer(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
s.Stop()
// Set -1 as the timeout to make sure if transportMonitor gets error
@ -319,7 +352,14 @@ func TestTimeoutOnDeadServer(t *testing.T) {
}
func TestEmptyUnary(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testEmptyUnary(t, e)
}
}
func testEmptyUnary(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
reply, err := tc.EmptyCall(context.Background(), &testpb.Empty{})
@ -329,7 +369,14 @@ func TestEmptyUnary(t *testing.T) {
}
func TestFailedEmptyUnary(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testFailedEmptyUnary(t, e)
}
}
func testFailedEmptyUnary(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata)
@ -339,7 +386,14 @@ func TestFailedEmptyUnary(t *testing.T) {
}
func TestLargeUnary(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testLargeUnary(t, e)
}
}
func testLargeUnary(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 271828
@ -361,7 +415,14 @@ func TestLargeUnary(t *testing.T) {
}
func TestMetadataUnaryRPC(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testMetadataUnaryRPC(t, e)
}
}
func testMetadataUnaryRPC(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@ -405,11 +466,18 @@ func performOneRPC(t *testing.T, tc testpb.TestServiceClient, wg *sync.WaitGroup
wg.Done()
}
func TestRetry(t *testing.T) {
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testRetry(t, e)
}
}
// This test mimics a user who sends 1000 RPCs concurrently on a faulty transport.
// TODO(zhaoq): Refactor to make this clearer and add more cases to test racy
// and error-prone paths.
func TestRetry(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
func testRetry(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
var wg sync.WaitGroup
@ -431,9 +499,16 @@ func TestRetry(t *testing.T) {
wg.Wait()
}
// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism.
func TestRPCTimeout(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testRPCTimeout(t, e)
}
}
// TODO(zhaoq): Have a better test coverage of timeout and cancellation mechanism.
func testRPCTimeout(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@ -456,7 +531,14 @@ func TestRPCTimeout(t *testing.T) {
}
func TestCancel(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testCancel(t, e)
}
}
func testCancel(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
argSize := 2718
@ -482,7 +564,14 @@ var (
)
func TestPingPong(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testPingPong(t, e)
}
}
func testPingPong(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
stream, err := tc.FullDuplexCall(context.Background())
@ -527,7 +616,14 @@ func TestPingPong(t *testing.T) {
}
func TestMetadataStreamingRPC(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testMetadataStreamingRPC(t, e)
}
}
func testMetadataStreamingRPC(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
ctx := metadata.NewContext(context.Background(), testMetadata)
@ -578,7 +674,14 @@ func TestMetadataStreamingRPC(t *testing.T) {
}
func TestServerStreaming(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testServerStreaming(t, e)
}
}
func testServerStreaming(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes))
@ -624,7 +727,14 @@ func TestServerStreaming(t *testing.T) {
}
func TestFailedServerStreaming(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testFailedServerStreaming(t, e)
}
}
func testFailedServerStreaming(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
respParam := make([]*testpb.ResponseParameters, len(respSizes))
@ -648,7 +758,14 @@ func TestFailedServerStreaming(t *testing.T) {
}
func TestClientStreaming(t *testing.T) {
s, cc := setUp(true, math.MaxUint32)
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testClientStreaming(t, e)
}
}
func testClientStreaming(t *testing.T, e env) {
s, cc := setUp(math.MaxUint32, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
stream, err := tc.StreamingInputCall(context.Background())
@ -676,8 +793,15 @@ func TestClientStreaming(t *testing.T) {
}
func TestExceedMaxStreamsLimit(t *testing.T) {
for _, e := range listTestEnv() {
log.Println("Testing in the env: ", e)
testExceedMaxStreamsLimit(t, e)
}
}
func testExceedMaxStreamsLimit(t *testing.T, e env) {
// Only allows 1 live stream per server transport.
s, cc := setUp(true, 1)
s, cc := setUp(1, e)
tc := testpb.NewTestServiceClient(cc)
defer tearDown(s, cc)
var err error

View File

@ -110,12 +110,12 @@ func newHTTP2Client(addr string, opts *ConnectOptions) (_ ClientTransport, err e
// multiple ones provided. Revisit this if it is not appropriate. Probably
// place the ClientTransport construction into a separate function to make
// things clear.
conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, "tcp", addr)
conn, connErr = ccreds.DialWithDialer(&net.Dialer{Timeout: opts.Timeout}, opts.Network, addr)
break
}
}
if scheme == "http" {
conn, connErr = net.DialTimeout("tcp", addr, opts.Timeout)
conn, connErr = net.DialTimeout(opts.Network, addr, opts.Timeout)
}
if connErr != nil {
return nil, ConnectionErrorf("transport: %v", connErr)

View File

@ -315,7 +315,9 @@ func NewServerTransport(protocol string, conn net.Conn, maxStreams uint32) (Serv
// ConnectOptions covers all relevant options for dialing a server.
type ConnectOptions struct {
Protocol string
// Network indicates the type of network where the connection is established.
// Known networks are "tcp", "tcp4", "tcp6", "unix"
Network string
AuthOptions []credentials.Credentials
Timeout time.Duration
}