From 25958081e65d7976c4e523295712b0c8f155734a Mon Sep 17 00:00:00 2001 From: "Yuan (Bob) Gong" <4957653+Bobgy@users.noreply.github.com> Date: Fri, 6 Aug 2021 09:25:50 +0800 Subject: [PATCH] chore(v2): parameter passing. Fixes #6151 (#6208) * chore(v2): publish output parameters * chore(v2): resolve input parameters from producer task * chore(v2): add implicit parameter dependencies to DAG tasks * test(v2): add verification for v2 tests * fix compiler unit test --- samples/test/config.yaml | 8 +- samples/v2/hello_world_test.py | 18 +++- samples/v2/producer_consumer_param_test.py | 76 +++++++++++++++ v2/cmd/launcher-v2/main.go | 3 +- v2/compiler/argo.go | 2 + v2/compiler/argo_test.go | 9 +- v2/compiler/container.go | 15 ++- v2/compiler/dag.go | 46 ++++++++- v2/component/launcher.go | 7 +- v2/component/launcher_v2.go | 104 +++++++++++++++++++-- v2/driver/driver.go | 56 ++++++++++- v2/metadata/client.go | 85 +++++++++++++---- v2/metadata/converter.go | 12 +-- 13 files changed, 397 insertions(+), 44 deletions(-) create mode 100644 samples/v2/producer_consumer_param_test.py diff --git a/samples/test/config.yaml b/samples/test/config.yaml index 98a5a33265..1011c4f304 100644 --- a/samples/test/config.yaml +++ b/samples/test/config.yaml @@ -22,8 +22,6 @@ # The field `path` corresponds to the test's python module path # e.g. if folder path is `samples/test/fail_test.py`, then module path is # `samples.test.fail_test`. -- name: hello_world - path: samples.v2.hello_world_test - name: condition path: samples.core.condition.condition_test - name: nested_condition @@ -84,3 +82,9 @@ # TODO(Bobgy): This is currently passing, should it fail? # - name: fail_parameter_value_missing # path: samples.test.fail_parameter_value_missing_test + +# v2 samples +- name: hello_world + path: samples.v2.hello_world_test +- name: producer_consumer_param + path: samples.v2.producer_consumer_param_test diff --git a/samples/v2/hello_world_test.py b/samples/v2/hello_world_test.py index 43644013cb..a81bc789e5 100644 --- a/samples/v2/hello_world_test.py +++ b/samples/v2/hello_world_test.py @@ -21,7 +21,8 @@ import kfp import kfp_server_api from .hello_world import pipeline_hello_world -from ..test.util import run_pipeline_func, TestCase, KfpMlmdClient +from ..test.util import KfpTask, TaskInputs, TaskOutputs, run_pipeline_func, TestCase, KfpMlmdClient +from ml_metadata.proto import Execution def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs): @@ -31,6 +32,21 @@ def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs): client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config) tasks = client.get_tasks(run_id=run.id) pprint(tasks) + t.assertEqual( + { + 'hello-world': + KfpTask( + name='hello-world', + type='system.ContainerExecution', + state=Execution.State.COMPLETE, + inputs=TaskInputs( + parameters={'text': 'hi there'}, artifacts=[] + ), + outputs=TaskOutputs(parameters={}, artifacts=[]) + ) + }, + tasks, + ) if __name__ == '__main__': diff --git a/samples/v2/producer_consumer_param_test.py b/samples/v2/producer_consumer_param_test.py new file mode 100644 index 0000000000..1a3c51c4e1 --- /dev/null +++ b/samples/v2/producer_consumer_param_test.py @@ -0,0 +1,76 @@ +# Copyright 2021 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Hello world v2 engine pipeline.""" + +from __future__ import annotations +import unittest +from pprint import pprint + +import kfp +import kfp_server_api + +from .producer_consumer_param import producer_consumer_param_pipeline +from ..test.util import KfpTask, TaskInputs, TaskOutputs, run_pipeline_func, TestCase, KfpMlmdClient +from ml_metadata.proto import Execution + + +def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs): + t = unittest.TestCase() + t.maxDiff = None # we always want to see full diff + t.assertEqual(run.status, 'Succeeded') + client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config) + tasks = client.get_tasks(run_id=run.id) + pprint(tasks) + t.assertEqual({ + 'consumer': + KfpTask( + name='consumer', + type='system.ContainerExecution', + state=Execution.State.COMPLETE, + inputs=TaskInputs( + parameters={ + 'input_value': + 'Hello world, this is an output parameter\n' + }, + artifacts=[] + ), + outputs=TaskOutputs(parameters={}, artifacts=[]) + ), + 'producer': + KfpTask( + name='producer', + type='system.ContainerExecution', + state=Execution.State.COMPLETE, + inputs=TaskInputs( + parameters={'input_text': 'Hello world'}, artifacts=[] + ), + outputs=TaskOutputs( + parameters={ + 'output_value': + 'Hello world, this is an output parameter\n' + }, + artifacts=[] + ) + ) + }, tasks) + + +if __name__ == '__main__': + run_pipeline_func([ + TestCase( + pipeline_func=producer_consumer_param_pipeline, + verify_func=verify, + mode=kfp.dsl.PipelineExecutionMode.V2_ENGINE, + ), + ]) diff --git a/v2/cmd/launcher-v2/main.go b/v2/cmd/launcher-v2/main.go index e39b33755b..07a5e9c53c 100644 --- a/v2/cmd/launcher-v2/main.go +++ b/v2/cmd/launcher-v2/main.go @@ -27,6 +27,7 @@ var ( copy = flag.String("copy", "", "copy this binary to specified destination path") executionID = flag.Int64("execution_id", 0, "Execution ID of this task.") executorInputJSON = flag.String("executor_input", "", "The JSON-encoded ExecutorInput.") + componentSpecJSON = flag.String("component_spec", "", "The JSON-encoded ComponentSpec.") namespace = flag.String("namespace", "", "The Kubernetes namespace this Pod belongs to.") podName = flag.String("pod_name", "", "Kubernetes Pod name.") podUID = flag.String("pod_uid", "", "Kubernetes Pod UID.") @@ -61,7 +62,7 @@ func run() error { MLMDServerAddress: *mlmdServerAddress, MLMDServerPort: *mlmdServerPort, } - launcher, err := component.NewLauncherV2(*executionID, *executorInputJSON, flag.Args(), opts) + launcher, err := component.NewLauncherV2(*executionID, *executorInputJSON, *componentSpecJSON, flag.Args(), opts) if err != nil { return err } diff --git a/v2/compiler/argo.go b/v2/compiler/argo.go index 17ead18a57..3a040d3b6b 100644 --- a/v2/compiler/argo.go +++ b/v2/compiler/argo.go @@ -133,10 +133,12 @@ func workflowParameter(name string) string { return fmt.Sprintf("{{workflow.parameters.%s}}", name) } +// In a container template, refer to inputs to the template. 
func inputValue(parameter string) string { return fmt.Sprintf("{{inputs.parameters.%s}}", parameter) } +// In a DAG/steps template, refer to inputs to the parent template. func inputParameter(parameter string) string { return fmt.Sprintf("{{inputs.parameters.%s}}", parameter) } diff --git a/v2/compiler/argo_test.go b/v2/compiler/argo_test.go index 05e304fc94..8e8d4ff479 100644 --- a/v2/compiler/argo_test.go +++ b/v2/compiler/argo_test.go @@ -103,6 +103,8 @@ func Test_argo_compiler(t *testing.T) { - '{{inputs.parameters.execution-id}}' - --executor_input - '{{inputs.parameters.executor-input}}' + - --component_spec + - '{{inputs.parameters.component}}' - --namespace - $(KFP_NAMESPACE) - --pod_name @@ -153,6 +155,7 @@ func Test_argo_compiler(t *testing.T) { parameters: - name: executor-input - name: execution-id + - name: component metadata: {} name: comp-hello-world-container outputs: {} @@ -164,7 +167,7 @@ func Test_argo_compiler(t *testing.T) { - arguments: parameters: - name: component - value: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}' + value: '{{inputs.parameters.component}}' - name: task value: '{{inputs.parameters.task}}' - name: dag-context-id @@ -179,6 +182,8 @@ func Test_argo_compiler(t *testing.T) { value: '{{tasks.driver.outputs.parameters.executor-input}}' - name: execution-id value: '{{tasks.driver.outputs.parameters.execution-id}}' + - name: component + value: '{{inputs.parameters.component}}' dependencies: - driver name: container @@ -188,6 +193,8 @@ func Test_argo_compiler(t *testing.T) { - name: task - name: dag-context-id - name: dag-execution-id + - default: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}' + name: component metadata: {} name: comp-hello-world outputs: {} diff --git a/v2/compiler/container.go b/v2/compiler/container.go index 3acbba71b5..6929529958 100644 --- a/v2/compiler/container.go +++ b/v2/compiler/container.go @@ -23,7 +23,13 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon if err != nil { return fmt.Errorf("workflowCompiler.Container: marlshaling component spec to proto JSON failed: %w", err) } - driverTask, driverOutputs := c.containerDriverTask("driver", componentJson, inputParameter(paramTask), inputParameter(paramDAGContextID), inputParameter(paramDAGExecutionID)) + driverTask, driverOutputs := c.containerDriverTask( + "driver", + inputParameter(paramComponent), + inputParameter(paramTask), + inputParameter(paramDAGContextID), + inputParameter(paramDAGExecutionID), + ) if err != nil { return err } @@ -39,6 +45,8 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon {Name: paramTask}, {Name: paramDAGContextID}, {Name: paramDAGExecutionID}, + // TODO(Bobgy): reuse the entire 2-step container template + {Name: paramComponent, Default: wfapi.AnyStringPtr(componentJson)}, }, }, DAG: &wfapi.DAGTemplate{ @@ -51,6 +59,9 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon }, { Name: paramExecutionID, Value: wfapi.AnyStringPtr(driverOutputs.executionID), + }, { + Name: paramComponent, + Value: wfapi.AnyStringPtr(inputParameter(paramComponent)), }}, }}, }, @@ -136,6 +147,7 @@ func containerExecutorTemplate(container *pipelinespec.PipelineDeploymentConfig_ volumePathKFPLauncher + "/launch", "--execution_id", inputValue(paramExecutionID), "--executor_input", inputValue(paramExecutorInput), + "--component_spec", inputValue(paramComponent), "--namespace", 
"$(KFP_NAMESPACE)", "--pod_name", @@ -154,6 +166,7 @@ func containerExecutorTemplate(container *pipelinespec.PipelineDeploymentConfig_ Parameters: []wfapi.Parameter{ {Name: paramExecutorInput}, {Name: paramExecutionID}, + {Name: paramComponent}, }, }, Volumes: []k8score.Volume{{ diff --git a/v2/compiler/dag.go b/v2/compiler/dag.go index 22341e6e12..9793a4b873 100644 --- a/v2/compiler/dag.go +++ b/v2/compiler/dag.go @@ -13,6 +13,10 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen if name != "root" { return fmt.Errorf("SubDAG not implemented yet") } + err := addImplicitDependencies(dagSpec) + if err != nil { + return err + } dag := &wfapi.Template{ Inputs: wfapi.Inputs{ Parameters: []wfapi.Parameter{ @@ -29,8 +33,9 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen return fmt.Errorf("DAG: marshaling task spec to proto JSON failed: %w", err) } dag.DAG.Tasks = append(dag.DAG.Tasks, wfapi.DAGTask{ - Name: kfpTask.GetTaskInfo().GetName(), - Template: c.templateName(kfpTask.GetComponentRef().GetName()), + Name: kfpTask.GetTaskInfo().GetName(), + Template: c.templateName(kfpTask.GetComponentRef().GetName()), + Dependencies: kfpTask.GetDependentTasks(), Arguments: wfapi.Arguments{ Parameters: []wfapi.Parameter{ { @@ -172,3 +177,40 @@ func (c *workflowCompiler) addDAGDriverTemplate() string { c.wf.Spec.Templates = append(c.wf.Spec.Templates, *t) return name } + +func addImplicitDependencies(dagSpec *pipelinespec.DagSpec) error { + for _, task := range dagSpec.GetTasks() { + // TODO(Bobgy): add implicit dependencies introduced by artifacts + for _, input := range task.GetInputs().GetParameters() { + wrap := func(err error) error { + return fmt.Errorf("failed to add implicit deps: %w", err) + } + switch input.Kind.(type) { + case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: + producer := input.GetTaskOutputParameter().GetProducerTask() + _, ok := dagSpec.GetTasks()[producer] + if !ok { + return wrap(fmt.Errorf("unknown producer task %q in DAG", producer)) + } + if task.DependentTasks == nil { + task.DependentTasks = make([]string, 0) + } + // add the dependency if it's not already added + found := false + for _, dep := range task.DependentTasks { + if dep == producer { + found = true + } + } + if !found { + task.DependentTasks = append(task.DependentTasks, producer) + } + case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_: + return wrap(fmt.Errorf("task final status not supported yet")) + default: + // other input types do not introduce implicit dependencies + } + } + } + return nil +} diff --git a/v2/component/launcher.go b/v2/component/launcher.go index d87e704db0..763a7a7e04 100644 --- a/v2/component/launcher.go +++ b/v2/component/launcher.go @@ -460,8 +460,8 @@ func execute(ctx context.Context, executorInput *pipelinespec.ExecutorInput, cmd return nil, err } - // Collect outputs - return getExecutorOutput() + // Collect outputs from output metadata file. 
+ return getExecutorOutputFile() } func (l *Launcher) publish(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, execution *metadata.Execution) error { @@ -802,7 +802,8 @@ func mergeRuntimeArtifacts(src, dst *pipelinespec.RuntimeArtifact) { } } -func getExecutorOutput() (*pipelinespec.ExecutorOutput, error) { +func getExecutorOutputFile() (*pipelinespec.ExecutorOutput, error) { + // collect user executor output file executorOutput := &pipelinespec.ExecutorOutput{ Parameters: map[string]*pipelinespec.Value{}, Artifacts: map[string]*pipelinespec.ArtifactList{}, diff --git a/v2/component/launcher_v2.go b/v2/component/launcher_v2.go index eb53139bfe..d547c70dee 100644 --- a/v2/component/launcher_v2.go +++ b/v2/component/launcher_v2.go @@ -3,12 +3,16 @@ package component import ( "context" "fmt" + "io/ioutil" + "strconv" + "strings" "github.com/golang/glog" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" "github.com/kubeflow/pipelines/v2/metadata" "github.com/kubeflow/pipelines/v2/objectstore" pb "github.com/kubeflow/pipelines/v2/third_party/ml_metadata" + "gocloud.dev/blob" "google.golang.org/protobuf/encoding/protojson" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" @@ -26,6 +30,7 @@ type LauncherV2Options struct { type LauncherV2 struct { executionID int64 executorInput *pipelinespec.ExecutorInput + component *pipelinespec.ComponentSpec command string args []string options LauncherV2Options @@ -35,7 +40,7 @@ type LauncherV2 struct { k8sClient *kubernetes.Clientset } -func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string, opts *LauncherV2Options) (l *LauncherV2, err error) { +func NewLauncherV2(executionID int64, executorInputJSON, componentSpecJSON string, cmdArgs []string, opts *LauncherV2Options) (l *LauncherV2, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to create component launcher v2: %w", err) @@ -49,6 +54,11 @@ func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string if err != nil { return nil, fmt.Errorf("failed to unmarshal executor input: %w", err) } + component := &pipelinespec.ComponentSpec{} + err = protojson.Unmarshal([]byte(componentSpecJSON), component) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal component spec: %w", err) + } if len(cmdArgs) == 0 { return nil, fmt.Errorf("command and arguments are empty") } @@ -76,9 +86,13 @@ func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string if err != nil { return nil, err } + if err = addOutputs(executorInput, component.GetOutputDefinitions()); err != nil { + return nil, err + } return &LauncherV2{ executionID: executionID, executorInput: executorInput, + component: component, command: cmdArgs[0], args: cmdArgs[1:], options: *opts, @@ -105,11 +119,11 @@ func (l *LauncherV2) Execute(ctx context.Context) (err error) { if err != nil { return err } - _, err = execute(ctx, l.executorInput, l.command, l.args, bucket, bucketConfig) + executorOutput, err := executeV2(ctx, l.executorInput, l.component, l.command, l.args, bucket, bucketConfig) if err != nil { return err } - return l.publish(ctx, execution) + return l.publish(ctx, execution, executorOutput) } func (o *LauncherV2Options) validate() error { @@ -153,18 +167,94 @@ func (l *LauncherV2) prePublish(ctx context.Context) (execution *metadata.Execut return l.metadataClient.PrePublishExecution(ctx, execution, ecfg) } -func (l *LauncherV2) publish(ctx context.Context, execution 
*metadata.Execution) (err error) { +func (l *LauncherV2) publish(ctx context.Context, execution *metadata.Execution, executorOutput *pipelinespec.ExecutorOutput) (err error) { defer func() { if err != nil { err = fmt.Errorf("failed to publish results to ML Metadata: %w", err) } }() - // TODO(Bobgy): read output parameters from local path, and add them to executorOutput. + outputParameters, err := metadata.NewParameters(executorOutput.GetParameters()) + if err != nil { + return err + } // TODO(Bobgy): upload output artifacts. // TODO(Bobgy): when adding artifacts, we will need execution.pipeline to be non-nil, because we need // to publish output artifacts to the context too. - if err := l.metadataClient.PublishExecution(ctx, execution, nil, nil, pb.Execution_COMPLETE); err != nil { - return fmt.Errorf("unable to publish execution: %w", err) + return l.metadataClient.PublishExecution(ctx, execution, outputParameters, nil, pb.Execution_COMPLETE) +} + +// Add outputs info from component spec to executor input. +func addOutputs(executorInput *pipelinespec.ExecutorInput, outputs *pipelinespec.ComponentOutputsSpec) error { + if executorInput == nil { + return fmt.Errorf("cannot add outputs to nil executor input") + } + if executorInput.Outputs == nil { + executorInput.Outputs = &pipelinespec.ExecutorInput_Outputs{} + } + if executorInput.Outputs.Parameters == nil { + executorInput.Outputs.Parameters = make(map[string]*pipelinespec.ExecutorInput_OutputParameter) + } + // TODO(Bobgy): add output artifacts + for name, _ := range outputs.GetParameters() { + executorInput.Outputs.Parameters[name] = &pipelinespec.ExecutorInput_OutputParameter{ + OutputFile: fmt.Sprintf("/tmp/kfp/outputs/%s", name), + } } return nil } + +func executeV2(ctx context.Context, executorInput *pipelinespec.ExecutorInput, component *pipelinespec.ComponentSpec, cmd string, args []string, bucket *blob.Bucket, bucketConfig *objectstore.Config) (*pipelinespec.ExecutorOutput, error) { + executorOutput, err := execute(ctx, executorInput, cmd, args, bucket, bucketConfig) + if err != nil { + return nil, err + } + + // Collect Output Parameters + // + // These are not added in execute(), because execute() is shared between v2 compatible and v2 engine launcher. + // In v2 compatible mode, we get output parameter info from runtimeInfo. In v2 engine, we get it from component spec. + // Because of the difference, we cannot put parameter collection logic in one method. + if executorOutput.Parameters == nil { + executorOutput.Parameters = make(map[string]*pipelinespec.Value) + } + outputParameters := executorOutput.GetParameters() + for name, param := range executorInput.GetOutputs().GetParameters() { + _, ok := outputParameters[name] + if ok { + // If the output parameter was already specified in output metadata file, + // we don't need to collect it from file, because output metadata file has + // the highest priority. 
+ continue + } + paramSpec, ok := component.GetOutputDefinitions().GetParameters()[name] + if !ok { + return nil, fmt.Errorf("failed to find output parameter name=%q in component spec", name) + } + msg := func(err error) error { + return fmt.Errorf("failed to read output parameter name=%q type=%q path=%q: %w", name, paramSpec.GetType(), param.GetOutputFile(), err) + } + b, err := ioutil.ReadFile(param.GetOutputFile()) + if err != nil { + return nil, msg(err) + } + switch paramSpec.GetType() { + case pipelinespec.PrimitiveType_STRING: + outputParameters[name] = metadata.StringValue(string(b)) + case pipelinespec.PrimitiveType_INT: + i, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 0) + if err != nil { + return nil, msg(err) + } + outputParameters[name] = metadata.IntValue(i) + case pipelinespec.PrimitiveType_DOUBLE: + f, err := strconv.ParseFloat(strings.TrimSpace(string(b)), 0) + if err != nil { + return nil, msg(err) + } + outputParameters[name] = metadata.DoubleValue(f) + default: + return nil, msg(fmt.Errorf("unknown type. Expected STRING, INT or DOUBLE")) + } + } + return executorOutput, nil +} diff --git a/v2/driver/driver.go b/v2/driver/driver.go index 6538fa6159..dbc1137ca8 100644 --- a/v2/driver/driver.go +++ b/v2/driver/driver.go @@ -206,6 +206,19 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, task *pipelinespec.Pi inputs := &pipelinespec.ExecutorInput_Inputs{ Parameters: make(map[string]*pipelinespec.Value), } + // get executions in context on demand + var tasksCache map[string]*metadata.Execution + getDAGTasks := func() (map[string]*metadata.Execution, error) { + if tasksCache != nil { + return tasksCache, nil + } + tasks, err := mlmd.GetExecutionsInDAG(ctx, dag) + if err != nil { + return nil, err + } + tasksCache = tasks + return tasks, nil + } if task.GetInputs() != nil { for name, paramSpec := range task.GetInputs().Parameters { paramError := func(err error) error { @@ -214,20 +227,55 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, task *pipelinespec.Pi if paramSpec.GetParameterExpressionSelector() != "" { return nil, paramError(fmt.Errorf("parameter expression selector not implemented yet")) } - componentInput := paramSpec.GetComponentInputParameter() - if componentInput != "" { + switch t := paramSpec.Kind.(type) { + case *pipelinespec.TaskInputsSpec_InputParameterSpec_ComponentInputParameter: + componentInput := paramSpec.GetComponentInputParameter() + if componentInput == "" { + return nil, paramError(fmt.Errorf("empty component input")) + } v, ok := inputParams[componentInput] if !ok { return nil, paramError(fmt.Errorf("parent DAG does not have input parameter %s", componentInput)) } inputs.Parameters[name] = v - } else { - return nil, paramError(fmt.Errorf("parameter spec not implemented yet")) + + case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: + taskOutput := paramSpec.GetTaskOutputParameter() + if taskOutput.GetProducerTask() == "" { + return nil, paramError(fmt.Errorf("producer task is empty")) + } + if taskOutput.GetOutputParameterKey() == "" { + return nil, paramError(fmt.Errorf("output parameter key is empty")) + } + tasks, err := getDAGTasks() + if err != nil { + return nil, paramError(err) + } + producer, ok := tasks[taskOutput.GetProducerTask()] + if !ok { + return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask())) + } + _, outputs, err := producer.GetParameters() + if err != nil { + return nil, paramError(fmt.Errorf("get producer output parameters: %w", 
err)) + } + param, ok := outputs[taskOutput.GetOutputParameterKey()] + if !ok { + return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask())) + } + inputs.Parameters[name] = param + + // TODO(Bobgy): implement the following cases + // case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: + // case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_: + default: + return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t)) } } if len(task.GetInputs().GetArtifacts()) > 0 { return nil, fmt.Errorf("failed to resolve inputs: artifact inputs not implemented yet") } } + // TODO(Bobgy): validate executor inputs match component inputs definition return inputs, nil } diff --git a/v2/metadata/client.go b/v2/metadata/client.go index 494e261601..40c9e65da0 100644 --- a/v2/metadata/client.go +++ b/v2/metadata/client.go @@ -97,6 +97,27 @@ type Parameters struct { DoubleParameters map[string]float64 } +func NewParameters(params map[string]*pipelinespec.Value) (*Parameters, error) { + result := &Parameters{ + IntParameters: make(map[string]int64), + StringParameters: make(map[string]string), + DoubleParameters: make(map[string]float64), + } + for name, parameter := range params { + switch t := parameter.Value.(type) { + case *pipelinespec.Value_StringValue: + result.StringParameters[name] = parameter.GetStringValue() + case *pipelinespec.Value_IntValue: + result.IntParameters[name] = parameter.GetIntValue() + case *pipelinespec.Value_DoubleValue: + result.DoubleParameters[name] = parameter.GetDoubleValue() + default: + return nil, fmt.Errorf("failed to convert from map[string]*pipelinespec.Value to metadata.Parameters: unknown parameter type for parameter name=%q: %T", name, t) + } + } + return result, nil +} + // ExecutionConfig represents the input parameters and artifacts to an Execution. type ExecutionConfig struct { InputParameters *Parameters @@ -161,6 +182,13 @@ func (e *Execution) String() string { return e.execution.String() } +func (e *Execution) TaskName() string { + if e == nil { + return "" + } + return e.execution.GetCustomProperties()[keyTaskName].GetStringValue() +} + // GetPipeline returns the current pipeline represented by the specified // pipeline name and run ID. func (c *Client) GetPipeline(ctx context.Context, pipelineName, pipelineRunID, namespace, runResource string) (*Pipeline, error) { @@ -199,6 +227,11 @@ type DAG struct { context *pb.Context } +// identifier info for error message purposes +func (d *DAG) Info() string { + return fmt.Sprintf("DAG(executionID=%v, contextID=%v)", d.Execution.GetID(), d.context.GetId()) +} + func (c *Client) GetDAG(ctx context.Context, executionID int64, contextID int64) (*DAG, error) { dagError := func(err error) error { return fmt.Errorf("failed to get DAG executionID=%v contextID=%v: %w", executionID, contextID, err) @@ -467,6 +500,38 @@ func (c *Client) GetExecution(ctx context.Context, id int64) (*Execution, error) return &Execution{execution: executions[0]}, nil } +// GetExecutionsInDAG gets all executions in the DAG context, and organize them +// into a map, keyed by task name. 
+func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG) (executionsMap map[string]*Execution, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("failed to get executions in %s: %w", dag.Info(), err) + } + }() + executionsMap = make(map[string]*Execution) + res, err := c.svc.GetExecutionsByContext(ctx, &pb.GetExecutionsByContextRequest{ + ContextId: dag.context.Id, + }) + if err != nil { + return nil, err + } + execs := res.GetExecutions() + for _, e := range execs { + execution := &Execution{execution: e} + taskName := execution.TaskName() + if taskName == "" { + return nil, fmt.Errorf("empty task name for execution ID: %v", execution.GetID()) + } + existing, ok := executionsMap[taskName] + if ok { + // TODO(Bobgy): to support retry, we need to handle multiple tasks with the same task name. + return nil, fmt.Errorf("two tasks have the same task name %q, id1=%v id2=%v", taskName, existing.GetID(), execution.GetID()) + } + executionsMap[taskName] = execution + } + return executionsMap, nil +} + // GetEventsByArtifactIDs ... func (c *Client) GetEventsByArtifactIDs(ctx context.Context, artifactIds []int64) ([]*pb.Event, error) { req := &pb.GetEventsByArtifactIDsRequest{ArtifactIds: artifactIds} @@ -664,11 +729,6 @@ func getOrInsertContext(ctx context.Context, svc pb.MetadataStoreServiceClient, func GenerateExecutionConfig(executorInput *pipelinespec.ExecutorInput) (*ExecutionConfig, error) { ecfg := &ExecutionConfig{ - InputParameters: &Parameters{ - IntParameters: make(map[string]int64), - StringParameters: make(map[string]string), - DoubleParameters: make(map[string]float64), - }, InputArtifactIDs: make(map[string][]int64), } @@ -682,18 +742,11 @@ func GenerateExecutionConfig(executorInput *pipelinespec.ExecutorInput) (*Execut } } - for name, parameter := range executorInput.Inputs.Parameters { - switch t := parameter.Value.(type) { - case *pipelinespec.Value_StringValue: - ecfg.InputParameters.StringParameters[name] = parameter.GetStringValue() - case *pipelinespec.Value_IntValue: - ecfg.InputParameters.IntParameters[name] = parameter.GetIntValue() - case *pipelinespec.Value_DoubleValue: - ecfg.InputParameters.DoubleParameters[name] = parameter.GetDoubleValue() - default: - return nil, fmt.Errorf("unknown parameter type: %T", t) - } + parameters, err := NewParameters(executorInput.Inputs.Parameters) + if err != nil { + return nil, err } + ecfg.InputParameters = parameters return ecfg, nil } diff --git a/v2/metadata/converter.go b/v2/metadata/converter.go index b239630787..953e74fe04 100644 --- a/v2/metadata/converter.go +++ b/v2/metadata/converter.go @@ -13,29 +13,29 @@ import ( func mlmdValueToPipelineSpecValue(v *pb.Value) (*pipelinespec.Value, error) { switch t := v.Value.(type) { case *pb.Value_StringValue: - return stringKFPValue(t.StringValue), nil + return StringValue(t.StringValue), nil case *pb.Value_DoubleValue: - return doubleKFPValue(t.DoubleValue), nil + return DoubleValue(t.DoubleValue), nil case *pb.Value_IntValue: - return intKFPValue(t.IntValue), nil + return IntValue(t.IntValue), nil default: return nil, fmt.Errorf("unknown value type %T", t) } } -func stringKFPValue(v string) *pipelinespec.Value { +func StringValue(v string) *pipelinespec.Value { return &pipelinespec.Value{ Value: &pipelinespec.Value_StringValue{StringValue: v}, } } -func doubleKFPValue(v float64) *pipelinespec.Value { +func DoubleValue(v float64) *pipelinespec.Value { return &pipelinespec.Value{ Value: &pipelinespec.Value_DoubleValue{DoubleValue: v}, } } -func intKFPValue(v 
int64) *pipelinespec.Value { +func IntValue(v int64) *pipelinespec.Value { return &pipelinespec.Value{ Value: &pipelinespec.Value_IntValue{IntValue: v}, }
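
Note on the output-parameter collection added in executeV2 (v2/component/launcher_v2.go): each declared output parameter is read back from its output file and converted according to the type declared in the component spec's output definitions. Below is a minimal, self-contained sketch of that parsing rule using only the standard library; the helper name and example path are illustrative and not part of this patch.

package main

import (
	"fmt"
	"io/ioutil"
	"strconv"
	"strings"
)

// parseOutputParameter mirrors the type dispatch executeV2 performs when it
// reads an output parameter file written by the user container. The declared
// type comes from the component spec's output definitions.
// Illustrative sketch only, not code from this patch.
func parseOutputParameter(path, declaredType string) (interface{}, error) {
	b, err := ioutil.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read output parameter file %q: %w", path, err)
	}
	switch declaredType {
	case "STRING":
		// Strings are taken verbatim, including any trailing newline.
		return string(b), nil
	case "INT":
		i, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 64)
		if err != nil {
			return nil, err
		}
		return i, nil
	case "DOUBLE":
		f, err := strconv.ParseFloat(strings.TrimSpace(string(b)), 64)
		if err != nil {
			return nil, err
		}
		return f, nil
	default:
		return nil, fmt.Errorf("unknown type %q, expected STRING, INT or DOUBLE", declaredType)
	}
}

func main() {
	// Hypothetical usage: the launcher points output files at /tmp/kfp/outputs/<name>.
	v, err := parseOutputParameter("/tmp/kfp/outputs/output_value", "STRING")
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Printf("collected output parameter: %v\n", v)
}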
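
Note on addImplicitDependencies (v2/compiler/dag.go): consuming a task-output parameter implicitly makes the consumer depend on the producer, and the pass validates that the producer exists in the DAG while avoiding duplicate dependency entries. Below is a simplified, self-contained sketch of that dedup-and-validate logic over plain maps; the helper and its argument shapes are illustrative, whereas the real code walks the pipelinespec protos.

package main

import "fmt"

// addImplicitDeps appends each producer referenced by a task's inputs to that
// task's dependency list, skipping producers already listed and rejecting
// references to tasks that are not part of the DAG.
// Illustrative sketch only, not code from this patch.
func addImplicitDeps(deps map[string][]string, producers map[string][]string, known map[string]bool) error {
	for task, referenced := range producers {
		for _, producer := range referenced {
			if !known[producer] {
				return fmt.Errorf("unknown producer task %q in DAG", producer)
			}
			found := false
			for _, dep := range deps[task] {
				if dep == producer {
					found = true
					break
				}
			}
			if !found {
				deps[task] = append(deps[task], producer)
			}
		}
	}
	return nil
}

func main() {
	deps := map[string][]string{}
	producers := map[string][]string{"consumer": {"producer"}}
	known := map[string]bool{"producer": true, "consumer": true}
	if err := addImplicitDeps(deps, producers, known); err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Println(deps) // map[consumer:[producer]]
}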