diff --git a/sdk/python/kfp/_client.py b/sdk/python/kfp/_client.py index d2c19788f7..7012e82aa0 100644 --- a/sdk/python/kfp/_client.py +++ b/sdk/python/kfp/_client.py @@ -33,13 +33,8 @@ from kfp.compiler import compiler from kfp.compiler._k8s_helper import sanitize_k8s_name from kfp._auth import get_auth_token, get_gcp_access_token +from kfp_server_api import ApiException -# TTL of the access token associated with the client. This is needed because -# `gcloud auth print-access-token` generates a token with TTL=1 hour, after -# which the authentication expires. This TTL is needed for kfp.Client() -# initialized with host=. -# Set to 55 mins to provide some safe margin. -_GCP_ACCESS_TOKEN_TIMEOUT = datetime.timedelta(minutes=55) # Operators on scalar values. Only applies to one of |int_value|, # |long_value|, |string_value| or |timestamp_value|. _FILTER_OPERATIONS = { @@ -1222,18 +1217,23 @@ class Client(object): """ status = 'Running:' start_time = datetime.datetime.now() - last_token_refresh_time = datetime.datetime.now() if isinstance(timeout, datetime.timedelta): timeout = timeout.total_seconds() + is_valid_token = False while (status is None or status.lower() not in ['succeeded', 'failed', 'skipped', 'error']): - # Refreshes the access token before it hits the TTL. - if (datetime.datetime.now() - last_token_refresh_time > - _GCP_ACCESS_TOKEN_TIMEOUT): - self._refresh_api_client_token() - last_token_refresh_time = datetime.datetime.now() - - get_run_response = self._run_api.get_run(run_id=run_id) + try: + get_run_response = self._run_api.get_run(run_id=run_id) + is_valid_token = True + except ApiException as api_ex: + # if the token is valid but receiving 401 Unauthorized error + # then refresh the token + if is_valid_token and api_ex.status == 401: + logging.info('Access token has expired !!! Refreshing ...') + self._refresh_api_client_token() + continue + else: + raise api_ex status = get_run_response.run.status elapsed_time = (datetime.datetime.now() - start_time).total_seconds()