206 lines
4.3 KiB
Go
206 lines
4.3 KiB
Go
package middleware
|
|
|
|
import (
|
|
"bufio"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
|
|
. "github.com/onsi/ginkgo/v2"
|
|
. "github.com/onsi/gomega"
|
|
"go.uber.org/mock/gomock"
|
|
)
|
|
|
|
var _ = Describe("responseWriter", func() {
|
|
Context("Interface compliance", func() {
|
|
It("implements http.ResponseWriter interface", func() {
|
|
var rw *responseWriter
|
|
var _ http.ResponseWriter = rw
|
|
})
|
|
|
|
It("implements http.Hijacker interface", func() {
|
|
var rw *responseWriter
|
|
var _ http.Hijacker = rw
|
|
})
|
|
})
|
|
|
|
Context("New", func() {
|
|
It("returns new object with expected field values set", func() {
|
|
var (
|
|
w = httptest.NewRecorder()
|
|
)
|
|
|
|
rw := newResponseWriter(w)
|
|
Expect(rw).NotTo(BeNil())
|
|
Expect(rw.downstreamResponseWriter).To(Equal(w))
|
|
Expect(rw.bytesWritten).To(Equal(0))
|
|
Expect(rw.statusCode).To(Equal(0))
|
|
})
|
|
})
|
|
|
|
Context("BytesWritten", func() {
|
|
It("returns the expected value", func() {
|
|
const (
|
|
bw = 128
|
|
)
|
|
|
|
rw := &responseWriter{
|
|
bytesWritten: bw,
|
|
}
|
|
|
|
ret := rw.BytesWritten()
|
|
Expect(ret).To(Equal(bw))
|
|
})
|
|
})
|
|
|
|
Context("StatusCode", func() {
|
|
It("returns the expected value", func() {
|
|
const (
|
|
sc = http.StatusTeapot
|
|
)
|
|
|
|
rw := &responseWriter{
|
|
statusCode: sc,
|
|
}
|
|
|
|
ret := rw.StatusCode()
|
|
Expect(ret).To(Equal(sc))
|
|
})
|
|
})
|
|
|
|
Context("Header", func() {
|
|
It("returns downstream method call", func() {
|
|
var (
|
|
w = httptest.NewRecorder()
|
|
)
|
|
|
|
rw := &responseWriter{
|
|
downstreamResponseWriter: w,
|
|
}
|
|
|
|
h := w.Header()
|
|
h.Set("Content-Type", "application/json")
|
|
|
|
ret := rw.Header()
|
|
Expect(ret).To(Equal(h))
|
|
})
|
|
})
|
|
|
|
Context("Write", func() {
|
|
It("invokes downstream method, increases bytesWritten accordingly, and returns expected values", func() {
|
|
const (
|
|
body = "KEDA"
|
|
bodyLen = len(body)
|
|
initialBW = 60
|
|
)
|
|
|
|
var (
|
|
w = httptest.NewRecorder()
|
|
)
|
|
|
|
rw := &responseWriter{
|
|
bytesWritten: initialBW,
|
|
downstreamResponseWriter: w,
|
|
}
|
|
|
|
n, err := rw.Write([]byte(body))
|
|
Expect(err).To(BeNil())
|
|
Expect(n).To(Equal(bodyLen))
|
|
|
|
Expect(rw.bytesWritten).To(Equal(initialBW + bodyLen))
|
|
|
|
Expect(w.Body.String()).To(Equal(body))
|
|
})
|
|
})
|
|
|
|
Context("WriteHeader", func() {
|
|
It("invokes downstream method and records the value", func() {
|
|
const (
|
|
sc = http.StatusTeapot
|
|
)
|
|
|
|
var (
|
|
w = httptest.NewRecorder()
|
|
)
|
|
|
|
rw := &responseWriter{
|
|
statusCode: http.StatusOK,
|
|
downstreamResponseWriter: w,
|
|
}
|
|
rw.WriteHeader(sc)
|
|
|
|
Expect(rw.statusCode).To(Equal(sc))
|
|
|
|
Expect(w.Code).To(Equal(sc))
|
|
})
|
|
})
|
|
|
|
Context("Hijack", func() {
|
|
var ctrl *gomock.Controller
|
|
|
|
BeforeEach(func() {
|
|
ctrl = gomock.NewController(GinkgoT())
|
|
})
|
|
|
|
AfterEach(func() {
|
|
ctrl.Finish()
|
|
})
|
|
|
|
It("successfully hijacks when downstream ResponseWriter implements http.Hijacker", func() {
|
|
// Create mocks using the generated mocks
|
|
mockConn := NewMockConn(ctrl)
|
|
mockReadWriter := &bufio.ReadWriter{}
|
|
mockHijackerWriter := NewMockHijackerResponseWriter(ctrl)
|
|
|
|
// Set up expectations
|
|
mockHijackerWriter.EXPECT().Hijack().Return(mockConn, mockReadWriter, nil)
|
|
|
|
rw := &responseWriter{
|
|
downstreamResponseWriter: mockHijackerWriter,
|
|
}
|
|
|
|
conn, readWriter, err := rw.Hijack()
|
|
|
|
Expect(err).To(BeNil())
|
|
Expect(conn).To(Equal(mockConn))
|
|
Expect(readWriter).To(Equal(mockReadWriter))
|
|
})
|
|
|
|
It("returns error when downstream ResponseWriter does not implement http.Hijacker", func() {
|
|
var (
|
|
w = httptest.NewRecorder()
|
|
)
|
|
|
|
rw := &responseWriter{
|
|
downstreamResponseWriter: w,
|
|
}
|
|
|
|
conn, readWriter, err := rw.Hijack()
|
|
|
|
Expect(err).NotTo(BeNil())
|
|
Expect(err.Error()).To(Equal("http.Hijacker not implemented"))
|
|
Expect(conn).To(BeNil())
|
|
Expect(readWriter).To(BeNil())
|
|
})
|
|
|
|
It("forwards error when downstream hijacker returns error", func() {
|
|
expectedError := fmt.Errorf("hijack failed")
|
|
mockHijackerWriter := NewMockHijackerResponseWriter(ctrl)
|
|
|
|
// Set up expectations
|
|
mockHijackerWriter.EXPECT().Hijack().Return(nil, nil, expectedError)
|
|
|
|
rw := &responseWriter{
|
|
downstreamResponseWriter: mockHijackerWriter,
|
|
}
|
|
|
|
conn, readWriter, err := rw.Hijack()
|
|
|
|
Expect(err).NotTo(BeNil())
|
|
Expect(err.Error()).To(Equal("hijack failed"))
|
|
Expect(conn).To(BeNil())
|
|
Expect(readWriter).To(BeNil())
|
|
})
|
|
})
|
|
})
|