# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
"""Tests for kfp.dsl.pipeline_channel."""
|
|
|
|
from typing import List
|
|
import unittest
|
|
|
|
from absl.testing import parameterized
|
|
from kfp import dsl
|
|
from kfp.dsl import Artifact
|
|
from kfp.dsl import Dataset
|
|
from kfp.dsl import Output
|
|
from kfp.dsl import pipeline_channel
|
|
|
|
|
|
class PipelineChannelTest(parameterized.TestCase):
    """Tests for creating, validating, printing and extracting pipeline channels."""

    def test_instantiate_pipline_channel(self):
        """Instantiating PipelineChannel directly raises a TypeError."""
        with self.assertRaisesRegex(
                TypeError, "Can't instantiate abstract class PipelineChannel"):
            # Only the raised error matters, so the result is not bound.
            pipeline_channel.PipelineChannel(
                name='channel',
                channel_type='String',
            )

    def test_invalid_name(self):
        """A channel name that starts with a digit is rejected."""
        with self.assertRaisesRegex(
                ValueError,
                'Only letters, numbers, spaces, "_", and "-" are allowed in the '
                'name. Must begin with a letter. Got name: 123_abc'):
            pipeline_channel.create_pipeline_channel(
                name='123_abc',
                channel_type='String',
            )

    def test_task_name_and_value_both_set(self):
        """Passing both task_name and value raises a ValueError."""
        with self.assertRaisesRegex(ValueError,
                                    'task_name and value cannot be both set.'):
            pipeline_channel.create_pipeline_channel(
                name='abc',
                channel_type='Integer',
                task_name='task1',
                value=123,
            )

    def test_invalid_type(self):
        """Each concrete channel class rejects the other kind of type."""
        # An artifact type given to the parameter channel.
        with self.assertRaisesRegex(TypeError,
                                    'Artifact is not a parameter type.'):
            pipeline_channel.PipelineParameterChannel(
                name='channel1',
                channel_type='Artifact',
            )

        # A parameter type given to the artifact channel.
        with self.assertRaisesRegex(TypeError,
                                    'String is not an artifact type.'):
            pipeline_channel.PipelineArtifactChannel(
                name='channel1',
                channel_type='String',
                task_name='task1',
                is_artifact_list=False,
            )

    @parameterized.parameters(
        {
            'pipeline_channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel1',
                    task_name='task1',
                    channel_type='String',
                ),
            'str_repr':
                '{{channel:task=task1;name=channel1;type=String;}}',
        },
        {
            'pipeline_channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel2',
                    channel_type='Integer',
                ),
            'str_repr':
                '{{channel:task=;name=channel2;type=Integer;}}',
        },
        {
            'pipeline_channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel3',
                    channel_type={'type_a': {
                        'property_b': 'c'
                    }},
                    task_name='task3',
                ),
            'str_repr':
                '{{channel:task=task3;name=channel3;type={"type_a": {"property_b": "c"}};}}',
        },
        {
            'pipeline_channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel4',
                    channel_type='Float',
                    value=1.23,
                ),
            'str_repr':
                '{{channel:task=;name=channel4;type=Float;}}',
        },
        {
            'pipeline_channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel5',
                    channel_type='system.Artifact@0.0.1',
                    task_name='task5',
                ),
            'str_repr':
                '{{channel:task=task5;name=channel5;type=system.Artifact@0.0.1;}}',
        },
    )
    def test_str_repr(self, pipeline_channel, str_repr):
        """str() of a channel equals the expected placeholder string."""
        self.assertEqual(str_repr, str(pipeline_channel))

    def test_extract_pipeline_channels(self):
        """Channels embedded in strings are extracted in order, without duplicates."""
        p1 = pipeline_channel.create_pipeline_channel(
            name='channel1',
            channel_type='String',
            value='abc',
        )
        p2 = pipeline_channel.create_pipeline_channel(
            name='channel2',
            channel_type='customized_type_b',
            task_name='task2',
        )
        p3 = pipeline_channel.create_pipeline_channel(
            name='channel3',
            channel_type={'customized_type_c': {
                'property_c': 'value_c'
            }},
            task_name='task3',
        )
        separator = ' between '

        # Single string payload holding all three placeholders.
        payload = separator.join([str(p1), str(p2), str(p3)])
        params = pipeline_channel.extract_pipeline_channels_from_string(payload)
        self.assertListEqual([p1, p2, p3], params)

        # extract_pipeline_channels_from_any is expected to de-duplicate
        # pipeline channels across all the payloads (p2 appears in both).
        payload = [
            str(p1) + separator + str(p2),
            str(p2) + separator + str(p3),
        ]
        params = pipeline_channel.extract_pipeline_channels_from_any(payload)
        self.assertListEqual([p1, p2, p3], params)
