Merge pull request #610 from peter-edge/percent_encoding

Do percent encoding matching grpc-java for the statusDesc
This commit is contained in:
Menghan Li 2016-07-21 11:36:23 -07:00 committed by GitHub
commit 2350c4144c
6 changed files with 113 additions and 17 deletions

View File

@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
h := ht.rw.Header() h := ht.rw.Header()
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode)) h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
if statusDesc != "" { if statusDesc != "" {
h.Set("Grpc-Message", statusDesc) h.Set("Grpc-Message", grpcMessageEncode(statusDesc))
} }
if md := s.Trailer(); len(md) > 0 { if md := s.Trailer(); len(md) > 0 {
for k, vv := range md { for k, vv := range md {

View File

@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
"Content-Type": {"application/grpc"}, "Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"}, "Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {fmt.Sprint(uint32(statusCode))}, "Grpc-Status": {fmt.Sprint(uint32(statusCode))},
"Grpc-Message": {msg}, "Grpc-Message": {grpcMessageEncode(msg)},
} }
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) { if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader) t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
"Content-Type": {"application/grpc"}, "Content-Type": {"application/grpc"},
"Trailer": {"Grpc-Status", "Grpc-Message"}, "Trailer": {"Grpc-Status", "Grpc-Message"},
"Grpc-Status": {"4"}, "Grpc-Status": {"4"},
"Grpc-Message": {"too slow"}, "Grpc-Message": {grpcMessageEncode("too slow")},
} }
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) { if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader) t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)

View File

@ -504,7 +504,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
Name: "grpc-status", Name: "grpc-status",
Value: strconv.Itoa(int(statusCode)), Value: strconv.Itoa(int(statusCode)),
}) })
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: statusDesc}) t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: grpcMessageEncode(statusDesc)})
// Attach the trailer metadata. // Attach the trailer metadata.
for k, v := range s.trailer { for k, v := range s.trailer {
// Clients don't tolerate reading restricted headers after some non restricted ones were sent. // Clients don't tolerate reading restricted headers after some non restricted ones were sent.

View File

@ -174,7 +174,7 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
} }
d.statusCode = codes.Code(code) d.statusCode = codes.Code(code)
case "grpc-message": case "grpc-message":
d.statusDesc = f.Value d.statusDesc = grpcMessageDecode(f.Value)
case "grpc-timeout": case "grpc-timeout":
d.timeoutSet = true d.timeoutSet = true
var err error var err error

View File

@ -43,6 +43,7 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strconv"
"sync" "sync"
"time" "time"
@ -529,3 +530,74 @@ func wait(ctx context.Context, done, closing <-chan struct{}, proceed <-chan int
return i, nil return i, nil
} }
} }
const (
spaceByte = ' '
tildaByte = '~'
percentByte = '%'
)
// grpcMessageEncode encodes the grpc-message field in the same
// manner as https://github.com/grpc/grpc-java/pull/1517.
func grpcMessageEncode(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
return grpcMessageEncodeUnchecked(msg)
}
}
return msg
}
func grpcMessageEncodeUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c >= spaceByte && c < tildaByte && c != percentByte {
_ = buf.WriteByte(c)
} else {
_, _ = buf.WriteString(fmt.Sprintf("%%%02X", c))
}
}
return buf.String()
}
// grpcMessageDecode decodes the grpc-message field in the same
// manner as https://github.com/grpc/grpc-java/pull/1517.
func grpcMessageDecode(msg string) string {
if msg == "" {
return ""
}
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
if msg[i] == percentByte && i+2 < lenMsg {
return grpcMessageDecodeUnchecked(msg)
}
}
return msg
}
func grpcMessageDecodeUnchecked(msg string) string {
var buf bytes.Buffer
lenMsg := len(msg)
for i := 0; i < lenMsg; i++ {
c := msg[i]
if c == percentByte && i+2 < lenMsg {
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
if err != nil {
_ = buf.WriteByte(c)
} else {
_ = buf.WriteByte(byte(parsed))
i += 2
}
} else {
_ = buf.WriteByte(c)
}
}
return buf.String()
}

View File

@ -37,6 +37,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"math" "math"
"net" "net"
"reflect" "reflect"
@ -75,7 +76,7 @@ const (
normal hType = iota normal hType = iota
suspended suspended
misbehaved misbehaved
malformedStatus encodingRequiredStatus
) )
func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStream(t *testing.T, s *Stream) {
@ -128,9 +129,8 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
} }
} }
func (h *testStreamHandler) handleStreamMalformedStatus(t *testing.T, s *Stream) { func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
// raw newline is not accepted by http2 framer and a http2.StreamError is // raw newline is not accepted by http2 framer so it must be encoded.
// generated.
h.t.WriteStatus(s, codes.Internal, "\n") h.t.WriteStatus(s, codes.Internal, "\n")
} }
@ -179,9 +179,9 @@ func (s *server) start(t *testing.T, port int, maxStreams uint32, ht hType) {
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
go h.handleStreamMisbehave(t, s) go h.handleStreamMisbehave(t, s)
}) })
case malformedStatus: case encodingRequiredStatus:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
go h.handleStreamMalformedStatus(t, s) go h.handleStreamEncodingRequiredStatus(t, s)
}) })
default: default:
go transport.HandleStreams(func(s *Stream) { go transport.HandleStreams(func(s *Stream) {
@ -714,8 +714,8 @@ func TestClientWithMisbehavedServer(t *testing.T) {
server.stop() server.stop()
} }
func TestMalformedStatus(t *testing.T) { func TestEncodingRequiredStatus(t *testing.T) {
server, ct := setUp(t, 0, math.MaxUint32, malformedStatus) server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
callHdr := &CallHdr{ callHdr := &CallHdr{
Host: "localhost", Host: "localhost",
Method: "foo", Method: "foo",
@ -731,10 +731,8 @@ func TestMalformedStatus(t *testing.T) {
if err := ct.Write(s, expectedRequest, &opts); err != nil { if err := ct.Write(s, expectedRequest, &opts); err != nil {
t.Fatalf("Failed to write the request: %v", err) t.Fatalf("Failed to write the request: %v", err)
} }
p := make([]byte, http2MaxFrameLen) if _, err = ioutil.ReadAll(s); err != nil {
expectedErr := StreamErrorf(codes.Internal, "invalid header field value \"\\n\"") t.Fatal(err)
if _, err = s.dec.Read(p); err != expectedErr {
t.Fatalf("Read the err %v, want %v", err, expectedErr)
} }
ct.Close() ct.Close()
server.stop() server.stop()
@ -771,3 +769,29 @@ func TestIsReservedHeader(t *testing.T) {
} }
} }
} }
func TestGrpcMessageEncode(t *testing.T) {
testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00")
testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25")
}
func TestGrpcMessageDecode(t *testing.T) {
testGrpcMessageDecode(t, "Hello", "Hello")
testGrpcMessageDecode(t, "H%61o", "Hao")
testGrpcMessageDecode(t, "H%6", "H%6")
testGrpcMessageDecode(t, "%G0", "%G0")
}
func testGrpcMessageEncode(t *testing.T, input string, expected string) {
actual := grpcMessageEncode(input)
if expected != actual {
t.Errorf("Expected %s from grpcMessageEncode, got %s", expected, actual)
}
}
func testGrpcMessageDecode(t *testing.T, input string, expected string) {
actual := grpcMessageDecode(input)
if expected != actual {
t.Errorf("Expected %s from grpcMessageDecode, got %s", expected, actual)
}
}