# pipelines/sdk/python/kfp/client/client_test.py
#
# 420 lines
# 17 KiB
# Python

# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import textwrap
import unittest
from unittest.mock import Mock
from unittest.mock import patch
from absl.testing import parameterized
from google.protobuf import json_format
from kfp.client import auth
from kfp.client import client
from kfp.compiler import Compiler
from kfp.dsl import component
from kfp.dsl import pipeline
from kfp.pipeline_spec import pipeline_spec_pb2
import kfp_server_api
import yaml
class TestValidatePipelineName(parameterized.TestCase):
    """Checks which names `client.validate_pipeline_display_name` accepts."""

    # Names the validator is expected to accept without raising.
    _ACCEPTED_NAMES = [
        'pipeline', 'my-pipeline', 'my-pipeline-1', '1pipeline', 'pipeline1',
        'my_pipeline', "person's-pipeline", 'my pipeline', 'pipeline.yaml'
    ]

    # Empty or whitespace-only names the validator is expected to reject.
    _REJECTED_NAMES = ['', ' ', '\t']

    @parameterized.parameters(_ACCEPTED_NAMES)
    def test_valid(self, name: str):
        # The test passes as long as the call returns without raising.
        client.validate_pipeline_display_name(name)

    @parameterized.parameters(_REJECTED_NAMES)
    def test_invalid(self, name: str):
        # The message is matched as a regular expression by assertRaisesRegex.
        expected_message = (
            'Invalid pipeline name. Pipeline name cannot be empty or contain only whitespace.'
        )
        with self.assertRaisesRegex(ValueError, expected_message):
            client.validate_pipeline_display_name(name)
class TestOverrideCachingOptions(parameterized.TestCase):
    """Exercises `client._override_caching_options` on a compiled pipeline."""

    def test_override_caching_of_multiple_components(self):
        # The two tasks are authored with opposite caching settings
        # (True / False). After overriding with True, both tasks must report
        # enableCache as true in the resulting spec.

        @component
        def hello_word(text: str) -> str:
            return text

        @component
        def to_lower(text: str) -> str:
            return text.lower()

        @pipeline(
            name='sample two-step pipeline',
            description='a minimal two-step pipeline')
        def pipeline_with_two_component(text: str = 'hi there'):
            first_task = hello_word(text=text).set_caching_options(True)
            to_lower(text=first_task.output).set_caching_options(False)

        # Compile to a temporary YAML file and read it back as a plain dict.
        with tempfile.TemporaryDirectory() as workdir:
            package_path = os.path.join(workdir, 'hello_world_pipeline.yaml')
            Compiler().compile(
                pipeline_func=pipeline_with_two_component,
                package_path=package_path)
            with open(package_path, 'r') as spec_file:
                compiled = yaml.safe_load(spec_file)

        # The override operates on the protobuf message, so convert the dict
        # to a PipelineSpec, apply the override, and convert back to inspect.
        spec = json_format.ParseDict(compiled,
                                     pipeline_spec_pb2.PipelineSpec())
        client._override_caching_options(spec, True)
        tasks = json_format.MessageToDict(spec)['root']['dag']['tasks']

        for task_name in ('hello-word', 'to-lower'):
            self.assertTrue(tasks[task_name]['cachingOptions']['enableCache'])
