Merge branch 'master' of https://github.com/dapr/components-contrib into release-1.10

This commit is contained in:
ItalyPaleAle 2023-02-17 01:16:55 +00:00
commit 9c9d96b64e
2 changed files with 80 additions and 13 deletions

View File

@ -30,10 +30,9 @@ import (
"time"
"unicode"
"github.com/mitchellh/mapstructure"
"github.com/dapr/components-contrib/bindings"
"github.com/dapr/components-contrib/internal/utils"
"github.com/dapr/components-contrib/metadata"
"github.com/dapr/kit/logger"
)
@ -60,12 +59,13 @@ type HTTPSource struct {
}
type httpMetadata struct {
URL string `mapstructure:"url"`
MTLSClientCert string `mapstructure:"mtlsClientCert"`
MTLSClientKey string `mapstructure:"mtlsClientKey"`
MTLSRootCA string `mapstructure:"mtlsRootCA"`
SecurityToken string `mapstructure:"securityToken"`
SecurityTokenHeader string `mapstructure:"securityTokenHeader"`
URL string `mapstructure:"url"`
MTLSClientCert string `mapstructure:"mtlsClientCert"`
MTLSClientKey string `mapstructure:"mtlsClientKey"`
MTLSRootCA string `mapstructure:"mtlsRootCA"`
SecurityToken string `mapstructure:"securityToken"`
SecurityTokenHeader string `mapstructure:"securityTokenHeader"`
ResponseTimeout *time.Duration `mapstructure:"responseTimeout"`
}
// NewHTTP returns a new HTTPSource.
@ -74,11 +74,12 @@ func NewHTTP(logger logger.Logger) bindings.OutputBinding {
}
// Init performs metadata parsing.
func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error {
func (h *HTTPSource) Init(_ context.Context, meta bindings.Metadata) error {
var err error
if err = mapstructure.Decode(metadata.Properties, &h.metadata); err != nil {
if err = metadata.DecodeMetadata(meta.Properties, &h.metadata); err != nil {
return err
}
var tlsConfig *tls.Config
if h.metadata.MTLSClientCert != "" && h.metadata.MTLSClientKey != "" {
tlsConfig, err = h.readMTLSCertificates()
@ -100,11 +101,11 @@ func (h *HTTPSource) Init(_ context.Context, metadata bindings.Metadata) error {
netTransport.TLSClientConfig = tlsConfig
}
h.client = &http.Client{
Timeout: time.Second * 30,
Timeout: 0, // no time out here, we use request timeouts instead
Transport: netTransport,
}
if val := metadata.Properties["errorIfNot2XX"]; val != "" {
if val := meta.Properties["errorIfNot2XX"]; val != "" {
h.errorIfNot2XX = utils.IsTruthy(val)
} else {
// Default behavior
@ -186,7 +187,7 @@ func (h *HTTPSource) Operations() []bindings.OperationKind {
}
// Invoke performs an HTTP request to the configured HTTP endpoint.
func (h *HTTPSource) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
func (h *HTTPSource) Invoke(parentCtx context.Context, req *bindings.InvokeRequest) (*bindings.InvokeResponse, error) {
u := h.metadata.URL
errorIfNot2XX := h.errorIfNot2XX // Default to the component config (default is true)
@ -222,6 +223,15 @@ func (h *HTTPSource) Invoke(ctx context.Context, req *bindings.InvokeRequest) (*
return nil, fmt.Errorf("invalid operation: %s", req.Operation)
}
var ctx context.Context
if h.metadata.ResponseTimeout == nil {
ctx = parentCtx
} else {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(parentCtx, *h.metadata.ResponseTimeout)
defer cancel()
}
request, err := http.NewRequestWithContext(ctx, method, u, body)
if err != nil {
return nil, err

View File

@ -26,6 +26,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -100,6 +101,12 @@ func (h *HTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
input = inputFromHeader
}
sleepSeconds := req.Header.Get("X-Delay-Seconds")
if sleepSeconds != "" {
seconds, _ := strconv.Atoi(sleepSeconds)
time.Sleep(time.Duration(seconds) * time.Second)
}
w.Header().Set("Content-Type", "text/plain")
statusCode := req.Header.Get("X-Status-Code")
@ -614,3 +621,53 @@ func verifyNon2XXErrorsSuppressed(t *testing.T, hs bindings.OutputBinding, handl
})
}
}
func TestTimeoutHonored(t *testing.T) {
handler := NewHTTPHandler()
s := httptest.NewServer(handler)
defer s.Close()
hs, err := InitBinding(s, map[string]string{"responseTimeout": "1s"})
require.NoError(t, err)
verifyTimeoutBehavior(t, hs, handler)
}
func verifyTimeoutBehavior(t *testing.T, hs bindings.OutputBinding, handler *HTTPHandler) {
tests := map[string]TestCase{
"exceedsResponseTimeout": {
input: "GET",
operation: "get",
metadata: map[string]string{"X-Delay-Seconds": "2"},
path: "/",
err: "context deadline exceeded",
statusCode: 200,
},
"meetsResposneTimeout": {
input: "GET",
operation: "get",
metadata: map[string]string{},
path: "/",
err: "",
statusCode: 200,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
req := tc.ToInvokeRequest()
response, err := hs.Invoke(context.Background(), &req)
if tc.err == "" {
require.NoError(t, err)
assert.Equal(t, tc.path, handler.Path)
if tc.statusCode != 204 {
// 204 will return no content, so we should skip checking
assert.Equal(t, strings.ToUpper(tc.input), string(response.Data))
}
assert.Equal(t, "text/plain", response.Metadata["Content-Type"])
} else {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.err)
}
})
}
}