* refresh access token whenerver it expires * tight the condition when refreshing the access token
This commit is contained in:
parent
8064383cf1
commit
6d55e262b4
|
@ -33,13 +33,8 @@ from kfp.compiler import compiler
|
|||
from kfp.compiler._k8s_helper import sanitize_k8s_name
|
||||
|
||||
from kfp._auth import get_auth_token, get_gcp_access_token
|
||||
from kfp_server_api import ApiException
|
||||
|
||||
# TTL of the access token associated with the client. This is needed because
|
||||
# `gcloud auth print-access-token` generates a token with TTL=1 hour, after
|
||||
# which the authentication expires. This TTL is needed for kfp.Client()
|
||||
# initialized with host=<inverse proxy endpoint>.
|
||||
# Set to 55 mins to provide some safe margin.
|
||||
_GCP_ACCESS_TOKEN_TIMEOUT = datetime.timedelta(minutes=55)
|
||||
# Operators on scalar values. Only applies to one of |int_value|,
|
||||
# |long_value|, |string_value| or |timestamp_value|.
|
||||
_FILTER_OPERATIONS = {
|
||||
|
@ -1222,18 +1217,23 @@ class Client(object):
|
|||
"""
|
||||
status = 'Running:'
|
||||
start_time = datetime.datetime.now()
|
||||
last_token_refresh_time = datetime.datetime.now()
|
||||
if isinstance(timeout, datetime.timedelta):
|
||||
timeout = timeout.total_seconds()
|
||||
is_valid_token = False
|
||||
while (status is None or status.lower()
|
||||
not in ['succeeded', 'failed', 'skipped', 'error']):
|
||||
# Refreshes the access token before it hits the TTL.
|
||||
if (datetime.datetime.now() - last_token_refresh_time >
|
||||
_GCP_ACCESS_TOKEN_TIMEOUT):
|
||||
self._refresh_api_client_token()
|
||||
last_token_refresh_time = datetime.datetime.now()
|
||||
|
||||
get_run_response = self._run_api.get_run(run_id=run_id)
|
||||
try:
|
||||
get_run_response = self._run_api.get_run(run_id=run_id)
|
||||
is_valid_token = True
|
||||
except ApiException as api_ex:
|
||||
# if the token is valid but receiving 401 Unauthorized error
|
||||
# then refresh the token
|
||||
if is_valid_token and api_ex.status == 401:
|
||||
logging.info('Access token has expired !!! Refreshing ...')
|
||||
self._refresh_api_client_token()
|
||||
continue
|
||||
else:
|
||||
raise api_ex
|
||||
status = get_run_response.run.status
|
||||
elapsed_time = (datetime.datetime.now() -
|
||||
start_time).total_seconds()
|
||||
|
|
Loading…
Reference in New Issue