* chore(v2): publish output parameters
* chore(v2): resolve input parameters from producer task
* chore(v2): add implicit parameter dependencies to DAG tasks
* test(v2): add verification for v2 tests
* fix compiler unit test
parent 2f843f95aa, commit 25958081e6
@@ -22,8 +22,6 @@
 # The field `path` corresponds to the test's python module path
 # e.g. if folder path is `samples/test/fail_test.py`, then module path is
 # `samples.test.fail_test`.
-- name: hello_world
-  path: samples.v2.hello_world_test
 - name: condition
   path: samples.core.condition.condition_test
 - name: nested_condition

@@ -84,3 +82,9 @@
 # TODO(Bobgy): This is currently passing, should it fail?
 # - name: fail_parameter_value_missing
 #   path: samples.test.fail_parameter_value_missing_test
+
+# v2 samples
+- name: hello_world
+  path: samples.v2.hello_world_test
+- name: producer_consumer_param
+  path: samples.v2.producer_consumer_param_test

@@ -21,7 +21,8 @@ import kfp
 import kfp_server_api
 
 from .hello_world import pipeline_hello_world
-from ..test.util import run_pipeline_func, TestCase, KfpMlmdClient
+from ..test.util import KfpTask, TaskInputs, TaskOutputs, run_pipeline_func, TestCase, KfpMlmdClient
+from ml_metadata.proto import Execution
 
 
 def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs):

@@ -31,6 +32,21 @@ def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs):
     client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config)
     tasks = client.get_tasks(run_id=run.id)
     pprint(tasks)
+    t.assertEqual(
+        {
+            'hello-world':
+                KfpTask(
+                    name='hello-world',
+                    type='system.ContainerExecution',
+                    state=Execution.State.COMPLETE,
+                    inputs=TaskInputs(
+                        parameters={'text': 'hi there'}, artifacts=[]
+                    ),
+                    outputs=TaskOutputs(parameters={}, artifacts=[])
+                )
+        },
+        tasks,
+    )
 
 
 if __name__ == '__main__':

@@ -0,0 +1,76 @@
+# Copyright 2021 The Kubeflow Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Producer-consumer parameter v2 engine pipeline test."""
+
+from __future__ import annotations
+import unittest
+from pprint import pprint
+
+import kfp
+import kfp_server_api
+
+from .producer_consumer_param import producer_consumer_param_pipeline
+from ..test.util import KfpTask, TaskInputs, TaskOutputs, run_pipeline_func, TestCase, KfpMlmdClient
+from ml_metadata.proto import Execution
+
+
+def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs):
+    t = unittest.TestCase()
+    t.maxDiff = None  # we always want to see full diff
+    t.assertEqual(run.status, 'Succeeded')
+    client = KfpMlmdClient(mlmd_connection_config=mlmd_connection_config)
+    tasks = client.get_tasks(run_id=run.id)
+    pprint(tasks)
+    t.assertEqual({
+        'consumer':
+            KfpTask(
+                name='consumer',
+                type='system.ContainerExecution',
+                state=Execution.State.COMPLETE,
+                inputs=TaskInputs(
+                    parameters={
+                        'input_value':
+                            'Hello world, this is an output parameter\n'
+                    },
+                    artifacts=[]
+                ),
+                outputs=TaskOutputs(parameters={}, artifacts=[])
+            ),
+        'producer':
+            KfpTask(
+                name='producer',
+                type='system.ContainerExecution',
+                state=Execution.State.COMPLETE,
+                inputs=TaskInputs(
+                    parameters={'input_text': 'Hello world'}, artifacts=[]
+                ),
+                outputs=TaskOutputs(
+                    parameters={
+                        'output_value':
+                            'Hello world, this is an output parameter\n'
+                    },
+                    artifacts=[]
+                )
+            )
+    }, tasks)
+
+
+if __name__ == '__main__':
+    run_pipeline_func([
+        TestCase(
+            pipeline_func=producer_consumer_param_pipeline,
+            verify_func=verify,
+            mode=kfp.dsl.PipelineExecutionMode.V2_ENGINE,
+        ),
+    ])

@@ -27,6 +27,7 @@ var (
 	copy              = flag.String("copy", "", "copy this binary to specified destination path")
 	executionID       = flag.Int64("execution_id", 0, "Execution ID of this task.")
 	executorInputJSON = flag.String("executor_input", "", "The JSON-encoded ExecutorInput.")
+	componentSpecJSON = flag.String("component_spec", "", "The JSON-encoded ComponentSpec.")
 	namespace         = flag.String("namespace", "", "The Kubernetes namespace this Pod belongs to.")
 	podName           = flag.String("pod_name", "", "Kubernetes Pod name.")
 	podUID            = flag.String("pod_uid", "", "Kubernetes Pod UID.")

@@ -61,7 +62,7 @@ func run() error {
 		MLMDServerAddress: *mlmdServerAddress,
 		MLMDServerPort:    *mlmdServerPort,
 	}
-	launcher, err := component.NewLauncherV2(*executionID, *executorInputJSON, flag.Args(), opts)
+	launcher, err := component.NewLauncherV2(*executionID, *executorInputJSON, *componentSpecJSON, flag.Args(), opts)
 	if err != nil {
 		return err
 	}

@@ -133,10 +133,12 @@ func workflowParameter(name string) string {
 	return fmt.Sprintf("{{workflow.parameters.%s}}", name)
 }
 
+// In a container template, refer to inputs to the template.
 func inputValue(parameter string) string {
 	return fmt.Sprintf("{{inputs.parameters.%s}}", parameter)
 }
 
+// In a DAG/steps template, refer to inputs to the parent template.
 func inputParameter(parameter string) string {
 	return fmt.Sprintf("{{inputs.parameters.%s}}", parameter)
 }
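
The two helpers above render the same Argo expression; the new comments exist only to document which template scope each is meant for. A minimal runnable sketch (not part of the commit; the helpers are copied from the diff) showing the strings they emit:

package main

import "fmt"

// Copies of the two compiler helpers, for illustration only. They render
// identically, but one is used in container templates and the other in
// DAG templates.
func inputValue(parameter string) string {
	return fmt.Sprintf("{{inputs.parameters.%s}}", parameter)
}

func inputParameter(parameter string) string {
	return fmt.Sprintf("{{inputs.parameters.%s}}", parameter)
}

func main() {
	// Container template scope, e.g. the launcher args compiled below:
	fmt.Println(inputValue("executor-input")) // {{inputs.parameters.executor-input}}
	// DAG template scope, e.g. arguments forwarded to a DAG task:
	fmt.Println(inputParameter("component")) // {{inputs.parameters.component}}
}
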
@@ -103,6 +103,8 @@ func Test_argo_compiler(t *testing.T) {
         - '{{inputs.parameters.execution-id}}'
         - --executor_input
         - '{{inputs.parameters.executor-input}}'
+        - --component_spec
+        - '{{inputs.parameters.component}}'
         - --namespace
         - $(KFP_NAMESPACE)
         - --pod_name

@@ -153,6 +155,7 @@ func Test_argo_compiler(t *testing.T) {
       parameters:
       - name: executor-input
       - name: execution-id
+      - name: component
     metadata: {}
     name: comp-hello-world-container
     outputs: {}

@@ -164,7 +167,7 @@ func Test_argo_compiler(t *testing.T) {
     - arguments:
         parameters:
         - name: component
-          value: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}'
+          value: '{{inputs.parameters.component}}'
        - name: task
          value: '{{inputs.parameters.task}}'
        - name: dag-context-id

|
@ -179,6 +182,8 @@ func Test_argo_compiler(t *testing.T) {
|
|||
value: '{{tasks.driver.outputs.parameters.executor-input}}'
|
||||
- name: execution-id
|
||||
value: '{{tasks.driver.outputs.parameters.execution-id}}'
|
||||
- name: component
|
||||
value: '{{inputs.parameters.component}}'
|
||||
dependencies:
|
||||
- driver
|
||||
name: container
|
||||
|
|
@@ -188,6 +193,8 @@ func Test_argo_compiler(t *testing.T) {
       - name: task
       - name: dag-context-id
       - name: dag-execution-id
+      - default: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}'
+        name: component
     metadata: {}
     name: comp-hello-world
     outputs: {}

@@ -23,7 +23,13 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon
 	if err != nil {
 		return fmt.Errorf("workflowCompiler.Container: marshaling component spec to proto JSON failed: %w", err)
 	}
-	driverTask, driverOutputs := c.containerDriverTask("driver", componentJson, inputParameter(paramTask), inputParameter(paramDAGContextID), inputParameter(paramDAGExecutionID))
+	driverTask, driverOutputs := c.containerDriverTask(
+		"driver",
+		inputParameter(paramComponent),
+		inputParameter(paramTask),
+		inputParameter(paramDAGContextID),
+		inputParameter(paramDAGExecutionID),
+	)
 	if err != nil {
 		return err
 	}

@@ -39,6 +45,8 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon
 				{Name: paramTask},
 				{Name: paramDAGContextID},
 				{Name: paramDAGExecutionID},
+				// TODO(Bobgy): reuse the entire 2-step container template
+				{Name: paramComponent, Default: wfapi.AnyStringPtr(componentJson)},
 			},
 		},
 		DAG: &wfapi.DAGTemplate{

@@ -51,6 +59,9 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon
 			}, {
 				Name:  paramExecutionID,
 				Value: wfapi.AnyStringPtr(driverOutputs.executionID),
+			}, {
+				Name:  paramComponent,
+				Value: wfapi.AnyStringPtr(inputParameter(paramComponent)),
 			}},
 		}},
 	},

@@ -136,6 +147,7 @@ func containerExecutorTemplate(container *pipelinespec.PipelineDeploymentConfig_
 			volumePathKFPLauncher + "/launch",
 			"--execution_id", inputValue(paramExecutionID),
 			"--executor_input", inputValue(paramExecutorInput),
+			"--component_spec", inputValue(paramComponent),
 			"--namespace",
 			"$(KFP_NAMESPACE)",
 			"--pod_name",

@@ -154,6 +166,7 @@ func containerExecutorTemplate(container *pipelinespec.PipelineDeploymentConfig_
 		Parameters: []wfapi.Parameter{
 			{Name: paramExecutorInput},
 			{Name: paramExecutionID},
+			{Name: paramComponent},
 		},
 	},
 	Volumes: []k8score.Volume{{

@@ -13,6 +13,10 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen
 	if name != "root" {
 		return fmt.Errorf("SubDAG not implemented yet")
 	}
+	err := addImplicitDependencies(dagSpec)
+	if err != nil {
+		return err
+	}
 	dag := &wfapi.Template{
 		Inputs: wfapi.Inputs{
 			Parameters: []wfapi.Parameter{

@@ -29,8 +33,9 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen
 		return fmt.Errorf("DAG: marshaling task spec to proto JSON failed: %w", err)
 	}
 	dag.DAG.Tasks = append(dag.DAG.Tasks, wfapi.DAGTask{
-		Name:     kfpTask.GetTaskInfo().GetName(),
-		Template: c.templateName(kfpTask.GetComponentRef().GetName()),
+		Name:         kfpTask.GetTaskInfo().GetName(),
+		Template:     c.templateName(kfpTask.GetComponentRef().GetName()),
+		Dependencies: kfpTask.GetDependentTasks(),
 		Arguments: wfapi.Arguments{
 			Parameters: []wfapi.Parameter{
 				{

@@ -172,3 +177,40 @@ func (c *workflowCompiler) addDAGDriverTemplate() string {
 	c.wf.Spec.Templates = append(c.wf.Spec.Templates, *t)
 	return name
 }
+
+func addImplicitDependencies(dagSpec *pipelinespec.DagSpec) error {
+	for _, task := range dagSpec.GetTasks() {
+		// TODO(Bobgy): add implicit dependencies introduced by artifacts
+		for _, input := range task.GetInputs().GetParameters() {
+			wrap := func(err error) error {
+				return fmt.Errorf("failed to add implicit deps: %w", err)
+			}
+			switch input.Kind.(type) {
+			case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
+				producer := input.GetTaskOutputParameter().GetProducerTask()
+				_, ok := dagSpec.GetTasks()[producer]
+				if !ok {
+					return wrap(fmt.Errorf("unknown producer task %q in DAG", producer))
+				}
+				if task.DependentTasks == nil {
+					task.DependentTasks = make([]string, 0)
+				}
+				// add the dependency if it's not already added
+				found := false
+				for _, dep := range task.DependentTasks {
+					if dep == producer {
+						found = true
+					}
+				}
+				if !found {
+					task.DependentTasks = append(task.DependentTasks, producer)
+				}
+			case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_:
+				return wrap(fmt.Errorf("task final status not supported yet"))
+			default:
+				// other input types do not introduce implicit dependencies
+			}
+		}
+	}
+	return nil
+}
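
A hypothetical unit-test sketch (not part of the commit) of the behavior the function above is expected to have, assuming it runs inside the same compiler package; the pipelinespec type and field names are exactly the ones used in the diff:

package compiler

import (
	"testing"

	"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
)

func TestAddImplicitDependencies_sketch(t *testing.T) {
	// A two-task DAG: "consumer" reads the "output_value" output parameter
	// of "producer", but declares no explicit dependency on it.
	dag := &pipelinespec.DagSpec{
		Tasks: map[string]*pipelinespec.PipelineTaskSpec{
			"producer": {},
			"consumer": {
				Inputs: &pipelinespec.TaskInputsSpec{
					Parameters: map[string]*pipelinespec.TaskInputsSpec_InputParameterSpec{
						"input_value": {
							Kind: &pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter{
								TaskOutputParameter: &pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameterSpec{
									ProducerTask:       "producer",
									OutputParameterKey: "output_value",
								},
							},
						},
					},
				},
			},
		},
	}
	if err := addImplicitDependencies(dag); err != nil {
		t.Fatal(err)
	}
	// The parameter reference alone should have produced the dependency edge,
	// which the DAG compiler then emits as an Argo task dependency.
	deps := dag.Tasks["consumer"].GetDependentTasks()
	if len(deps) != 1 || deps[0] != "producer" {
		t.Errorf("want [producer], got %v", deps)
	}
}
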
@@ -460,8 +460,8 @@ func execute(ctx context.Context, executorInput *pipelinespec.ExecutorInput, cmd
 		return nil, err
 	}
 
-	// Collect outputs
-	return getExecutorOutput()
+	// Collect outputs from output metadata file.
+	return getExecutorOutputFile()
 }
 
 func (l *Launcher) publish(ctx context.Context, executorInput *pipelinespec.ExecutorInput, executorOutput *pipelinespec.ExecutorOutput, execution *metadata.Execution) error {

@@ -802,7 +802,8 @@ func mergeRuntimeArtifacts(src, dst *pipelinespec.RuntimeArtifact) {
 	}
 }
 
-func getExecutorOutput() (*pipelinespec.ExecutorOutput, error) {
+func getExecutorOutputFile() (*pipelinespec.ExecutorOutput, error) {
+	// collect user executor output file
 	executorOutput := &pipelinespec.ExecutorOutput{
 		Parameters: map[string]*pipelinespec.Value{},
 		Artifacts:  map[string]*pipelinespec.ArtifactList{},

@@ -3,12 +3,16 @@ package component
 import (
 	"context"
 	"fmt"
+	"io/ioutil"
+	"strconv"
+	"strings"
 
 	"github.com/golang/glog"
 	"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
 	"github.com/kubeflow/pipelines/v2/metadata"
 	"github.com/kubeflow/pipelines/v2/objectstore"
+	pb "github.com/kubeflow/pipelines/v2/third_party/ml_metadata"
 	"gocloud.dev/blob"
 	"google.golang.org/protobuf/encoding/protojson"
 	"k8s.io/client-go/kubernetes"
 	"k8s.io/client-go/rest"

@@ -26,6 +30,7 @@ type LauncherV2Options struct {
 type LauncherV2 struct {
 	executionID   int64
 	executorInput *pipelinespec.ExecutorInput
+	component     *pipelinespec.ComponentSpec
 	command       string
 	args          []string
 	options       LauncherV2Options

@@ -35,7 +40,7 @@ type LauncherV2 struct {
 	k8sClient *kubernetes.Clientset
 }
 
-func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string, opts *LauncherV2Options) (l *LauncherV2, err error) {
+func NewLauncherV2(executionID int64, executorInputJSON, componentSpecJSON string, cmdArgs []string, opts *LauncherV2Options) (l *LauncherV2, err error) {
 	defer func() {
 		if err != nil {
 			err = fmt.Errorf("failed to create component launcher v2: %w", err)

@@ -49,6 +54,11 @@ func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string
 	if err != nil {
 		return nil, fmt.Errorf("failed to unmarshal executor input: %w", err)
 	}
+	component := &pipelinespec.ComponentSpec{}
+	err = protojson.Unmarshal([]byte(componentSpecJSON), component)
+	if err != nil {
+		return nil, fmt.Errorf("failed to unmarshal component spec: %w", err)
+	}
 	if len(cmdArgs) == 0 {
 		return nil, fmt.Errorf("command and arguments are empty")
 	}

@@ -76,9 +86,13 @@ func NewLauncherV2(executionID int64, executorInputJSON string, cmdArgs []string
 	if err != nil {
 		return nil, err
 	}
+	if err = addOutputs(executorInput, component.GetOutputDefinitions()); err != nil {
+		return nil, err
+	}
 	return &LauncherV2{
 		executionID:   executionID,
 		executorInput: executorInput,
+		component:     component,
 		command:       cmdArgs[0],
 		args:          cmdArgs[1:],
 		options:       *opts,

@@ -105,11 +119,11 @@ func (l *LauncherV2) Execute(ctx context.Context) (err error) {
 	if err != nil {
 		return err
 	}
-	_, err = execute(ctx, l.executorInput, l.command, l.args, bucket, bucketConfig)
+	executorOutput, err := executeV2(ctx, l.executorInput, l.component, l.command, l.args, bucket, bucketConfig)
 	if err != nil {
 		return err
 	}
-	return l.publish(ctx, execution)
+	return l.publish(ctx, execution, executorOutput)
 }
 
 func (o *LauncherV2Options) validate() error {

@@ -153,18 +167,94 @@ func (l *LauncherV2) prePublish(ctx context.Context) (execution *metadata.Execut
 	return l.metadataClient.PrePublishExecution(ctx, execution, ecfg)
 }
 
-func (l *LauncherV2) publish(ctx context.Context, execution *metadata.Execution) (err error) {
+func (l *LauncherV2) publish(ctx context.Context, execution *metadata.Execution, executorOutput *pipelinespec.ExecutorOutput) (err error) {
 	defer func() {
 		if err != nil {
 			err = fmt.Errorf("failed to publish results to ML Metadata: %w", err)
 		}
 	}()
-	// TODO(Bobgy): read output parameters from local path, and add them to executorOutput.
+	outputParameters, err := metadata.NewParameters(executorOutput.GetParameters())
+	if err != nil {
+		return err
+	}
 	// TODO(Bobgy): upload output artifacts.
 	// TODO(Bobgy): when adding artifacts, we will need execution.pipeline to be non-nil, because we need
 	// to publish output artifacts to the context too.
-	if err := l.metadataClient.PublishExecution(ctx, execution, nil, nil, pb.Execution_COMPLETE); err != nil {
-		return fmt.Errorf("unable to publish execution: %w", err)
-	}
-	return nil
+	return l.metadataClient.PublishExecution(ctx, execution, outputParameters, nil, pb.Execution_COMPLETE)
 }
 
+// Add outputs info from component spec to executor input.
+func addOutputs(executorInput *pipelinespec.ExecutorInput, outputs *pipelinespec.ComponentOutputsSpec) error {
+	if executorInput == nil {
+		return fmt.Errorf("cannot add outputs to nil executor input")
+	}
+	if executorInput.Outputs == nil {
+		executorInput.Outputs = &pipelinespec.ExecutorInput_Outputs{}
+	}
+	if executorInput.Outputs.Parameters == nil {
+		executorInput.Outputs.Parameters = make(map[string]*pipelinespec.ExecutorInput_OutputParameter)
+	}
+	// TODO(Bobgy): add output artifacts
+	for name, _ := range outputs.GetParameters() {
+		executorInput.Outputs.Parameters[name] = &pipelinespec.ExecutorInput_OutputParameter{
+			OutputFile: fmt.Sprintf("/tmp/kfp/outputs/%s", name),
+		}
+	}
+	return nil
+}
+
+func executeV2(ctx context.Context, executorInput *pipelinespec.ExecutorInput, component *pipelinespec.ComponentSpec, cmd string, args []string, bucket *blob.Bucket, bucketConfig *objectstore.Config) (*pipelinespec.ExecutorOutput, error) {
+	executorOutput, err := execute(ctx, executorInput, cmd, args, bucket, bucketConfig)
+	if err != nil {
+		return nil, err
+	}
+
+	// Collect Output Parameters
+	//
+	// These are not added in execute(), because execute() is shared between v2 compatible and v2 engine launcher.
+	// In v2 compatible mode, we get output parameter info from runtimeInfo. In v2 engine, we get it from component spec.
+	// Because of the difference, we cannot put parameter collection logic in one method.
+	if executorOutput.Parameters == nil {
+		executorOutput.Parameters = make(map[string]*pipelinespec.Value)
+	}
+	outputParameters := executorOutput.GetParameters()
+	for name, param := range executorInput.GetOutputs().GetParameters() {
+		_, ok := outputParameters[name]
+		if ok {
+			// If the output parameter was already specified in output metadata file,
+			// we don't need to collect it from file, because output metadata file has
+			// the highest priority.
+			continue
+		}
+		paramSpec, ok := component.GetOutputDefinitions().GetParameters()[name]
+		if !ok {
+			return nil, fmt.Errorf("failed to find output parameter name=%q in component spec", name)
+		}
+		msg := func(err error) error {
+			return fmt.Errorf("failed to read output parameter name=%q type=%q path=%q: %w", name, paramSpec.GetType(), param.GetOutputFile(), err)
+		}
+		b, err := ioutil.ReadFile(param.GetOutputFile())
+		if err != nil {
+			return nil, msg(err)
+		}
+		switch paramSpec.GetType() {
+		case pipelinespec.PrimitiveType_STRING:
+			outputParameters[name] = metadata.StringValue(string(b))
+		case pipelinespec.PrimitiveType_INT:
+			i, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 0)
+			if err != nil {
+				return nil, msg(err)
+			}
+			outputParameters[name] = metadata.IntValue(i)
+		case pipelinespec.PrimitiveType_DOUBLE:
+			f, err := strconv.ParseFloat(strings.TrimSpace(string(b)), 0)
+			if err != nil {
+				return nil, msg(err)
+			}
+			outputParameters[name] = metadata.DoubleValue(f)
+		default:
+			return nil, msg(fmt.Errorf("unknown type. Expected STRING, INT or DOUBLE"))
+		}
+	}
+	return executorOutput, nil
+}
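
A minimal standalone sketch (not from the commit; the temp path and values are illustrative) of the file contract the two functions above set up: addOutputs points each declared output parameter at a file such as /tmp/kfp/outputs/<name>, the user container writes the raw value there, and executeV2 reads it back and parses it by the type declared in the component spec:

package main

import (
	"fmt"
	"io/ioutil"
	"os"
	"strconv"
	"strings"
)

func main() {
	// Stand-in for the path addOutputs would assign, using a temp dir here.
	dir, err := ioutil.TempDir("", "kfp-outputs")
	if err != nil {
		panic(err)
	}
	defer os.RemoveAll(dir)
	path := dir + "/output_value"

	// The user container writes the raw value to the output file...
	if err := ioutil.WriteFile(path, []byte("42\n"), 0644); err != nil {
		panic(err)
	}

	// ...and the launcher reads it back, trimming whitespace and parsing by
	// declared type (here INT), like the switch statement in executeV2.
	b, err := ioutil.ReadFile(path)
	if err != nil {
		panic(err)
	}
	i, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 0)
	if err != nil {
		panic(err)
	}
	fmt.Println("collected output parameter:", i) // collected output parameter: 42
}
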
@@ -206,6 +206,19 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, task *pipelinespec.Pi
 	inputs := &pipelinespec.ExecutorInput_Inputs{
 		Parameters: make(map[string]*pipelinespec.Value),
 	}
+	// get executions in context on demand
+	var tasksCache map[string]*metadata.Execution
+	getDAGTasks := func() (map[string]*metadata.Execution, error) {
+		if tasksCache != nil {
+			return tasksCache, nil
+		}
+		tasks, err := mlmd.GetExecutionsInDAG(ctx, dag)
+		if err != nil {
+			return nil, err
+		}
+		tasksCache = tasks
+		return tasks, nil
+	}
 	if task.GetInputs() != nil {
 		for name, paramSpec := range task.GetInputs().Parameters {
 			paramError := func(err error) error {

@@ -214,20 +227,55 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, task *pipelinespec.Pi
 			if paramSpec.GetParameterExpressionSelector() != "" {
 				return nil, paramError(fmt.Errorf("parameter expression selector not implemented yet"))
 			}
-			componentInput := paramSpec.GetComponentInputParameter()
-			if componentInput != "" {
+			switch t := paramSpec.Kind.(type) {
+			case *pipelinespec.TaskInputsSpec_InputParameterSpec_ComponentInputParameter:
+				componentInput := paramSpec.GetComponentInputParameter()
+				if componentInput == "" {
+					return nil, paramError(fmt.Errorf("empty component input"))
+				}
 				v, ok := inputParams[componentInput]
 				if !ok {
 					return nil, paramError(fmt.Errorf("parent DAG does not have input parameter %s", componentInput))
 				}
 				inputs.Parameters[name] = v
-			} else {
-				return nil, paramError(fmt.Errorf("parameter spec not implemented yet"))
+
+			case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter:
+				taskOutput := paramSpec.GetTaskOutputParameter()
+				if taskOutput.GetProducerTask() == "" {
+					return nil, paramError(fmt.Errorf("producer task is empty"))
+				}
+				if taskOutput.GetOutputParameterKey() == "" {
+					return nil, paramError(fmt.Errorf("output parameter key is empty"))
+				}
+				tasks, err := getDAGTasks()
+				if err != nil {
+					return nil, paramError(err)
+				}
+				producer, ok := tasks[taskOutput.GetProducerTask()]
+				if !ok {
+					return nil, paramError(fmt.Errorf("cannot find producer task %q", taskOutput.GetProducerTask()))
+				}
+				_, outputs, err := producer.GetParameters()
+				if err != nil {
+					return nil, paramError(fmt.Errorf("get producer output parameters: %w", err))
+				}
+				param, ok := outputs[taskOutput.GetOutputParameterKey()]
+				if !ok {
+					return nil, paramError(fmt.Errorf("cannot find output parameter key %q in producer task %q", taskOutput.GetOutputParameterKey(), taskOutput.GetProducerTask()))
+				}
+				inputs.Parameters[name] = param
+
+			// TODO(Bobgy): implement the following cases
+			// case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue:
+			// case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_:
+			default:
+				return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t))
 			}
 		}
 		if len(task.GetInputs().GetArtifacts()) > 0 {
 			return nil, fmt.Errorf("failed to resolve inputs: artifact inputs not implemented yet")
 		}
 	}
 	// TODO(Bobgy): validate executor inputs match component inputs definition
 	return inputs, nil
 }

@@ -97,6 +97,27 @@ type Parameters struct {
 	DoubleParameters map[string]float64
 }
 
+func NewParameters(params map[string]*pipelinespec.Value) (*Parameters, error) {
+	result := &Parameters{
+		IntParameters:    make(map[string]int64),
+		StringParameters: make(map[string]string),
+		DoubleParameters: make(map[string]float64),
+	}
+	for name, parameter := range params {
+		switch t := parameter.Value.(type) {
+		case *pipelinespec.Value_StringValue:
+			result.StringParameters[name] = parameter.GetStringValue()
+		case *pipelinespec.Value_IntValue:
+			result.IntParameters[name] = parameter.GetIntValue()
+		case *pipelinespec.Value_DoubleValue:
+			result.DoubleParameters[name] = parameter.GetDoubleValue()
+		default:
+			return nil, fmt.Errorf("failed to convert from map[string]*pipelinespec.Value to metadata.Parameters: unknown parameter type for parameter name=%q: %T", name, t)
+		}
+	}
+	return result, nil
+}
+
 // ExecutionConfig represents the input parameters and artifacts to an Execution.
 type ExecutionConfig struct {
 	InputParameters *Parameters
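
A short hypothetical usage sketch of NewParameters (not part of the commit): pipelinespec values arrive as a single map and get grouped into per-type maps, which is how the launcher's publish step converts executor output parameters before writing them to MLMD. StringValue and IntValue are the constructors exported elsewhere in this commit:

package main

import (
	"fmt"

	"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
	"github.com/kubeflow/pipelines/v2/metadata"
)

func main() {
	// Mixed-type parameters, as they would appear in ExecutorOutput.Parameters.
	params, err := metadata.NewParameters(map[string]*pipelinespec.Value{
		"output_value": metadata.StringValue("Hello world, this is an output parameter\n"),
		"count":        metadata.IntValue(42),
	})
	if err != nil {
		panic(err)
	}
	// Each value lands in the map matching its type.
	fmt.Println(params.StringParameters["output_value"])
	fmt.Println(params.IntParameters["count"])
}
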
@@ -161,6 +182,13 @@ func (e *Execution) String() string {
 	return e.execution.String()
 }
 
+func (e *Execution) TaskName() string {
+	if e == nil {
+		return ""
+	}
+	return e.execution.GetCustomProperties()[keyTaskName].GetStringValue()
+}
+
 // GetPipeline returns the current pipeline represented by the specified
 // pipeline name and run ID.
 func (c *Client) GetPipeline(ctx context.Context, pipelineName, pipelineRunID, namespace, runResource string) (*Pipeline, error) {

@@ -199,6 +227,11 @@ type DAG struct {
 	context *pb.Context
 }
 
+// identifier info for error message purposes
+func (d *DAG) Info() string {
+	return fmt.Sprintf("DAG(executionID=%v, contextID=%v)", d.Execution.GetID(), d.context.GetId())
+}
+
 func (c *Client) GetDAG(ctx context.Context, executionID int64, contextID int64) (*DAG, error) {
 	dagError := func(err error) error {
 		return fmt.Errorf("failed to get DAG executionID=%v contextID=%v: %w", executionID, contextID, err)

@@ -467,6 +500,38 @@ func (c *Client) GetExecution(ctx context.Context, id int64) (*Execution, error)
 	return &Execution{execution: executions[0]}, nil
 }
 
+// GetExecutionsInDAG gets all executions in the DAG context, and organizes them
+// into a map, keyed by task name.
+func (c *Client) GetExecutionsInDAG(ctx context.Context, dag *DAG) (executionsMap map[string]*Execution, err error) {
+	defer func() {
+		if err != nil {
+			err = fmt.Errorf("failed to get executions in %s: %w", dag.Info(), err)
+		}
+	}()
+	executionsMap = make(map[string]*Execution)
+	res, err := c.svc.GetExecutionsByContext(ctx, &pb.GetExecutionsByContextRequest{
+		ContextId: dag.context.Id,
+	})
+	if err != nil {
+		return nil, err
+	}
+	execs := res.GetExecutions()
+	for _, e := range execs {
+		execution := &Execution{execution: e}
+		taskName := execution.TaskName()
+		if taskName == "" {
+			return nil, fmt.Errorf("empty task name for execution ID: %v", execution.GetID())
+		}
+		existing, ok := executionsMap[taskName]
+		if ok {
+			// TODO(Bobgy): to support retry, we need to handle multiple tasks with the same task name.
+			return nil, fmt.Errorf("two tasks have the same task name %q, id1=%v id2=%v", taskName, existing.GetID(), execution.GetID())
+		}
+		executionsMap[taskName] = execution
+	}
+	return executionsMap, nil
+}
+
 // GetEventsByArtifactIDs ...
 func (c *Client) GetEventsByArtifactIDs(ctx context.Context, artifactIds []int64) ([]*pb.Event, error) {
 	req := &pb.GetEventsByArtifactIDsRequest{ArtifactIds: artifactIds}

@@ -664,11 +729,6 @@ func getOrInsertContext(ctx context.Context, svc pb.MetadataStoreServiceClient,
 
 func GenerateExecutionConfig(executorInput *pipelinespec.ExecutorInput) (*ExecutionConfig, error) {
 	ecfg := &ExecutionConfig{
-		InputParameters: &Parameters{
-			IntParameters:    make(map[string]int64),
-			StringParameters: make(map[string]string),
-			DoubleParameters: make(map[string]float64),
-		},
 		InputArtifactIDs: make(map[string][]int64),
 	}
 

@@ -682,18 +742,11 @@ func GenerateExecutionConfig(executorInput *pipelinespec.ExecutorInput) (*Execut
 		}
 	}
 
-	for name, parameter := range executorInput.Inputs.Parameters {
-		switch t := parameter.Value.(type) {
-		case *pipelinespec.Value_StringValue:
-			ecfg.InputParameters.StringParameters[name] = parameter.GetStringValue()
-		case *pipelinespec.Value_IntValue:
-			ecfg.InputParameters.IntParameters[name] = parameter.GetIntValue()
-		case *pipelinespec.Value_DoubleValue:
-			ecfg.InputParameters.DoubleParameters[name] = parameter.GetDoubleValue()
-		default:
-			return nil, fmt.Errorf("unknown parameter type: %T", t)
-		}
-	}
+	parameters, err := NewParameters(executorInput.Inputs.Parameters)
+	if err != nil {
+		return nil, err
+	}
+	ecfg.InputParameters = parameters
 	return ecfg, nil
 }
 

@@ -13,29 +13,29 @@ import (
 func mlmdValueToPipelineSpecValue(v *pb.Value) (*pipelinespec.Value, error) {
 	switch t := v.Value.(type) {
 	case *pb.Value_StringValue:
-		return stringKFPValue(t.StringValue), nil
+		return StringValue(t.StringValue), nil
 	case *pb.Value_DoubleValue:
-		return doubleKFPValue(t.DoubleValue), nil
+		return DoubleValue(t.DoubleValue), nil
 	case *pb.Value_IntValue:
-		return intKFPValue(t.IntValue), nil
+		return IntValue(t.IntValue), nil
 	default:
 		return nil, fmt.Errorf("unknown value type %T", t)
 	}
 }
 
-func stringKFPValue(v string) *pipelinespec.Value {
+func StringValue(v string) *pipelinespec.Value {
 	return &pipelinespec.Value{
 		Value: &pipelinespec.Value_StringValue{StringValue: v},
 	}
 }
 
-func doubleKFPValue(v float64) *pipelinespec.Value {
+func DoubleValue(v float64) *pipelinespec.Value {
 	return &pipelinespec.Value{
 		Value: &pipelinespec.Value_DoubleValue{DoubleValue: v},
 	}
 }
 
-func intKFPValue(v int64) *pipelinespec.Value {
+func IntValue(v int64) *pipelinespec.Value {
 	return &pipelinespec.Value{
 		Value: &pipelinespec.Value_IntValue{IntValue: v},
 	}