* refresh access token whenerver it expires * tight the condition when refreshing the access token
This commit is contained in:
parent
8064383cf1
commit
6d55e262b4
|
@ -33,13 +33,8 @@ from kfp.compiler import compiler
|
||||||
from kfp.compiler._k8s_helper import sanitize_k8s_name
|
from kfp.compiler._k8s_helper import sanitize_k8s_name
|
||||||
|
|
||||||
from kfp._auth import get_auth_token, get_gcp_access_token
|
from kfp._auth import get_auth_token, get_gcp_access_token
|
||||||
|
from kfp_server_api import ApiException
|
||||||
|
|
||||||
# TTL of the access token associated with the client. This is needed because
|
|
||||||
# `gcloud auth print-access-token` generates a token with TTL=1 hour, after
|
|
||||||
# which the authentication expires. This TTL is needed for kfp.Client()
|
|
||||||
# initialized with host=<inverse proxy endpoint>.
|
|
||||||
# Set to 55 mins to provide some safe margin.
|
|
||||||
_GCP_ACCESS_TOKEN_TIMEOUT = datetime.timedelta(minutes=55)
|
|
||||||
# Operators on scalar values. Only applies to one of |int_value|,
|
# Operators on scalar values. Only applies to one of |int_value|,
|
||||||
# |long_value|, |string_value| or |timestamp_value|.
|
# |long_value|, |string_value| or |timestamp_value|.
|
||||||
_FILTER_OPERATIONS = {
|
_FILTER_OPERATIONS = {
|
||||||
|
@ -1222,18 +1217,23 @@ class Client(object):
|
||||||
"""
|
"""
|
||||||
status = 'Running:'
|
status = 'Running:'
|
||||||
start_time = datetime.datetime.now()
|
start_time = datetime.datetime.now()
|
||||||
last_token_refresh_time = datetime.datetime.now()
|
|
||||||
if isinstance(timeout, datetime.timedelta):
|
if isinstance(timeout, datetime.timedelta):
|
||||||
timeout = timeout.total_seconds()
|
timeout = timeout.total_seconds()
|
||||||
|
is_valid_token = False
|
||||||
while (status is None or status.lower()
|
while (status is None or status.lower()
|
||||||
not in ['succeeded', 'failed', 'skipped', 'error']):
|
not in ['succeeded', 'failed', 'skipped', 'error']):
|
||||||
# Refreshes the access token before it hits the TTL.
|
try:
|
||||||
if (datetime.datetime.now() - last_token_refresh_time >
|
get_run_response = self._run_api.get_run(run_id=run_id)
|
||||||
_GCP_ACCESS_TOKEN_TIMEOUT):
|
is_valid_token = True
|
||||||
self._refresh_api_client_token()
|
except ApiException as api_ex:
|
||||||
last_token_refresh_time = datetime.datetime.now()
|
# if the token is valid but receiving 401 Unauthorized error
|
||||||
|
# then refresh the token
|
||||||
get_run_response = self._run_api.get_run(run_id=run_id)
|
if is_valid_token and api_ex.status == 401:
|
||||||
|
logging.info('Access token has expired !!! Refreshing ...')
|
||||||
|
self._refresh_api_client_token()
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise api_ex
|
||||||
status = get_run_response.run.status
|
status = get_run_response.run.status
|
||||||
elapsed_time = (datetime.datetime.now() -
|
elapsed_time = (datetime.datetime.now() -
|
||||||
start_time).total_seconds()
|
start_time).total_seconds()
|
||||||
|
|
Loading…
Reference in New Issue