532 lines
21 KiB
Python
532 lines
21 KiB
Python
# Copyright 2022 The Kubeflow Authors
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from contextlib import contextmanager
|
|
import json
|
|
import logging
|
|
import os
|
|
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple
|
|
from urllib.parse import parse_qs
|
|
from urllib.parse import urlparse
|
|
from webbrowser import open_new_tab
|
|
import wsgiref.simple_server
|
|
import wsgiref.util
|
|
|
|
import google.auth
|
|
import google.auth.app_engine
|
|
import google.auth.compute_engine.credentials
|
|
import google.auth.iam
|
|
from google.auth.transport.requests import Request
|
|
import google.oauth2.credentials
|
|
import google.oauth2.service_account
|
|
import requests
|
|
|
|
# OAuth scope requested for the bootstrap Application Default Credentials
# when building service-account credentials.
IAM_SCOPE = 'https://www.googleapis.com/auth/iam'
# Token endpoint that every token exchange in this module posts to.
OAUTH_TOKEN_URI = 'https://www.googleapis.com/oauth2/v4/token'
# Local JSON file caching, per IAP client ID, the desktop OAuth client
# credentials and the token obtained through the user-account flow.
LOCAL_KFP_CREDENTIAL = os.path.expanduser('~/.config/kfp/credentials.json')
|
|
|
|
|
|
def get_gcp_access_token() -> Optional[str]:
    """Fetches an access token for the Application Default Credentials.

    The credentials are refreshed first when they are not currently valid.
    Any failure along the way is logged as a warning instead of being raised.

    Returns:
        The GCP access token, or None if no valid token could be obtained.
        For more information, see
        https://cloud.google.com/sdk/gcloud/reference/auth/application-default/print-access-token.
    """
    try:
        adc, _ = google.auth.default(
            scopes=['https://www.googleapis.com/auth/cloud-platform'])
        if not adc.valid:
            adc.refresh(Request())
        if adc.valid:
            return adc.token
    except Exception as e:
        # Best effort: callers treat a missing token as "not available".
        logging.warning('Failed to get GCP access token: %s', e)
    return None
|
|
|
|
|
|
def _load_local_credentials() -> Dict[str, Any]:
    """Loads the credentials cached in LOCAL_KFP_CREDENTIAL.

    Returns:
        The cached JSON object, or an empty dict when the file does not
        exist, cannot be decoded, or does not hold a JSON object.
    """
    if not os.path.exists(LOCAL_KFP_CREDENTIAL):
        return {}
    try:
        with open(LOCAL_KFP_CREDENTIAL, 'r') as f:
            credentials = json.load(f)
    except ValueError as e:
        # json.JSONDecodeError subclasses ValueError. A corrupted cache must
        # not block authentication; fall back to a fresh flow instead.
        logging.warning('Ignoring unreadable credentials in %s: %s',
                        LOCAL_KFP_CREDENTIAL, e)
        return {}
    if not isinstance(credentials, dict):
        logging.warning('Ignoring unexpected content in %s.',
                        LOCAL_KFP_CREDENTIAL)
        return {}
    return credentials


def _save_local_credentials(credentials: Dict[str, Any]) -> None:
    """Writes credentials to LOCAL_KFP_CREDENTIAL, readable by the owner only."""
    os.makedirs(os.path.dirname(LOCAL_KFP_CREDENTIAL), exist_ok=True)
    # The file holds a client secret and a refresh token, so create it with
    # owner-only permissions rather than the process default.
    fd = os.open(LOCAL_KFP_CREDENTIAL, os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
                 0o600)
    with os.fdopen(fd, 'w') as f:
        json.dump(credentials, f)
    # The mode passed to os.open only applies to newly created files; tighten
    # a file that already existed.
    os.chmod(LOCAL_KFP_CREDENTIAL, 0o600)


