pipelines/sdk/python/kfp/client/auth.py

532 lines
21 KiB
Python

# Copyright 2022 The Kubeflow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import json
import logging
import os
from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple
from urllib.parse import parse_qs
from urllib.parse import urlparse
from webbrowser import open_new_tab
import wsgiref.simple_server
import wsgiref.util
import google.auth
import google.auth.app_engine
import google.auth.compute_engine.credentials
import google.auth.iam
from google.auth.transport.requests import Request
import google.oauth2.credentials
import google.oauth2.service_account
import requests
IAM_SCOPE = 'https://www.googleapis.com/auth/iam'
OAUTH_TOKEN_URI = 'https://www.googleapis.com/oauth2/v4/token'
LOCAL_KFP_CREDENTIAL = os.path.expanduser('~/.config/kfp/credentials.json')
def get_gcp_access_token() -> Optional[str]:
"""Gets GCP access token for the current Application Default Credentials.
Returns:
GCP access token or None, if it is not set. For more information, see
https://cloud.google.com/sdk/gcloud/reference/auth/application-default/print-access-token.
"""
token = None
try:
creds, _ = google.auth.default(
scopes=['https://www.googleapis.com/auth/cloud-platform'])
if not creds.valid:
auth_req = Request()
creds.refresh(auth_req)
if creds.valid:
token = creds.token
except Exception as e:
logging.warning('Failed to get GCP access token: %s', e)
return token
def get_auth_token(client_id: str, other_client_id: str,
other_client_secret: str) -> Tuple[Optional[str], bool]:
"""Gets auth token from default service account or user account.
Returns:
Tuple of (ID token or None, if not found, and a boolean
indicating whether a refresh token has been saved locally)
"""
is_refresh_token = True
if os.path.exists(LOCAL_KFP_CREDENTIAL):
# fetch IAP auth token using the locally stored credentials.
with open(LOCAL_KFP_CREDENTIAL, 'r') as f:
credentials = json.load(f)
if client_id in credentials:
saved_refresh_token = credentials[client_id].get('refresh_token')
saved_other_client_id = credentials[client_id].get(
'other_client_id')
saved_other_client_secret = credentials[client_id].get(
'other_client_secret')
if None not in {
saved_refresh_token, saved_other_client_id,
saved_other_client_secret
}:
return id_token_from_refresh_token(saved_other_client_id,
saved_other_client_secret,
saved_refresh_token,
client_id), is_refresh_token
if other_client_id is None or other_client_secret is None:
# fetch IAP auth token: service accounts
token = get_auth_token_from_sa(client_id)
is_refresh_token = False
else:
# fetch IAP auth token: user account
# Obtain the ID token for provided Client ID with user accounts.
# Flow: get authorization code -> exchange for refresh token -> obtain
# and return ID token
refresh_token, is_refresh_token = get_refresh_token_from_client_id(
other_client_id, other_client_secret)
credentials = {}
if os.path.exists(LOCAL_KFP_CREDENTIAL):
with open(LOCAL_KFP_CREDENTIAL, 'r') as f:
credentials = json.load(f)
credentials[client_id] = {
'other_client_id': other_client_id,
'other_client_secret': other_client_secret
}
if is_refresh_token:
credentials[client_id]['refresh_token'] = refresh_token
else:
credentials[client_id]['access_token'] = refresh_token
# TODO: handle the case when the refresh_token expires, which only
# happens if the refresh_token is not used once for six months.
if not os.path.exists(os.path.dirname(LOCAL_KFP_CREDENTIAL)):
os.makedirs(os.path.dirname(LOCAL_KFP_CREDENTIAL))
with open(LOCAL_KFP_CREDENTIAL, 'w') as f:
json.dump(credentials, f)
token = id_token_from_refresh_token(other_client_id,
other_client_secret, refresh_token,
client_id)
return token, is_refresh_token
def get_auth_token_from_sa(client_id: str) -> Optional[str]:
"""Gets auth token from default service account.
Returns:
Authorization token or None, if not found.
"""
service_account_credentials = get_service_account_credentials(client_id)
if service_account_credentials:
return get_google_open_id_connect_token(service_account_credentials)
return None
def get_service_account_credentials(
client_id: str) -> Optional[google.oauth2.service_account.Credentials]:
"""Figure out what environment we're running in and get some preliminary
information about the service account.
Args:
client_id: OAuth client ID.
Returns:
OAuth2 credentials or None.
"""
bootstrap_credentials, _ = google.auth.default(scopes=[IAM_SCOPE])
if isinstance(bootstrap_credentials, google.oauth2.credentials.Credentials):
logging.info('Found OAuth2 credentials and skip SA auth.')
return None
if isinstance(bootstrap_credentials, google.auth.app_engine.Credentials):
# import requests_toolbelt.adapters.appengine here for those who run KFP
# in an environment where urllib3<2.0.0 (https://github.com/kubeflow/pipelines/blob/9f278f3682662b24b46be2d9ef4a783bcc1f9b0c/sdk/python/requirements.in#L25C14-L25C14)
# is not available, preventing breaks due to https://github.com/kubeflow/pipelines/issues/9326#issuecomment-1535491761
# whenever the user runs `import kfp`.
# by putting the import statement here, only those invoking the KFP SDK client
# from within App Engine are strictly required to have urllib3<2.0.0.
import requests_toolbelt.adapters.appengine
requests_toolbelt.adapters.appengine.monkeypatch()
# For service account's using the Compute Engine metadata service,
# service_account_email isn't available until refresh is called.
bootstrap_credentials.refresh(Request())
signer_email = bootstrap_credentials.service_account_email
if isinstance(bootstrap_credentials,
google.auth.compute_engine.credentials.Credentials):
# Since the Compute Engine metadata service doesn't expose the service
# account key, we use the IAM signBlob API to sign instead.
# In order for this to work:
#
# 1. Your VM needs the https://www.googleapis.com/auth/iam scope.
# You can specify this specific scope when creating a VM
# through the API or gcloud. When using Cloud Console,
# you'll need to specify the "full access to all Cloud APIs"
# scope. A VM's scopes can only be specified at creation time.
#
# 2. The VM's default service account needs the "Service Account Actor"
# role. This can be found under the "Project" category in Cloud
# Console, or roles/iam.serviceAccountActor in gcloud.
signer = google.auth.iam.Signer(Request(), bootstrap_credentials,
signer_email)
else:
# A Signer object can sign a JWT using the service account's key.
signer = bootstrap_credentials.signer
# Construct OAuth 2.0 service account credentials using the signer
# and email acquired from the bootstrap credentials.
return google.oauth2.service_account.Credentials(
signer,
signer_email,
token_uri=OAUTH_TOKEN_URI,
additional_claims={'target_audience': client_id})
def get_google_open_id_connect_token(
service_account_credentials: google.oauth2.service_account.Credentials
) -> str:
"""Gets an OpenID Connect token issued by Google for the service account.
This function:
1. Generates a JWT signed with the service account's private key
containing a special "target_audience" claim.
2. Sends it to the OAUTH_TOKEN_URI endpoint. Because the JWT in #1
has a target_audience claim, that endpoint will respond with
an OpenID Connect token for the service account -- in other words,
a JWT signed by *Google*. The aud claim in this JWT will be
set to the value from the target_audience claim in #1.
For more information, see
https://developers.google.com/identity/protocols/OAuth2ServiceAccount
The HTTP/REST example on that page describes the JWT structure and
demonstrates how to call the token endpoint. (The example on that page
shows how to get an OAuth2 access token; this code is using a
modified version of it to get an OpenID Connect token).
Returns:
OAuth ID token.
"""
service_account_jwt = (
service_account_credentials._make_authorization_grant_assertion())
request = google.auth.transport.requests.Request()
body = {
'assertion': service_account_jwt,
'grant_type': google.oauth2._client._JWT_GRANT_TYPE,
}
token_response = google.oauth2._client._token_endpoint_request(
request, OAUTH_TOKEN_URI, body)
return token_response['id_token']
def get_refresh_token_from_client_id(client_id: str,
client_secret: str) -> Tuple[str, bool]:
"""Obtains the ID token for provided Client ID with user accounts.
Flow: get authorization code -> exchange for refresh token -> obtain and
return ID token.
Args:
client_id: OAuth client ID.
client_secret: OAuth client secret.
See https://console.cloud.google.com/apis/credentials.
Returns:
Tuple of (OAuth short-lived access token or long-lived refresh token,
and a boolean indicating whether the returned token is refresh_token)
"""
auth_code, redirect_uri = get_auth_code(client_id)
token, is_refresh_token = get_refresh_token_from_code(
auth_code, client_id, client_secret, redirect_uri)
return token, is_refresh_token
def get_auth_code(client_id: str) -> Tuple[str, str]:
"""Retrieves authorization token using Loopback flow.
Args:
client_id: OAuth client ID. To retrieve it, visit
https://console.cloud.google.com/apis/credentials.
Returns:
A tuple of (authorization token, redirect_uri parameter).
Raises:
ValueError: If the provided authorization_response is empty.
"""
host = 'localhost'
port = 9901
redirect_uri = f'http://{host}:{port}'
auth_url = ('https://accounts.google.com/o/oauth2/v2/auth?'
f'client_id={client_id}&response_type=code&'
'scope=openid%20email&access_type=offline&'
f'redirect_uri={redirect_uri}')
authorization_response = None
if ('SSH_CONNECTION' in os.environ) or ('SSH_CLIENT'
in os.environ) or is_ipython():
try:
print(('SSH connection or IPython shell detected. Please follow the'
' instructions below. Otherwise, press CTRL+C if you are not'
' connected via SSH and not using IPython (e.g. Jupyter'
' Notebook).'))
authorization_response = get_auth_response_ssh(host, port, auth_url)
except KeyboardInterrupt:
authorization_response = None
logging.warning('User pressed CTRL+C. Trying to open browser...')
if authorization_response is None:
try:
print(('Using a local web-server. Please follow the instructions '
'below. Otherwise, press CTRL+C to cancel and manually '
'copy-paste the response URL.'))
authorization_response = get_auth_response_local(
host, port, auth_url)
except KeyboardInterrupt:
logging.warning('User pressed CTRL+C. See instructions below.')
authorization_response = get_auth_response_ssh(host, port, auth_url)
except OSError as err:
logging.warning(
('%s.\n Error occurred while creating a local web-server. '
'Possibly http://%s:%i is allocated to '
'another process. Falling back to manual mode. '
'See instructions below.'), err, host, port)
authorization_response = get_auth_response_ssh(host, port, auth_url)
if authorization_response is None:
raise ValueError(
'Authorization response URL is empty. This may be caused by '
f'corrupted or expired credentials in {LOCAL_KFP_CREDENTIAL}'
'. Try renaming or moving them to another directory before '
'running again.')
token = fetch_auth_token_from_response(authorization_response)
return token, redirect_uri
def get_auth_response_ssh(host: str, port: int, auth_url: str) -> str:
"""Fetches OAuth authorization response URL for remote SSH connection.
Args:
host: Hostname in redirect_uri.
port: Port in redirect_uri.
auth_url: OAuth request URL.
Returns:
A URL containing authorization code.
"""
print(auth_url)
return input(
'Carefully follow these steps: (1) open the URL above in your'
' browser, (2) authenticate and copy a url of the response page'
f' that starts with http://{host}:{port}..., and (3) paste it'
' below:\n')
def get_auth_response_local(host: str, port: int,
auth_url: str) -> Optional[str]:
"""Fetches OAuth authorization response URL using a local web-server.
Args:
host: Hostname of the server.
port: Port of the server.
auth_url: OAuth request URL.
Returns:
A URL containing authorization code.
"""
with get_local_server_app(host, port) as (local_server, wsgi_app):
open_new_tab(auth_url)
print(f'Please visit this URL to authorize Kubeflow SDK: {auth_url}')
print((f'Make sure that http://{host}:{port} is added to Authorized'
' redirect URIs for your OAuth 2.0 Client ID. Check it here:'
' https://console.cloud.google.com/apis/credentials'))
local_server.handle_request()
return wsgi_app.last_request_uri
def get_refresh_token_from_code(auth_code: str, client_id: str,
client_secret: str,
redirect_uri: str) -> Tuple[str, bool]:
"""Returns refresh or access token from authorization code.
Args:
auth_code: OAuth authorization code.
client_id: OAuth client ID.
client_secret: OAuth client secret.
redirect_uri: Redirect uri used to obtain auth_code.
Returns:
Tuple of (OAuth short-lived access token or long-lived refresh token,
and a boolean indicating whether the returned token is refresh_token)
Raises:
ValueError: If HTTP request returns
a requests.exceptions.HTTPError.
"""
payload = {
'code': auth_code,
'client_id': client_id,
'client_secret': client_secret,
'redirect_uri': redirect_uri,
'grant_type': 'authorization_code'
}
res = requests.post(OAUTH_TOKEN_URI, data=payload)
try:
res.raise_for_status()
except requests.exceptions.HTTPError as err:
raise ValueError(
('Some HTTPErrors are caused by expired credentials in '
f'{LOCAL_KFP_CREDENTIAL} Try renaming or moving '
'them to another directory before running again.')) from err
parsed_res = json.loads(res.text)
token = parsed_res.get('refresh_token')
is_refresh_token = True
if token is None:
token = parsed_res.get('access_token')
is_refresh_token = False
return str(token), is_refresh_token
def id_token_from_refresh_token(client_id: str, client_secret: str,
refresh_token: str, audience: str) -> str:
"""Returns ID token from refresh token.
Args:
client_id: OAuth client ID.
client_secret: OAuth client secret.
refresh_token: OAuth refresh token.
audience: OAuth audience.
Returns:
OAuth ID token.
Raises:
ValueError: If HTTP request returns
a requests.exceptions.HTTPError.
"""
payload = {
'client_id': client_id,
'client_secret': client_secret,
'refresh_token': refresh_token,
'grant_type': 'refresh_token',
'audience': audience
}
res = requests.post(OAUTH_TOKEN_URI, data=payload)
try:
res.raise_for_status()
except requests.exceptions.HTTPError as err:
raise ValueError(
('Some HTTPErrors are caused by expired credentials in '
f'{LOCAL_KFP_CREDENTIAL}. Try renaming or moving '
'them to another directory before running again.')) from err
return str(json.loads(res.text).get('id_token'))
class RedirectWSGIApp:
"""WSGI app to handle the authorization redirect.
Stores the request URI and displays the given success message.
"""
def __init__(self, success_message: str) -> None:
"""
Args:
success_message: The message to display in the web browser
the authorization flow is complete.
"""
self.last_request_uri = None
self._success_message = success_message
def __call__(
self, environ: Dict[str, Any], start_response: Callable[[str, list],
Callable[...,
None]]
) -> Iterable[bytes]:
"""WSGI Callable. Updates environment dictionary with parameters
required for WSGI and returns it.
Args:
environ: The WSGI environment.
start_response: The WSGI start_response
callable.
Returns:
The response body.
"""
wsgiref.util.setup_testing_defaults(environ)
start_response('200 OK',
[('Content-type', 'text/plain; charset=utf-8')])
self.last_request_uri = wsgiref.util.request_uri(environ)
return [self._success_message.encode('utf-8')]
@contextmanager
def get_local_server_app(
host: str, port: int
) -> Generator[Tuple[wsgiref.simple_server.WSGIServer, RedirectWSGIApp], None,
None]:
"""Creates a local web-server for given host and port.
Args:
host: Hostname of the server.
port: Port of the server.
Returns:
Tuple of (a local server instance, WSGI app that handles
the authorization redirect).
"""
success_message = ('Kubeflow SDK authentication is completed.'
' You may close this window now.')
wsgi_app = RedirectWSGIApp(success_message)
wsgiref.simple_server.WSGIServer.allow_reuse_address = False
local_server = wsgiref.simple_server.make_server(
host,
port,
wsgi_app,
handler_class=wsgiref.simple_server.WSGIRequestHandler)
try:
yield local_server, wsgi_app
finally:
local_server.server_close()
del wsgi_app
def fetch_auth_token_from_response(url: str) -> str:
"""Fetches authorization code for OAuth2.0 Loopback flow.
Args:
url: A string containing the response URL.
Returns:
An access code.
Raises:
KeyError: If no authorization code is found in the provided url.
"""
parsed_url = urlparse(url)
parsed_query = parse_qs(parsed_url.query)
access_code = parsed_query.get('code')
if access_code is None:
raise KeyError((
'Authorization code is missing or empty in the response.'
' Please, try again or check '
'https://www.kubeflow.org/docs/distributions/gke/deploy/oauth-setup'
))
if isinstance(access_code, list):
access_code = str(access_code.pop(0))
return access_code
def is_ipython() -> bool:
"""Returns whether we are running in notebook."""
try:
import IPython
ipy = IPython.get_ipython()
if ipy is None:
return False
except ImportError:
return False
return True