Add read limits in various places (#7968)

Prevent servers from buffering unbounded data.

Signed-off-by: Alex Leong <alex@buoyant.io>
This commit is contained in:
Alex Leong 2022-03-03 09:59:43 -08:00 committed by GitHub
parent 67bcd8f642
commit 9314191ab2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 34 additions and 9 deletions

View File

@ -24,6 +24,7 @@ RUN bin/scurl -o linkerd-await https://github.com/linkerd/linkerd-await/releases
## compile proxy-identity agent
FROM go-deps as golang
WORKDIR /linkerd-build
COPY pkg/util pkg/util
COPY pkg/flags pkg/flags
COPY pkg/tls pkg/tls
COPY pkg/version pkg/version

View File

@ -3,7 +3,6 @@ package heartbeat
import (
"context"
"fmt"
"io/ioutil"
"math"
"net/http"
"net/url"
@ -13,6 +12,7 @@ import (
pkgK8s "github.com/linkerd/linkerd2/controller/k8s"
"github.com/linkerd/linkerd2/pkg/config"
"github.com/linkerd/linkerd2/pkg/k8s"
"github.com/linkerd/linkerd2/pkg/util"
"github.com/linkerd/linkerd2/pkg/version"
promv1 "github.com/prometheus/client_golang/api/prometheus/v1"
"github.com/prometheus/common/model"
@ -235,7 +235,7 @@ func send(client *http.Client, baseURL string, v url.Values) error {
return fmt.Errorf("check URL [%s] request failed with: %w", req.URL.String(), err)
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
body, err := util.ReadAllLimit(resp.Body, util.MB)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}

View File

@ -6,13 +6,13 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"sync/atomic"
"github.com/linkerd/linkerd2/controller/k8s"
pkgk8s "github.com/linkerd/linkerd2/pkg/k8s"
pkgTls "github.com/linkerd/linkerd2/pkg/tls"
"github.com/linkerd/linkerd2/pkg/util"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
admissionv1beta1 "k8s.io/api/admission/v1beta1"
@ -125,7 +125,7 @@ func (s *Server) serve(res http.ResponseWriter, req *http.Request) {
err error
)
if req.Body != nil {
data, err = ioutil.ReadAll(req.Body)
data, err = util.ReadAllLimit(req.Body, 10*util.MB)
if err != nil {
http.Error(res, err.Error(), http.StatusInternalServerError)
return

View File

@ -6,11 +6,11 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"github.com/golang/protobuf/proto"
"github.com/linkerd/linkerd2/pkg/k8s"
"github.com/linkerd/linkerd2/pkg/util"
metricsPb "github.com/linkerd/linkerd2/viz/metrics-api/gen/viz"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/status"
@ -45,7 +45,7 @@ func (e HTTPError) Error() string {
// HTTPRequestToProto converts an HTTP Request to a protobuf request.
func HTTPRequestToProto(req *http.Request, protoRequestOut proto.Message) error {
bytes, err := ioutil.ReadAll(req.Body)
bytes, err := util.ReadAllLimit(req.Body, 100*util.MB)
if err != nil {
return HTTPError{
Code: http.StatusBadRequest,
@ -168,7 +168,7 @@ func CheckIfResponseHasError(rsp *http.Response) error {
// check for JSON-encoded error
if rsp.StatusCode != http.StatusOK {
if rsp.Body != nil {
bytes, err := ioutil.ReadAll(rsp.Body)
bytes, err := util.ReadAllLimit(rsp.Body, 100*util.MB)
if err == nil && len(bytes) > 0 {
body := string(bytes)
obj, err := k8s.ToRuntimeObject(body)

View File

@ -1,11 +1,20 @@
package util
import (
"fmt"
"io"
"io/ioutil"
"strings"
httpPb "github.com/linkerd/linkerd2-proxy-api/go/http_types"
)
// KB = Kilobyte
const KB = 1024
// MB = Megabyte
const MB = KB * 1024
// ParseScheme converts a scheme string to protobuf
// TODO: validate scheme
func ParseScheme(scheme string) *httpPb.Scheme {
@ -41,3 +50,17 @@ func ParseMethod(method string) *httpPb.HttpMethod {
},
}
}
// ReadAllLimit reads from r until EOF or until limit bytes are read. If EOF is
// reached, the full bytes are returned. If the limit is reached, an error is
// returned.
func ReadAllLimit(r io.Reader, limit int) ([]byte, error) {
bytes, err := ioutil.ReadAll(io.LimitReader(r, int64(limit)))
if err != nil {
return nil, err
}
if len(bytes) == limit {
return nil, fmt.Errorf("limit reached while reading: %d", limit)
}
return bytes, nil
}

View File

@ -5,8 +5,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"net/http"
"github.com/linkerd/linkerd2/pkg/util"
)
// Channels provides an interface to interact with a set of release channels.
@ -80,7 +81,7 @@ func getLatestVersions(ctx context.Context, client *http.Client, url string) (Ch
return Channels{}, fmt.Errorf("unexpected versioncheck response: %s", rsp.Status)
}
bytes, err := ioutil.ReadAll(rsp.Body)
bytes, err := util.ReadAllLimit(rsp.Body, util.MB)
if err != nil {
return Channels{}, err
}