def get_auth_token(
        client_id: str, other_client_id: Optional[str],
        other_client_secret: Optional[str]) -> Tuple[Optional[str], bool]:
    """Gets auth token from default service account or user account.

    Resolution order:
      1. A complete entry for client_id cached in LOCAL_KFP_CREDENTIAL
         (refresh token plus the desktop client ID and secret).
      2. The default service account, when other_client_id or
         other_client_secret is None.
      3. The interactive user-account flow; its result is cached in
         LOCAL_KFP_CREDENTIAL before being exchanged for an ID token.

    Args:
        client_id: OAuth client ID used as the audience of the ID token.
        other_client_id: OAuth client ID used for the user-account flow, or
            None to use the default service account.
        other_client_secret: OAuth client secret matching other_client_id, or
            None to use the default service account.

    Returns:
        Tuple of (ID token or None, if not found, and a boolean
        indicating whether a refresh token has been saved locally)
    """
    saved = _load_local_credentials().get(client_id)
    if isinstance(saved, dict):
        # fetch IAP auth token using the locally stored credentials.
        saved_refresh_token = saved.get('refresh_token')
        saved_other_client_id = saved.get('other_client_id')
        saved_other_client_secret = saved.get('other_client_secret')
        if all(value is not None
               for value in (saved_refresh_token, saved_other_client_id,
                             saved_other_client_secret)):
            return id_token_from_refresh_token(saved_other_client_id,
                                               saved_other_client_secret,
                                               saved_refresh_token,
                                               client_id), True

    if other_client_id is None or other_client_secret is None:
        # fetch IAP auth token: service accounts
        return get_auth_token_from_sa(client_id), False

    # fetch IAP auth token: user account
    # Flow: get authorization code -> exchange for refresh token -> obtain
    # and return ID token
    refresh_token, is_refresh_token = get_refresh_token_from_client_id(
        other_client_id, other_client_secret)

    entry = {
        'other_client_id': other_client_id,
        'other_client_secret': other_client_secret
    }
    entry['refresh_token' if is_refresh_token else
          'access_token'] = refresh_token
    # Re-read the cache so entries stored for other client IDs are kept.
    credentials = _load_local_credentials()
    credentials[client_id] = entry
    # TODO: handle the case when the refresh_token expires, which only
    # happens if the refresh_token is not used once for six months.
    _save_local_credentials(credentials)

    # NOTE(review): when only an access token was obtained it is still sent
    # as the refresh token below, as in the original code -- confirm intent.
    token = id_token_from_refresh_token(other_client_id, other_client_secret,
                                        refresh_token, client_id)
    return token, is_refresh_token
|
|
|
|
|
|
def get_auth_token_from_sa(client_id: str) -> Optional[str]:
    """Obtains an auth token through the default service account.

    Args:
        client_id: OAuth client ID the token is requested for.

    Returns:
        Authorization token, or None when no service account credentials
        are available.
    """
    sa_credentials = get_service_account_credentials(client_id)
    if not sa_credentials:
        return None
    return get_google_open_id_connect_token(sa_credentials)
