mirror of https://github.com/grpc/grpc-go.git
Merge pull request #772 from menghanl/percent_encoding
grpc-message status desc encoding fixes
This commit is contained in:
commit
0e0265f78b
|
|
@ -83,7 +83,7 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request) (ServerTr
|
|||
}
|
||||
|
||||
if v := r.Header.Get("grpc-timeout"); v != "" {
|
||||
to, err := timeoutDecode(v)
|
||||
to, err := decodeTimeout(v)
|
||||
if err != nil {
|
||||
return nil, StreamErrorf(codes.Internal, "malformed time-out: %v", err)
|
||||
}
|
||||
|
|
@ -194,7 +194,7 @@ func (ht *serverHandlerTransport) WriteStatus(s *Stream, statusCode codes.Code,
|
|||
h := ht.rw.Header()
|
||||
h.Set("Grpc-Status", fmt.Sprintf("%d", statusCode))
|
||||
if statusDesc != "" {
|
||||
h.Set("Grpc-Message", grpcMessageEncode(statusDesc))
|
||||
h.Set("Grpc-Message", encodeGrpcMessage(statusDesc))
|
||||
}
|
||||
if md := s.Trailer(); len(md) > 0 {
|
||||
for k, vv := range md {
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
|
|||
"Content-Type": {"application/grpc"},
|
||||
"Trailer": {"Grpc-Status", "Grpc-Message"},
|
||||
"Grpc-Status": {fmt.Sprint(uint32(statusCode))},
|
||||
"Grpc-Message": {grpcMessageEncode(msg)},
|
||||
"Grpc-Message": {encodeGrpcMessage(msg)},
|
||||
}
|
||||
if !reflect.DeepEqual(st.rw.HeaderMap, wantHeader) {
|
||||
t.Errorf("Header+Trailer mismatch.\n got: %#v\nwant: %#v", st.rw.HeaderMap, wantHeader)
|
||||
|
|
@ -381,7 +381,7 @@ func TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
|
|||
"Content-Type": {"application/grpc"},
|
||||
"Trailer": {"Grpc-Status", "Grpc-Message"},
|
||||
"Grpc-Status": {"4"},
|
||||
"Grpc-Message": {grpcMessageEncode("too slow")},
|
||||
"Grpc-Message": {encodeGrpcMessage("too slow")},
|
||||
}
|
||||
if !reflect.DeepEqual(rw.HeaderMap, wantHeader) {
|
||||
t.Errorf("Header+Trailer Map mismatch.\n got: %#v\nwant: %#v", rw.HeaderMap, wantHeader)
|
||||
|
|
|
|||
|
|
@ -341,7 +341,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea
|
|||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-encoding", Value: callHdr.SendCompress})
|
||||
}
|
||||
if timeout > 0 {
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: timeoutEncode(timeout)})
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-timeout", Value: encodeTimeout(timeout)})
|
||||
}
|
||||
for k, v := range authData {
|
||||
// Capital header names are illegal in HTTP/2.
|
||||
|
|
|
|||
|
|
@ -507,7 +507,7 @@ func (t *http2Server) WriteStatus(s *Stream, statusCode codes.Code, statusDesc s
|
|||
Name: "grpc-status",
|
||||
Value: strconv.Itoa(int(statusCode)),
|
||||
})
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: grpcMessageEncode(statusDesc)})
|
||||
t.hEnc.WriteField(hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(statusDesc)})
|
||||
// Attach the trailer metadata.
|
||||
for k, v := range s.trailer {
|
||||
// Clients don't tolerate reading restricted headers after some non restricted ones were sent.
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ package transport
|
|||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
|
@ -174,11 +175,11 @@ func (d *decodeState) processHeaderField(f hpack.HeaderField) {
|
|||
}
|
||||
d.statusCode = codes.Code(code)
|
||||
case "grpc-message":
|
||||
d.statusDesc = grpcMessageDecode(f.Value)
|
||||
d.statusDesc = decodeGrpcMessage(f.Value)
|
||||
case "grpc-timeout":
|
||||
d.timeoutSet = true
|
||||
var err error
|
||||
d.timeout, err = timeoutDecode(f.Value)
|
||||
d.timeout, err = decodeTimeout(f.Value)
|
||||
if err != nil {
|
||||
d.setErr(StreamErrorf(codes.Internal, "transport: malformed time-out: %v", err))
|
||||
return
|
||||
|
|
@ -251,7 +252,7 @@ func div(d, r time.Duration) int64 {
|
|||
}
|
||||
|
||||
// TODO(zhaoq): It is the simplistic and not bandwidth efficient. Improve it.
|
||||
func timeoutEncode(t time.Duration) string {
|
||||
func encodeTimeout(t time.Duration) string {
|
||||
if d := div(t, time.Nanosecond); d <= maxTimeoutValue {
|
||||
return strconv.FormatInt(d, 10) + "n"
|
||||
}
|
||||
|
|
@ -271,7 +272,7 @@ func timeoutEncode(t time.Duration) string {
|
|||
return strconv.FormatInt(div(t, time.Hour), 10) + "H"
|
||||
}
|
||||
|
||||
func timeoutDecode(s string) (time.Duration, error) {
|
||||
func decodeTimeout(s string) (time.Duration, error) {
|
||||
size := len(s)
|
||||
if size < 2 {
|
||||
return 0, fmt.Errorf("transport: timeout string is too short: %q", s)
|
||||
|
|
@ -288,6 +289,80 @@ func timeoutDecode(s string) (time.Duration, error) {
|
|||
return d * time.Duration(t), nil
|
||||
}
|
||||
|
||||
const (
|
||||
spaceByte = ' '
|
||||
tildaByte = '~'
|
||||
percentByte = '%'
|
||||
)
|
||||
|
||||
// encodeGrpcMessage is used to encode status code in header field
|
||||
// "grpc-message".
|
||||
// It checks to see if each individual byte in msg is an
|
||||
// allowable byte, and then either percent encoding or passing it through.
|
||||
// When percent encoding, the byte is converted into hexadecimal notation
|
||||
// with a '%' prepended.
|
||||
func encodeGrpcMessage(msg string) string {
|
||||
if msg == "" {
|
||||
return ""
|
||||
}
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
c := msg[i]
|
||||
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
|
||||
return encodeGrpcMessageUnchecked(msg)
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func encodeGrpcMessageUnchecked(msg string) string {
|
||||
var buf bytes.Buffer
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
c := msg[i]
|
||||
if c >= spaceByte && c < tildaByte && c != percentByte {
|
||||
buf.WriteByte(c)
|
||||
} else {
|
||||
buf.WriteString(fmt.Sprintf("%%%02X", c))
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// decodeGrpcMessage decodes the msg encoded by encodeGrpcMessage.
|
||||
func decodeGrpcMessage(msg string) string {
|
||||
if msg == "" {
|
||||
return ""
|
||||
}
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
if msg[i] == percentByte && i+2 < lenMsg {
|
||||
return decodeGrpcMessageUnchecked(msg)
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func decodeGrpcMessageUnchecked(msg string) string {
|
||||
var buf bytes.Buffer
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
c := msg[i]
|
||||
if c == percentByte && i+2 < lenMsg {
|
||||
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
|
||||
if err != nil {
|
||||
buf.WriteByte(c)
|
||||
} else {
|
||||
buf.WriteByte(byte(parsed))
|
||||
i += 2
|
||||
}
|
||||
} else {
|
||||
buf.WriteByte(c)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
type framer struct {
|
||||
numWriters int32
|
||||
reader io.Reader
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ func TestTimeoutEncode(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to parse duration string %s: %v", test.in, err)
|
||||
}
|
||||
out := timeoutEncode(d)
|
||||
out := encodeTimeout(d)
|
||||
if out != test.out {
|
||||
t.Fatalf("timeoutEncode(%s) = %s, want %s", test.in, out, test.out)
|
||||
}
|
||||
|
|
@ -79,7 +79,7 @@ func TestTimeoutDecode(t *testing.T) {
|
|||
{"1", 0, fmt.Errorf("transport: timeout string is too short: %q", "1")},
|
||||
{"", 0, fmt.Errorf("transport: timeout string is too short: %q", "")},
|
||||
} {
|
||||
d, err := timeoutDecode(test.s)
|
||||
d, err := decodeTimeout(test.s)
|
||||
if d != test.d || fmt.Sprint(err) != fmt.Sprint(test.err) {
|
||||
t.Fatalf("timeoutDecode(%q) = %d, %v, want %d, %v", test.s, int64(d), err, int64(test.d), test.err)
|
||||
}
|
||||
|
|
@ -107,3 +107,38 @@ func TestValidContentType(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncodeGrpcMessage(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"Hello", "Hello"},
|
||||
{"my favorite character is \u0000", "my favorite character is %00"},
|
||||
{"my favorite character is %", "my favorite character is %25"},
|
||||
} {
|
||||
actual := encodeGrpcMessage(tt.input)
|
||||
if tt.expected != actual {
|
||||
t.Errorf("encodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeGrpcMessage(t *testing.T) {
|
||||
for _, tt := range []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"", ""},
|
||||
{"Hello", "Hello"},
|
||||
{"H%61o", "Hao"},
|
||||
{"H%6", "H%6"},
|
||||
{"%G0", "%G0"},
|
||||
} {
|
||||
actual := decodeGrpcMessage(tt.input)
|
||||
if tt.expected != actual {
|
||||
t.Errorf("dncodeGrpcMessage(%v) = %v, want %v", tt.input, actual, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -43,7 +43,6 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -559,74 +558,3 @@ func wait(ctx context.Context, done, goAway, closing <-chan struct{}, proceed <-
|
|||
return i, nil
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
spaceByte = ' '
|
||||
tildaByte = '~'
|
||||
percentByte = '%'
|
||||
)
|
||||
|
||||
// grpcMessageEncode encodes the grpc-message field in the same
|
||||
// manner as https://github.com/grpc/grpc-java/pull/1517.
|
||||
func grpcMessageEncode(msg string) string {
|
||||
if msg == "" {
|
||||
return ""
|
||||
}
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
c := msg[i]
|
||||
if !(c >= spaceByte && c < tildaByte && c != percentByte) {
|
||||
return grpcMessageEncodeUnchecked(msg)
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func grpcMessageEncodeUnchecked(msg string) string {
|
||||
var buf bytes.Buffer
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
c := msg[i]
|
||||
if c >= spaceByte && c < tildaByte && c != percentByte {
|
||||
_ = buf.WriteByte(c)
|
||||
} else {
|
||||
_, _ = buf.WriteString(fmt.Sprintf("%%%02X", c))
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
||||
// grpcMessageDecode decodes the grpc-message field in the same
|
||||
// manner as https://github.com/grpc/grpc-java/pull/1517.
|
||||
func grpcMessageDecode(msg string) string {
|
||||
if msg == "" {
|
||||
return ""
|
||||
}
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
if msg[i] == percentByte && i+2 < lenMsg {
|
||||
return grpcMessageDecodeUnchecked(msg)
|
||||
}
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
func grpcMessageDecodeUnchecked(msg string) string {
|
||||
var buf bytes.Buffer
|
||||
lenMsg := len(msg)
|
||||
for i := 0; i < lenMsg; i++ {
|
||||
c := msg[i]
|
||||
if c == percentByte && i+2 < lenMsg {
|
||||
parsed, err := strconv.ParseInt(msg[i+1:i+3], 16, 8)
|
||||
if err != nil {
|
||||
_ = buf.WriteByte(c)
|
||||
} else {
|
||||
_ = buf.WriteByte(byte(parsed))
|
||||
i += 2
|
||||
}
|
||||
} else {
|
||||
_ = buf.WriteByte(c)
|
||||
}
|
||||
}
|
||||
return buf.String()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,7 +37,6 @@ import (
|
|||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"reflect"
|
||||
|
|
@ -131,7 +130,7 @@ func (h *testStreamHandler) handleStreamMisbehave(t *testing.T, s *Stream) {
|
|||
|
||||
func (h *testStreamHandler) handleStreamEncodingRequiredStatus(t *testing.T, s *Stream) {
|
||||
// raw newline is not accepted by http2 framer so it must be encoded.
|
||||
h.t.WriteStatus(s, codes.Internal, "\n")
|
||||
h.t.WriteStatus(s, encodingTestStatusCode, encodingTestStatusDesc)
|
||||
}
|
||||
|
||||
// start starts server. Other goroutines should block on s.readyChan for further operations.
|
||||
|
|
@ -714,6 +713,11 @@ func TestClientWithMisbehavedServer(t *testing.T) {
|
|||
server.stop()
|
||||
}
|
||||
|
||||
var (
|
||||
encodingTestStatusCode = codes.Internal
|
||||
encodingTestStatusDesc = "\n"
|
||||
)
|
||||
|
||||
func TestEncodingRequiredStatus(t *testing.T) {
|
||||
server, ct := setUp(t, 0, math.MaxUint32, encodingRequiredStatus)
|
||||
callHdr := &CallHdr{
|
||||
|
|
@ -731,8 +735,12 @@ func TestEncodingRequiredStatus(t *testing.T) {
|
|||
if err := ct.Write(s, expectedRequest, &opts); err != nil {
|
||||
t.Fatalf("Failed to write the request: %v", err)
|
||||
}
|
||||
if _, err = ioutil.ReadAll(s); err != nil {
|
||||
t.Fatal(err)
|
||||
p := make([]byte, http2MaxFrameLen)
|
||||
if _, err := s.dec.Read(p); err != io.EOF {
|
||||
t.Fatalf("Read got error %v, want %v", err, io.EOF)
|
||||
}
|
||||
if s.StatusCode() != encodingTestStatusCode || s.StatusDesc() != encodingTestStatusDesc {
|
||||
t.Fatalf("stream with status code %d, status desc %v, want %d, %v", s.StatusCode(), s.StatusDesc(), encodingTestStatusCode, encodingTestStatusDesc)
|
||||
}
|
||||
ct.Close()
|
||||
server.stop()
|
||||
|
|
@ -769,29 +777,3 @@ func TestIsReservedHeader(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGrpcMessageEncode(t *testing.T) {
|
||||
testGrpcMessageEncode(t, "my favorite character is \u0000", "my favorite character is %00")
|
||||
testGrpcMessageEncode(t, "my favorite character is %", "my favorite character is %25")
|
||||
}
|
||||
|
||||
func TestGrpcMessageDecode(t *testing.T) {
|
||||
testGrpcMessageDecode(t, "Hello", "Hello")
|
||||
testGrpcMessageDecode(t, "H%61o", "Hao")
|
||||
testGrpcMessageDecode(t, "H%6", "H%6")
|
||||
testGrpcMessageDecode(t, "%G0", "%G0")
|
||||
}
|
||||
|
||||
func testGrpcMessageEncode(t *testing.T, input string, expected string) {
|
||||
actual := grpcMessageEncode(input)
|
||||
if expected != actual {
|
||||
t.Errorf("Expected %s from grpcMessageEncode, got %s", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func testGrpcMessageDecode(t *testing.T, input string, expected string) {
|
||||
actual := grpcMessageDecode(input)
|
||||
if expected != actual {
|
||||
t.Errorf("Expected %s from grpcMessageDecode, got %s", expected, actual)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue