use cmd context in sparkctl (#2447)

Signed-off-by: Manabu McCloskey <manabu.mccloskey@gmail.com>
This commit is contained in:
Manabu McCloskey 2025-02-19 19:52:42 -08:00 committed by GitHub
parent 405ae51de4
commit bd197c6f8c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 30 additions and 20 deletions

View File

@ -57,6 +57,8 @@ var createCmd = &cobra.Command{
Short: "Create a SparkApplication object", Short: "Create a SparkApplication object",
Long: `Create a SparkApplication from a given YAML file storing the application specification.`, Long: `Create a SparkApplication from a given YAML file storing the application specification.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
if From != "" && len(args) != 1 { if From != "" && len(args) != 1 {
fmt.Fprintln(os.Stderr, "must specify the name of a ScheduledSparkApplication") fmt.Fprintln(os.Stderr, "must specify the name of a ScheduledSparkApplication")
return return
@ -80,11 +82,11 @@ var createCmd = &cobra.Command{
} }
if From != "" { if From != "" {
if err := createFromScheduledSparkApplication(args[0], kubeClient, crdClient); err != nil { if err := createFromScheduledSparkApplication(ctx, args[0], kubeClient, crdClient); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err) fmt.Fprintf(os.Stderr, "%v\n", err)
} }
} else { } else {
if err := createFromYaml(args[0], kubeClient, crdClient); err != nil { if err := createFromYaml(ctx, args[0], kubeClient, crdClient); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err) fmt.Fprintf(os.Stderr, "%v\n", err)
} }
} }
@ -114,20 +116,20 @@ func init() {
"the name of ScheduledSparkApplication from which a forced SparkApplication run is created") "the name of ScheduledSparkApplication from which a forced SparkApplication run is created")
} }
func createFromYaml(yamlFile string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error { func createFromYaml(ctx context.Context, yamlFile string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
app, err := loadFromYAML(yamlFile) app, err := loadFromYAML(yamlFile)
if err != nil { if err != nil {
return fmt.Errorf("failed to read a SparkApplication from %s: %v", yamlFile, err) return fmt.Errorf("failed to read a SparkApplication from %s: %v", yamlFile, err)
} }
if err := createSparkApplication(app, kubeClient, crdClient); err != nil { if err := createSparkApplication(ctx, app, kubeClient, crdClient); err != nil {
return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err) return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err)
} }
return nil return nil
} }
func createFromScheduledSparkApplication(name string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error { func createFromScheduledSparkApplication(ctx context.Context, name string, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
sapp, err := crdClient.SparkoperatorV1beta2().ScheduledSparkApplications(Namespace).Get(context.TODO(), From, metav1.GetOptions{}) sapp, err := crdClient.SparkoperatorV1beta2().ScheduledSparkApplications(Namespace).Get(context.TODO(), From, metav1.GetOptions{})
if err != nil { if err != nil {
return fmt.Errorf("failed to get ScheduledSparkApplication %s: %v", From, err) return fmt.Errorf("failed to get ScheduledSparkApplication %s: %v", From, err)
@ -149,14 +151,14 @@ func createFromScheduledSparkApplication(name string, kubeClient clientset.Inter
Spec: *sapp.Spec.Template.DeepCopy(), Spec: *sapp.Spec.Template.DeepCopy(),
} }
if err := createSparkApplication(app, kubeClient, crdClient); err != nil { if err := createSparkApplication(ctx, app, kubeClient, crdClient); err != nil {
return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err) return fmt.Errorf("failed to create SparkApplication %s: %v", app.Name, err)
} }
return nil return nil
} }
func createSparkApplication(app *v1beta2.SparkApplication, kubeClient clientset.Interface, crdClient crdclientset.Interface) error { func createSparkApplication(ctx context.Context, app *v1beta2.SparkApplication, kubeClient clientset.Interface, crdClient crdclientset.Interface) error {
if DeleteIfExists { if DeleteIfExists {
if err := deleteSparkApplication(app.Name, crdClient); err != nil { if err := deleteSparkApplication(app.Name, crdClient); err != nil {
return err return err
@ -190,7 +192,7 @@ func createSparkApplication(app *v1beta2.SparkApplication, kubeClient clientset.
fmt.Printf("SparkApplication \"%s\" created\n", app.Name) fmt.Printf("SparkApplication \"%s\" created\n", app.Name)
if LogsEnabled { if LogsEnabled {
if err := doLog(app.Name, true, kubeClient, crdClient); err != nil { if err := doLog(ctx, app.Name, true, kubeClient, crdClient); err != nil {
return nil return nil
} }
} }

View File

@ -39,6 +39,8 @@ var logCommand = &cobra.Command{
Short: "log is a sub-command of sparkctl that fetches logs of a Spark application.", Short: "log is a sub-command of sparkctl that fetches logs of a Spark application.",
Long: ``, Long: ``,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
ctx := cmd.Context()
if len(args) != 1 { if len(args) != 1 {
fmt.Fprintln(os.Stderr, "must specify a SparkApplication name") fmt.Fprintln(os.Stderr, "must specify a SparkApplication name")
return return
@ -56,7 +58,7 @@ var logCommand = &cobra.Command{
return return
} }
if err := doLog(args[0], FollowLogs, kubeClientset, crdClientset); err != nil { if err := doLog(ctx, args[0], FollowLogs, kubeClientset, crdClientset); err != nil {
fmt.Fprintf(os.Stderr, "failed to get driver logs of SparkApplication %s: %v\n", args[0], err) fmt.Fprintf(os.Stderr, "failed to get driver logs of SparkApplication %s: %v\n", args[0], err)
} }
}, },
@ -69,13 +71,14 @@ func init() {
} }
func doLog( func doLog(
ctx context.Context,
name string, name string,
followLogs bool, followLogs bool,
kubeClient clientset.Interface, kubeClient clientset.Interface,
crdClient crdclientset.Interface) error { crdClient crdclientset.Interface) error {
timeout := 30 * time.Second timeout := 30 * time.Second
podNameChannel := getPodNameChannel(name, crdClient) podNameChannel := getPodNameChannel(ctx, name, crdClient)
var podName string var podName string
select { select {
@ -84,7 +87,7 @@ func doLog(
return fmt.Errorf("not found pod name") return fmt.Errorf("not found pod name")
} }
waitLogsChannel := waitForLogsFromPodChannel(podName, kubeClient, crdClient) waitLogsChannel := waitForLogsFromPodChannel(ctx, podName, kubeClient, crdClient)
select { select {
case <-waitLogsChannel: case <-waitLogsChannel:
@ -93,19 +96,20 @@ func doLog(
} }
if followLogs { if followLogs {
return streamLogs(os.Stdout, kubeClient, podName) return streamLogs(ctx, os.Stdout, kubeClient, podName)
} }
return printLogs(os.Stdout, kubeClient, podName) return printLogs(ctx, os.Stdout, kubeClient, podName)
} }
func getPodNameChannel( func getPodNameChannel(
ctx context.Context,
sparkApplicationName string, sparkApplicationName string,
crdClient crdclientset.Interface) chan string { crdClient crdclientset.Interface) chan string {
channel := make(chan string, 1) channel := make(chan string, 1)
go func() { go func() {
for { for {
app, _ := crdClient.SparkoperatorV1beta2().SparkApplications(Namespace).Get( app, _ := crdClient.SparkoperatorV1beta2().SparkApplications(Namespace).Get(
context.TODO(), ctx,
sparkApplicationName, sparkApplicationName,
metav1.GetOptions{}) metav1.GetOptions{})
@ -119,13 +123,14 @@ func getPodNameChannel(
} }
func waitForLogsFromPodChannel( func waitForLogsFromPodChannel(
ctx context.Context,
podName string, podName string,
kubeClient clientset.Interface, kubeClient clientset.Interface,
_ crdclientset.Interface) chan bool { _ crdclientset.Interface) chan bool {
channel := make(chan bool, 1) channel := make(chan bool, 1)
go func() { go func() {
for { for {
_, err := kubeClient.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(context.TODO()).Raw() _, err := kubeClient.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(ctx).Raw()
if err == nil { if err == nil {
channel <- true channel <- true
@ -137,8 +142,8 @@ func waitForLogsFromPodChannel(
} }
// printLogs is a one time operation that prints the fetched logs of the given pod. // printLogs is a one time operation that prints the fetched logs of the given pod.
func printLogs(out io.Writer, kubeClientset clientset.Interface, podName string) error { func printLogs(ctx context.Context, out io.Writer, kubeClientset clientset.Interface, podName string) error {
rawLogs, err := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(context.TODO()).Raw() rawLogs, err := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{}).Do(ctx).Raw()
if err != nil { if err != nil {
return err return err
} }
@ -147,9 +152,9 @@ func printLogs(out io.Writer, kubeClientset clientset.Interface, podName string)
} }
// streamLogs streams the logs of the given pod until there are no more logs available. // streamLogs streams the logs of the given pod until there are no more logs available.
func streamLogs(out io.Writer, kubeClientset clientset.Interface, podName string) error { func streamLogs(ctx context.Context, out io.Writer, kubeClientset clientset.Interface, podName string) error {
request := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{Follow: true}) request := kubeClientset.CoreV1().Pods(Namespace).GetLogs(podName, &corev1.PodLogOptions{Follow: true})
reader, err := request.Stream(context.TODO()) reader, err := request.Stream(ctx)
if err != nil { if err != nil {
return err return err
} }

View File

@ -17,6 +17,7 @@ limitations under the License.
package app package app
import ( import (
"context"
"fmt" "fmt"
"os" "os"
@ -52,7 +53,9 @@ func init() {
} }
func Execute() { func Execute() {
if err := rootCmd.Execute(); err != nil { ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := rootCmd.ExecuteContext(ctx); err != nil {
fmt.Fprintf(os.Stderr, "%v", err) fmt.Fprintf(os.Stderr, "%v", err)
} }
} }