diff --git a/sparkctl/cmd/client.go b/sparkctl/cmd/client.go index fa368353..b9ccaf07 100644 --- a/sparkctl/cmd/client.go +++ b/sparkctl/cmd/client.go @@ -17,10 +17,12 @@ limitations under the License. package cmd import ( - crdclientset "k8s.io/spark-on-k8s-operator/pkg/client/clientset/versioned" - "k8s.io/client-go/tools/clientcmd" - "k8s.io/client-go/rest" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" clientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" + "k8s.io/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1alpha1" + crdclientset "k8s.io/spark-on-k8s-operator/pkg/client/clientset/versioned" ) func buildConfig(kubeConfig string) (*rest.Config, error) { @@ -32,6 +34,10 @@ func getKubeClient() (clientset.Interface, error) { if err != nil { return nil, err } + return getKubeClientForConfig(config) +} + +func getKubeClientForConfig(config *rest.Config) (clientset.Interface, error) { return clientset.NewForConfig(config) } @@ -40,5 +46,17 @@ func getSparkApplicationClient() (crdclientset.Interface, error) { if err != nil { return nil, err } + return getSparkApplicationClientForConfig(config) +} + +func getSparkApplicationClientForConfig(config *rest.Config) (crdclientset.Interface, error) { return crdclientset.NewForConfig(config) } + +func getSparkApplication(name string, crdClientset crdclientset.Interface) (*v1alpha1.SparkApplication, error) { + app, err := crdClientset.SparkoperatorV1alpha1().SparkApplications(Namespace).Get(name, metav1.GetOptions{}) + if err != nil { + return nil, err + } + return app, nil +} diff --git a/sparkctl/cmd/forward.go b/sparkctl/cmd/forward.go new file mode 100644 index 00000000..015c3541 --- /dev/null +++ b/sparkctl/cmd/forward.go @@ -0,0 +1,171 @@ +/* +Copyright 2017 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cmd + +import ( + "fmt" + "net/http" + "net/url" + "os" + "os/signal" + + "github.com/spf13/cobra" + apiv1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + clientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + "k8s.io/client-go/tools/portforward" + "k8s.io/client-go/transport/spdy" + crdclientset "k8s.io/spark-on-k8s-operator/pkg/client/clientset/versioned" + "time" +) + +var LocalPort int32 +var RemotePort int32 + +var forwardCmd = &cobra.Command{ + Use: "forward [--local-port ] [--remote-port ]", + Short: "Start to forward a local port to the remote port of the driver UI", + Long: `Start to forward a local port to the remote port of the driver UI so the UI can be accessed locally.`, + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + fmt.Fprintln(os.Stderr, "must specify a SparkApplication name") + return + } + + config, err := buildConfig(KubeConfig) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get kubeconfig: %v", err) + return + } + + crdClientset, err := getSparkApplicationClientForConfig(config) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get SparkApplication client: %v\n", err) + return + } + + kubeClientset, err := getKubeClientForConfig(config) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get REST client: %v\n", err) + return + } + restClient := kubeClientset.CoreV1().RESTClient() + + driverPodUrl, driverPodName, err := getDriverPodUrlAndName(args[0], restClient, crdClientset) + if err != nil { + fmt.Fprintf(os.Stderr, + "failed to get an API server URL of the driver pod of SparkApplication %s: %v", + args[0], err) + return + } + + stopCh := make(chan struct{}, 1) + readyCh := make(chan struct{}) + + forwarder, err := newPortForwarder(config, driverPodUrl, stopCh, readyCh) + if err != nil { + fmt.Fprintf(os.Stderr, "failed to get a port forwarder: %v", err) + return + } + + fmt.Printf("Forwarding from %d -> %d\n", LocalPort, RemotePort) + if err = runPortForward(driverPodName, stopCh, forwarder, kubeClientset); err != nil { + fmt.Fprintf(os.Stderr, "failed to run port forwarding: %v", err) + } + }, +} + +func init() { + forwardCmd.Flags().Int32VarP(&LocalPort, "local-port", "l", 4040, + "local port to forward from") + forwardCmd.Flags().Int32VarP(&RemotePort, "remote-port", "r", 4040, + "remote port to forward to") +} + +func newPortForwarder( + config *rest.Config, + url *url.URL, + stopCh chan struct{}, + readyCh chan struct{}) (*portforward.PortForwarder, error) { + transport, upgrader, err := spdy.RoundTripperFor(config) + if err != nil { + return nil, err + } + + dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url) + ports := []string{fmt.Sprintf("%d:%d", LocalPort, RemotePort)} + fw, err := portforward.New(dialer, ports, stopCh, readyCh, nil, os.Stderr) + if err != nil { + return nil, err + } + + return fw, nil +} + +func getDriverPodUrlAndName( + name string, + restClient rest.Interface, + crdClientset crdclientset.Interface) (*url.URL, string, error) { + app, err := getSparkApplication(name, crdClientset) + if err != nil { + return nil, "", fmt.Errorf("failed to get SparkApplication %s: %v", name, err) + } + + if app.Status.DriverInfo.PodName != "" { + request := restClient.Post(). + Resource("pods"). + Namespace(Namespace). + Name(app.Status.DriverInfo.PodName). + SubResource("portforward") + return request.URL(), app.Status.DriverInfo.PodName, nil + } + + return nil, "", fmt.Errorf("driver pod name of SparkApplication %s is not available yet", name) +} + +func runPortForward( + driverPodName string, + stopCh chan struct{}, + forwarder *portforward.PortForwarder, + kubeClientset clientset.Interface) error { + signals := make(chan os.Signal, 1) + signal.Notify(signals, os.Interrupt) + defer signal.Stop(signals) + + go func() { + defer close(stopCh) + for { + pod, err := kubeClientset.CoreV1().Pods(Namespace).Get(driverPodName, metav1.GetOptions{}) + if err != nil { + break + } + if pod.Status.Phase == apiv1.PodSucceeded || pod.Status.Phase == apiv1.PodFailed { + break + } + time.Sleep(1 * time.Second) + } + fmt.Println("Driver pod has terminated, stopping forwarding") + }() + + go func() { + <-signals + close(stopCh) + }() + + return forwarder.ForwardPorts() +} \ No newline at end of file diff --git a/sparkctl/cmd/root.go b/sparkctl/cmd/root.go index be750e9a..302cdac5 100644 --- a/sparkctl/cmd/root.go +++ b/sparkctl/cmd/root.go @@ -17,10 +17,10 @@ limitations under the License. package cmd import ( + "fmt" "os" "github.com/spf13/cobra" - "fmt" ) var defaultKubeConfig = os.Getenv("HOME") + "/.kube/config" @@ -40,7 +40,7 @@ func init() { "The namespace in which the SparkApplication is to be created") rootCmd.PersistentFlags().StringVarP(&KubeConfig, "kubeconfig", "c", defaultKubeConfig, "The namespace in which the SparkApplication is to be created") - rootCmd.AddCommand(createCmd, deleteCmd, statusCmd, logCommand, listCmd) + rootCmd.AddCommand(createCmd, deleteCmd, statusCmd, logCommand, listCmd, forwardCmd) } func Execute() { diff --git a/sparkctl/cmd/status.go b/sparkctl/cmd/status.go index 2a44dac7..f26b489e 100644 --- a/sparkctl/cmd/status.go +++ b/sparkctl/cmd/status.go @@ -21,7 +21,6 @@ import ( "os" "github.com/spf13/cobra" - metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1alpha1" crdclientset "k8s.io/spark-on-k8s-operator/pkg/client/clientset/versioned" ) @@ -49,7 +48,7 @@ var statusCmd = &cobra.Command{ } func doStatus(name string, crdClientset crdclientset.Interface) error { - app, err := crdClientset.SparkoperatorV1alpha1().SparkApplications(Namespace).Get(name, metav1.GetOptions{}) + app, err := getSparkApplication(name, crdClientset) if err != nil { return fmt.Errorf("failed to get SparkApplication %s: %v", name, err) }