# File: pipelines/sdk/python/kfp/dsl/pipeline_channel_test.py
#
# 389 lines
# 13 KiB
# Python

# Copyright 2021 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for kfp.dsl.pipeline_channel."""
from typing import List
import unittest
from absl.testing import parameterized
from kfp import dsl
from kfp.dsl import Artifact
from kfp.dsl import Dataset
from kfp.dsl import Output
from kfp.dsl import pipeline_channel
class PipelineChannelTest(parameterized.TestCase):
    """Tests for constructing, stringifying and extracting pipeline channels."""

    def test_instantiate_pipline_channel(self):
        """The abstract PipelineChannel class cannot be instantiated directly."""
        with self.assertRaisesRegex(
                TypeError, "Can't instantiate abstract class PipelineChannel"):
            pipeline_channel.PipelineChannel(
                name='channel',
                channel_type='String',
            )

    def test_invalid_name(self):
        """A channel name that does not begin with a letter is rejected."""
        with self.assertRaisesRegex(
                ValueError,
                'Only letters, numbers, spaces, "_", and "-" are allowed in the '
                r'name\. Must begin with a letter\. Got name: 123_abc'):
            pipeline_channel.create_pipeline_channel(
                name='123_abc',
                channel_type='String',
            )

    def test_task_name_and_value_both_set(self):
        """task_name and value cannot both be supplied for one channel."""
        with self.assertRaisesRegex(
                ValueError, r'task_name and value cannot be both set\.'):
            pipeline_channel.create_pipeline_channel(
                name='abc',
                channel_type='Integer',
                task_name='task1',
                value=123,
            )

    def test_invalid_type(self):
        """Parameter and artifact channels each reject the other's types."""
        with self.assertRaisesRegex(TypeError,
                                    r'Artifact is not a parameter type\.'):
            pipeline_channel.PipelineParameterChannel(
                name='channel1',
                channel_type='Artifact',
            )

        with self.assertRaisesRegex(TypeError,
                                    r'String is not an artifact type\.'):
            pipeline_channel.PipelineArtifactChannel(
                name='channel1',
                channel_type='String',
                task_name='task1',
                is_artifact_list=False,
            )

    # The parameter is named `channel` (not `pipeline_channel`) so that it
    # does not shadow the `pipeline_channel` module inside the test body.
    @parameterized.parameters(
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel1',
                    task_name='task1',
                    channel_type='String',
                ),
            'str_repr':
                '{{channel:task=task1;name=channel1;type=String;}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel2',
                    channel_type='Integer',
                ),
            'str_repr':
                '{{channel:task=;name=channel2;type=Integer;}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel3',
                    channel_type={'type_a': {
                        'property_b': 'c'
                    }},
                    task_name='task3',
                ),
            'str_repr':
                '{{channel:task=task3;name=channel3;type={"type_a": {"property_b": "c"}};}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel4',
                    channel_type='Float',
                    value=1.23,
                ),
            'str_repr':
                '{{channel:task=;name=channel4;type=Float;}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel5',
                    channel_type='system.Artifact@0.0.1',
                    task_name='task5',
                ),
            'str_repr':
                '{{channel:task=task5;name=channel5;type=system.Artifact@0.0.1;}}',
        },
    )
    def test_str_repr(self, channel, str_repr):
        """str() renders a channel as its `{{channel:...}}` placeholder.

        A constant `value` (see channel4) does not appear in the placeholder.
        """
        self.assertEqual(str_repr, str(channel))

    def test_extract_pipeline_channels(self):
        """Channels are parsed back out of text embedding their placeholders."""
        p1 = pipeline_channel.create_pipeline_channel(
            name='channel1',
            channel_type='String',
            value='abc',
        )
        p2 = pipeline_channel.create_pipeline_channel(
            name='channel2',
            channel_type='customized_type_b',
            task_name='task2',
        )
        p3 = pipeline_channel.create_pipeline_channel(
            name='channel3',
            channel_type={'customized_type_c': {
                'property_c': 'value_c'
            }},
            task_name='task3',
        )
        # Filler text placed between placeholders.
        stuff_chars = ' between '
        payload = str(p1) + stuff_chars + str(p2) + stuff_chars + str(p3)
        params = pipeline_channel.extract_pipeline_channels_from_string(payload)
        self.assertListEqual([p1, p2, p3], params)

        # extract_pipeline_channels_from_any is expected to dedup channels
        # across all payloads: p2 appears in both strings but is returned once.
        payload = [
            str(p1) + stuff_chars + str(p2),
            str(p2) + stuff_chars + str(p3)
        ]
        params = pipeline_channel.extract_pipeline_channels_from_any(payload)
        self.assertListEqual([p1, p2, p3], params)
