test(sdk): implement small tests for Client class (#8517)
Co-authored-by: droctothorpe <mythicalsunlight@gmail.com> Co-authored-by: droctothorpe <mythicalsunlight@gmail.com>
This commit is contained in:
parent
931c14a742
commit
b26dd100e1
|
@ -398,9 +398,14 @@ class Client:
|
|||
with open(Client._LOCAL_KFP_CONTEXT, 'w') as f:
|
||||
json.dump(self._context_setting, f)
|
||||
|
||||
def get_kfp_healthz(self) -> kfp_server_api.ApiGetHealthzResponse:
|
||||
def get_kfp_healthz(
|
||||
self,
|
||||
sleep_duration: int = 5) -> kfp_server_api.ApiGetHealthzResponse:
|
||||
"""Gets healthz info for KFP deployment.
|
||||
|
||||
Args:
|
||||
sleep_duration: Time in seconds between retries.
|
||||
|
||||
Returns:
|
||||
JSON response from the healtz endpoint.
|
||||
"""
|
||||
|
@ -422,8 +427,9 @@ class Client:
|
|||
except kfp_server_api.ApiException:
|
||||
# logging.exception also logs detailed info about the ApiException
|
||||
logging.exception(
|
||||
f'Failed to get healthz info attempt {count} of 5.')
|
||||
time.sleep(5)
|
||||
f'Failed to get healthz info attempt {count} of {sleep_duration}.'
|
||||
)
|
||||
time.sleep(sleep_duration)
|
||||
|
||||
def get_user_namespace(self) -> str:
|
||||
"""Gets user namespace in context config.
|
||||
|
@ -463,7 +469,7 @@ class Client:
|
|||
logging.info(f'Creating experiment {name}.')
|
||||
|
||||
resource_references = []
|
||||
if namespace:
|
||||
if namespace is not None:
|
||||
key = kfp_server_api.ApiResourceKey(
|
||||
id=namespace, type=kfp_server_api.ApiResourceType.NAMESPACE)
|
||||
reference = kfp_server_api.ApiResourceReference(
|
||||
|
@ -584,7 +590,7 @@ class Client:
|
|||
'stringValue': experiment_name,
|
||||
}]
|
||||
})
|
||||
if namespace:
|
||||
if namespace is not None:
|
||||
result = self._experiment_api.list_experiment(
|
||||
filter=experiment_filter,
|
||||
resource_reference_key_type=kfp_server_api.ApiResourceType
|
||||
|
@ -1213,7 +1219,7 @@ class Client:
|
|||
.EXPERIMENT,
|
||||
resource_reference_key_id=experiment_id,
|
||||
filter=filter)
|
||||
elif namespace:
|
||||
elif namespace is not None:
|
||||
response = self._run_api.list_runs(
|
||||
page_token=page_token,
|
||||
page_size=page_size,
|
||||
|
@ -1299,13 +1305,17 @@ class Client:
|
|||
"""
|
||||
return self._run_api.get_run(run_id=run_id)
|
||||
|
||||
def wait_for_run_completion(self, run_id: str,
|
||||
timeout: int) -> kfp_server_api.ApiRun:
|
||||
def wait_for_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
timeout: int,
|
||||
sleep_duration: int = 5) -> kfp_server_api.ApiRun:
|
||||
"""Waits for a run to complete.
|
||||
|
||||
Args:
|
||||
run_id: ID of the run.
|
||||
timeout: Timeout after which the client should stop waiting for run completion (seconds).
|
||||
sleep_duration: Time in seconds between retries.
|
||||
|
||||
Returns:
|
||||
``ApiRun`` object.
|
||||
|
@ -1335,7 +1345,7 @@ class Client:
|
|||
logging.info('Waiting for the job to complete...')
|
||||
if elapsed_time > timeout:
|
||||
raise TimeoutError('Run timeout')
|
||||
time.sleep(5)
|
||||
time.sleep(sleep_duration)
|
||||
return get_run_response
|
||||
|
||||
def _get_workflow_json(self, run_id: str) -> dict:
|
||||
|
|
|
@ -12,15 +12,20 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import patch
|
||||
|
||||
from absl.testing import parameterized
|
||||
from kfp.client import client
|
||||
from kfp.compiler import Compiler
|
||||
from kfp.dsl import component
|
||||
from kfp.dsl import pipeline
|
||||
import kfp_server_api
|
||||
import yaml
|
||||
|
||||
|
||||
|
@ -69,7 +74,7 @@ class TestOverrideCachingOptions(parameterized.TestCase):
|
|||
|
||||
with open(temp_filepath, 'r') as f:
|
||||
pipeline_obj = yaml.safe_load(f)
|
||||
test_client = client.Client(namespace='dummy_namespace')
|
||||
test_client = client.Client(namespace='foo')
|
||||
test_client._override_caching_options(pipeline_obj, False)
|
||||
for _, task in pipeline_obj['root']['dag']['tasks'].items():
|
||||
self.assertFalse(task['cachingOptions']['enableCache'])
|
||||
|
@ -101,7 +106,7 @@ class TestOverrideCachingOptions(parameterized.TestCase):
|
|||
|
||||
with open(temp_filepath, 'r') as f:
|
||||
pipeline_obj = yaml.safe_load(f)
|
||||
test_client = client.Client(namespace='dummy_namespace')
|
||||
test_client = client.Client(namespace='foo')
|
||||
test_client._override_caching_options(pipeline_obj, False)
|
||||
self.assertFalse(
|
||||
pipeline_obj['root']['dag']['tasks']['hello-word']
|
||||
|
@ -110,5 +115,175 @@ class TestOverrideCachingOptions(parameterized.TestCase):
|
|||
['to-lower']['cachingOptions']['enableCache'])
|
||||
|
||||
|
||||
class TestClient(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.client = client.Client(namespace='foo')
|
||||
|
||||
def test__is_ipython_return_false(self):
|
||||
mock = MagicMock()
|
||||
with patch.dict('sys.modules', IPython=mock):
|
||||
mock.get_ipython.return_value = None
|
||||
self.assertFalse(self.client._is_ipython())
|
||||
|
||||
def test__is_ipython_return_true(self):
|
||||
mock = MagicMock()
|
||||
with patch.dict('sys.modules', IPython=mock):
|
||||
mock.get_ipython.return_value = 'Something'
|
||||
self.assertTrue(self.client._is_ipython())
|
||||
|
||||
def test__is_ipython_should_raise_error(self):
|
||||
mock = MagicMock()
|
||||
with patch.dict('sys.modules', mock):
|
||||
mock.side_effect = ImportError
|
||||
self.assertFalse(self.client._is_ipython())
|
||||
|
||||
def test_wait_for_run_completion_invalid_token_should_raise_error(self):
|
||||
with self.assertRaises(kfp_server_api.ApiException):
|
||||
with patch.object(
|
||||
self.client._run_api,
|
||||
'get_run',
|
||||
side_effect=kfp_server_api.ApiException) as mock_get_run:
|
||||
self.client.wait_for_run_completion(
|
||||
run_id='foo', timeout=1, sleep_duration=0)
|
||||
mock_get_run.assert_called_once()
|
||||
|
||||
def test_wait_for_run_completion_expired_access_token(self):
|
||||
with patch.object(self.client._run_api, 'get_run') as mock_get_run:
|
||||
# We need to iterate through multiple side effects in order to test this logic.
|
||||
mock_get_run.side_effect = [
|
||||
Mock(run=Mock(status='foo')),
|
||||
kfp_server_api.ApiException(status=401),
|
||||
Mock(run=Mock(status='succeeded')),
|
||||
]
|
||||
|
||||
with patch.object(self.client, '_refresh_api_client_token'
|
||||
) as mock_refresh_api_client_token:
|
||||
self.client.wait_for_run_completion(
|
||||
run_id='foo', timeout=1, sleep_duration=0)
|
||||
mock_get_run.assert_called_with(run_id='foo')
|
||||
mock_refresh_api_client_token.assert_called_once()
|
||||
|
||||
def test_wait_for_run_completion_valid_token(self):
|
||||
with patch.object(self.client._run_api, 'get_run') as mock_get_run:
|
||||
mock_get_run.return_value = Mock(run=Mock(status='succeeded'))
|
||||
response = self.client.wait_for_run_completion(
|
||||
run_id='foo', timeout=1, sleep_duration=0)
|
||||
mock_get_run.assert_called_once_with(run_id='foo')
|
||||
assert response == mock_get_run.return_value
|
||||
|
||||
def test_wait_for_run_completion_run_timeout_should_raise_error(self):
|
||||
with self.assertRaises(TimeoutError):
|
||||
with patch.object(self.client._run_api, 'get_run') as mock_get_run:
|
||||
mock_get_run.return_value = Mock(run=Mock(status='foo'))
|
||||
self.client.wait_for_run_completion(
|
||||
run_id='foo', timeout=1, sleep_duration=0)
|
||||
mock_get_run.assert_called_once_with(run_id='foo')
|
||||
|
||||
@patch('kfp.Client.get_experiment', side_effect=ValueError)
|
||||
def test_create_experiment_no_experiment_should_raise_error(
|
||||
self, mock_get_experiment):
|
||||
with self.assertRaises(ValueError):
|
||||
self.client.create_experiment(name='foo', namespace='foo')
|
||||
mock_get_experiment.assert_called_once_with(
|
||||
name='foo', namespace='foo')
|
||||
|
||||
@patch('kfp.Client.get_experiment', return_value=Mock(id='foo'))
|
||||
@patch('kfp.Client._get_url_prefix', return_value='/pipeline')
|
||||
def test_create_experiment_existing_experiment(self, mock_get_url_prefix,
|
||||
mock_get_experiment):
|
||||
self.client.create_experiment(name='foo')
|
||||
mock_get_experiment.assert_called_once_with(
|
||||
experiment_name='foo', namespace='foo')
|
||||
mock_get_url_prefix.assert_called_once()
|
||||
|
||||
@patch('kfp_server_api.ApiExperiment')
|
||||
@patch(
|
||||
'kfp.Client.get_experiment',
|
||||
side_effect=ValueError('No experiment is found with name'))
|
||||
@patch('kfp.Client._get_url_prefix', return_value='/pipeline')
|
||||
def test__create_experiment_name_not_found(self, mock_get_url_prefix,
|
||||
mock_get_experiment,
|
||||
mock_api_experiment):
|
||||
# experiment with the specified name is not found, so a new experiment
|
||||
# is created.
|
||||
with patch.object(
|
||||
self.client._experiment_api,
|
||||
'create_experiment',
|
||||
return_value=Mock(id='foo')) as mock_create_experiment:
|
||||
self.client.create_experiment(name='foo')
|
||||
mock_get_experiment.assert_called_once_with(
|
||||
experiment_name='foo', namespace='foo')
|
||||
mock_api_experiment.assert_called_once()
|
||||
mock_create_experiment.assert_called_once()
|
||||
mock_get_url_prefix.assert_called_once()
|
||||
|
||||
def test_get_experiment_no_experiment_id_or_name_should_raise_error(self):
|
||||
with self.assertRaises(ValueError):
|
||||
self.client.get_experiment()
|
||||
|
||||
@patch('kfp.Client.get_user_namespace', return_value=None)
|
||||
def test_get_experiment_does_not_exist_should_raise_error(
|
||||
self, mock_get_user_namespace):
|
||||
with self.assertRaises(ValueError):
|
||||
with patch.object(
|
||||
self.client._experiment_api,
|
||||
'list_experiment',
|
||||
return_value=Mock(
|
||||
experiments=None)) as mock_list_experiment:
|
||||
self.client.get_experiment(experiment_name='foo')
|
||||
mock_list_experiment.assert_called_once()
|
||||
mock_get_user_namespace.assert_called_once()
|
||||
|
||||
@patch('kfp.Client.get_user_namespace', return_value=None)
|
||||
def test_get_experiment_multiple_experiments_with_name_should_raise_error(
|
||||
self, mock_get_user_namespace):
|
||||
with self.assertRaises(ValueError):
|
||||
with patch.object(
|
||||
self.client._experiment_api,
|
||||
'list_experiment',
|
||||
return_value=Mock(
|
||||
experiments=['foo', 'foo'])) as mock_list_experiment:
|
||||
self.client.get_experiment(experiment_name='foo')
|
||||
mock_list_experiment.assert_called_once()
|
||||
mock_get_user_namespace.assert_called_once()
|
||||
|
||||
def test_get_experiment_with_experiment_id(self):
|
||||
with patch.object(self.client._experiment_api,
|
||||
'get_experiment') as mock_get_experiment:
|
||||
self.client.get_experiment(experiment_id='foo')
|
||||
mock_get_experiment.assert_called_once_with(id='foo')
|
||||
|
||||
def test_get_experiment_with_experiment_name_and_namespace(self):
|
||||
with patch.object(self.client._experiment_api,
|
||||
'list_experiment') as mock_list_experiment:
|
||||
self.client.get_experiment(experiment_name='foo', namespace='foo')
|
||||
mock_list_experiment.assert_called_once()
|
||||
|
||||
@patch('kfp.Client.get_user_namespace', return_value=None)
|
||||
def test_get_experiment_with_experiment_name_and_no_namespace(
|
||||
self, mock_get_user_namespace):
|
||||
with patch.object(self.client._experiment_api,
|
||||
'list_experiment') as mock_list_experiment:
|
||||
self.client.get_experiment(experiment_name='foo')
|
||||
mock_list_experiment.assert_called_once()
|
||||
mock_get_user_namespace.assert_called_once()
|
||||
|
||||
@patch('kfp_server_api.HealthzServiceApi.get_healthz')
|
||||
def test_get_kfp_healthz(self, mock_get_kfp_healthz):
|
||||
mock_get_kfp_healthz.return_value = json.dumps([{'foo': 'bar'}])
|
||||
response = self.client.get_kfp_healthz()
|
||||
mock_get_kfp_healthz.assert_called_once()
|
||||
assert (response == mock_get_kfp_healthz.return_value)
|
||||
|
||||
@patch(
|
||||
'kfp_server_api.HealthzServiceApi.get_healthz',
|
||||
side_effect=kfp_server_api.ApiException)
|
||||
def test_get_kfp_healthz_should_raise_error(self, mock_get_kfp_healthz):
|
||||
with self.assertRaises(TimeoutError):
|
||||
self.client.get_kfp_healthz(sleep_duration=0)
|
||||
mock_get_kfp_healthz.assert_called()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
Loading…
Reference in New Issue