http-add-on/interceptor/middleware/responsewriter_test.go

206 lines
4.3 KiB
Go

package middleware
import (
"bufio"
"fmt"
"net/http"
"net/http/httptest"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"go.uber.org/mock/gomock"
)
var _ = Describe("responseWriter", func() {
Context("Interface compliance", func() {
It("implements http.ResponseWriter interface", func() {
var rw *responseWriter
var _ http.ResponseWriter = rw
})
It("implements http.Hijacker interface", func() {
var rw *responseWriter
var _ http.Hijacker = rw
})
})
Context("New", func() {
It("returns new object with expected field values set", func() {
var (
w = httptest.NewRecorder()
)
rw := newResponseWriter(w)
Expect(rw).NotTo(BeNil())
Expect(rw.downstreamResponseWriter).To(Equal(w))
Expect(rw.bytesWritten).To(Equal(0))
Expect(rw.statusCode).To(Equal(0))
})
})
Context("BytesWritten", func() {
It("returns the expected value", func() {
const (
bw = 128
)
rw := &responseWriter{
bytesWritten: bw,
}
ret := rw.BytesWritten()
Expect(ret).To(Equal(bw))
})
})
Context("StatusCode", func() {
It("returns the expected value", func() {
const (
sc = http.StatusTeapot
)
rw := &responseWriter{
statusCode: sc,
}
ret := rw.StatusCode()
Expect(ret).To(Equal(sc))
})
})
Context("Header", func() {
It("returns downstream method call", func() {
var (
w = httptest.NewRecorder()
)
rw := &responseWriter{
downstreamResponseWriter: w,
}
h := w.Header()
h.Set("Content-Type", "application/json")
ret := rw.Header()
Expect(ret).To(Equal(h))
})
})
Context("Write", func() {
It("invokes downstream method, increases bytesWritten accordingly, and returns expected values", func() {
const (
body = "KEDA"
bodyLen = len(body)
initialBW = 60
)
var (
w = httptest.NewRecorder()
)
rw := &responseWriter{
bytesWritten: initialBW,
downstreamResponseWriter: w,
}
n, err := rw.Write([]byte(body))
Expect(err).To(BeNil())
Expect(n).To(Equal(bodyLen))
Expect(rw.bytesWritten).To(Equal(initialBW + bodyLen))
Expect(w.Body.String()).To(Equal(body))
})
})
Context("WriteHeader", func() {
It("invokes downstream method and records the value", func() {
const (
sc = http.StatusTeapot
)
var (
w = httptest.NewRecorder()
)
rw := &responseWriter{
statusCode: http.StatusOK,
downstreamResponseWriter: w,
}
rw.WriteHeader(sc)
Expect(rw.statusCode).To(Equal(sc))
Expect(w.Code).To(Equal(sc))
})
})
Context("Hijack", func() {
var ctrl *gomock.Controller
BeforeEach(func() {
ctrl = gomock.NewController(GinkgoT())
})
AfterEach(func() {
ctrl.Finish()
})
It("successfully hijacks when downstream ResponseWriter implements http.Hijacker", func() {
// Create mocks using the generated mocks
mockConn := NewMockConn(ctrl)
mockReadWriter := &bufio.ReadWriter{}
mockHijackerWriter := NewMockHijackerResponseWriter(ctrl)
// Set up expectations
mockHijackerWriter.EXPECT().Hijack().Return(mockConn, mockReadWriter, nil)
rw := &responseWriter{
downstreamResponseWriter: mockHijackerWriter,
}
conn, readWriter, err := rw.Hijack()
Expect(err).To(BeNil())
Expect(conn).To(Equal(mockConn))
Expect(readWriter).To(Equal(mockReadWriter))
})
It("returns error when downstream ResponseWriter does not implement http.Hijacker", func() {
var (
w = httptest.NewRecorder()
)
rw := &responseWriter{
downstreamResponseWriter: w,
}
conn, readWriter, err := rw.Hijack()
Expect(err).NotTo(BeNil())
Expect(err.Error()).To(Equal("http.Hijacker not implemented"))
Expect(conn).To(BeNil())
Expect(readWriter).To(BeNil())
})
It("forwards error when downstream hijacker returns error", func() {
expectedError := fmt.Errorf("hijack failed")
mockHijackerWriter := NewMockHijackerResponseWriter(ctrl)
// Set up expectations
mockHijackerWriter.EXPECT().Hijack().Return(nil, nil, expectedError)
rw := &responseWriter{
downstreamResponseWriter: mockHijackerWriter,
}
conn, readWriter, err := rw.Hijack()
Expect(err).NotTo(BeNil())
Expect(err.Error()).To(Equal("hijack failed"))
Expect(conn).To(BeNil())
Expect(readWriter).To(BeNil())
})
})
})