fix(sdk): refresh access token only when it expires. Fixes #6883 (#6941)

* refresh access token whenerver it expires

* tight the condition when refreshing the access token
This commit is contained in:
hieuhc 2021-11-23 20:35:07 +01:00 committed by GitHub
parent 8064383cf1
commit 6d55e262b4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 14 deletions

View File

@ -33,13 +33,8 @@ from kfp.compiler import compiler
from kfp.compiler._k8s_helper import sanitize_k8s_name
from kfp._auth import get_auth_token, get_gcp_access_token
from kfp_server_api import ApiException
# TTL of the access token associated with the client. This is needed because
# `gcloud auth print-access-token` generates a token with TTL=1 hour, after
# which the authentication expires. This TTL is needed for kfp.Client()
# initialized with host=<inverse proxy endpoint>.
# Set to 55 mins to provide some safe margin.
_GCP_ACCESS_TOKEN_TIMEOUT = datetime.timedelta(minutes=55)
# Operators on scalar values. Only applies to one of |int_value|,
# |long_value|, |string_value| or |timestamp_value|.
_FILTER_OPERATIONS = {
@ -1222,18 +1217,23 @@ class Client(object):
"""
status = 'Running:'
start_time = datetime.datetime.now()
last_token_refresh_time = datetime.datetime.now()
if isinstance(timeout, datetime.timedelta):
timeout = timeout.total_seconds()
is_valid_token = False
while (status is None or status.lower()
not in ['succeeded', 'failed', 'skipped', 'error']):
# Refreshes the access token before it hits the TTL.
if (datetime.datetime.now() - last_token_refresh_time >
_GCP_ACCESS_TOKEN_TIMEOUT):
self._refresh_api_client_token()
last_token_refresh_time = datetime.datetime.now()
get_run_response = self._run_api.get_run(run_id=run_id)
try:
get_run_response = self._run_api.get_run(run_id=run_id)
is_valid_token = True
except ApiException as api_ex:
# if the token is valid but receiving 401 Unauthorized error
# then refresh the token
if is_valid_token and api_ex.status == 401:
logging.info('Access token has expired !!! Refreshing ...')
self._refresh_api_client_token()
continue
else:
raise api_ex
status = get_run_response.run.status
elapsed_time = (datetime.datetime.now() -
start_time).total_seconds()