diff --git a/cmd/sparkctl/app/create.go b/cmd/sparkctl/app/create.go
index 9509b9d7..34c68760 100644
--- a/cmd/sparkctl/app/create.go
+++ b/cmd/sparkctl/app/create.go
@@ -57,6 +57,8 @@ var createCmd = &cobra.Command{
 	Short: "Create a SparkApplication object",
 	Long:  `Create a SparkApplication from a given YAML file storing the application specification.`,
 	Run: func(cmd *cobra.Command, args []string) {
+		ctx := cmd.Context()
+
 		if From != "" && len(args) != 1 {
 			fmt.Fprintln(os.Stderr, "must specify the name of a ScheduledSparkApplication")
 			return
@@ -80,11 +82,11 @@ var createCmd = &cobra.Command{
 		}
 
 		if From != "" {
-			if err := createFromScheduledSparkApplication(args[0], kubeClient, crdClient); err != nil {
+			if err := createFromScheduledSparkApplication(ctx, args[0], kubeClient, crdClient); err != nil {
 				fmt.Fprintf(os.Stderr, "%v\n", err)
 			}
 		} else {
-			if err := createFromYaml(args[0], kubeClient, crdClient); err != nil {
+			if err := createFromYaml(ctx, args[0], kubeClient, crdClient); err != nil {
 				fmt.Fprintf(os.Stderr, "%v\n", err)
 			}
 		}
@@ -114,20 +116,20 @@ func init() {
 		"the name of ScheduledSparkApplication from which a forced SparkApplication run is created")
 }
 
-func createFromYaml(yamlFile string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
+func createFromYaml(ctx context.Context, yamlFile string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
 	app, err := loadFromYAML(yamlFile)
 	if err != nil {
 		return fmt.Errorf("failed to read a SparkApplication from %s: %v", yamlFile, err)
 	}
 
-	if err := createSparkApplication(app, kubeClient, crdClient); err != nil {
+	if err := createSparkApplication(ctx, app, kubeClient, crdClient); err != nil {
 		return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err)
 	}
 
 	return nil
 }
 
-func createFromScheduledSparkApplication(name string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
+func createFromScheduledSparkApplication(ctx context.Context, name string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
 	sapp, err := crdClient.SparkoperatorV1beta2().ScheduledSparkApplications(Namespace).Get(context.TODO(), From, metav1.GetOptions{})
 	if err != nil {
 		return fmt.Errorf("failed to get ScheduledSparkApplication %s: %v", From, err)
@@ -149,14 +151,14 @@ func createFromScheduledSparkApplication(name string, kubeClient clientset.Inter
 		Spec: *sapp.Spec.Template.DeepCopy(),
 	}
 
-	if err := createSparkApplication(app, kubeClient, crdClient); err != nil {
+	if err := createSparkApplication(ctx, app, kubeClient, crdClient); err != nil {
 		return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err)
 	}
 
 	return nil
 }
 
-func createSparkApplication(app *v1beta2.SparkApplication, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
+func createSparkApplication(ctx context.Context, app *v1beta2.SparkApplication, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
 	if DeleteIfExists {
 		if err := deleteSparkApplication(app.Name, crdClient); err != nil {
 			return err
@@ -190,7 +192,7 @@ func createSparkApplication(app *v1beta2.SparkApplication, kubeClient clientset.
 	fmt.Printf("SparkApplication \"%s\" created\n", app.Name)
 
 	if LogsEnabled {
-		if err := doLog(app.Name, true, kubeClient, crdClient); err != nil {
+		if err := doLog(ctx, app.Name, true, kubeClient, crdClient); err != nil {
 			return nil
 		}
 	}
diff --git a/cmd/sparkctl/app/log.go b/cmd/sparkctl/app/log.go
index 9ccdf4a8..788ce6b5 100644
--- a/cmd/sparkctl/app/log.go
+++ b/cmd/sparkctl/app/log.go
@@ -39,6 +39,8 @@ var logCommand = &cobra.Command{
 	Short: "log is a sub-command of sparkctl that fetches logs of a Spark application.",
 	Long:  ``,
 	Run: func(cmd *cobra.Command, args []string) {
+		ctx := cmd.Context()
+
 		if len(args) != 1 {
 			fmt.Fprintln(os.Stderr, "must specify a SparkApplication name")
 			return
@@ -56,7 +58,7 @@ var logCommand = &cobra.Command{
 			return
 		}
 
-		if err := doLog(args[0], FollowLogs, kubeClientset, crdClientset); err != nil {
+		if err := doLog(ctx, args[0], FollowLogs, kubeClientset, crdClientset); err != nil {
 			fmt.Fprintf(os.Stderr, "failed to get driver logs of SparkApplication %s: %v\n", args[0], err)
 		}
 	},
@@ -69,13 +71,14 @@ func init() {
 }
 
 func doLog(
+	ctx context.Context,
 	name string,
 	followLogs bool,
 	kubeClient clientset.Interface,
 	crdClient crdclientset.Interface) error {
 	timeout := 30 * time.Second
 
-	podNameChannel := getPodNameChannel(name, crdClient)
+	podNameChannel := getPodNameChannel(ctx, name, crdClient)
 	var podName string
 
 	select {
@@ -84,7 +87,7 @@ func doLog(
 		return fmt.Errorf("not found pod name")
 	}
 
-	waitLogsChannel := waitForLogsFromPodChannel(podName, kubeClient, crdClient)
+	waitLogsChannel := waitForLogsFromPodChannel(ctx, podName, kubeClient, crdClient)
 
 	select {
 	case <-waitLogsChannel:
@@ -93,19 +96,20 @@ func doLog(
 	}
 
 	if followLogs {
-		return streamLogs(os.Stdout, kubeClient, podName)
+		return streamLogs(ctx, os.Stdout, kubeClient, podName)
 	}
-	return printLogs(os.Stdout, kubeClient, podName)
+	return printLogs(ctx, os.Stdout, kubeClient, podName)
 }
 
 func getPodNameChannel(
+	ctx context.Context,
 	sparkApplicationName string,
 	crdClient crdclientset.Interface) chan string {
 	channel := make(chan string, 1)
 	go func() {
 		for {
 			app, _ := crdClient.SparkoperatorV1beta2().SparkApplications(Namespace).Get(
-				context.TODO(),
+				ctx,
 				sparkApplicationName,
 				metav1.GetOptions{})
 
@@ -119,13 +123,14 @@ func getPodNameChannel(
 }
 
 func waitForLogsFromPodChannel(
+	ctx context.Context,
 	podName string,
 	kubeClient clientset.Interface,
 	_ crdclientset.Interface) chan bool {
 	channel := make(chan bool, 1)
 	go func() {
 		for {
-			_, err := kubeClient.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(context.TODO()).Raw()
+			_, err := kubeClient.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(ctx).Raw()
 
 			if err == nil {
 				channel <- true
@@ -137,8 +142,8 @@ func waitForLogsFromPodChannel(
 }
 
 // printLogs is a one time operation that prints the fetched logs of the given pod.
-func printLogs(out io.Writer, kubeClientset clientset.Interface, podName string) error {
-	rawLogs, err := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(context.TODO()).Raw()
+func printLogs(ctx context.Context, out io.Writer, kubeClientset clientset.Interface, podName string) error {
+	rawLogs, err := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(ctx).Raw()
 	if err != nil {
 		return err
 	}
@@ -147,9 +152,9 @@ func printLogs(out io.Writer, kubeClientset clientset.Interface, podName string)
 }
 
 // streamLogs streams the logs of the given pod until there are no more logs available.
-func streamLogs(out io.Writer, kubeClientset clientset.Interface, podName string) error {
+func streamLogs(ctx context.Context, out io.Writer, kubeClientset clientset.Interface, podName string) error {
 	request := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{Follow: true})
-	reader, err := request.Stream(context.TODO())
+	reader, err := request.Stream(ctx)
 	if err != nil {
 		return err
 	}
diff --git a/cmd/sparkctl/app/root.go b/cmd/sparkctl/app/root.go
index e845b8be..41d074ac 100644
--- a/cmd/sparkctl/app/root.go
+++ b/cmd/sparkctl/app/root.go
@@ -17,6 +17,7 @@ limitations under the License.
 package app
 
 import (
+	"context"
 	"fmt"
 	"os"
 
@@ -52,7 +53,9 @@ func init() {
 }
 
 func Execute() {
-	if err := rootCmd.Execute(); err != nil {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+	if err := rootCmd.ExecuteContext(ctx); err != nil {
 		fmt.Fprintf(os.Stderr, "%v", err)
 	}
 }
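Note: the root context introduced in root.go is only cancelled when Execute returns, so Ctrl-C still kills the process rather than cancelling in-flight API calls and log streams. If interrupt-driven cancellation is the goal, Execute could instead derive the context from process signals. The following is a minimal sketch, not part of the diff above; it assumes Go 1.16+ (for signal.NotifyContext), the existing rootCmd from root.go, and "os/signal" and "syscall" added to the import block.

package app

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
)

func Execute() {
	// Cancel the command context on SIGINT/SIGTERM so that API requests and
	// log streams started via cmd.Context() are interrupted cleanly.
	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
	defer stop()
	if err := rootCmd.ExecuteContext(ctx); err != nil {
		fmt.Fprintf(os.Stderr, "%v", err)
	}
}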