class TestExtractPipelineYAML(parameterized.TestCase):
    """Tests `client._extract_pipeline_yaml` on one- and two-document YAML."""

    # Pipeline spec document shared by both tests. The nesting matters: the
    # assertions below read pipelineInfo -> name, so `name` must be a child
    # of `pipelineInfo`.
    # NOTE(review): the nesting of the remaining keys was reconstructed from
    # the usual pipeline spec layout - confirm against the original fixture.
    _PIPELINE_SPEC_YAML = textwrap.dedent('''
        components:
          comp-foo:
            executorLabel: exec-foo
        deploymentSpec:
          executors:
            exec-foo:
              container:
                command:
                - sh
                - -c
                - cat /data/file.txt
                image: alpine
        pipelineInfo:
          name: my-pipeline
        root:
          dag:
            tasks:
              foo:
                componentRef:
                  name: comp-foo
                taskInfo:
                  name: foo
        schemaVersion: 2.1.0
        sdkVersion: kfp-2.0.0-beta.13
        ''')

    # Platform spec document. The multi-document test reads
    # platforms -> kubernetes -> deploymentSpec -> executors -> exec-foo
    # -> pvcMount[0] -> constant, which fixes the nesting used here.
    _PLATFORM_SPEC_YAML = textwrap.dedent('''
        platforms:
          kubernetes:
            deploymentSpec:
              executors:
                exec-foo:
                  pvcMount:
                  - mountPath: /data
                    constant: my-pvc
        ''')

    def _extract_to_dict(self, filename: str, yaml_text: str) -> dict:
        # Writes `yaml_text` to `filename` inside a temporary directory, runs
        # client._extract_pipeline_yaml on that path and returns the result
        # of its to_dict(). Extraction happens while the file still exists.
        with tempfile.TemporaryDirectory() as tempdir:
            temp_filepath = os.path.join(tempdir, filename)
            with open(temp_filepath, 'w') as f:
                f.write(yaml_text)
            return client._extract_pipeline_yaml(temp_filepath).to_dict()

    def test_extract_pipeline_yaml_single_doc(self):
        # A single document yields a dict with the pipeline spec keys at the
        # top level.
        pipeline_dict = self._extract_to_dict('single_doc_pipeline.yaml',
                                              self._PIPELINE_SPEC_YAML)
        self.assertEqual('my-pipeline',
                         pipeline_dict['pipelineInfo']['name'])

    def test_extract_pipeline_yaml_multiple_docs(self):
        # Two documents separated by '---' yield a dict that keeps them
        # apart under 'pipeline_spec' and 'platform_spec'.
        multi_doc_yaml = (
            self._PIPELINE_SPEC_YAML + '---' + self._PLATFORM_SPEC_YAML)
        pipeline_dict = self._extract_to_dict('multi_docs_pipeline.yaml',
                                              multi_doc_yaml)
        self.assertEqual(
            'my-pipeline',
            pipeline_dict['pipeline_spec']['pipelineInfo']['name'])
        executor = (
            pipeline_dict['platform_spec']['platforms']['kubernetes']
            ['deploymentSpec']['executors']['exec-foo'])
        self.assertEqual('my-pvc', executor['pvcMount'][0]['constant'])
