mirror of https://github.com/grpc/grpc-go.git
fix some bugs
This commit is contained in:
parent
f53faa647d
commit
6205cb25ab
|
@ -73,7 +73,7 @@ func TestCredentialsMisuse(t *testing.T) {
|
||||||
t.Fatalf("Failed to create authenticator %v", err)
|
t.Fatalf("Failed to create authenticator %v", err)
|
||||||
}
|
}
|
||||||
// Two conflicting credential configurations
|
// Two conflicting credential configurations
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errCredentialsConflict {
|
if _, err := Dial("Non-Existent.Server:80", WithTransportCredentials(tlsCreds), WithBlock(), WithInsecure()); err != errCredentialsConflict {
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errCredentialsConflict)
|
||||||
}
|
}
|
||||||
rpcCreds, err := oauth.NewJWTAccessFromKey(nil)
|
rpcCreds, err := oauth.NewJWTAccessFromKey(nil)
|
||||||
|
@ -81,7 +81,7 @@ func TestCredentialsMisuse(t *testing.T) {
|
||||||
t.Fatalf("Failed to create credentials %v", err)
|
t.Fatalf("Failed to create credentials %v", err)
|
||||||
}
|
}
|
||||||
// security info on insecure connection
|
// security info on insecure connection
|
||||||
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithTimeout(time.Millisecond), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
|
if _, err := Dial("Non-Existent.Server:80", WithPerRPCCredentials(rpcCreds), WithBlock(), WithInsecure()); err != errTransportCredentialsMissing {
|
||||||
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
|
t.Fatalf("Dial(_, _) = _, %v, want _, %v", err, errTransportCredentialsMissing)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -123,4 +123,5 @@ func testBackoffConfigSet(t *testing.T, expected *BackoffConfig, opts ...DialOpt
|
||||||
if actual != *expected {
|
if actual != *expected {
|
||||||
t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected)
|
t.Fatalf("unexpected backoff config on connection: %v, want %v", actual, expected)
|
||||||
}
|
}
|
||||||
|
conn.Close()
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,8 +90,8 @@ var (
|
||||||
var raceMode bool // set by race_test.go in race mode
|
var raceMode bool // set by race_test.go in race mode
|
||||||
|
|
||||||
type testServer struct {
|
type testServer struct {
|
||||||
security string // indicate the authentication protocol used by this server.
|
security string // indicate the authentication protocol used by this server.
|
||||||
streamingInputCallErr bool // whether to error out the StreamingInputCall handler prematurely.
|
earlyFail bool // whether to error out the execution of a service handler.
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
func (s *testServer) EmptyCall(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
|
||||||
|
@ -220,7 +220,7 @@ func (s *testServer) StreamingInputCall(stream testpb.TestService_StreamingInput
|
||||||
}
|
}
|
||||||
p := in.GetPayload().GetBody()
|
p := in.GetPayload().GetBody()
|
||||||
sum += len(p)
|
sum += len(p)
|
||||||
if s.streamingInputCallErr {
|
if s.earlyFail {
|
||||||
return grpc.Errorf(codes.NotFound, "not found")
|
return grpc.Errorf(codes.NotFound, "not found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1515,7 +1515,7 @@ func TestClientStreamingError(t *testing.T) {
|
||||||
|
|
||||||
func testClientStreamingError(t *testing.T, e env) {
|
func testClientStreamingError(t *testing.T, e env) {
|
||||||
te := newTest(t, e)
|
te := newTest(t, e)
|
||||||
te.startServer(&testServer{security: e.security, streamingInputCallErr: true})
|
te.startServer(&testServer{security: e.security, earlyFail: true})
|
||||||
defer te.tearDown()
|
defer te.tearDown()
|
||||||
tc := testpb.NewTestServiceClient(te.clientConn())
|
tc := testpb.NewTestServiceClient(te.clientConn())
|
||||||
|
|
||||||
|
@ -1538,12 +1538,11 @@ func testClientStreamingError(t *testing.T, e env) {
|
||||||
for {
|
for {
|
||||||
if err := stream.Send(req); err == nil {
|
if err := stream.Send(req); err == nil {
|
||||||
continue
|
continue
|
||||||
} else {
|
|
||||||
if grpc.Code(err) != codes.NotFound {
|
|
||||||
t.Fatalf("%v.Send(_) = %v, want error %d", stream, err, codes.NotFound)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
|
if _, err := stream.CloseAndRecv(); grpc.Code(err) != codes.NotFound {
|
||||||
|
t.Fatalf("%v.Send(_) = %v, want error %d", stream, err, codes.NotFound)
|
||||||
|
}
|
||||||
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -498,9 +498,10 @@ func (t *http2Client) GracefulClose() error {
|
||||||
// if it improves the performance.
|
// if it improves the performance.
|
||||||
func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
// The stream has been done. Return the status directly.
|
||||||
if s.state == streamDone {
|
if s.state == streamDone {
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return StreamErrorf(s.statusCode, "%s", s.statusDesc)
|
return StreamErrorf(s.statusCode, "%v", s.statusDesc)
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
r := bytes.NewBuffer(data)
|
r := bytes.NewBuffer(data)
|
||||||
|
@ -599,11 +600,11 @@ func (t *http2Client) Write(s *Stream, data []byte, opts *Options) error {
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.state != streamDone {
|
if s.state != streamDone {
|
||||||
if s.state == streamReadDone {
|
//if s.state == streamReadDone {
|
||||||
s.state = streamDone
|
// s.state = streamDone
|
||||||
} else {
|
//} else {
|
||||||
s.state = streamWriteDone
|
s.state = streamWriteDone
|
||||||
}
|
//}
|
||||||
}
|
}
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
return nil
|
return nil
|
||||||
|
@ -678,11 +679,14 @@ func (t *http2Client) handleData(f *http2.DataFrame) {
|
||||||
// the read direction is closed, and set the status appropriately.
|
// the read direction is closed, and set the status appropriately.
|
||||||
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
|
if f.FrameHeader.Flags.Has(http2.FlagDataEndStream) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.state == streamWriteDone {
|
s.state = streamDone
|
||||||
s.state = streamDone
|
/*
|
||||||
} else {
|
if s.state == streamWriteDone {
|
||||||
s.state = streamReadDone
|
s.state = streamDone
|
||||||
}
|
} else {
|
||||||
|
s.state = streamReadDone
|
||||||
|
}
|
||||||
|
*/
|
||||||
s.statusCode = codes.Internal
|
s.statusCode = codes.Internal
|
||||||
s.statusDesc = "server closed the stream without sending trailers"
|
s.statusDesc = "server closed the stream without sending trailers"
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
@ -786,12 +790,20 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) {
|
||||||
if len(state.mdata) > 0 {
|
if len(state.mdata) > 0 {
|
||||||
s.trailer = state.mdata
|
s.trailer = state.mdata
|
||||||
}
|
}
|
||||||
s.state = streamDone
|
|
||||||
s.statusCode = state.statusCode
|
s.statusCode = state.statusCode
|
||||||
s.statusDesc = state.statusDesc
|
s.statusDesc = state.statusDesc
|
||||||
|
var cancel bool
|
||||||
|
if s.state != streamWriteDone {
|
||||||
|
// s will be canceled. This is required to interrupt any pending
|
||||||
|
// blocking Write calls when the final RPC status has been arrived.
|
||||||
|
cancel = true
|
||||||
|
}
|
||||||
|
s.state = streamDone
|
||||||
s.mu.Unlock()
|
s.mu.Unlock()
|
||||||
|
|
||||||
s.write(recvMsg{err: io.EOF})
|
s.write(recvMsg{err: io.EOF})
|
||||||
|
if cancel {
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleMalformedHTTP2(s *Stream, err error) {
|
func handleMalformedHTTP2(s *Stream, err error) {
|
||||||
|
|
|
@ -140,6 +140,14 @@ func (r *recvBufferReader) Read(p []byte) (n int, err error) {
|
||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case <-r.ctx.Done():
|
case <-r.ctx.Done():
|
||||||
|
select {
|
||||||
|
case i := <-r.recv.get():
|
||||||
|
m := i.(*recvMsg)
|
||||||
|
if m.err != nil {
|
||||||
|
return 0, m.err
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
return 0, ContextErr(r.ctx.Err())
|
return 0, ContextErr(r.ctx.Err())
|
||||||
case i := <-r.recv.get():
|
case i := <-r.recv.get():
|
||||||
r.recv.load()
|
r.recv.load()
|
||||||
|
|
Loading…
Reference in New Issue