fix: immediately flush data to client for event-stream response (#3375)

This commit is contained in:
mchtech 2024-07-29 12:20:36 +08:00 committed by GitHub
parent d2453b93d5
commit 83450d5924
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 212 additions and 1 deletions

View File

@ -22,6 +22,7 @@ import (
"encoding/base64"
"errors"
"io"
"mime"
"net"
"net/http"
"net/http/httputil"
@ -401,6 +402,25 @@ func parseBasicAuth(auth string) (username, password string, ok bool) {
return cs[:s], cs[s+1:], true
}
// flushInterval decides how the body of res must be flushed to the client.
// It returns a negative duration when every write has to be flushed
// immediately, and zero when no explicit flushing is needed.
func (proxy *Proxy) flushInterval(res *http.Response) time.Duration {
	// A negative interval means "flush right after each write".
	const immediately time.Duration = -1

	// The parse error is ignored on purpose; only the returned media type is
	// compared below.
	mediaType, _, _ := mime.ParseMediaType(res.Header.Get("Content-Type"))

	switch {
	case mediaType == "text/event-stream":
		// Server-Sent Events response. The MIME type is defined in
		// https://www.w3.org/TR/eventsource/#text-event-stream
		return immediately
	case res.ContentLength == -1:
		// Unknown length: this might be a streaming response.
		return immediately
	}
	return 0
}
func (proxy *Proxy) handleHTTP(span trace.Span, w http.ResponseWriter, req *http.Request) {
resp, err := proxy.transport.RoundTrip(req)
if err != nil {
@ -412,7 +432,26 @@ func (proxy *Proxy) handleHTTP(span trace.Span, w http.ResponseWriter, req *http
copyHeader(w.Header(), resp.Header)
w.WriteHeader(resp.StatusCode)
span.SetAttributes(semconv.HTTPStatusCodeKey.Int(resp.StatusCode))
if n, err := io.Copy(w, resp.Body); err != nil && err != io.EOF {
// support event stream responses, see: https://github.com/golang/go/issues/2012
var lw io.Writer = w
if flushInterval := proxy.flushInterval(resp); flushInterval != 0 {
mlw := &maxLatencyWriter{
dst: w,
flush: http.NewResponseController(w).Flush,
latency: flushInterval,
}
defer mlw.stop()
// set up initial timer so headers get flushed even if body writes are delayed
mlw.flushPending = true
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
lw = mlw
logger.Debugf("handle event stream response: %s, url%s", req.Host, req.URL.String())
}
if n, err := io.Copy(lw, resp.Body); err != nil && err != io.EOF {
if peerID := resp.Header.Get(config.HeaderDragonflyPeer); peerID != "" {
logger.Errorf("failed to write http body: %v, peer: %s, task: %s, written bytes: %d",
err, peerID, resp.Header.Get(config.HeaderDragonflyTask), n)

View File

@ -18,9 +18,13 @@ package proxy
import (
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
@ -133,6 +137,56 @@ func (tc *testCase) TestMirror(t *testing.T) {
}
}
// TestEventStream checks the rule matching of every item in the test case and,
// for items whose URL path contains "event-stream", sends the request through
// handleHTTP with a mock transport and asserts how many times the response
// writer was flushed.
func (tc *testCase) TestEventStream(t *testing.T) {
	a := assert.New(t)
	if !a.Nil(tc.Error) {
		return
	}

	tp, err := NewProxy(WithRules(tc.Rules))
	if !a.Nil(err) {
		return
	}
	// Replace the real transport so no network traffic happens.
	tp.transport = &mockTransport{}

	// Value sent in X-Response-Batch and used as the flush-count threshold.
	batch := 10

	for _, item := range tc.Items {
		req, err := http.NewRequest("GET", item.URL, nil)
		if !a.Nil(err) {
			continue
		}

		// NOTE(review): shouldUseDragonfly is called before the scheme and
		// redirect assertions; presumably it may rewrite req.URL - confirm.
		if !a.Equal(tp.shouldUseDragonfly(req), !item.Direct) {
			fmt.Println(item.URL)
		}

		wantScheme := "http"
		if item.UseHTTPS {
			wantScheme = "https"
		}
		a.Equal(req.URL.Scheme, wantScheme)

		if item.Redirect != "" {
			a.Equal(item.Redirect, req.URL.String())
		}

		// Only the event-stream URLs exercise handleHTTP.
		if !strings.Contains(req.URL.Path, "event-stream") {
			continue
		}

		_, span := tp.tracer.Start(req.Context(), config.SpanProxy)
		w := &mockResponseWriter{}
		req.Header.Set("X-Response-Batch", strconv.Itoa(batch))

		if req.URL.Path != "/event-stream" {
			// Ordinary download with a known length: fewer flushes than
			// body chunks are expected.
			req.Header.Set("X-Event-Stream", "false")
			req.Header.Set("X-Response-Content-Length", strconv.Itoa(batch))
			req.Header.Set("X-Response-Content-Encoding", "")
			req.Header.Set("X-Response-Content-Type", "application/octet-stream")
			tp.handleHTTP(span, w, req)
			a.Less(w.flushCount, batch)
			continue
		}

		// Server-Sent Events response: at least one flush per body chunk is
		// expected.
		req.Header.Set("X-Event-Stream", "true")
		req.Header.Set("X-Response-Content-Length", "-1")
		req.Header.Set("X-Response-Content-Encoding", "chunked")
		req.Header.Set("X-Response-Content-Type", "text/event-stream")
		tp.handleHTTP(span, w, req)
		a.GreaterOrEqual(w.flushCount, batch)
	}
}
func TestMatch(t *testing.T) {
newTestCase().
WithRule("/blobs/sha256/", false, false, "").
@ -235,3 +289,64 @@ func TestMatchWithRedirect(t *testing.T) {
TestMirror(t)
}
// TestProxyEventStream builds a test case with one rule and two URLs, one
// whose path is "/event-stream" and one that is not, and runs the event-stream
// checks against it.
func TestProxyEventStream(t *testing.T) {
	withRule := newTestCase().WithRule("/blobs/sha256/", false, false, "")

	// Server-Sent Events endpoint.
	withStream := withRule.WithTest("http://h/event-stream", true, false, "")
	// Ordinary endpoint used as the control.
	withBoth := withStream.WithTest("http://h/not-event-stream", true, false, "")

	withBoth.TestEventStream(t)
}
// mockResponseWriter is a minimal http.ResponseWriter for tests. It discards
// the body, keeps the headers that were set, and counts Flush calls.
type mockResponseWriter struct {
	// flushCount is the number of times Flush has been called.
	flushCount int
	// header is created lazily by Header, so the zero value is ready to use.
	header http.Header
}

// Header returns the writer's header map. The same map is returned on every
// call, so headers set by the code under test are retained instead of being
// written into a throwaway map.
func (w *mockResponseWriter) Header() http.Header {
	if w.header == nil {
		w.header = make(http.Header)
	}
	return w.header
}

// Write discards p and reports all of it as written.
func (w *mockResponseWriter) Write(p []byte) (int, error) {
	return len(p), nil
}

// WriteHeader ignores the status code.
func (w *mockResponseWriter) WriteHeader(int) {}

// Flush records that a flush was requested.
func (w *mockResponseWriter) Flush() {
	w.flushCount++
}
// mockTransport is an http.RoundTripper for tests. It never touches the
// network; the response is built entirely from X-Response-* request headers.
type mockTransport struct{}

// RoundTrip returns a 200 response described by the request headers:
//
//   - X-Response-Batch: number passed to the mockReadCloser body
//   - X-Response-Content-Length: Content-Length header and ContentLength field
//   - X-Response-Content-Encoding: Content-Encoding header
//   - X-Response-Content-Type: Content-Type header
//
// It always returns a nil error.
func (rt *mockTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	// A missing or malformed batch header deliberately yields batch 0.
	batch, _ := strconv.Atoi(r.Header.Get("X-Response-Batch"))

	rawLength := r.Header.Get("X-Response-Content-Length")
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Body:       &mockReadCloser{batch: batch},
		Header: http.Header{
			"Content-Length":   []string{rawLength},
			"Content-Encoding": []string{r.Header.Get("X-Response-Content-Encoding")},
			"Content-Type":     []string{r.Header.Get("X-Response-Content-Type")},
		},
	}

	// Keep the ContentLength field in step with the advertised header, so a
	// requested length of -1 is visible to code that inspects the field.
	// An absent or malformed value leaves the field at its zero value.
	if length, err := strconv.ParseInt(rawLength, 10, 64); err == nil {
		resp.ContentLength = length
	}
	return resp, nil
}
// mockReadCloser is a slow response body for tests: it yields batch reads of
// one byte each and then reports io.EOF.
type mockReadCloser struct {
	// batch is the total number of one-byte reads to produce.
	batch int
	// count is the number of reads produced so far.
	count int
}

// Read writes a single '0' byte into p after a 100ms pause. Once batch reads
// have been produced it returns io.EOF. A zero or negative batch therefore
// yields an empty body rather than one that never ends.
func (rc *mockReadCloser) Read(p []byte) (n int, err error) {
	if rc.count >= rc.batch {
		return 0, io.EOF
	}
	// A zero-length buffer cannot take the byte; report no progress instead
	// of indexing out of range.
	if len(p) == 0 {
		return 0, nil
	}
	// Delay each chunk so the data arrives spread out over time.
	time.Sleep(100 * time.Millisecond)
	p[0] = '0'
	rc.count++
	return 1, nil
}

// Close does nothing and always succeeds.
func (rc *mockReadCloser) Close() error {
	return nil
}

View File

@ -0,0 +1,57 @@
package proxy
import (
"io"
"sync"
"time"
)
// maxLatencyWriter wraps a writer and makes sure written data is flushed no
// later than latency after the write, or right away when latency is negative.
//
// Copied from the Go standard library, see
// https://github.com/golang/go/blob/master/src/net/http/httputil/reverseproxy.go
type maxLatencyWriter struct {
	dst     io.Writer
	flush   func() error
	latency time.Duration // non-zero; negative means to flush immediately

	mu           sync.Mutex // protects t, flushPending, and dst.Flush
	t            *time.Timer
	flushPending bool
}

// Write forwards p to dst and returns dst's result. With a negative latency
// it flushes straight away; otherwise it arms the flush timer unless a flush
// is already pending.
func (m *maxLatencyWriter) Write(p []byte) (int, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	written, writeErr := m.dst.Write(p)

	switch {
	case m.latency < 0:
		m.flush() // nolint: errcheck
	case m.flushPending:
		// The timer is already armed; it will cover this write too.
	default:
		if m.t != nil {
			m.t.Reset(m.latency)
		} else {
			m.t = time.AfterFunc(m.latency, m.delayedFlush)
		}
		m.flushPending = true
	}
	return written, writeErr
}

// delayedFlush is the timer callback. It flushes only if a flush is still
// pending, then clears the pending mark.
func (m *maxLatencyWriter) delayedFlush() {
	m.mu.Lock()
	defer m.mu.Unlock()

	// stop may have been called after AfterFunc already started this
	// goroutine; in that case there is nothing left to flush.
	if m.flushPending {
		m.flush() // nolint: errcheck
		m.flushPending = false
	}
}

// stop cancels any pending flush and stops the timer if one was created.
func (m *maxLatencyWriter) stop() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if timer := m.t; timer != nil {
		timer.Stop()
	}
	m.flushPending = false
}