chore(v2): parameter passing. Fixes #6151 (#6208)

* chore(v2): publish output parameters

* chore(v2): resolve input parameters from producer task

* chore(v2): add implicit parameter dependencies to DAG tasks

* test(v2): add verification for v2 tests

* fix compiler unit test
Yuan (Bob) Gong 2021-08-06 09:25:50 +08:00 committed by GitHub
parent 2f843f95aa
commit 25958081e6
13 changed files with 397 additions and 44 deletions

View File

@@ -22,8 +22,6 @@
# The field `path` corresponds to the test's python module path
# e.g. if folder path is `samples/test/fail_test.py`, then module path is
# `samples.test.fail_test`.
- name: hello_world
path: samples.v2.hello_world_test
- name: condition
path: samples.core.condition.condition_test
- name: nested_condition
@@ -84,3 +82,9 @@
# TODO(Bobgy): This is currently passing, should it fail?
# - name: fail_parameter_value_missing
# path: samples.test.fail_parameter_value_missing_test
# v2 samples
- name: hello_world
path: samples.v2.hello_world_test
- name: producer_consumer_param
path: samples.v2.producer_consumer_param_test

View File

@@ -21,7 +21,8 @@ import kfp
import kfp_server_api
from .hello_world import pipeline_hello_world
from ..test.util import run_pipeline_func, TestCase, KfpMlmdClient
from ..test.util import KfpTask, TaskInputs, TaskOutputs, run_pipeline_func, TestCase, KfpMlmdClient
from ml_metadata.proto import Execution
def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs):
@@ -31,6 +32,21 @@ def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs):
client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config)
tasks = client.get_tasks(run_id=run.id)
pprint(tasks)
t.assertEqual(
{
'hello-world':
KfpTask(
name='hello-world',
type='system.ContainerExecution',
state=Execution.State.COMPLETE,
inputs=TaskInputs(
parameters={'text': 'hi there'}, artifacts=[]
),
outputs=TaskOutputs(parameters={}, artifacts=[])
)
},
tasks,
)
if __name__ == '__main__':

View File

@@ -0,0 +1,76 @@
# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hello world v2 engine pipeline."""
from __future__ import annotations
import unittest
from pprint import pprint
import kfp
import kfp_server_api
from .producer_consumer_param import producer_consumer_param_pipeline
from ..test.util import KfpTask, TaskInputs, TaskOutputs, run_pipeline_func, TestCase, KfpMlmdClient
from ml_metadata.proto import Execution
def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs):
t = unittest.TestCase()
t.maxDiff = None # we always want to see full diff
t.assertEqual(run.status, 'Succeeded')
client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config)
tasks = client.get_tasks(run_id=run.id)
pprint(tasks)
t.assertEqual({
'consumer':
KfpTask(
name='consumer',
type='system.ContainerExecution',
state=Execution.State.COMPLETE,
inputs=TaskInputs(
parameters={
'input_value':
'Hello world, this is an output parameter\n'
},
artifacts=[]
),
outputs=TaskOutputs(parameters={}, artifacts=[])
),
'producer':
KfpTask(
name='producer',
type='system.ContainerExecution',
state=Execution.State.COMPLETE,
inputs=TaskInputs(
parameters={'input_text': 'Hello world'}, artifacts=[]
),
outputs=TaskOutputs(
parameters={
'output_value':
'Hello world, this is an output parameter\n'
},
artifacts=[]
)
)
}, tasks)
if __name__ == '__main__':
run_pipeline_func([
TestCase(
pipeline_func=producer_consumer_param_pipeline,
verify_func=verify,
mode=kfp.dsl.PipelineExecutionMode.V2_ENGINE,
),
])

View File

@@ -27,6 +27,7 @@ var (
copy = flag.String("copy", "", "copy this binary to specified destination path")
executionID = flag.Int64("execution_id", 0, "Execution ID of this task.")
executorInputJSON = flag.String("executor_input", "", "The JSON-encoded ExecutorInput.")
componentSpecJSON = flag.String("component_spec", "", "The JSON-encoded ComponentSpec.")
namespace = flag.String("namespace", "", "The Kubernetes namespace this Pod belongs to.")
podName = flag.String("pod_name", "", "Kubernetes Pod name.")
podUID = flag.String("pod_uid", "", "Kubernetes Pod UID.")
@@ -61,7 +62,7 @@ func run() error {
MLMDServerAddress: *mlmdServerAddress,
MLMDServerPort: *mlmdServerPort,
}
launcher, err := component.NewLauncherV2(*executionID, *executorInputJSON, flag.Args(), opts)
launcher, err := component.NewLauncherV2(*executionID, *executorInputJSON, *componentSpecJSON, flag.Args(), opts)
if err != nil {
return err
}

View File

@@ -133,10 +133,12 @@ func workflowParameter(name string) string {
return fmt.Sprintf("{{workflow.parameters.%s}}", name)
}
// In a container template, refer to inputs to the template.
func inputValue(parameter string) string {
return fmt.Sprintf("{{inputs.parameters.%s}}", parameter)
}
// In a DAG/steps template, refer to inputs to the parent template.
func inputParameter(parameter string) string {
return fmt.Sprintf("{{inputs.parameters.%s}}", parameter)
}
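Both helpers expand to the same Argo string; only the context where the reference is legal differs. A runnable sketch of the expansion (the parameter names are borrowed from the compiler test below):

package main

import "fmt"

// inputParameter mirrors the helper above; in a container template the
// same expression is produced by inputValue and substituted with a
// concrete value, while in a DAG/steps template it forwards the parent
// template's input.
func inputParameter(parameter string) string {
	return fmt.Sprintf("{{inputs.parameters.%s}}", parameter)
}

func main() {
	fmt.Println(inputParameter("executor-input")) // {{inputs.parameters.executor-input}}
	fmt.Println(inputParameter("component"))      // {{inputs.parameters.component}}
}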

View File

@@ -103,6 +103,8 @@ func Test_argo_compiler(t *testing.T) {
- '{{inputs.parameters.execution-id}}'
- --executor_input
- '{{inputs.parameters.executor-input}}'
- --component_spec
- '{{inputs.parameters.component}}'
- --namespace
- $(KFP_NAMESPACE)
- --pod_name
@@ -153,6 +155,7 @@ func Test_argo_compiler(t *testing.T) {
parameters:
- name: executor-input
- name: execution-id
- name: component
metadata: {}
name: comp-hello-world-container
outputs: {}
@@ -164,7 +167,7 @@ func Test_argo_compiler(t *testing.T) {
- arguments:
parameters:
- name: component
value: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}'
value: '{{inputs.parameters.component}}'
- name: task
value: '{{inputs.parameters.task}}'
- name: dag-context-id
@@ -179,6 +182,8 @@ func Test_argo_compiler(t *testing.T) {
value: '{{tasks.driver.outputs.parameters.executor-input}}'
- name: execution-id
value: '{{tasks.driver.outputs.parameters.execution-id}}'
- name: component
value: '{{inputs.parameters.component}}'
dependencies:
- driver
name: container
@@ -188,6 +193,8 @@ func Test_argo_compiler(t *testing.T) {
- name: task
- name: dag-context-id
- name: dag-execution-id
- default: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}'
name: component
metadata: {}
name: comp-hello-world
outputs: {}

View File

@@ -23,7 +23,13 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon
if err != nil {
return fmt.Errorf("workflowCompiler.Container: marlshaling component spec to proto JSON failed: %w", err)
}
driverTask, driverOutputs := c.containerDriverTask("driver", componentJson, inputParameter(paramTask), inputParameter(paramDAGContextID), inputParameter(paramDAGExecutionID))
driverTask, driverOutputs := c.containerDriverTask(
"driver",
inputParameter(paramComponent),
inputParameter(paramTask),
inputParameter(paramDAGContextID),
inputParameter(paramDAGExecutionID),
)
if err != nil {
return err
}
@@ -39,6 +45,8 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon
{Name: paramTask},
{Name: paramDAGContextID},
{Name: paramDAGExecutionID},
// TODO(Bobgy): reuse the entire 2-step container template
{Name: paramComponent, Default: wfapi.AnyStringPtr(componentJson)},
},
},
DAG: &wfapi.DAGTemplate{
@@ -51,6 +59,9 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon
}, {
Name: paramExecutionID,
Value: wfapi.AnyStringPtr(driverOutputs.executionID),
}, {
Name: paramComponent,
Value: wfapi.AnyStringPtr(inputParameter(paramComponent)),
}},
}},
},
@@ -136,6 +147,7 @@ func containerExecutorTemplate(container *pipelinespec.PipelineDeploymentConfig_
volumePathKFPLauncher + "/launch",
"--execution_id", inputValue(paramExecutionID),
"--executor_input", inputValue(paramExecutorInput),
"--component_spec", inputValue(paramComponent),
"--namespace",
"$(KFP_NAMESPACE)",
"--pod_name",
@@ -154,6 +166,7 @@ func containerExecutorTemplate(container *pipelinespec.PipelineDeploymentConfig_
Parameters: []wfapi.Parameter{
{Name: paramExecutorInput},
{Name: paramExecutionID},
{Name: paramComponent},
},
},
Volumes: []k8score.Volume{{

View File

@@ -13,6 +13,10 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen
if name != "root" {
return fmt.Errorf("SubDAG not implemented yet")
}
err := addImplicitDependencies(dagSpec)
if err != nil {
return err
}
dag := &wfapi.Template{
Inputs: wfapi.Inputs{
Parameters: []wfapi.Parameter{
@@ -29,8 +33,9 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen
return fmt.Errorf("DAG: marshaling task spec to proto JSON failed: %w", err)
}
dag.DAG.Tasks = append(dag.DAG.Tasks, wfapi.DAGTask{
Name: kfpTask.GetTaskInfo().GetName(),
Template: c.templateName(kfpTask.GetComponentRef().GetName()),
Name: kfpTask.GetTaskInfo().GetName(),
Template: c.templateName(kfpTask.GetComponentRef().GetName()),
Dependencies: kfpTask.GetDependentTasks(),
Arguments: wfapi.Arguments{
Parameters: []wfapi.Parameter{
{
@@ -172,3 +177,40 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
c.wf.Spec.Templates = append(c.wf.Spec.Templates, *t)
return name
}
func addImplicitDependencies(dagSpec *pipelinespec.DagSpec) error {
for _, task := range dagSpec.GetTasks() {
// TODO(Bobgy): add implicit dependencies introduced by artifacts
for _, input := range task.GetInputs().GetParameters() {
wrap := func(err error) error {
return fmt.Errorf("failed to add implicit deps: %w", err)
}
switch input.Kind.(type) {
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
producer := input.GetTaskOutputParameter().GetProducerTask()
_, ok := dagSpec.GetTasks()[producer]
if !ok {
return wrap(fmt.Errorf("unknown producer task %q in DAG", producer))
}
if task.DependentTasks == nil {
task.DependentTasks = make([]string, 0)
}
// add the dependency if it's not already added
found := false
for _, dep := range task.DependentTasks {
if dep == producer {
found = true
break
}
}
if !found {
task.DependentTasks = append(task.DependentTasks, producer)
}
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_:
return wrap(fmt.Errorf("task final status not supported yet"))
default:
// other input types do not introduce implicit dependencies
}
}
}
return nil
}
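For illustration, a minimal in-package sketch of addImplicitDependencies at work. It assumes the generated proto type for the task-output reference is named TaskOutputParameterSpec (inferred from the getters above, not confirmed) and that pipelinespec is imported as in this file:

// implicitDepsSketch builds a two-task DAG where consumer reads
// producer's output parameter, then lets addImplicitDependencies
// derive the task ordering.
func implicitDepsSketch() error {
	dag := &pipelinespec.DagSpec{
		Tasks: map[string]*pipelinespec.PipelineTaskSpec{
			"producer": {},
			"consumer": {
				Inputs: &pipelinespec.TaskInputsSpec{
					Parameters: map[string]*pipelinespec.TaskInputsSpec_InputParameterSpec{
						"input_value": {
							Kind: &pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter{
								TaskOutputParameter: &pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameterSpec{
									ProducerTask:       "producer",
									OutputParameterKey: "output_value",
								},
							},
						},
					},
				},
			},
		},
	}
	if err := addImplicitDependencies(dag); err != nil {
		return err // e.g. an unknown producer task name
	}
	// dag.Tasks["consumer"].DependentTasks is now []string{"producer"},
	// which DAG() above emits as an Argo `dependencies:` entry.
	return nil
}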

View File

@@ -460,8 +460,8 @@ func execute(ctx context.Context, executorInput *pipelinespec.ExecutorInput, cmd
return nil, err
}
// Collect outputs
return getExecutorOutput()
// Collect outputs from output metadata file.
return getExecutorOutputFile()
}
func (l *Launcher) publish(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, execution *metadata.Execution) error {
@@ -802,7 +802,8 @@ func mergeRuntimeArtifacts(src, dst *pipelinespec.RuntimeArtifact) {
}
}
func getExecutorOutput() (*pipelinespec.ExecutorOutput, error) {
func getExecutorOutputFile() (*pipelinespec.ExecutorOutput, error) {
// collect user executor output file
executorOutput := &pipelinespec.ExecutorOutput{
Parameters: map[string]*pipelinespec.Value{},
Artifacts: map[string]*pipelinespec.ArtifactList{},

View File

@@ -3,12 +3,16 @@ package component
import (
"context"
"fmt"
"io/ioutil"
"strconv"
"strings"
"github.com/golang/glog"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/v2/metadata"
"github.com/kubeflow/pipelines/v2/objectstore"
pb "github.com/kubeflow/pipelines/v2/third_party/ml_metadata"
"gocloud.dev/blob"
"google.golang.org/protobuf/encoding/protojson"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
@@ -26,6 +30,7 @@ type LauncherV2Options struct {
type LauncherV2 struct {
executionID int64
executorInput *pipelinespec.ExecutorInput
component *pipelinespec.ComponentSpec
command string
args []string
options LauncherV2Options
@@ -35,7 +40,7 @@ type LauncherV2 struct {
k8sClient *kubernetes.Clientset
}
func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string, opts *LauncherV2Options) (l *LauncherV2, err error) {
func NewLauncherV2(executionID int64, executorInputJSON, componentSpecJSON string, cmdArgs []string, opts *LauncherV2Options) (l *LauncherV2, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("failed to create component launcher v2: %w", err)
@@ -49,6 +54,11 @@ func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string
if err != nil {
return nil, fmt.Errorf("failed to unmarshal executor input: %w", err)
}
component := &pipelinespec.ComponentSpec{}
err = protojson.Unmarshal([]byte(componentSpecJSON), component)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal component spec: %w", err)
}
if len(cmdArgs) == 0 {
return nil, fmt.Errorf("command and arguments are empty")
}
@@ -76,9 +86,13 @@ func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string
if err != nil {
return nil, err
}
if err = addOutputs(executorInput, component.GetOutputDefinitions()); err != nil {
return nil, err
}
return &LauncherV2{
executionID: executionID,
executorInput: executorInput,
component: component,
command: cmdArgs[0],
args: cmdArgs[1:],
options: *opts,
@@ -105,11 +119,11 @@ func (l *LauncherV2) Execute(ctx context.Context) (err error) {
if err != nil {
return err
}
_, err = execute(ctx, l.executorInput, l.command, l.args, bucket, bucketConfig)
executorOutput, err := executeV2(ctx, l.executorInput, l.component, l.command, l.args, bucket, bucketConfig)
if err != nil {
return err
}
return l.publish(ctx, execution)
return l.publish(ctx, execution, executorOutput)
}
func (o *LauncherV2Options) validate() error {
@@ -153,18 +167,94 @@ func (l *LauncherV2) prePublish(ctx context.Context) (execution *metadata.Execut
return l.metadataClient.PrePublishExecution(ctx, execution, ecfg)
}
func (l *LauncherV2) publish(ctx context.Context, execution *metadata.Execution) (err error) {
func (l *LauncherV2) publish(ctx context.Context, execution *metadata.Execution, executorOutput *pipelinespec.ExecutorOutput) (err error) {
defer func() {
if err != nil {
err = fmt.Errorf("failed to publish results to ML Metadata: %w", err)
}
}()
// TODO(Bobgy): read output parameters from local path, and add them to executorOutput.
outputParameters, err := metadata.NewParameters(executorOutput.GetParameters())
if err != nil {
return err
}
// TODO(Bobgy): upload output artifacts.
// TODO(Bobgy): when adding artifacts, we will need execution.pipeline to be non-nil, because we need
// to publish output artifacts to the context too.
if err := l.metadataClient.PublishExecution(ctx, execution, nil, nil, pb.Execution_COMPLETE); err != nil {
return fmt.Errorf("unable to publish execution: %w", err)
return l.metadataClient.PublishExecution(ctx, execution, outputParameters, nil, pb.Execution_COMPLETE)
}
// Add outputs info from component spec to executor input.
func addOutputs(executorInput *pipelinespec.ExecutorInput, outputs *pipelinespec.ComponentOutputsSpec) error {
if executorInput == nil {
return fmt.Errorf("cannot add outputs to nil executor input")
}
if executorInput.Outputs == nil {
executorInput.Outputs = &pipelinespec.ExecutorInput_Outputs{}
}
if executorInput.Outputs.Parameters == nil {
executorInput.Outputs.Parameters = make(map[string]*pipelinespec.ExecutorInput_OutputParameter)
}
// TODO(Bobgy): add output artifacts
for name := range outputs.GetParameters() {
executorInput.Outputs.Parameters[name] = &pipelinespec.ExecutorInput_OutputParameter{
OutputFile: fmt.Sprintf("/tmp/kfp/outputs/%s", name),
}
}
return nil
}
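A sketch of what addOutputs produces for a component with one STRING output parameter; the ComponentOutputsSpec field names are assumed from the getters used in this file:

// addOutputsSketch shows the executor-input rewrite performed by
// addOutputs before the user container starts.
func addOutputsSketch() error {
	outputs := &pipelinespec.ComponentOutputsSpec{
		Parameters: map[string]*pipelinespec.ComponentOutputsSpec_ParameterSpec{
			"output_value": {Type: pipelinespec.PrimitiveType_STRING},
		},
	}
	executorInput := &pipelinespec.ExecutorInput{}
	if err := addOutputs(executorInput, outputs); err != nil {
		return err // only a nil executor input fails
	}
	// executorInput.Outputs.Parameters["output_value"].OutputFile is now
	// "/tmp/kfp/outputs/output_value"; the user code writes the value
	// there, and executeV2 below reads it back.
	return nil
}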
func executeV2(ctx context.Context, executorInput *pipelinespec.ExecutorInput, component *pipelinespec.ComponentSpec, cmd string, args []string, bucket *blob.Bucket, bucketConfig *objectstore.Config) (*pipelinespec.ExecutorOutput, error) {
executorOutput, err := execute(ctx, executorInput, cmd, args, bucket, bucketConfig)
if err != nil {
return nil, err
}
// Collect Output Parameters
//
// These are not added in execute(), because execute() is shared between v2 compatible and v2 engine launcher.
// In v2 compatible mode, we get output parameter info from runtimeInfo. In v2 engine, we get it from component spec.
// Because of the difference, we cannot put parameter collection logic in one method.
if executorOutput.Parameters == nil {
executorOutput.Parameters = make(map[string]*pipelinespec.Value)
}
outputParameters := executorOutput.GetParameters()
for name, param := range executorInput.GetOutputs().GetParameters() {
_, ok := outputParameters[name]
if ok {
// If the output parameter was already specified in output metadata file,
// we don't need to collect it from file, because output metadata file has
// the highest priority.
continue
}
paramSpec, ok := component.GetOutputDefinitions().GetParameters()[name]
if !ok {
return nil, fmt.Errorf("failed to find output parameter name=%q in component spec", name)
}
msg := func(err error) error {
return fmt.Errorf("failed to read output parameter name=%q type=%q path=%q: %w", name, paramSpec.GetType(), param.GetOutputFile(), err)
}
b, err := ioutil.ReadFile(param.GetOutputFile())
if err != nil {
return nil, msg(err)
}
switch paramSpec.GetType() {
case pipelinespec.PrimitiveType_STRING:
outputParameters[name] = metadata.StringValue(string(b))
case pipelinespec.PrimitiveType_INT:
i, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 64)
if err != nil {
return nil, msg(err)
}
outputParameters[name] = metadata.IntValue(i)
case pipelinespec.PrimitiveType_DOUBLE:
f, err := strconv.ParseFloat(strings.TrimSpace(string(b)), 64)
if err != nil {
return nil, msg(err)
}
outputParameters[name] = metadata.DoubleValue(f)
default:
return nil, msg(fmt.Errorf("unknown type. Expected STRING, INT or DOUBLE"))
}
}
return executorOutput, nil
}
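The per-type collection above can be exercised in isolation. A self-contained sketch of the INT case (the directory, file name, and contents are illustrative only):

package main

import (
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"strconv"
	"strings"
)

func main() {
	// Simulate a component writing an INT output parameter to the
	// output file that addOutputs assigned to it.
	dir, err := ioutil.TempDir("", "kfp-outputs")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(dir)
	path := filepath.Join(dir, "count")
	if err := ioutil.WriteFile(path, []byte("42\n"), 0644); err != nil {
		panic(err)
	}

	// Collect it back the way executeV2 does for PrimitiveType_INT:
	// read the file, trim whitespace, parse base-10.
	b, err := ioutil.ReadFile(path)
	if err != nil {
		panic(err)
	}
	i, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 64)
	if err != nil {
		panic(err)
	}
	fmt.Println(i) // 42
}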

View File

@@ -206,6 +206,19 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, task *pipelinespec.Pi
inputs := &pipelinespec.ExecutorInput_Inputs{
Parameters: make(map[string]*pipelinespec.Value),
}
// get executions in context on demand
var tasksCache map[string]*metadata.Execution
getDAGTasks := func() (map[string]*metadata.Execution, error) {
if tasksCache != nil {
return tasksCache, nil
}
tasks, err := mlmd.GetExecutionsInDAG(ctx, dag)
if err != nil {
return nil, err
}
tasksCache = tasks
return tasks, nil
}
if task.GetInputs() != nil {
for name, paramSpec := range task.GetInputs().Parameters {
paramError := func(err error) error {
@@ -214,20 +227,55 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, task *pipelinespec.Pi
if paramSpec.GetParameterExpressionSelector() != "" {
return nil, paramError(fmt.Errorf("parameter expression selector not implemented yet"))
}
componentInput := paramSpec.GetComponentInputParameter()
if componentInput != "" {
switch t := paramSpec.Kind.(type) {
case *pipelinespec.TaskInputsSpec_InputParameterSpec_ComponentInputParameter:
componentInput := paramSpec.GetComponentInputParameter()
if componentInput == "" {
return nil, paramError(fmt.Errorf("empty component input"))
}
v, ok := inputParams[componentInput]
if !ok {
return nil, paramError(fmt.Errorf("parent DAG does not have input parameter %s", componentInput))
}
inputs.Parameters[name] = v
} else {
return nil, paramError(fmt.Errorf("parameter spec not implemented yet"))
case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
taskOutput := paramSpec.GetTaskOutputParameter()
if taskOutput.GetProducerTask() == "" {
return nil, paramError(fmt.Errorf("producer task is empty"))
}
if taskOutput.GetOutputParameterKey() == "" {
return nil, paramError(fmt.Errorf("output parameter key is empty"))
}
tasks, err := getDAGTasks()
if err != nil {
return nil, paramError(err)
}
producer, ok := tasks[taskOutput.GetProducerTask()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()))
}
_, outputs, err := producer.GetParameters()
if err != nil {
return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
}
param, ok := outputs[taskOutput.GetOutputParameterKey()]
if !ok {
return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
}
inputs.Parameters[name] = param
// TODO(Bobgy): implement the following cases
// case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
// case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_:
default:
return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t))
}
}
if len(task.GetInputs().GetArtifacts()) > 0 {
return nil, fmt.Errorf("failed to resolve inputs: artifact inputs not implemented yet")
}
}
// TODO(Bobgy): validate executor inputs match component inputs definition
return inputs, nil
}
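The getDAGTasks closure above is a memoize-on-first-use pattern: when no input references a task output, the MLMD query never runs, and otherwise it runs exactly once no matter how many inputs consume it. The same idea in isolation:

package main

import "fmt"

func main() {
	calls := 0
	// fetch stands in for the expensive mlmd.GetExecutionsInDAG call.
	fetch := func() (map[string]string, error) {
		calls++
		return map[string]string{"producer": "execution #1"}, nil
	}
	var cache map[string]string
	getTasks := func() (map[string]string, error) {
		if cache != nil {
			return cache, nil
		}
		tasks, err := fetch()
		if err != nil {
			return nil, err
		}
		cache = tasks
		return tasks, nil
	}
	getTasks()
	getTasks()
	fmt.Println(calls) // 1 -- fetched once, served from cache afterwards
}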

View File

@@ -97,6 +97,27 @@ type Parameters struct {
DoubleParameters map[string]float64
}
func NewParameters(params map[string]*pipelinespec.Value) (*Parameters, error) {
result := &Parameters{
IntParameters: make(map[string]int64),
StringParameters: make(map[string]string),
DoubleParameters: make(map[string]float64),
}
for name, parameter := range params {
switch t := parameter.Value.(type) {
case *pipelinespec.Value_StringValue:
result.StringParameters[name] = parameter.GetStringValue()
case *pipelinespec.Value_IntValue:
result.IntParameters[name] = parameter.GetIntValue()
case *pipelinespec.Value_DoubleValue:
result.DoubleParameters[name] = parameter.GetDoubleValue()
default:
return nil, fmt.Errorf("failed to convert from map[string]*pipelinespec.Value to metadata.Parameters: unknown parameter type for parameter name=%q: %T", name, t)
}
}
return result, nil
}
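A usage sketch combining NewParameters with the value constructors exported later in this commit:

package main

import (
	"fmt"

	"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
	"github.com/kubeflow/pipelines/v2/metadata"
)

func main() {
	params := map[string]*pipelinespec.Value{
		"text":  metadata.StringValue("hi there"),
		"count": metadata.IntValue(3),
		"rate":  metadata.DoubleValue(0.5),
	}
	p, err := metadata.NewParameters(params)
	if err != nil {
		panic(err) // only an unknown value type errors out
	}
	fmt.Println(p.StringParameters["text"]) // hi there
	fmt.Println(p.IntParameters["count"])   // 3
	fmt.Println(p.DoubleParameters["rate"]) // 0.5
}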
// ExecutionConfig represents the input parameters and artifacts to an Execution.
type ExecutionConfig struct {
InputParameters *Parameters
@@ -161,6 +182,13 @@ func (e *Execution) String() string {
return e.execution.String()
}
func (e *Execution) TaskName() string {
if e == nil {
return ""
}
return e.execution.GetCustomProperties()[keyTaskName].GetStringValue()
}
// GetPipeline returns the current pipeline represented by the specified
// pipeline name and run ID.
func (c *Client) GetPipeline(ctx context.Context, pipelineName, pipelineRunID, namespace, runResource string) (*Pipeline, error) {
@@ -199,6 +227,11 @@ type DAG struct {
context *pb.Context
}
// Info returns identifier information about the DAG for error messages.
func (d *DAG) Info() string {
return fmt.Sprintf("DAG(executionID=%v, contextID=%v)", d.Execution.GetID(), d.context.GetId())
}
func (c *Client) GetDAG(ctx context.Context, executionID int64, contextID int64) (*DAG, error) {
dagError := func(err error) error {
return fmt.Errorf("failed to get DAG executionID=%v contextID=%v: %w", executionID, contextID, err)
@@ -467,6 +500,38 @@ func (c *Client) GetExecution(ctx context.Context, id int64) (*Execution, error)
return &Execution{execution: executions[0]}, nil
}
// GetExecutionsInDAG gets all executions in the DAG context and organizes them
// into a map keyed by task name.
func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG) (executionsMap map[string]*Execution, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("failed to get executions in %s: %w", dag.Info(), err)
}
}()
executionsMap = make(map[string]*Execution)
res, err := c.svc.GetExecutionsByContext(ctx, &pb.GetExecutionsByContextRequest{
ContextId: dag.context.Id,
})
if err != nil {
return nil, err
}
execs := res.GetExecutions()
for _, e := range execs {
execution := &Execution{execution: e}
taskName := execution.TaskName()
if taskName == "" {
return nil, fmt.Errorf("empty task name for execution ID: %v", execution.GetID())
}
existing, ok := executionsMap[taskName]
if ok {
// TODO(Bobgy): to support retry, we need to handle multiple tasks with the same task name.
return nil, fmt.Errorf("two tasks have the same task name %q, id1=%v id2=%v", taskName, existing.GetID(), execution.GetID())
}
executionsMap[taskName] = execution
}
return executionsMap, nil
}
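A sketch of how a driver-side caller uses this, mirroring resolveInputs above; the wrapper function and its name are hypothetical:

package example

import (
	"context"
	"fmt"

	"github.com/kubeflow/pipelines/v2/metadata"
)

// lookupProducer fetches the parent DAG, lists its executions keyed by
// task name, and returns the producer task's execution, whose output
// parameters can then be read via GetParameters().
func lookupProducer(ctx context.Context, c *metadata.Client, dagExecutionID, dagContextID int64, producerTask string) (*metadata.Execution, error) {
	dag, err := c.GetDAG(ctx, dagExecutionID, dagContextID)
	if err != nil {
		return nil, err
	}
	tasks, err := c.GetExecutionsInDAG(ctx, dag)
	if err != nil {
		return nil, err
	}
	producer, ok := tasks[producerTask]
	if !ok {
		return nil, fmt.Errorf("cannot find producer task %q in %s", producerTask, dag.Info())
	}
	return producer, nil
}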
// GetEventsByArtifactIDs ...
func (c *Client) GetEventsByArtifactIDs(ctx context.Context, artifactIds []int64) ([]*pb.Event, error) {
req := &pb.GetEventsByArtifactIDsRequest{ArtifactIds: artifactIds}
@@ -664,11 +729,6 @@ func getOrInsertContext(ctx context.Context, svc pb.MetadataStoreServiceClient,
func GenerateExecutionConfig(executorInput *pipelinespec.ExecutorInput) (*ExecutionConfig, error) {
ecfg := &ExecutionConfig{
InputParameters: &Parameters{
IntParameters: make(map[string]int64),
StringParameters: make(map[string]string),
DoubleParameters: make(map[string]float64),
},
InputArtifactIDs: make(map[string][]int64),
}
@@ -682,18 +742,11 @@ func GenerateExecutionConfig(executorInput *pipelinespec.ExecutorInput) (*Execut
}
}
for name, parameter := range executorInput.Inputs.Parameters {
switch t := parameter.Value.(type) {
case *pipelinespec.Value_StringValue:
ecfg.InputParameters.StringParameters[name] = parameter.GetStringValue()
case *pipelinespec.Value_IntValue:
ecfg.InputParameters.IntParameters[name] = parameter.GetIntValue()
case *pipelinespec.Value_DoubleValue:
ecfg.InputParameters.DoubleParameters[name] = parameter.GetDoubleValue()
default:
return nil, fmt.Errorf("unknown parameter type: %T", t)
}
parameters, err := NewParameters(executorInput.Inputs.Parameters)
if err != nil {
return nil, err
}
ecfg.InputParameters = parameters
return ecfg, nil
}

View File

@@ -13,29 +13,29 @@ import (
func mlmdValueToPipelineSpecValue(v *pb.Value) (*pipelinespec.Value, error) {
switch t := v.Value.(type) {
case *pb.Value_StringValue:
return stringKFPValue(t.StringValue), nil
return StringValue(t.StringValue), nil
case *pb.Value_DoubleValue:
return doubleKFPValue(t.DoubleValue), nil
return DoubleValue(t.DoubleValue), nil
case *pb.Value_IntValue:
return intKFPValue(t.IntValue), nil
return IntValue(t.IntValue), nil
default:
return nil, fmt.Errorf("unknown value type %T", t)
}
}
func stringKFPValue(v string) *pipelinespec.Value {
func StringValue(v string) *pipelinespec.Value {
return &pipelinespec.Value{
Value: &pipelinespec.Value_StringValue{StringValue: v},
}
}
func doubleKFPValue(v float64) *pipelinespec.Value {
func DoubleValue(v float64) *pipelinespec.Value {
return &pipelinespec.Value{
Value: &pipelinespec.Value_DoubleValue{DoubleValue: v},
}
}
func intKFPValue(v int64) *pipelinespec.Value {
func IntValue(v int64) *pipelinespec.Value {
return &pipelinespec.Value{
Value: &pipelinespec.Value_IntValue{IntValue: v},
}