|
|
|
|
|
|
# Fixture component used by the tests below; returns a fixed string.
@dsl.component
def string_comp() -> str:
    return 'text'
|
|
|
|
|
|
# Fixture component used by the tests below; returns a one-element list of
# strings.
@dsl.component
def list_comp() -> List[str]:
    return ['text']
|
|
|
|
|
|
# Fixture component: picks one of 'heads', 'tails' or 'draw' at random.
@dsl.component
def roll_three_sided_die() -> str:
    import random

    # Index 0 -> 'heads', 1 -> 'tails', 2 -> 'draw'.
    outcomes = ('heads', 'tails', 'draw')
    return outcomes[random.randint(0, 2)]
|
|
|
|
|
|
# Fixture component: prints its input and returns it unchanged.
@dsl.component
def print_and_return(text: str) -> str:
    print(text)
    return text
|
|
|
|
|
|
class TestCanAccessTask(unittest.TestCase):
    """Checks that a task's output channel refers back to the task itself."""

    def test(self):

        # NOTE(review): the assertion lives inside the pipeline function, so
        # it only runs if the decorator executes the body at definition
        # time -- confirm.
        @dsl.pipeline
        def my_pipeline():
            producer = string_comp()
            self.assertEqual(producer.output.task, producer)
|
|
|
|
|
|
class TestOneOfAndCollectedNotComposable(unittest.TestCase):
    """dsl.OneOf and dsl.Collected must reject each other as inputs."""

    def test_collected_in_oneof(self):
        expected_error = 'dsl.Collected cannot be used inside of dsl.OneOf.'

        with self.assertRaisesRegex(ValueError, expected_error):

            @dsl.pipeline
            def my_pipeline(x: str):
                with dsl.If(x == 'foo'):
                    list_task = list_comp()
                with dsl.Else():
                    with dsl.ParallelFor([1, 2, 3]):
                        loop_task = string_comp()
                    collected = dsl.Collected(loop_task.output)
                # The OneOf is neither returned nor passed to a task, so the
                # error has to come from dsl.OneOf itself.
                dsl.OneOf(list_task.output, collected)

    def test_oneof_in_collected(self):
        expected_error = 'dsl.OneOf cannot be used inside of dsl.Collected.'

        with self.assertRaisesRegex(ValueError, expected_error):

            @dsl.pipeline
            def my_pipeline(x: str):
                with dsl.ParallelFor([1, 2, 3]):
                    with dsl.If(x == 'foo'):
                        if_task = string_comp()
                    with dsl.Else():
                        else_task = string_comp()
                    oneof = dsl.OneOf(if_task.output, else_task.output)
                # The Collected is neither returned nor passed to a task, so
                # the error has to come from the dsl.Collected constructor.
                dsl.Collected(oneof)
