diff --git a/cli/cmd/inject.go b/cli/cmd/inject.go index a3a05cf3c..10dadbe70 100644 --- a/cli/cmd/inject.go +++ b/cli/cmd/inject.go @@ -70,7 +70,7 @@ sub-folders, or coming from stdin.`, kubectl get deploy -o yaml | linkerd inject - | kubectl apply -f - # Injecting a file from a remote URL - linkerd inject http://url.to/yml | kubectl apply -f - + linkerd inject https://url.to/yml | kubectl apply -f - # Inject all the resources inside a folder and its sub-folders. linkerd inject | kubectl apply -f -`, diff --git a/cli/cmd/inject_test.go b/cli/cmd/inject_test.go index 997d53c85..f3b56372d 100644 --- a/cli/cmd/inject_test.go +++ b/cli/cmd/inject_test.go @@ -577,7 +577,7 @@ func TestInjectFilePath(t *testing.T) { }) } -func TestValidURL(t *testing.T) { +func TestToURL(t *testing.T) { // if the string follows a URL pattern, true has to be returned // if not false is returned @@ -593,9 +593,9 @@ func TestValidURL(t *testing.T) { } for url, expectedValue := range tests { - value := isValidURL(url) - if value != expectedValue { - t.Errorf("Result mismatch for %s. expected %v, but got %v", url, expectedValue, value) + _, ok := toURL(url) + if ok != expectedValue { + t.Errorf("Result mismatch for %s. expected %v, but got %v", url, expectedValue, ok) } } diff --git a/cli/cmd/inject_util.go b/cli/cmd/inject_util.go index 60175f80e..2c4508a3f 100644 --- a/cli/cmd/inject_util.go +++ b/cli/cmd/inject_util.go @@ -10,6 +10,7 @@ import ( "net/url" "os" "path/filepath" + "strings" "github.com/linkerd/linkerd2/pkg/inject" corev1 "k8s.io/api/core/v1" @@ -142,17 +143,19 @@ func processList(bytes []byte, rt resourceTransformer) ([]byte, []inject.Report, // Read all the resource files found in path into a slice of readers. // path can be either a file, directory or stdin. func read(path string) ([]io.Reader, error) { - var ( - in []io.Reader - err error - ) if path == "-" { - in = append(in, os.Stdin) - } else if isValidURL(path) { - resp, err := http.Get(path) + return []io.Reader{os.Stdin}, nil + } + + if url, ok := toURL(path); ok { + if strings.ToLower(url.Scheme) != "https" { + return nil, fmt.Errorf("only HTTPS URLs are allowed") + } + resp, err := http.Get(url.String()) if err != nil { return nil, err } + defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("unable to read URL %q, server reported %s, status code=%d", path, resp.Status, resp.StatusCode) @@ -164,26 +167,21 @@ func read(path string) ([]io.Reader, error) { if err != nil { return nil, err } - resp.Body.Close() - in = append(in, buf) - } else { - in, err = walk(path) - if err != nil { - return nil, err - } + + return []io.Reader{buf}, nil } - return in, nil + return walk(path) } // checks if the given string is a valid URL -func isValidURL(path string) bool { +func toURL(path string) (*url.URL, bool) { u, err := url.ParseRequestURI(path) - if err != nil { - return false + if err == nil && u.Host != "" && u.Scheme != "" { + return u, true } - return u.Host != "" && u.Scheme != "" + return nil, false } // walk walks the file tree rooted at path. path may be a file or a directory.