175 lines
5.8 KiB
Python
175 lines
5.8 KiB
Python
# Copyright 2021 The Kubeflow Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tests for kfp.components.pipeline_channel."""
|
|
|
|
import unittest
|
|
|
|
from absl.testing import parameterized
|
|
from kfp import dsl
|
|
from kfp.components import pipeline_channel
|
|
|
|
|
|
class PipelineChannelTest(parameterized.TestCase):
    """Tests for constructing, validating, printing and parsing channels."""

    def test_instantiate_pipline_channel(self):
        """The abstract PipelineChannel base class cannot be instantiated."""
        with self.assertRaisesRegex(
                TypeError, "Can't instantiate abstract class PipelineChannel"):
            pipeline_channel.PipelineChannel(
                name='channel',
                channel_type='String',
            )

    def test_invalid_name(self):
        """A channel name that does not begin with a letter is rejected."""
        with self.assertRaisesRegex(
                ValueError,
                'Only letters, numbers, spaces, "_", and "-" are allowed in the '
                'name. Must begin with a letter. Got name: 123_abc'):
            pipeline_channel.create_pipeline_channel(
                name='123_abc',
                channel_type='String',
            )

    def test_task_name_and_value_both_set(self):
        """Passing both `task_name` and `value` is rejected."""
        with self.assertRaisesRegex(ValueError,
                                    'task_name and value cannot be both set.'):
            pipeline_channel.create_pipeline_channel(
                name='abc',
                channel_type='Integer',
                task_name='task1',
                value=123,
            )

    def test_invalid_type(self):
        """Parameter and artifact channels each reject the other's types."""
        # An artifact type is not accepted by a parameter channel ...
        with self.assertRaisesRegex(TypeError,
                                    'Artifact is not a parameter type.'):
            pipeline_channel.PipelineParameterChannel(
                name='channel1',
                channel_type='Artifact',
            )

        # ... and a parameter type is not accepted by an artifact channel.
        with self.assertRaisesRegex(TypeError,
                                    'String is not an artifact type.'):
            pipeline_channel.PipelineArtifactChannel(
                name='channel1',
                channel_type='String',
                task_name='task1',
                is_artifact_list=False,
            )

    # The argument is named `channel` (not `pipeline_channel`) so it does not
    # shadow the `pipeline_channel` module inside the test method.
    @parameterized.parameters(
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel1',
                    task_name='task1',
                    channel_type='String',
                ),
            'expected_str':
                '{{channel:task=task1;name=channel1;type=String;}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel2',
                    channel_type='Integer',
                ),
            'expected_str':
                '{{channel:task=;name=channel2;type=Integer;}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel3',
                    channel_type={'type_a': {
                        'property_b': 'c'
                    }},
                    task_name='task3',
                ),
            'expected_str':
                '{{channel:task=task3;name=channel3;type={"type_a": {"property_b": "c"}};}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel4',
                    channel_type='Float',
                    value=1.23,
                ),
            'expected_str':
                '{{channel:task=;name=channel4;type=Float;}}',
        },
        {
            'channel':
                pipeline_channel.create_pipeline_channel(
                    name='channel5',
                    channel_type='system.Artifact@0.0.1',
                    task_name='task5',
                ),
            'expected_str':
                '{{channel:task=task5;name=channel5;type=system.Artifact@0.0.1;}}',
        },
    )
    def test_str_repr(self, channel, expected_str):
        """str() renders a channel as a `{{channel:...}}` placeholder string.

        Channels without a task_name (channel2, channel4) render an empty
        `task=` field.
        """
        self.assertEqual(expected_str, str(channel))

    def test_extract_pipeline_channels(self):
        """Channels are recovered, in order, from their placeholder strings."""
        p1 = pipeline_channel.create_pipeline_channel(
            name='channel1',
            channel_type='String',
            value='abc',
        )
        p2 = pipeline_channel.create_pipeline_channel(
            name='channel2',
            channel_type='customized_type_b',
            task_name='task2',
        )
        p3 = pipeline_channel.create_pipeline_channel(
            name='channel3',
            channel_type={'customized_type_c': {
                'property_c': 'value_c'
            }},
            task_name='task3',
        )
        stuff_chars = ' between '

        payload = str(p1) + stuff_chars + str(p2) + stuff_chars + str(p3)
        params = pipeline_channel.extract_pipeline_channels_from_string(payload)
        self.assertListEqual([p1, p2, p3], params)

        # extract_pipeline_channels_from_any is expected to dedup pipeline
        # channels across all the payloads: p2 appears in both strings below
        # but must be reported only once.
        payloads = [
            str(p1) + stuff_chars + str(p2),
            str(p2) + stuff_chars + str(p3),
        ]
        params = pipeline_channel.extract_pipeline_channels_from_any(payloads)
        self.assertListEqual([p1, p2, p3], params)
|
|
|
|
|
|
class TestCanAccessTask(unittest.TestCase):
    """Checks that a task's output channel exposes the task that produced it."""

    def test(self):
        """`op1.output.task` refers back to `op1` inside a pipeline body."""

        @dsl.component
        def comp() -> str:
            return 'text'

        # The assertion lives inside the pipeline body, so it only runs if
        # `dsl.pipeline` executes that body while building the pipeline
        # (presumably at decoration time -- confirm for the kfp version in
        # use). Record each execution so the test cannot pass vacuously if
        # the body is never run.
        body_runs = []

        @dsl.pipeline
        def my_pipeline():
            op1 = comp()
            self.assertEqual(op1.output.task, op1)
            body_runs.append(True)

        self.assertTrue(
            body_runs,
            'pipeline body was never executed, so the assertion did not run')
|
|
|
|
|
|
if __name__ == '__main__':
    # Only when this file is executed directly (not imported): hand control
    # to unittest's command-line runner for the tests defined above.
    unittest.main()
|