|
|
|
|
|
|
class TestOneOfRequiresSameType(unittest.TestCase):
    """dsl.OneOf accepts only channels of one common type and needs an Else."""

    def test_same_parameter_type(self):

        @dsl.pipeline
        def my_pipeline(x: str) -> str:
            with dsl.If(x == 'foo'):
                t1 = string_comp()
            with dsl.Else():
                t2 = string_comp()
            return dsl.OneOf(t1.output, t2.output)

        branches = my_pipeline.pipeline_spec.components[
            'comp-condition-branches-1']
        oneof_param = branches.output_definitions.parameters[
            'pipelinechannel--condition-branches-1-oneof-1']
        # NOTE(review): 3 is presumably the enum value for a string
        # parameter -- confirm against the pipeline spec definition.
        self.assertEqual(oneof_param.parameter_type, 3)

    def test_different_parameter_types(self):
        expected_error = r'Task outputs passed to dsl\.OneOf must be the same type. Got two channels with different types: String at index 0 and typing\.List\[str\] at index 1\.'

        with self.assertRaisesRegex(TypeError, expected_error):

            @dsl.pipeline
            def my_pipeline(x: str) -> str:
                with dsl.If(x == 'foo'):
                    t1 = string_comp()
                with dsl.Else():
                    t2 = list_comp()
                return dsl.OneOf(t1.output, t2.output)

    def test_same_artifact_type(self):

        @dsl.component
        def artifact_comp(out: Output[Artifact]):
            with open(out.path, 'w') as f:
                f.write('foo')

        @dsl.pipeline
        def my_pipeline(x: str) -> Artifact:
            with dsl.If(x == 'foo'):
                t1 = artifact_comp()
            with dsl.Else():
                t2 = artifact_comp()
            return dsl.OneOf(t1.outputs['out'], t2.outputs['out'])

        branches = my_pipeline.pipeline_spec.components[
            'comp-condition-branches-1']
        oneof_artifact = branches.output_definitions.artifacts[
            'pipelinechannel--condition-branches-1-oneof-1']
        artifact_type = oneof_artifact.artifact_type

        self.assertEqual(artifact_type.schema_title, 'system.Artifact')
        self.assertEqual(artifact_type.schema_version, '0.0.1')

    def test_different_artifact_type(self):

        @dsl.component
        def artifact_comp_one(out: Output[Artifact]):
            with open(out.path, 'w') as f:
                f.write('foo')

        @dsl.component
        def artifact_comp_two(out: Output[Dataset]):
            with open(out.path, 'w') as f:
                f.write('foo')

        expected_error = r'Task outputs passed to dsl\.OneOf must be the same type. Got two channels with different types: system.Artifact@0.0.1 at index 0 and system.Dataset@0.0.1 at index 1\.'

        with self.assertRaisesRegex(TypeError, expected_error):

            @dsl.pipeline
            def my_pipeline(x: str) -> Artifact:
                with dsl.If(x == 'foo'):
                    t1 = artifact_comp_one()
                with dsl.Else():
                    t2 = artifact_comp_two()
                return dsl.OneOf(t1.outputs['out'], t2.outputs['out'])

    def test_different_artifact_type_due_to_list(self):
        # If components ever support list-of-artifact outputs, this test will
        # start failing. That is intended: it must be rewritten at that point.
        expected_error = r"Output lists of artifacts are only supported for pipelines\. Got output list of artifacts for output parameter 'out' of component 'artifact-comp-two'\."

        # The component definitions sit inside the context manager, so an
        # error raised while defining them is caught as well.
        with self.assertRaisesRegex(ValueError, expected_error):

            @dsl.component
            def artifact_comp_one(out: Output[Artifact]):
                with open(out.path, 'w') as f:
                    f.write('foo')

            @dsl.component
            def artifact_comp_two(out: Output[List[Artifact]]):
                with open(out.path, 'w') as f:
                    f.write('foo')

            @dsl.pipeline
            def my_pipeline(x: str) -> Artifact:
                with dsl.If(x == 'foo'):
                    t1 = artifact_comp_one()
                with dsl.Else():
                    t2 = artifact_comp_two()
                return dsl.OneOf(t1.outputs['out'], t2.outputs['out'])

    def test_parameters_mixed_with_artifacts(self):

        @dsl.component
        def artifact_comp(out: Output[Artifact]):
            with open(out.path, 'w') as f:
                f.write('foo')

        expected_error = r'Task outputs passed to dsl\.OneOf must be the same type\. Found a mix of parameters and artifacts passed to dsl\.OneOf\.'

        with self.assertRaisesRegex(TypeError, expected_error):

            @dsl.pipeline
            def my_pipeline(x: str) -> str:
                with dsl.If(x == 'foo'):
                    t1 = artifact_comp()
                with dsl.Else():
                    t2 = string_comp()
                return dsl.OneOf(t1.output, t2.output)

    def test_no_else_raises(self):
        expected_error = r'dsl\.OneOf must include an output from a task in a dsl\.Else group to ensure at least one output is available at runtime\.'

        with self.assertRaisesRegex(ValueError, expected_error):

            @dsl.pipeline
            def roll_die_pipeline():
                flip_coin_task = roll_three_sided_die()
                # Only If and Elif branches: there is no Else group.
                with dsl.If(flip_coin_task.output == 'heads'):
                    t1 = print_and_return(text='Got heads!')
                with dsl.Elif(flip_coin_task.output == 'tails'):
                    t2 = print_and_return(text='Got tails!')
                print_and_return(text=dsl.OneOf(t1.output, t2.output))
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the test cases in this module when it is executed as a script.
    unittest.main()
|