|
|
|
|
|
|
def get_service_account_credentials(
    client_id: str) -> Optional[google.oauth2.service_account.Credentials]:
    """Builds service account credentials for the current environment.

    Starts from the Application Default Credentials, determines how JWTs can
    be signed in this environment, and wraps signer and service account email
    into OAuth2 service account credentials targeting client_id.

    Args:
        client_id: OAuth client ID, set as the 'target_audience' claim.

    Returns:
        OAuth2 service account credentials, or None when the default
        credentials are OAuth2 user credentials.
    """
    adc, _ = google.auth.default(scopes=[IAM_SCOPE])

    if isinstance(adc, google.oauth2.credentials.Credentials):
        logging.info('Found OAuth2 credentials and skip SA auth.')
        return None

    if isinstance(adc, google.auth.app_engine.Credentials):
        # requests_toolbelt.adapters.appengine is imported lazily so that
        # only users calling the KFP SDK client from within App Engine need
        # urllib3<2.0.0 (https://github.com/kubeflow/pipelines/blob/9f278f3682662b24b46be2d9ef4a783bcc1f9b0c/sdk/python/requirements.in#L25C14-L25C14).
        # A module-level import would break a plain `import kfp` wherever
        # that version is unavailable, see
        # https://github.com/kubeflow/pipelines/issues/9326#issuecomment-1535491761
        import requests_toolbelt.adapters.appengine
        requests_toolbelt.adapters.appengine.monkeypatch()

    # With the Compute Engine metadata service, service_account_email is not
    # available until refresh has been called.
    adc.refresh(Request())
    sa_email = adc.service_account_email

    uses_metadata_service = isinstance(
        adc, google.auth.compute_engine.credentials.Credentials)
    if not uses_metadata_service:
        # The credentials carry a Signer that signs JWTs with the service
        # account's own key.
        jwt_signer = adc.signer
    else:
        # The Compute Engine metadata service does not expose the service
        # account key, so signing goes through the IAM signBlob API.
        # For this to work:
        #
        # 1. The VM needs the https://www.googleapis.com/auth/iam scope.
        #    It can be set when creating a VM through the API or gcloud.
        #    In Cloud Console, choose the "full access to all Cloud APIs"
        #    scope. A VM's scopes can only be specified at creation time.
        #
        # 2. The VM's default service account needs the "Service Account
        #    Actor" role, found under the "Project" category in Cloud
        #    Console, or roles/iam.serviceAccountActor in gcloud.
        jwt_signer = google.auth.iam.Signer(Request(), adc, sa_email)

    # Assemble OAuth 2.0 service account credentials from the signer and
    # email obtained via the bootstrap credentials.
    return google.oauth2.service_account.Credentials(
        jwt_signer,
        sa_email,
        token_uri=OAUTH_TOKEN_URI,
        additional_claims={'target_audience': client_id})
|
|
|
|
|
|
def get_google_open_id_connect_token(
    service_account_credentials: google.oauth2.service_account.Credentials
) -> str:
    """Exchanges a service-account JWT for a Google-issued OpenID Connect token.

    Steps:
      1. Create a JWT, signed for the service account, that carries the
         special "target_audience" claim.
      2. Post it to the OAUTH_TOKEN_URI endpoint. Because of the
         target_audience claim the endpoint answers with an OpenID Connect
         token for the service account, i.e. a JWT signed by *Google* whose
         aud claim equals the target_audience claim from step 1.

    For more information, see
    https://developers.google.com/identity/protocols/OAuth2ServiceAccount
    The HTTP/REST example on that page describes the JWT structure and how to
    call the token endpoint. (It shows how to get an OAuth2 access token;
    this code uses a modified version of it to get an OpenID Connect token.)

    Args:
        service_account_credentials: Credentials used to create the signed
            grant assertion.

    Returns:
        OAuth ID token.
    """
    signed_assertion = (
        service_account_credentials._make_authorization_grant_assertion())
    http_request = google.auth.transport.requests.Request()
    grant = {
        'assertion': signed_assertion,
        'grant_type': google.oauth2._client._JWT_GRANT_TYPE,
    }
    response = google.oauth2._client._token_endpoint_request(
        http_request, OAUTH_TOKEN_URI, grant)
    return response['id_token']
|
|
|
|
|
|
def get_refresh_token_from_client_id(client_id: str,
                                     client_secret: str) -> Tuple[str, bool]:
    """Runs the user-account flow up to the refresh (or access) token.

    Flow: get authorization code -> exchange it for a refresh token (or an
    access token, when no refresh token is issued).

    Args:
        client_id: OAuth client ID.
        client_secret: OAuth client secret.
            See https://console.cloud.google.com/apis/credentials.

    Returns:
        Tuple of (OAuth short-lived access token or long-lived refresh token,
        and a boolean indicating whether the returned token is refresh_token)
    """
    code, redirect_uri = get_auth_code(client_id)
    refresh_or_access_token, is_refresh_token = get_refresh_token_from_code(
        code, client_id, client_secret, redirect_uri)
    return refresh_or_access_token, is_refresh_token
