Support to send default service account jwt token for pipeline client. (#779)
* Support to send default service account jwt token for pipeline client. * Configure auth for both kfp_run and kfp_experiment APIs
This commit is contained in:
parent
e9bd7c6a4d
commit
69b7fd31de
|
|
@ -23,7 +23,7 @@
|
||||||
# Setup:
|
# Setup:
|
||||||
# apt-get update -y
|
# apt-get update -y
|
||||||
# apt-get install --no-install-recommends -y -q default-jdk
|
# apt-get install --no-install-recommends -y -q default-jdk
|
||||||
# wget http://central.maven.org/maven2/io/swagger/swagger-codegen-cli/2.3.1/swagger-codegen-cli-2.3.1.jar -O /tmp/swagger-codegen-cli.jar
|
# wget http://central.maven.org/maven2/io/swagger/swagger-codegen-cli/2.4.1/swagger-codegen-cli-2.4.1.jar -O /tmp/swagger-codegen-cli.jar
|
||||||
|
|
||||||
get_abs_filename() {
|
get_abs_filename() {
|
||||||
# $1 : relative filename
|
# $1 : relative filename
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
# Copyright 2018 Google LLC
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import google.auth
|
||||||
|
import google.auth.app_engine
|
||||||
|
import google.auth.compute_engine.credentials
|
||||||
|
import google.auth.iam
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
import google.oauth2.credentials
|
||||||
|
import google.oauth2.service_account
|
||||||
|
import requests_toolbelt.adapters.appengine
|
||||||
|
|
||||||
|
IAM_SCOPE = 'https://www.googleapis.com/auth/iam'
|
||||||
|
OAUTH_TOKEN_URI = 'https://www.googleapis.com/oauth2/v4/token'
|
||||||
|
|
||||||
|
def get_auth_token(client_id):
|
||||||
|
"""Gets auth token from default service account.
|
||||||
|
|
||||||
|
If no service account credential is found, returns None.
|
||||||
|
"""
|
||||||
|
service_account_credentials = get_service_account_credentials(client_id)
|
||||||
|
if service_account_credentials:
|
||||||
|
return get_google_open_id_connect_token(service_account_credentials)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_service_account_credentials(client_id):
|
||||||
|
# Figure out what environment we're running in and get some preliminary
|
||||||
|
# information about the service account.
|
||||||
|
bootstrap_credentials, _ = google.auth.default(
|
||||||
|
scopes=[IAM_SCOPE])
|
||||||
|
if isinstance(bootstrap_credentials,
|
||||||
|
google.oauth2.credentials.Credentials):
|
||||||
|
logging.info('Found OAuth2 credentials and skip SA auth.')
|
||||||
|
return None
|
||||||
|
elif isinstance(bootstrap_credentials,
|
||||||
|
google.auth.app_engine.Credentials):
|
||||||
|
requests_toolbelt.adapters.appengine.monkeypatch()
|
||||||
|
|
||||||
|
# For service account's using the Compute Engine metadata service,
|
||||||
|
# service_account_email isn't available until refresh is called.
|
||||||
|
bootstrap_credentials.refresh(Request())
|
||||||
|
signer_email = bootstrap_credentials.service_account_email
|
||||||
|
if isinstance(bootstrap_credentials,
|
||||||
|
google.auth.compute_engine.credentials.Credentials):
|
||||||
|
# Since the Compute Engine metadata service doesn't expose the service
|
||||||
|
# account key, we use the IAM signBlob API to sign instead.
|
||||||
|
# In order for this to work:
|
||||||
|
#
|
||||||
|
# 1. Your VM needs the https://www.googleapis.com/auth/iam scope.
|
||||||
|
# You can specify this specific scope when creating a VM
|
||||||
|
# through the API or gcloud. When using Cloud Console,
|
||||||
|
# you'll need to specify the "full access to all Cloud APIs"
|
||||||
|
# scope. A VM's scopes can only be specified at creation time.
|
||||||
|
#
|
||||||
|
# 2. The VM's default service account needs the "Service Account Actor"
|
||||||
|
# role. This can be found under the "Project" category in Cloud
|
||||||
|
# Console, or roles/iam.serviceAccountActor in gcloud.
|
||||||
|
signer = google.auth.iam.Signer(
|
||||||
|
Request(), bootstrap_credentials, signer_email)
|
||||||
|
else:
|
||||||
|
# A Signer object can sign a JWT using the service account's key.
|
||||||
|
signer = bootstrap_credentials.signer
|
||||||
|
|
||||||
|
# Construct OAuth 2.0 service account credentials using the signer
|
||||||
|
# and email acquired from the bootstrap credentials.
|
||||||
|
return google.oauth2.service_account.Credentials(
|
||||||
|
signer, signer_email, token_uri=OAUTH_TOKEN_URI, additional_claims={
|
||||||
|
'target_audience': client_id
|
||||||
|
})
|
||||||
|
|
||||||
|
def get_google_open_id_connect_token(service_account_credentials):
|
||||||
|
"""Get an OpenID Connect token issued by Google for the service account.
|
||||||
|
This function:
|
||||||
|
1. Generates a JWT signed with the service account's private key
|
||||||
|
containing a special "target_audience" claim.
|
||||||
|
2. Sends it to the OAUTH_TOKEN_URI endpoint. Because the JWT in #1
|
||||||
|
has a target_audience claim, that endpoint will respond with
|
||||||
|
an OpenID Connect token for the service account -- in other words,
|
||||||
|
a JWT signed by *Google*. The aud claim in this JWT will be
|
||||||
|
set to the value from the target_audience claim in #1.
|
||||||
|
For more information, see
|
||||||
|
https://developers.google.com/identity/protocols/OAuth2ServiceAccount .
|
||||||
|
The HTTP/REST example on that page describes the JWT structure and
|
||||||
|
demonstrates how to call the token endpoint. (The example on that page
|
||||||
|
shows how to get an OAuth2 access token; this code is using a
|
||||||
|
modified version of it to get an OpenID Connect token.)
|
||||||
|
"""
|
||||||
|
|
||||||
|
service_account_jwt = (
|
||||||
|
service_account_credentials._make_authorization_grant_assertion())
|
||||||
|
request = google.auth.transport.requests.Request()
|
||||||
|
body = {
|
||||||
|
'assertion': service_account_jwt,
|
||||||
|
'grant_type': google.oauth2._client._JWT_GRANT_TYPE,
|
||||||
|
}
|
||||||
|
token_response = google.oauth2._client._token_endpoint_request(
|
||||||
|
request, OAUTH_TOKEN_URI, body)
|
||||||
|
return token_response['id_token']
|
||||||
|
|
@ -24,6 +24,7 @@ from datetime import datetime
|
||||||
from .compiler import compiler
|
from .compiler import compiler
|
||||||
from .compiler import _k8s_helper
|
from .compiler import _k8s_helper
|
||||||
|
|
||||||
|
from ._auth import get_auth_token
|
||||||
|
|
||||||
class Client(object):
|
class Client(object):
|
||||||
""" API Client for KubeFlow Pipeline.
|
""" API Client for KubeFlow Pipeline.
|
||||||
|
|
@ -32,7 +33,7 @@ class Client(object):
|
||||||
# in-cluster DNS name of the pipeline service
|
# in-cluster DNS name of the pipeline service
|
||||||
IN_CLUSTER_DNS_NAME = 'ml-pipeline.kubeflow.svc.cluster.local:8888'
|
IN_CLUSTER_DNS_NAME = 'ml-pipeline.kubeflow.svc.cluster.local:8888'
|
||||||
|
|
||||||
def __init__(self, host=None):
|
def __init__(self, host=None, client_id=None):
|
||||||
"""Create a new instance of kfp client.
|
"""Create a new instance of kfp client.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -41,6 +42,7 @@ class Client(object):
|
||||||
in the same cluster (such as a Jupyter instance spawned by Kubeflow's
|
in the same cluster (such as a Jupyter instance spawned by Kubeflow's
|
||||||
JupyterHub). If you have a different connection to cluster, such as a kubectl
|
JupyterHub). If you have a different connection to cluster, such as a kubectl
|
||||||
proxy connection, then set it to something like "127.0.0.1:8080/pipeline".
|
proxy connection, then set it to something like "127.0.0.1:8080/pipeline".
|
||||||
|
client_id: The client ID used by Identity-Aware Proxy.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -55,17 +57,28 @@ class Client(object):
|
||||||
|
|
||||||
self._host = host
|
self._host = host
|
||||||
|
|
||||||
|
token = None
|
||||||
|
if host and client_id:
|
||||||
|
token = get_auth_token(client_id)
|
||||||
|
|
||||||
config = kfp_run.configuration.Configuration()
|
config = kfp_run.configuration.Configuration()
|
||||||
config.host = host if host else Client.IN_CLUSTER_DNS_NAME
|
config.host = host if host else Client.IN_CLUSTER_DNS_NAME
|
||||||
|
self._configure_auth(config, token)
|
||||||
api_client = kfp_run.api_client.ApiClient(config)
|
api_client = kfp_run.api_client.ApiClient(config)
|
||||||
self._run_api = kfp_run.api.run_service_api.RunServiceApi(api_client)
|
self._run_api = kfp_run.api.run_service_api.RunServiceApi(api_client)
|
||||||
|
|
||||||
config = kfp_experiment.configuration.Configuration()
|
config = kfp_experiment.configuration.Configuration()
|
||||||
config.host = host if host else Client.IN_CLUSTER_DNS_NAME
|
config.host = host if host else Client.IN_CLUSTER_DNS_NAME
|
||||||
|
self._configure_auth(config, token)
|
||||||
api_client = kfp_experiment.api_client.ApiClient(config)
|
api_client = kfp_experiment.api_client.ApiClient(config)
|
||||||
self._experiment_api = \
|
self._experiment_api = \
|
||||||
kfp_experiment.api.experiment_service_api.ExperimentServiceApi(api_client)
|
kfp_experiment.api.experiment_service_api.ExperimentServiceApi(api_client)
|
||||||
|
|
||||||
|
def _configure_auth(self, config, token):
|
||||||
|
if token:
|
||||||
|
config.api_key['authorization'] = token
|
||||||
|
config.api_key_prefix['authorization'] = 'Bearer'
|
||||||
|
|
||||||
def _is_ipython(self):
|
def _is_ipython(self):
|
||||||
"""Returns whether we are running in notebook."""
|
"""Returns whether we are running in notebook."""
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,8 @@ NAME = 'kfp'
|
||||||
VERSION = '0.1'
|
VERSION = '0.1'
|
||||||
|
|
||||||
REQUIRES = ['urllib3 >= 1.15', 'six >= 1.10', 'certifi', 'python-dateutil', 'PyYAML',
|
REQUIRES = ['urllib3 >= 1.15', 'six >= 1.10', 'certifi', 'python-dateutil', 'PyYAML',
|
||||||
'google-cloud-storage == 1.13.0', 'kubernetes == 8.0.0']
|
'google-cloud-storage == 1.13.0', 'kubernetes == 8.0.0', 'PyJWT==1.6.4',
|
||||||
|
'cryptography==2.4.2', 'google-auth==1.6.1', 'requests_toolbelt==0.8.0']
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name=NAME,
|
name=NAME,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue