test(sdk): implement small tests for Client class (#8517)

Co-authored-by: droctothorpe <mythicalsunlight@gmail.com>

Co-authored-by: droctothorpe <mythicalsunlight@gmail.com>
This commit is contained in:
andreafehrman 2022-12-01 12:39:45 -07:00 committed by GitHub
parent 931c14a742
commit b26dd100e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 196 additions and 11 deletions

View File

@ -398,9 +398,14 @@ class Client:
with open(Client._LOCAL_KFP_CONTEXT, 'w') as f:
json.dump(self._context_setting, f)
def get_kfp_healthz(self) -> kfp_server_api.ApiGetHealthzResponse:
def get_kfp_healthz(
self,
sleep_duration: int = 5) -> kfp_server_api.ApiGetHealthzResponse:
"""Gets healthz info for KFP deployment.
Args:
sleep_duration: Time in seconds between retries.
Returns:
JSON response from the healtz endpoint.
"""
@ -422,8 +427,9 @@ class Client:
except kfp_server_api.ApiException:
# logging.exception also logs detailed info about the ApiException
logging.exception(
f'Failed to get healthz info attempt {count} of 5.')
time.sleep(5)
f'Failed to get healthz info attempt {count} of {sleep_duration}.'
)
time.sleep(sleep_duration)
def get_user_namespace(self) -> str:
"""Gets user namespace in context config.
@ -463,7 +469,7 @@ class Client:
logging.info(f'Creating experiment {name}.')
resource_references = []
if namespace:
if namespace is not None:
key = kfp_server_api.ApiResourceKey(
id=namespace, type=kfp_server_api.ApiResourceType.NAMESPACE)
reference = kfp_server_api.ApiResourceReference(
@ -584,7 +590,7 @@ class Client:
'stringValue': experiment_name,
}]
})
if namespace:
if namespace is not None:
result = self._experiment_api.list_experiment(
filter=experiment_filter,
resource_reference_key_type=kfp_server_api.ApiResourceType
@ -1213,7 +1219,7 @@ class Client:
.EXPERIMENT,
resource_reference_key_id=experiment_id,
filter=filter)
elif namespace:
elif namespace is not None:
response = self._run_api.list_runs(
page_token=page_token,
page_size=page_size,
@ -1299,13 +1305,17 @@ class Client:
"""
return self._run_api.get_run(run_id=run_id)
def wait_for_run_completion(self, run_id: str,
timeout: int) -> kfp_server_api.ApiRun:
def wait_for_run_completion(
self,
run_id: str,
timeout: int,
sleep_duration: int = 5) -> kfp_server_api.ApiRun:
"""Waits for a run to complete.
Args:
run_id: ID of the run.
timeout: Timeout after which the client should stop waiting for run completion (seconds).
sleep_duration: Time in seconds between retries.
Returns:
``ApiRun`` object.
@ -1335,7 +1345,7 @@ class Client:
logging.info('Waiting for the job to complete...')
if elapsed_time > timeout:
raise TimeoutError('Run timeout')
time.sleep(5)
time.sleep(sleep_duration)
return get_run_response
def _get_workflow_json(self, run_id: str) -> dict:

View File

@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import unittest
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch
from absl.testing import parameterized
from kfp.client import client
from kfp.compiler import Compiler
from kfp.dsl import component
from kfp.dsl import pipeline
import kfp_server_api
import yaml
@ -69,7 +74,7 @@ class TestOverrideCachingOptions(parameterized.TestCase):
with open(temp_filepath, 'r') as f:
pipeline_obj = yaml.safe_load(f)
test_client = client.Client(namespace='dummy_namespace')
test_client = client.Client(namespace='foo')
test_client._override_caching_options(pipeline_obj, False)
for _, task in pipeline_obj['root']['dag']['tasks'].items():
self.assertFalse(task['cachingOptions']['enableCache'])
@ -101,7 +106,7 @@ class TestOverrideCachingOptions(parameterized.TestCase):
with open(temp_filepath, 'r') as f:
pipeline_obj = yaml.safe_load(f)
test_client = client.Client(namespace='dummy_namespace')
test_client = client.Client(namespace='foo')
test_client._override_caching_options(pipeline_obj, False)
self.assertFalse(
pipeline_obj['root']['dag']['tasks']['hello-word']
@ -110,5 +115,175 @@ class TestOverrideCachingOptions(parameterized.TestCase):
['to-lower']['cachingOptions']['enableCache'])
class TestClient(unittest.TestCase):
def setUp(self):
self.client = client.Client(namespace='foo')
def test__is_ipython_return_false(self):
mock = MagicMock()
with patch.dict('sys.modules', IPython=mock):
mock.get_ipython.return_value = None
self.assertFalse(self.client._is_ipython())
def test__is_ipython_return_true(self):
mock = MagicMock()
with patch.dict('sys.modules', IPython=mock):
mock.get_ipython.return_value = 'Something'
self.assertTrue(self.client._is_ipython())
def test__is_ipython_should_raise_error(self):
mock = MagicMock()
with patch.dict('sys.modules', mock):
mock.side_effect = ImportError
self.assertFalse(self.client._is_ipython())
def test_wait_for_run_completion_invalid_token_should_raise_error(self):
with self.assertRaises(kfp_server_api.ApiException):
with patch.object(
self.client._run_api,
'get_run',
side_effect=kfp_server_api.ApiException) as mock_get_run:
self.client.wait_for_run_completion(
run_id='foo', timeout=1, sleep_duration=0)
mock_get_run.assert_called_once()
def test_wait_for_run_completion_expired_access_token(self):
with patch.object(self.client._run_api, 'get_run') as mock_get_run:
# We need to iterate through multiple side effects in order to test this logic.
mock_get_run.side_effect = [
Mock(run=Mock(status='foo')),
kfp_server_api.ApiException(status=401),
Mock(run=Mock(status='succeeded')),
]
with patch.object(self.client, '_refresh_api_client_token'
) as mock_refresh_api_client_token:
self.client.wait_for_run_completion(
run_id='foo', timeout=1, sleep_duration=0)
mock_get_run.assert_called_with(run_id='foo')
mock_refresh_api_client_token.assert_called_once()
def test_wait_for_run_completion_valid_token(self):
with patch.object(self.client._run_api, 'get_run') as mock_get_run:
mock_get_run.return_value = Mock(run=Mock(status='succeeded'))
response = self.client.wait_for_run_completion(
run_id='foo', timeout=1, sleep_duration=0)
mock_get_run.assert_called_once_with(run_id='foo')
assert response == mock_get_run.return_value
def test_wait_for_run_completion_run_timeout_should_raise_error(self):
with self.assertRaises(TimeoutError):
with patch.object(self.client._run_api, 'get_run') as mock_get_run:
mock_get_run.return_value = Mock(run=Mock(status='foo'))
self.client.wait_for_run_completion(
run_id='foo', timeout=1, sleep_duration=0)
mock_get_run.assert_called_once_with(run_id='foo')
@patch('kfp.Client.get_experiment', side_effect=ValueError)
def test_create_experiment_no_experiment_should_raise_error(
self, mock_get_experiment):
with self.assertRaises(ValueError):
self.client.create_experiment(name='foo', namespace='foo')
mock_get_experiment.assert_called_once_with(
name='foo', namespace='foo')
@patch('kfp.Client.get_experiment', return_value=Mock(id='foo'))
@patch('kfp.Client._get_url_prefix', return_value='/pipeline')
def test_create_experiment_existing_experiment(self, mock_get_url_prefix,
mock_get_experiment):
self.client.create_experiment(name='foo')
mock_get_experiment.assert_called_once_with(
experiment_name='foo', namespace='foo')
mock_get_url_prefix.assert_called_once()
@patch('kfp_server_api.ApiExperiment')
@patch(
'kfp.Client.get_experiment',
side_effect=ValueError('No experiment is found with name'))
@patch('kfp.Client._get_url_prefix', return_value='/pipeline')
def test__create_experiment_name_not_found(self, mock_get_url_prefix,
mock_get_experiment,
mock_api_experiment):
# experiment with the specified name is not found, so a new experiment
# is created.
with patch.object(
self.client._experiment_api,
'create_experiment',
return_value=Mock(id='foo')) as mock_create_experiment:
self.client.create_experiment(name='foo')
mock_get_experiment.assert_called_once_with(
experiment_name='foo', namespace='foo')
mock_api_experiment.assert_called_once()
mock_create_experiment.assert_called_once()
mock_get_url_prefix.assert_called_once()
def test_get_experiment_no_experiment_id_or_name_should_raise_error(self):
with self.assertRaises(ValueError):
self.client.get_experiment()
@patch('kfp.Client.get_user_namespace', return_value=None)
def test_get_experiment_does_not_exist_should_raise_error(
self, mock_get_user_namespace):
with self.assertRaises(ValueError):
with patch.object(
self.client._experiment_api,
'list_experiment',
return_value=Mock(
experiments=None)) as mock_list_experiment:
self.client.get_experiment(experiment_name='foo')
mock_list_experiment.assert_called_once()
mock_get_user_namespace.assert_called_once()
@patch('kfp.Client.get_user_namespace', return_value=None)
def test_get_experiment_multiple_experiments_with_name_should_raise_error(
self, mock_get_user_namespace):
with self.assertRaises(ValueError):
with patch.object(
self.client._experiment_api,
'list_experiment',
return_value=Mock(
experiments=['foo', 'foo'])) as mock_list_experiment:
self.client.get_experiment(experiment_name='foo')
mock_list_experiment.assert_called_once()
mock_get_user_namespace.assert_called_once()
def test_get_experiment_with_experiment_id(self):
with patch.object(self.client._experiment_api,
'get_experiment') as mock_get_experiment:
self.client.get_experiment(experiment_id='foo')
mock_get_experiment.assert_called_once_with(id='foo')
def test_get_experiment_with_experiment_name_and_namespace(self):
with patch.object(self.client._experiment_api,
'list_experiment') as mock_list_experiment:
self.client.get_experiment(experiment_name='foo', namespace='foo')
mock_list_experiment.assert_called_once()
@patch('kfp.Client.get_user_namespace', return_value=None)
def test_get_experiment_with_experiment_name_and_no_namespace(
self, mock_get_user_namespace):
with patch.object(self.client._experiment_api,
'list_experiment') as mock_list_experiment:
self.client.get_experiment(experiment_name='foo')
mock_list_experiment.assert_called_once()
mock_get_user_namespace.assert_called_once()
@patch('kfp_server_api.HealthzServiceApi.get_healthz')
def test_get_kfp_healthz(self, mock_get_kfp_healthz):
mock_get_kfp_healthz.return_value = json.dumps([{'foo': 'bar'}])
response = self.client.get_kfp_healthz()
mock_get_kfp_healthz.assert_called_once()
assert (response == mock_get_kfp_healthz.return_value)
@patch(
'kfp_server_api.HealthzServiceApi.get_healthz',
side_effect=kfp_server_api.ApiException)
def test_get_kfp_healthz_should_raise_error(self, mock_get_kfp_healthz):
with self.assertRaises(TimeoutError):
self.client.get_kfp_healthz(sleep_duration=0)
mock_get_kfp_healthz.assert_called()
if __name__ == '__main__':
unittest.main()