@dsl.component
def string_comp() -> str:
    """Component with a single string output that is always 'text'."""
    output_value = 'text'
    return output_value
@dsl.component
def list_comp() -> List[str]:
    """Component whose output is the one-element string list ['text']."""
    only_item = 'text'
    return [only_item]
@dsl.component
def roll_three_sided_die() -> str:
    """Return 'heads', 'tails' or 'draw', picked by random.randint(0, 2)."""
    # NOTE(review): the import is kept inside the body, presumably because the
    # component's function source is executed apart from this module — confirm.
    import random
    faces = ('heads', 'tails', 'draw')
    # randint is inclusive on both ends, so the index is 0, 1 or 2.
    return faces[random.randint(0, 2)]
@dsl.component
def print_and_return(text: str) -> str:
    """Print ``text`` to stdout and pass it through unchanged as the output."""
    echoed = text
    print(echoed)
    return echoed
class TestCanAccessTask(unittest.TestCase):
    """A task output channel exposes the task that produced it."""

    def test(self):
        """`task.output.task` is the very task object that was called."""

        @dsl.pipeline
        def my_pipeline():
            producer = string_comp()
            # Checked while @dsl.pipeline evaluates the function body.
            self.assertEqual(producer.output.task, producer)
class TestOneOfAndCollectedNotComposable(unittest.TestCase):
    """dsl.OneOf and dsl.Collected cannot be nested inside one another."""

    def test_collected_in_oneof(self):
        """Passing a dsl.Collected channel to dsl.OneOf raises ValueError."""
        with self.assertRaisesRegex(
                ValueError,
                r'dsl\.Collected cannot be used inside of dsl\.OneOf\.'):

            @dsl.pipeline
            def my_pipeline(x: str):
                with dsl.If(x == 'foo'):
                    t1 = list_comp()
                with dsl.Else():
                    with dsl.ParallelFor([1, 2, 3]):
                        t2 = string_comp()
                    collected = dsl.Collected(t2.output)
                # The OneOf is deliberately neither returned nor passed to a
                # task, so the error must come from the dsl.OneOf constructor.
                dsl.OneOf(t1.output, collected)

    def test_oneof_in_collected(self):
        """Passing a dsl.OneOf channel to dsl.Collected raises ValueError."""
        with self.assertRaisesRegex(
                ValueError,
                r'dsl\.OneOf cannot be used inside of dsl\.Collected\.'):

            @dsl.pipeline
            def my_pipeline(x: str):
                with dsl.ParallelFor([1, 2, 3]):
                    with dsl.If(x == 'foo'):
                        t1 = string_comp()
                    with dsl.Else():
                        t2 = string_comp()
                    oneof = dsl.OneOf(t1.output, t2.output)
                # The Collected is deliberately neither returned nor passed to
                # a task, so the error must come from the dsl.Collected
                # constructor.
                dsl.Collected(oneof)