|
|
|
|
|
|
def get_auth_code(client_id: str) -> Tuple[str, str]:
    """Retrieves an authorization code using the Loopback flow.

    In an SSH session or IPython shell the user is first asked to paste the
    response URL manually. Otherwise, or when that prompt is cancelled with
    CTRL+C, a local web-server receives the redirect. If the local server is
    cancelled or cannot be started, the manual prompt is used as a fallback.

    Args:
        client_id: OAuth client ID. To retrieve it, visit
            https://console.cloud.google.com/apis/credentials.

    Returns:
        A tuple of (authorization code, redirect_uri parameter).

    Raises:
        ValueError: If the provided authorization_response is empty.
        KeyError: If the response URL carries no authorization code.
    """
    host = 'localhost'
    port = 9901
    redirect_uri = f'http://{host}:{port}'
    auth_url = ('https://accounts.google.com/o/oauth2/v2/auth?'
                f'client_id={client_id}&response_type=code&'
                'scope=openid%20email&access_type=offline&'
                f'redirect_uri={redirect_uri}')
    authorization_response = None
    if ('SSH_CONNECTION' in os.environ) or ('SSH_CLIENT'
                                            in os.environ) or is_ipython():
        try:
            print(('SSH connection or IPython shell detected. Please follow the'
                   ' instructions below. Otherwise, press CTRL+C if you are not'
                   ' connected via SSH and not using IPython (e.g. Jupyter'
                   ' Notebook).'))
            authorization_response = get_auth_response_ssh(host, port, auth_url)
        except KeyboardInterrupt:
            authorization_response = None
            logging.warning('User pressed CTRL+C. Trying to open browser...')
    if authorization_response is None:
        try:
            print(('Using a local web-server. Please follow the instructions '
                   'below. Otherwise, press CTRL+C to cancel and manually '
                   'copy-paste the response URL.'))
            authorization_response = get_auth_response_local(
                host, port, auth_url)
        except KeyboardInterrupt:
            logging.warning('User pressed CTRL+C. See instructions below.')
            authorization_response = get_auth_response_ssh(host, port, auth_url)
        except OSError as err:
            logging.warning(
                ('%s.\n Error occurred while creating a local web-server. '
                 'Possibly http://%s:%i is allocated to '
                 'another process. Falling back to manual mode. '
                 'See instructions below.'), err, host, port)
            authorization_response = get_auth_response_ssh(host, port, auth_url)
    # A manual prompt answered with just ENTER yields '' rather than None, and
    # pasted URLs may carry stray whitespace; treat both as an empty response.
    if authorization_response is not None:
        authorization_response = authorization_response.strip()
    if not authorization_response:
        raise ValueError(
            'Authorization response URL is empty. This may be caused by '
            f'corrupted or expired credentials in {LOCAL_KFP_CREDENTIAL}'
            '. Try renaming or moving them to another directory before '
            'running again.')
    token = fetch_auth_token_from_response(authorization_response)
    return token, redirect_uri
|
|
|
|
|
|
def get_auth_response_ssh(host: str, port: int, auth_url: str) -> str:
    """Asks the user to paste the OAuth authorization response URL.

    Used when no browser redirect can be received, e.g. over a remote SSH
    connection: the request URL is printed and the response URL is read from
    standard input.

    Args:
        host: Hostname in redirect_uri.
        port: Port in redirect_uri.
        auth_url: OAuth request URL.

    Returns:
        A URL containing authorization code.
    """
    print(auth_url)
    instructions = (
        'Carefully follow these steps: (1) open the URL above in your'
        ' browser, (2) authenticate and copy a url of the response page'
        f' that starts with http://{host}:{port}..., and (3) paste it'
        ' below:\n')
    return input(instructions)
