96 lines
3.0 KiB
Python
96 lines
3.0 KiB
Python
# Copyright 2024 The Kubeflow Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Tests for io.py."""
|
|
|
|
import unittest
|
|
|
|
from kfp import dsl
|
|
from kfp.local import io
|
|
|
|
|
|
class IOStoreTest(unittest.TestCase):
|
|
|
|
def test_task_not_found(self):
|
|
store = io.IOStore()
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
r"Tried to get output 'foo' from task 'my-task', but task 'my-task' not found\."
|
|
):
|
|
store.get_task_output('my-task', 'foo')
|
|
|
|
def test_output_not_found(self):
|
|
store = io.IOStore()
|
|
store.put_task_output('my-task', 'bar', 'baz')
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
r"Tried to get output 'foo' from task 'my-task', but task 'my-task' has no output named 'foo'\."
|
|
):
|
|
store.get_task_output('my-task', 'foo')
|
|
|
|
def test_parent_input_not_found(self):
|
|
store = io.IOStore()
|
|
with self.assertRaisesRegex(
|
|
ValueError, r"Parent pipeline input argument 'foo' not found."):
|
|
store.get_parent_input('foo')
|
|
|
|
def test_put_and_get_task_output(self):
|
|
store = io.IOStore()
|
|
store.put_task_output('my-task', 'foo', 'bar')
|
|
store.put_task_output('my-task', 'baz', 'bat')
|
|
self.assertEqual(
|
|
store.get_task_output('my-task', 'foo'),
|
|
'bar',
|
|
)
|
|
self.assertEqual(
|
|
store.get_task_output('my-task', 'baz'),
|
|
'bat',
|
|
)
|
|
# test getting doesn't remove by getting twice
|
|
self.assertEqual(
|
|
store.get_task_output('my-task', 'baz'),
|
|
'bat',
|
|
)
|
|
|
|
def test_put_and_get_parent_input(self):
|
|
store = io.IOStore()
|
|
store.put_parent_input('foo', 'bar')
|
|
store.put_parent_input('baz', 'bat')
|
|
self.assertEqual(
|
|
store.get_parent_input('foo'),
|
|
'bar',
|
|
)
|
|
self.assertEqual(
|
|
store.get_parent_input('baz'),
|
|
'bat',
|
|
)
|
|
# test getting doesn't remove by getting twice
|
|
self.assertEqual(
|
|
store.get_parent_input('baz'),
|
|
'bat',
|
|
)
|
|
|
|
def test_put_and_get_task_output_with_artifact(self):
|
|
artifact = dsl.Artifact(
|
|
name='foo', uri='/my/uri', metadata={'foo': 'bar'})
|
|
store = io.IOStore()
|
|
store.put_task_output('my-task', 'foo', artifact)
|
|
self.assertEqual(
|
|
store.get_task_output('my-task', 'foo'),
|
|
artifact,
|
|
)
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|