# File: pipelines/sdk/python/kfp/local/io_test.py
#
# 96 lines
# 3.0 KiB
# Python

# Copyright 2024 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for io.py."""
import unittest
from kfp import dsl
from kfp.local import io
class IOStoreTest(unittest.TestCase):
    """Tests for the put/get accessors of ``io.IOStore``.

    Covers the error raised for a missing task, a missing output on a known
    task, and a missing parent input, plus round-trips of stored values.
    """

    def test_task_not_found(self):
        """Getting an output of a task that was never stored raises ValueError."""
        store = io.IOStore()
        with self.assertRaisesRegex(
                ValueError,
                r"Tried to get output 'foo' from task 'my-task', but task 'my-task' not found\."
        ):
            store.get_task_output('my-task', 'foo')

    def test_output_not_found(self):
        """Getting an unknown output name of a stored task raises ValueError."""
        store = io.IOStore()
        # Register the task under a different output name so that only the
        # output lookup, not the task lookup, can fail.
        store.put_task_output('my-task', 'bar', 'baz')
        with self.assertRaisesRegex(
                ValueError,
                r"Tried to get output 'foo' from task 'my-task', but task 'my-task' has no output named 'foo'\."
        ):
            store.get_task_output('my-task', 'foo')

    def test_parent_input_not_found(self):
        """Getting a parent input that was never stored raises ValueError."""
        store = io.IOStore()
        # The trailing period is escaped so it matches a literal '.', as in
        # the other message patterns in this class.
        with self.assertRaisesRegex(
                ValueError,
                r"Parent pipeline input argument 'foo' not found\."):
            store.get_parent_input('foo')

    def test_put_and_get_task_output(self):
        """Values stored as task outputs are returned unchanged, repeatedly."""
        store = io.IOStore()
        store.put_task_output('my-task', 'foo', 'bar')
        store.put_task_output('my-task', 'baz', 'bat')
        self.assertEqual(
            store.get_task_output('my-task', 'foo'),
            'bar',
        )
        self.assertEqual(
            store.get_task_output('my-task', 'baz'),
            'bat',
        )
        # test getting doesn't remove by getting twice
        self.assertEqual(
            store.get_task_output('my-task', 'baz'),
            'bat',
        )

    def test_put_and_get_parent_input(self):
        """Values stored as parent inputs are returned unchanged, repeatedly."""
        store = io.IOStore()
        store.put_parent_input('foo', 'bar')
        store.put_parent_input('baz', 'bat')
        self.assertEqual(
            store.get_parent_input('foo'),
            'bar',
        )
        self.assertEqual(
            store.get_parent_input('baz'),
            'bat',
        )
        # test getting doesn't remove by getting twice
        self.assertEqual(
            store.get_parent_input('baz'),
            'bat',
        )

    def test_put_and_get_task_output_with_artifact(self):
        """A ``dsl.Artifact`` stored as a task output compares equal on get."""
        artifact = dsl.Artifact(
            name='foo', uri='/my/uri', metadata={'foo': 'bar'})
        store = io.IOStore()
        store.put_task_output('my-task', 'foo', artifact)
        self.assertEqual(
            store.get_task_output('my-task', 'foo'),
            artifact,
        )
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()