|
|
|
|
|
|
def get_auth_response_local(host: str, port: int,
                            auth_url: str) -> Optional[str]:
    """Receives the OAuth authorization response on a local web-server.

    Opens auth_url in a new browser tab, prints instructions, and serves a
    single request on the local server.

    Args:
        host: Hostname of the server.
        port: Port of the server.
        auth_url: OAuth request URL.

    Returns:
        A URL containing authorization code, as recorded by the WSGI app.
    """
    redirect_hint = (
        f'Make sure that http://{host}:{port} is added to Authorized'
        ' redirect URIs for your OAuth 2.0 Client ID. Check it here:'
        ' https://console.cloud.google.com/apis/credentials')
    with get_local_server_app(host, port) as (server, redirect_app):
        open_new_tab(auth_url)
        print(f'Please visit this URL to authorize Kubeflow SDK: {auth_url}')
        print(redirect_hint)
        # Handles exactly one request: the redirect sent by the browser.
        server.handle_request()
        response_url = redirect_app.last_request_uri
    return response_url
|
|
|
|
|
|
def get_refresh_token_from_code(auth_code: str, client_id: str,
                                client_secret: str,
                                redirect_uri: str) -> Tuple[str, bool]:
    """Returns refresh or access token from authorization code.

    Posts an 'authorization_code' grant to OAUTH_TOKEN_URI and prefers the
    refresh token in the response; the access token is the fallback.

    Args:
        auth_code: OAuth authorization code.
        client_id: OAuth client ID.
        client_secret: OAuth client secret.
        redirect_uri: Redirect uri used to obtain auth_code.

    Returns:
        Tuple of (OAuth short-lived access token or long-lived refresh token,
        and a boolean indicating whether the returned token is refresh_token)

    Raises:
        ValueError: If HTTP request returns
            a requests.exceptions.HTTPError, or if the response contains
            neither a refresh token nor an access token.
    """
    payload = {
        'code': auth_code,
        'client_id': client_id,
        'client_secret': client_secret,
        'redirect_uri': redirect_uri,
        'grant_type': 'authorization_code'
    }
    res = requests.post(OAUTH_TOKEN_URI, data=payload)
    try:
        res.raise_for_status()
    except requests.exceptions.HTTPError as err:
        raise ValueError(
            ('Some HTTPErrors are caused by expired credentials in '
             f'{LOCAL_KFP_CREDENTIAL}. Try renaming or moving '
             'them to another directory before running again.')) from err
    parsed_res = json.loads(res.text)
    for key, is_refresh_token in (('refresh_token', True), ('access_token',
                                                            False)):
        token = parsed_res.get(key)
        if token is not None:
            return str(token), is_refresh_token
    # Without this check str(None) would be returned, and later cached, as if
    # the literal text 'None' were a token.
    raise ValueError(
        'Token response contains neither a refresh token nor an access '
        'token.')
|
|
|
|
|
|
def id_token_from_refresh_token(client_id: str, client_secret: str,
                                refresh_token: str, audience: str) -> str:
    """Returns ID token from refresh token.

    Posts a 'refresh_token' grant to OAUTH_TOKEN_URI and extracts the
    'id_token' field of the JSON response.

    Args:
        client_id: OAuth client ID.
        client_secret: OAuth client secret.
        refresh_token: OAuth refresh token.
        audience: OAuth audience.

    Returns:
        OAuth ID token.

    Raises:
        ValueError: If HTTP request returns
            a requests.exceptions.HTTPError, or if the response does not
            contain an ID token.
    """
    payload = {
        'client_id': client_id,
        'client_secret': client_secret,
        'refresh_token': refresh_token,
        'grant_type': 'refresh_token',
        'audience': audience
    }
    res = requests.post(OAUTH_TOKEN_URI, data=payload)
    try:
        res.raise_for_status()
    except requests.exceptions.HTTPError as err:
        raise ValueError(
            ('Some HTTPErrors are caused by expired credentials in '
             f'{LOCAL_KFP_CREDENTIAL}. Try renaming or moving '
             'them to another directory before running again.')) from err
    id_token = json.loads(res.text).get('id_token')
    if id_token is None:
        # Without this check str(None) would be handed back as if the
        # literal text 'None' were a token.
        raise ValueError('Token response does not contain an ID token.')
    return str(id_token)