class TestClient(parameterized.TestCase):
    """Unit tests for `client.Client` with its API objects replaced by mocks.

    Mock assertions that follow a call expected to raise are placed after
    the `assertRaises` block. Inside that block they would sit after the
    raising statement and never execute.
    """

    def setUp(self):
        # Each test gets a fresh client constructed with namespace 'ns1'.
        self.client = client.Client(namespace='ns1')

    def test_wait_for_run_completion_invalid_token_should_raise_error(self):
        # An ApiException raised by get_run must propagate to the caller.
        with patch.object(
                self.client._run_api,
                'get_run',
                side_effect=kfp_server_api.ApiException) as mock_get_run:
            with self.assertRaises(kfp_server_api.ApiException):
                self.client.wait_for_run_completion(
                    run_id='foo', timeout=1, sleep_duration=0)
            mock_get_run.assert_called_once()

    def test_wait_for_run_completion_expired_access_token(self):
        with patch.object(self.client._run_api, 'get_run') as mock_get_run:
            # We need to iterate through multiple side effects in order to
            # test this logic: a non-final state, then a 401, then success.
            mock_get_run.side_effect = [
                Mock(state='unknown state'),
                kfp_server_api.ApiException(status=401),
                Mock(state='succeeded'),
            ]
            with patch.object(self.client, '_refresh_api_client_token'
                             ) as mock_refresh_api_client_token:
                self.client.wait_for_run_completion(
                    run_id='foo', timeout=1, sleep_duration=0)
                mock_get_run.assert_called_with(run_id='foo')
                mock_refresh_api_client_token.assert_called_once()

    def test_wait_for_run_completion_valid_token(self):
        # A run that is already 'succeeded' is returned after a single poll.
        with patch.object(self.client._run_api, 'get_run') as mock_get_run:
            mock_get_run.return_value = Mock(state='succeeded')
            response = self.client.wait_for_run_completion(
                run_id='foo', timeout=1, sleep_duration=0)
            mock_get_run.assert_called_once_with(run_id='foo')
            self.assertEqual(response, mock_get_run.return_value)

    def test_wait_for_run_completion_run_timeout_should_raise_error(self):
        # The mocked run never reports a final state, so the wait must end
        # in TimeoutError.
        with patch.object(self.client._run_api, 'get_run') as mock_get_run:
            mock_get_run.return_value = Mock(run=Mock(status='foo'))
            with self.assertRaises(TimeoutError):
                self.client.wait_for_run_completion(
                    run_id='foo', timeout=1, sleep_duration=0)
            # With sleep_duration=0 the run is presumably polled many times
            # before the timeout, so only the call arguments are checked,
            # not the call count.
            mock_get_run.assert_called_with(run_id='foo')

    @patch('kfp.Client.get_experiment', side_effect=ValueError)
    def test_create_experiment_no_experiment_should_raise_error(
            self, mock_get_experiment):
        # A ValueError without the "not found" message is not swallowed.
        with self.assertRaises(ValueError):
            self.client.create_experiment(name='foo', namespace='ns1')
        # Same keywords that the other create_experiment tests in this class
        # observe on get_experiment.
        mock_get_experiment.assert_called_once_with(
            experiment_name='foo', namespace='ns1')

    @patch('kfp.Client.get_experiment', return_value=Mock(id='foo'))
    @patch('kfp.Client._get_url_prefix', return_value='/pipeline')
    def test_create_experiment_existing_experiment(self, mock_get_url_prefix,
                                                   mock_get_experiment):
        # An experiment that already exists is looked up by name in the
        # client's namespace.
        self.client.create_experiment(name='foo')
        mock_get_experiment.assert_called_once_with(
            experiment_name='foo', namespace='ns1')
        mock_get_url_prefix.assert_called_once()

    @patch('kfp_server_api.V2beta1Experiment')
    @patch(
        'kfp.Client.get_experiment',
        side_effect=ValueError('No experiment is found with name'))
    @patch('kfp.Client._get_url_prefix', return_value='/pipeline')
    def test__create_experiment_name_not_found(self, mock_get_url_prefix,
                                               mock_get_experiment,
                                               mock_api_experiment):
        # experiment with the specified name is not found, so a new experiment
        # is created.
        with patch.object(
                self.client._experiment_api,
                'create_experiment',
                return_value=Mock(
                    experiment_id='foo')) as mock_create_experiment:
            self.client.create_experiment(name='foo')
            mock_get_experiment.assert_called_once_with(
                experiment_name='foo', namespace='ns1')
            mock_api_experiment.assert_called_once()
            mock_create_experiment.assert_called_once()
            mock_get_url_prefix.assert_called_once()

    def test_get_experiment_no_experiment_id_or_name_should_raise_error(self):
        # Neither an id nor a name is given, so the lookup is rejected.
        with self.assertRaises(ValueError):
            self.client.get_experiment()

    @patch('kfp.Client.get_user_namespace', return_value=None)
    def test_get_experiment_does_not_exist_should_raise_error(
            self, mock_get_user_namespace):
        # The listing returns no experiments for the requested name.
        with patch.object(
                self.client._experiment_api,
                'list_experiments',
                return_value=Mock(experiments=None)) as mock_list_experiments:
            with self.assertRaises(ValueError):
                self.client.get_experiment(experiment_name='foo')
            mock_list_experiments.assert_called_once()
            mock_get_user_namespace.assert_called_once()

    @patch('kfp.Client.get_user_namespace', return_value=None)
    def test_get_experiment_multiple_experiments_with_name_should_raise_error(
            self, mock_get_user_namespace):
        # The listing returns two experiments for the same name.
        with patch.object(
                self.client._experiment_api,
                'list_experiments',
                return_value=Mock(
                    experiments=['foo', 'foo'])) as mock_list_experiments:
            with self.assertRaises(ValueError):
                self.client.get_experiment(experiment_name='foo')
            mock_list_experiments.assert_called_once()
            mock_get_user_namespace.assert_called_once()

    def test_get_experiment_with_experiment_id(self):
        # Lookup by id goes to the experiment API's get_experiment.
        with patch.object(self.client._experiment_api,
                          'get_experiment') as mock_get_experiment:
            self.client.get_experiment(experiment_id='foo')
            mock_get_experiment.assert_called_once_with(experiment_id='foo')

    def test_get_experiment_with_experiment_name_and_namespace(self):
        # Lookup by name goes through list_experiments.
        with patch.object(self.client._experiment_api,
                          'list_experiments') as mock_list_experiments:
            self.client.get_experiment(experiment_name='foo', namespace='ns1')
            mock_list_experiments.assert_called_once()

    @patch('kfp.Client.get_user_namespace', return_value=None)
    def test_get_experiment_with_experiment_name_and_no_namespace(
            self, mock_get_user_namespace):
        # Without an explicit namespace the user namespace is consulted.
        with patch.object(self.client._experiment_api,
                          'list_experiments') as mock_list_experiments:
            self.client.get_experiment(experiment_name='foo')
            mock_list_experiments.assert_called_once()
            mock_get_user_namespace.assert_called_once()

    @patch('kfp_server_api.HealthzServiceApi.get_healthz')
    def test_get_kfp_healthz(self, mock_get_kfp_healthz):
        # The healthz payload is returned to the caller unchanged.
        mock_get_kfp_healthz.return_value = json.dumps([{'foo': 'bar'}])
        response = self.client.get_kfp_healthz()
        mock_get_kfp_healthz.assert_called_once()
        self.assertEqual(response, mock_get_kfp_healthz.return_value)

    @patch(
        'kfp_server_api.HealthzServiceApi.get_healthz',
        side_effect=kfp_server_api.ApiException)
    def test_get_kfp_healthz_should_raise_error(self, mock_get_kfp_healthz):
        # A healthz endpoint that always fails surfaces as TimeoutError.
        with self.assertRaises(TimeoutError):
            self.client.get_kfp_healthz(sleep_duration=0)
        mock_get_kfp_healthz.assert_called()

    def test_upload_pipeline_without_name(self):
        # With no pipeline_name argument, the upload is expected to use the
        # name declared on the pipeline itself.

        @component
        def return_bool(boolean: bool) -> bool:
            return boolean

        @pipeline(name='test-upload-without-name', description='description')
        def pipeline_test_upload_without_name(boolean: bool = True):
            return_bool(boolean=boolean)

        with patch.object(self.client._upload_api,
                          'upload_pipeline') as mock_upload_pipeline:
            with patch.object(auth, 'is_ipython', return_value=False):
                with tempfile.TemporaryDirectory() as tmp_path:
                    pipeline_test_path = os.path.join(tmp_path, 'test.yaml')
                    Compiler().compile(
                        pipeline_func=pipeline_test_upload_without_name,
                        package_path=pipeline_test_path)
                    # Upload while the compiled package still exists on disk.
                    self.client.upload_pipeline(
                        pipeline_package_path=pipeline_test_path,
                        description='description',
                        namespace='ns1')
                    mock_upload_pipeline.assert_called_once_with(
                        pipeline_test_path,
                        name='test-upload-without-name',
                        description='description',
                        namespace='ns1')

    @parameterized.parameters([
        'pipeline',
        'my-pipeline',
        'my-pipeline-1',
        '1pipeline',
        'pipeline1',
        'my_pipeline',
        "person's-pipeline",
        'my pipeline',
        'pipeline.yaml',
    ])
    def test_upload_pipeline_with_name(self, pipeline_name):
        # An explicit, valid pipeline_name is forwarded to the upload API.
        with patch.object(self.client._upload_api,
                          'upload_pipeline') as mock_upload_pipeline:
            with patch.object(auth, 'is_ipython', return_value=False):
                self.client.upload_pipeline(
                    pipeline_package_path='fake.yaml',
                    pipeline_name=pipeline_name,
                    description='description',
                    namespace='ns1')
                mock_upload_pipeline.assert_called_once_with(
                    'fake.yaml',
                    name=pipeline_name,
                    description='description',
                    namespace='ns1')

    @parameterized.parameters([
        '',
        ' ',
        '\t',
    ])
    def test_upload_pipeline_with_name_invalid(self, pipeline_name):
        # Empty or whitespace-only names are rejected with this message. The
        # upload API stays patched so no real request can be attempted.
        with patch.object(self.client._upload_api, 'upload_pipeline'):
            with patch.object(auth, 'is_ipython', return_value=False):
                with self.assertRaisesRegex(
                        ValueError,
                        'Invalid pipeline name. Pipeline name cannot be empty or contain only whitespace.'
                ):
                    self.client.upload_pipeline(
                        pipeline_package_path='fake.yaml',
                        pipeline_name=pipeline_name,
                        description='description',
                        namespace='ns1')
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()