class TestOneOfRequiresSameType(unittest.TestCase):
    """Channels passed to dsl.OneOf must all have the same type."""

    def test_same_parameter_type(self):
        """Two string outputs become one string output of the branches group."""

        @dsl.pipeline
        def my_pipeline(x: str) -> str:
            with dsl.If(x == 'foo'):
                t1 = string_comp()
            with dsl.Else():
                t2 = string_comp()
            return dsl.OneOf(t1.output, t2.output)

        oneof_output = (
            my_pipeline.pipeline_spec.components['comp-condition-branches-1']
            .output_definitions.parameters[
                'pipelinechannel--condition-branches-1-oneof-1'])
        # NOTE(review): 3 is presumably the STRING member of the pipeline
        # spec's ParameterType enum — confirm against kfp.pipeline_spec.
        self.assertEqual(oneof_output.parameter_type, 3)

    def test_different_parameter_types(self):
        """Mixing a str output with a List[str] output raises TypeError."""
        with self.assertRaisesRegex(
                TypeError,
                r'Task outputs passed to dsl\.OneOf must be the same type\. Got two channels with different types: String at index 0 and typing\.List\[str\] at index 1\.'
        ):

            @dsl.pipeline
            def my_pipeline(x: str) -> str:
                with dsl.If(x == 'foo'):
                    t1 = string_comp()
                with dsl.Else():
                    t2 = list_comp()
                return dsl.OneOf(t1.output, t2.output)

    def test_same_artifact_type(self):
        """Two Artifact outputs become one system.Artifact@0.0.1 output."""

        @dsl.component
        def artifact_comp(out: Output[Artifact]):
            with open(out.path, 'w') as f:
                f.write('foo')

        @dsl.pipeline
        def my_pipeline(x: str) -> Artifact:
            with dsl.If(x == 'foo'):
                t1 = artifact_comp()
            with dsl.Else():
                t2 = artifact_comp()
            return dsl.OneOf(t1.outputs['out'], t2.outputs['out'])

        # Look the output definition up once; both assertions inspect it.
        oneof_artifact = (
            my_pipeline.pipeline_spec.components['comp-condition-branches-1']
            .output_definitions
            .artifacts['pipelinechannel--condition-branches-1-oneof-1'])
        self.assertEqual(
            oneof_artifact.artifact_type.schema_title,
            'system.Artifact',
        )
        self.assertEqual(
            oneof_artifact.artifact_type.schema_version,
            '0.0.1',
        )

    def test_different_artifact_type(self):
        """Mixing Artifact and Dataset outputs raises TypeError."""

        @dsl.component
        def artifact_comp_one(out: Output[Artifact]):
            with open(out.path, 'w') as f:
                f.write('foo')

        @dsl.component
        def artifact_comp_two(out: Output[Dataset]):
            with open(out.path, 'w') as f:
                f.write('foo')

        with self.assertRaisesRegex(
                TypeError,
                r'Task outputs passed to dsl\.OneOf must be the same type\. Got two channels with different types: system\.Artifact@0\.0\.1 at index 0 and system\.Dataset@0\.0\.1 at index 1\.'
        ):

            @dsl.pipeline
            def my_pipeline(x: str) -> Artifact:
                with dsl.If(x == 'foo'):
                    t1 = artifact_comp_one()
                with dsl.Else():
                    t2 = artifact_comp_two()
                return dsl.OneOf(t1.outputs['out'], t2.outputs['out'])

    def test_different_artifact_type_due_to_list(self):
        """A list-of-artifacts component output is rejected up front."""
        # If list-of-artifact outputs from components are ever supported, this
        # test will fail. That is intended: it must then be rewritten to check
        # the OneOf type mismatch instead.
        with self.assertRaisesRegex(
                ValueError,
                r"Output lists of artifacts are only supported for pipelines\. Got output list of artifacts for output parameter 'out' of component 'artifact-comp-two'\."
        ):

            @dsl.component
            def artifact_comp_one(out: Output[Artifact]):
                with open(out.path, 'w') as f:
                    f.write('foo')

            @dsl.component
            def artifact_comp_two(out: Output[List[Artifact]]):
                with open(out.path, 'w') as f:
                    f.write('foo')

            @dsl.pipeline
            def my_pipeline(x: str) -> Artifact:
                with dsl.If(x == 'foo'):
                    t1 = artifact_comp_one()
                with dsl.Else():
                    t2 = artifact_comp_two()
                return dsl.OneOf(t1.outputs['out'], t2.outputs['out'])

    def test_parameters_mixed_with_artifacts(self):
        """Mixing an artifact output with a parameter output raises TypeError."""

        @dsl.component
        def artifact_comp(out: Output[Artifact]):
            with open(out.path, 'w') as f:
                f.write('foo')

        with self.assertRaisesRegex(
                TypeError,
                r'Task outputs passed to dsl\.OneOf must be the same type\. Found a mix of parameters and artifacts passed to dsl\.OneOf\.'
        ):

            @dsl.pipeline
            def my_pipeline(x: str) -> str:
                with dsl.If(x == 'foo'):
                    t1 = artifact_comp()
                with dsl.Else():
                    t2 = string_comp()
                return dsl.OneOf(t1.output, t2.output)

    def test_no_else_raises(self):
        """dsl.OneOf over If/Elif branches without a dsl.Else raises ValueError."""
        with self.assertRaisesRegex(
                ValueError,
                r'dsl\.OneOf must include an output from a task in a dsl\.Else group to ensure at least one output is available at runtime\.'
        ):

            @dsl.pipeline
            def roll_die_pipeline():
                roll_task = roll_three_sided_die()
                with dsl.If(roll_task.output == 'heads'):
                    t1 = print_and_return(text='Got heads!')
                with dsl.Elif(roll_task.output == 'tails'):
                    t2 = print_and_return(text='Got tails!')
                # No dsl.Else branch: the 'draw' outcome produces no output.
                print_and_return(text=dsl.OneOf(t1.output, t2.output))
if __name__ == '__main__':
    # Allows running this file directly; module='__main__' is unittest's
    # default and is spelled out only for clarity.
    unittest.main(module='__main__')