|
|
|
|
|
|
class RedirectWSGIApp:
    """Minimal WSGI application that receives the authorization redirect.

    It remembers the URI of the most recent request and answers with a fixed
    plain-text success message.
    """

    def __init__(self, success_message: str) -> None:
        """
        Args:
            success_message: The message to display in the web browser once
                the authorization flow is complete.
        """
        self._success_message = success_message
        # URI of the latest handled request; None until one has arrived.
        self.last_request_uri = None

    def __call__(
        self, environ: Dict[str, Any], start_response: Callable[[str, list],
                                                                Callable[...,
                                                                         None]]
    ) -> Iterable[bytes]:
        """WSGI callable: records the request URI and sends the message.

        Args:
            environ: The WSGI environment; missing standard keys are filled
                in with the wsgiref testing defaults.
            start_response: The WSGI start_response
                callable.

        Returns:
            The response body.
        """
        wsgiref.util.setup_testing_defaults(environ)
        response_headers = [('Content-type', 'text/plain; charset=utf-8')]
        start_response('200 OK', response_headers)
        self.last_request_uri = wsgiref.util.request_uri(environ)
        body = self._success_message.encode('utf-8')
        return [body]
|
|
|
|
|
|
class _NoAddressReuseWSGIServer(wsgiref.simple_server.WSGIServer):
    """WSGIServer variant with address reuse disabled."""

    # Set on a private subclass so the shared WSGIServer class, and any other
    # server created from it in this process, keeps its default.
    allow_reuse_address = False


@contextmanager
def get_local_server_app(
    host: str, port: int
) -> Generator[Tuple[wsgiref.simple_server.WSGIServer, RedirectWSGIApp], None,
               None]:
    """Creates a local web-server for given host and port.

    Context manager: the server is created on entry and closed on exit, even
    when the body of the `with` statement raises.

    Args:
        host: Hostname of the server.
        port: Port of the server.

    Yields:
        Tuple of (a local server instance, WSGI app that handles
        the authorization redirect).
    """
    success_message = ('Kubeflow SDK authentication is completed.'
                       ' You may close this window now.')
    wsgi_app = RedirectWSGIApp(success_message)
    local_server = wsgiref.simple_server.make_server(
        host,
        port,
        wsgi_app,
        server_class=_NoAddressReuseWSGIServer,
        handler_class=wsgiref.simple_server.WSGIRequestHandler)
    try:
        yield local_server, wsgi_app
    finally:
        local_server.server_close()
|
|
|
|
|
|
def fetch_auth_token_from_response(url: str) -> str:
    """Extracts the authorization code of the OAuth2.0 Loopback flow.

    Args:
        url: A string containing the response URL.

    Returns:
        The first value of the 'code' query parameter.

    Raises:
        KeyError: If no authorization code is found in the provided url.
    """
    query_params = parse_qs(urlparse(url).query)
    codes = query_params.get('code')
    if codes is None:
        raise KeyError((
            'Authorization code is missing or empty in the response.'
            ' Please, try again or check '
            'https://www.kubeflow.org/docs/distributions/gke/deploy/oauth-setup'
        ))
    # parse_qs maps each parameter to a list of values; take the first one.
    if not isinstance(codes, list):
        return codes
    return str(codes.pop(0))
|
|
|
|
|
|
def is_ipython() -> bool:
    """Tells whether the code runs inside an active IPython shell.

    Returns:
        True if IPython can be imported and reports an active shell (as in a
        notebook), False otherwise.
    """
    try:
        import IPython
        shell = IPython.get_ipython()
    except ImportError:
        return False
    return shell is not None
|