feat(components): AWS SageMaker - Support for assuming a role (#4212)
* Add client assume role functionality * Add assume_role to component.yaml files * Update image to personal * Update input to force NoneType on empty * Update integration test setup with assumed role * Add assume role integration test * Update boto session to use refreshing credentials * Update assume role relax trust relationship * Add check for defined assumed role name * Add processing assume integ test * Add assume role unit test for main methods * Add assume_role to all READMEs * Update session to use AssumeRoleProvider * Remove region from child calls to session * Fix extra region_name in test * Update assume role processing integ test name * Add processing integ test to list * Update assumed role to remain if not generated * Update license version * Update image tag to new version * Add new version to Changelog
This commit is contained in:
parent
704c8c7660
commit
8014a44229
|
|
@ -4,6 +4,12 @@ The version of the AWS SageMaker Components is determined by the docker image ta
|
|||
Repository: https://hub.docker.com/repository/docker/amazon/aws-sagemaker-kfp-components
|
||||
|
||||
---------------------------------------------
|
||||
**Change log for version 0.7.0**
|
||||
- Add functionality to assume role when sending SageMaker requests
|
||||
|
||||
> Pull requests : [#4212](https://github.com/kubeflow/pipelines/pull/4212)
|
||||
|
||||
|
||||
**Change log for version 0.6.0**
|
||||
- Add functionality to stop SageMaker jobs on run termination
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
** Amazon SageMaker Components for Kubeflow Pipelines; version 0.6.0 --
|
||||
** Amazon SageMaker Components for Kubeflow Pipelines; version 0.7.0 --
|
||||
https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker
|
||||
Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
** boto3; version 1.12.33 -- https://github.com/boto/boto3/
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ Create a transform job in AWS SageMaker.
|
|||
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
|
||||
region | The region where the endpoint is created | No | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
|
||||
job_name | The name of the transform job. The name must be unique within an AWS Region in an AWS account | Yes | Yes | String | | is a generated name (combination of model_name and 'BatchTransform' string)|
|
||||
model_name | The name of the model that you want to use for the transform job. Model name must be the name of an existing Amazon SageMaker model within an AWS Region in an AWS account | No | No | String | | |
|
||||
max_concurrent | The maximum number of parallel requests that can be sent to each instance in a transform job | Yes | Yes | Integer | | 0 |
|
||||
|
|
|
|||
|
|
@ -90,6 +90,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -98,12 +102,13 @@ outputs:
|
|||
- {name: output_location, description: 'S3 URI of the transform job results.'}
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.6.0
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
batch_transform.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--job_name, {inputValue: job_name},
|
||||
--model_name, {inputValue: model_name},
|
||||
--max_concurrent, {inputValue: max_concurrent},
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
logging.info('Submitting Batch Transformation request to SageMaker...')
|
||||
batch_job_name = _utils.create_transform_job(client, vars(args))
|
||||
|
||||
|
|
@ -68,7 +68,7 @@ def main(argv=None):
|
|||
except:
|
||||
raise
|
||||
finally:
|
||||
cw_client = _utils.get_cloudwatch_client(args.region)
|
||||
cw_client = _utils.get_cloudwatch_client(args.region, assume_role_arn=args.assume_role)
|
||||
_utils.print_logs_for_job(cw_client, '/aws/sagemaker/TransformJobs', batch_job_name)
|
||||
|
||||
_utils.write_output(args.output_location_output_path, args.output_location)
|
||||
|
|
|
|||
|
|
@ -24,8 +24,18 @@ import json
|
|||
from pathlib2 import Path
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
from boto3.session import Session
|
||||
|
||||
from botocore.config import Config
|
||||
from botocore.credentials import (
|
||||
AssumeRoleCredentialFetcher,
|
||||
CredentialResolver,
|
||||
DeferredRefreshableCredentials,
|
||||
JSONFileCache
|
||||
)
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.session import Session as BotocoreSession
|
||||
|
||||
from sagemaker.amazon.amazon_estimator import get_image_uri
|
||||
|
||||
import logging
|
||||
|
|
@ -71,6 +81,7 @@ def nullable_string_argument(value):
|
|||
def add_default_client_arguments(parser):
    """Register the client-related CLI arguments shared by the components.

    Adds ``--region`` (required), ``--endpoint_url`` (optional) and
    ``--assume_role`` (optional) to ``parser``.

    Args:
        parser: The argument parser to add the arguments to.
    """
    parser.add_argument(
        '--region',
        type=str,
        required=True,
        help='The region where the training job launches.',
    )
    # NOTE(review): nullable_string_argument is defined elsewhere in this
    # file; presumably it maps an empty string to None - confirm there.
    parser.add_argument(
        '--endpoint_url',
        type=nullable_string_argument,
        required=False,
        help='The URL to use when communicating with the SageMaker service.',
    )
    parser.add_argument(
        '--assume_role',
        type=nullable_string_argument,
        required=False,
        help='The ARN of an IAM role to assume when connecting to SageMaker.',
    )
|
||||
|
||||
|
||||
def get_component_version():
|
||||
|
|
@ -113,17 +124,59 @@ def print_logs_for_job(cw_client, log_grp, job_name):
|
|||
logging.error(e)
|
||||
|
||||
|
||||
def get_sagemaker_client(region, endpoint_url=None):
|
||||
class AssumeRoleProvider(object):
    """Credential provider backed by an assume-role credential fetcher.

    Exposes the ``METHOD`` / ``load()`` pair so that an instance can be placed
    in a botocore ``CredentialResolver``.
    """

    # Method label handed to DeferredRefreshableCredentials.
    METHOD = 'assume-role'

    def __init__(self, fetcher):
        """Store the fetcher whose ``fetch_credentials`` supplies credentials.

        Args:
            fetcher: Object exposing a ``fetch_credentials`` callable.
        """
        self._fetcher = fetcher

    def load(self):
        """Return refreshable credentials that refresh through the fetcher."""
        refresh_using = self._fetcher.fetch_credentials
        return DeferredRefreshableCredentials(refresh_using, self.METHOD)
|
||||
|
||||
def get_boto3_session(region, role_arn=None):
    """Create a boto3 session, optionally assuming an IAM role.

    Args:
        region: AWS region name the returned session is bound to.
        role_arn: ARN of the IAM role to assume. ``None`` or an empty string
            means no role is assumed.

    Returns:
        A boto3 ``Session``. Without a role it is a plain session for
        ``region``; with a role it wraps a botocore session whose only
        credential provider is an ``AssumeRoleProvider`` for ``role_arn``.
    """
    # No role requested. An empty string is treated like None so that an
    # unset component input never reaches AssumeRole with an empty ARN.
    if not role_arn:
        return Session(region_name=region)

    # The following assume role example was taken from
    # https://github.com/boto/botocore/issues/761#issuecomment-426037853

    # Source session: its credentials are the ones used to call AssumeRole.
    assume_session = BotocoreSession()
    fetcher = AssumeRoleCredentialFetcher(
        assume_session.create_client,
        assume_session.get_credentials(),
        role_arn,
        extra_args={
            'DurationSeconds': 3600,  # request 1-hour assumed-role credentials
        },
        # NOTE(review): JSONFileCache() with no arguments uses botocore's
        # default on-disk cache location - confirm that is writable in the
        # component container.
        cache=JSONFileCache()
    )

    # Separate session whose credential resolver contains only the
    # assume-role provider, so clients built from it use the assumed role.
    role_session = BotocoreSession()
    role_session.register_component(
        'credential_provider',
        CredentialResolver([AssumeRoleProvider(fetcher)])
    )
    return Session(region_name=region, botocore_session=role_session)
|
||||
|
||||
def get_sagemaker_client(region, endpoint_url=None, assume_role_arn=None):
    """Builds a client to the AWS SageMaker API.

    Args:
        region: AWS region name used for the underlying session.
        endpoint_url: Optional endpoint URL passed through to the client.
        assume_role_arn: Optional ARN of an IAM role to assume; forwarded to
            ``get_boto3_session``.

    Returns:
        A ``sagemaker`` client created from the (optionally role-assuming)
        session, with a user agent that embeds the component version.
    """
    # The region is carried by the session, so it is not repeated on the
    # client call.
    session = get_boto3_session(region, assume_role_arn)
    session_config = Config(
        user_agent='sagemaker-on-kubeflow-pipelines-v{}'.format(get_component_version())
    )
    client = session.client('sagemaker', endpoint_url=endpoint_url, config=session_config)
    return client
|
||||
|
||||
|
||||
def get_cloudwatch_client(region, assume_role_arn=None):
    """Builds a client to the AWS CloudWatch Logs API.

    Args:
        region: AWS region name used for the underlying session.
        assume_role_arn: Optional ARN of an IAM role to assume; forwarded to
            ``get_boto3_session``.

    Returns:
        A ``logs`` client created from the (optionally role-assuming) session.
    """
    session = get_boto3_session(region, assume_role_arn)
    client = session.client('logs')
    return client
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ Create an endpoint in AWS SageMaker Hosting Service for model deployment.
|
|||
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
|
||||
region | The region where the endpoint is created | No | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
|
||||
endpoint_config_name | The name of the endpoint configuration | Yes | Yes | String | | |
|
||||
endpoint_config_tags | Key-value pairs to tag endpoint configurations in AWS | Yes | Yes | Dict | | {} |
|
||||
endpoint_tags | Key-value pairs to tag the Hosting endpoint in AWS | Yes | Yes | Dict | | {} |
|
||||
|
|
|
|||
|
|
@ -88,6 +88,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: endpoint_config_tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -104,12 +108,13 @@ outputs:
|
|||
- {name: endpoint_name, description: 'Endpoint name'}
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.6.0
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
deploy.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--endpoint_config_name, {inputValue: endpoint_config_name},
|
||||
--variant_name_1,{inputValue: variant_name_1},
|
||||
--model_name_1, {inputValue: model_name_1},
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
logging.info('Submitting Endpoint request to SageMaker...')
|
||||
endpoint_name = _utils.deploy_model(client, vars(args))
|
||||
logging.info('Endpoint creation request submitted. Waiting for completion...')
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ For Ground Truth jobs using AWS SageMaker.
|
|||
Argument | Description | Optional | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :----------| :----------| :---------- | :----------|
|
||||
region | The region where the cluster launches | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
|
||||
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | String | | |
|
||||
job_name | The name of the Ground Truth job. Must be unique within the same AWS account and AWS region | Yes | String | | LabelingJob-[datetime]-[random id]|
|
||||
label_attribute_name | The attribute name to use for the label in the output manifest file | Yes | String | | job_name |
|
||||
|
|
@ -39,7 +41,6 @@ time_limit | The maximum run time in seconds per training job | No | Int | [30,
|
|||
task_availibility | The length of time that a task remains available for labeling by human workers | Yes | Int | Public workforce: [1, 43200], other: [1, 864000] | |
|
||||
max_concurrent_tasks | The maximum number of data objects that can be labeled by human workers at the same time | Yes | Int | [1, 1000] | |
|
||||
workforce_task_price | The price that you pay for each task performed by a public worker in USD; Specify to the tenth fractions of a cent; Format as "0.000" | Yes | Float | 0.000 |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
|
||||
tags | Key-value pairs to categorize AWS resources | Yes | Dict | | {} |
|
||||
|
||||
## Outputs
|
||||
|
|
|
|||
|
|
@ -110,6 +110,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -119,12 +123,13 @@ outputs:
|
|||
- {name: active_learning_model_arn, description: 'The ARN for the most recent Amazon SageMaker model trained as part of automated data labeling.'}
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.6.0
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
ground_truth.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--role, {inputValue: role},
|
||||
--job_name, {inputValue: job_name},
|
||||
--label_attribute_name, {inputValue: label_attribute_name},
|
||||
|
|
|
|||
|
|
@ -60,7 +60,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
logging.info('Submitting Ground Truth Job request to SageMaker...')
|
||||
_utils.create_labeling_job(client, vars(args))
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ For hyperparameter tuning jobs using AWS SageMaker.
|
|||
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
|
||||
region | The region where the cluster launches | No | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
|
||||
job_name | The name of the tuning job. Must be unique within the same AWS account and AWS region | Yes | Yes | String | | HPOJob-[datetime]-[random id] |
|
||||
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | No | String | | |
|
||||
image | The registry path of the Docker image that contains the training algorithm | Yes | Yes | String | | |
|
||||
|
|
@ -44,7 +46,6 @@ max_wait_time | The maximum time in seconds you are willing to wait for a manage
|
|||
checkpoint_config | Dictionary of information about the output location for managed spot training checkpoint data | Yes | Yes | Dict | | {} |
|
||||
warm_start_type | Specifies the type of warm start used | Yes | No | String | IdenticalDataAndAlgorithm, TransferLearning | |
|
||||
parent_hpo_jobs | List of previously completed or stopped hyperparameter tuning jobs to be used as a starting point | Yes | Yes | String | Yes | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | Yes | String | | |
|
||||
tags | Key-value pairs to categorize AWS resources | Yes | Yes | Dict | | {} |
|
||||
|
||||
Notes:
|
||||
|
|
|
|||
|
|
@ -133,6 +133,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: tags
|
||||
description: 'Key-value pairs, to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -150,12 +154,13 @@ outputs:
|
|||
description: 'The registry path of the Docker image that contains the training algorithm'
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.6.0
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
hyperparameter_tuning.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--job_name, {inputValue: job_name},
|
||||
--role, {inputValue: role},
|
||||
--image, {inputValue: image},
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
logging.info('Submitting HyperParameter Tuning Job request to SageMaker...')
|
||||
hpo_job_name = _utils.create_hyperparameter_tuning_job(client, vars(args))
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ Create a model in Amazon SageMaker to be used for [creating an endpoint](https:/
|
|||
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
|
||||
region | The region where the model is created | No | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
|
||||
tags | Key-value pairs to tag the model created in AWS | Yes | Yes | Dict | | {} |
|
||||
role | The ARN of the IAM role that Amazon SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs | No | No | String | | |
|
||||
network_isolation | Isolates the model container. No inbound or outbound network calls can be made to or from the model container | Yes | Yes | Boolean | | True |
|
||||
|
|
|
|||
|
|
@ -51,6 +51,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -59,12 +63,13 @@ outputs:
|
|||
- {name: model_name, description: 'The model name SageMaker created'}
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.6.0
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
create_model.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--model_name, {inputValue: model_name},
|
||||
--role, {inputValue: role},
|
||||
--container_host_name, {inputValue: container_host_name},
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
|
||||
logging.info('Submitting model creation request to SageMaker...')
|
||||
_utils.create_model(client, vars(args))
|
||||
|
|
|
|||
|
|
@ -11,7 +11,8 @@ For running your data processing workloads, such as feature engineering, data va
|
|||
Argument | Description | Optional | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :----------| :----------| :---------- | :----------|
|
||||
region | The region where the cluster launches | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
|
||||
job_name | The name of the Processing job. Must be unique within the same AWS account and AWS region | Yes | String | | ProcessingJob-[datetime]-[random id]|
|
||||
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | String | | |
|
||||
image | The registry path of the Docker image that contains the processing script | Yes | String | | |
|
||||
|
|
|
|||
|
|
@ -80,6 +80,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: tags
|
||||
description: 'Key-value pairs, to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -89,12 +93,13 @@ outputs:
|
|||
- {name: output_artifacts, description: 'A dictionary containing the output S3 artifacts'}
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.6.0
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
process.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--job_name, {inputValue: job_name},
|
||||
--role, {inputValue: role},
|
||||
--image, {inputValue: image},
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
|
||||
logging.info('Submitting Processing Job to SageMaker...')
|
||||
job_name = _utils.create_processing_job(client, vars(args))
|
||||
|
|
@ -68,7 +68,7 @@ def main(argv=None):
|
|||
except:
|
||||
raise
|
||||
finally:
|
||||
cw_client = _utils.get_cloudwatch_client(args.region)
|
||||
cw_client = _utils.get_cloudwatch_client(args.region, assume_role_arn=args.assume_role)
|
||||
_utils.print_logs_for_job(cw_client, '/aws/sagemaker/ProcessingJobs', job_name)
|
||||
|
||||
outputs = _utils.get_processing_job_outputs(client, job_name)
|
||||
|
|
|
|||
|
|
@ -12,4 +12,7 @@ S3_DATA_BUCKET=my-data-bucket
|
|||
# EKS_EXISTING_CLUSTER=my-eks-cluster
|
||||
|
||||
# If you would like to skip the FSx set-up and tests
|
||||
# SKIP_FSX_TESTS=true
|
||||
# SKIP_FSX_TESTS=true
|
||||
|
||||
# If you have an IAM role that the EKS cluster should assume for the "assume role" tests
|
||||
# ASSUMED_ROLE_NAME=my-assumed-role
|
||||
|
|
@ -14,7 +14,8 @@ from utils import argo_utils
|
|||
pytest.param(
|
||||
"resources/config/kmeans-algo-mnist-processing",
|
||||
marks=pytest.mark.canary_test,
|
||||
)
|
||||
),
|
||||
"resources/config/assume-role-processing"
|
||||
],
|
||||
)
|
||||
def test_processingjob(
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from utils import argo_utils
|
|||
),
|
||||
pytest.param("resources/config/fsx-mnist-training", marks=pytest.mark.fsx_test),
|
||||
"resources/config/spot-sample-pipeline-training",
|
||||
"resources/config/assume-role-training",
|
||||
],
|
||||
)
|
||||
def test_trainingjob(
|
||||
|
|
@ -74,8 +75,9 @@ def test_trainingjob(
|
|||
else:
|
||||
assert f"dkr.ecr.{region}.amazonaws.com" in training_image
|
||||
|
||||
assert not argo_utils.error_in_cw_logs(workflow_json["metadata"]["name"]), \
|
||||
('Found the CloudWatch error message in the log output. Check SageMaker to see if the job has failed.')
|
||||
assert not argo_utils.error_in_cw_logs(
|
||||
workflow_json["metadata"]["name"]
|
||||
), "Found the CloudWatch error message in the log output. Check SageMaker to see if the job has failed."
|
||||
|
||||
utils.remove_dir(download_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ def pytest_addoption(parser):
|
|||
parser.addoption(
|
||||
"--role-arn", required=True, help="SageMaker execution IAM role ARN",
|
||||
)
|
||||
parser.addoption(
|
||||
"--assume-role-arn",
|
||||
required=True,
|
||||
help="The ARN of a role which the assume role tests will assume to access SageMaker.",
|
||||
)
|
||||
parser.addoption(
|
||||
"--s3-data-bucket",
|
||||
required=True,
|
||||
|
|
@ -61,6 +66,12 @@ def region(request):
|
|||
return request.config.getoption("--region")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
def assume_role_arn(request):
    """Expose the ``--assume-role-arn`` option to the test session.

    Exports the value as the ``ASSUME_ROLE_ARN`` environment variable and
    returns it as the fixture value.
    """
    arn = request.config.getoption("--assume-role-arn")
    os.environ["ASSUME_ROLE_ARN"] = arn
    return arn
|
||||
|
||||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def role_arn(request):
|
||||
os.environ["ROLE_ARN"] = request.config.getoption("--role-arn")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
PipelineDefinition: resources/definition/processing_pipeline.py
|
||||
TestName: assume-role-processing
|
||||
Timeout: 1800
|
||||
Arguments:
|
||||
region: ((REGION))
|
||||
assume_role: ((ASSUME_ROLE_ARN))
|
||||
job_name: process-mnist-for-kmeans
|
||||
image: 763104351884.dkr.ecr.((REGION)).amazonaws.com/pytorch-training:1.5.0-cpu-py36-ubuntu16.04
|
||||
container_entrypoint:
|
||||
- python
|
||||
- /opt/ml/processing/code/kmeans_preprocessing.py
|
||||
input_config:
|
||||
- InputName: mnist_tar
|
||||
S3Input:
|
||||
S3Uri: s3://sagemaker-sample-data-((REGION))/algorithms/kmeans/mnist/mnist.pkl.gz
|
||||
LocalPath: /opt/ml/processing/input
|
||||
S3DataType: S3Prefix
|
||||
S3InputMode: File
|
||||
S3CompressionType: None
|
||||
- InputName: source_code
|
||||
S3Input:
|
||||
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/processing_code/kmeans_preprocessing.py
|
||||
LocalPath: /opt/ml/processing/code
|
||||
S3DataType: S3Prefix
|
||||
S3InputMode: File
|
||||
S3CompressionType: None
|
||||
output_config:
|
||||
- OutputName: train_data
|
||||
S3Output:
|
||||
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
|
||||
LocalPath: /opt/ml/processing/output_train/
|
||||
S3UploadMode: EndOfJob
|
||||
- OutputName: test_data
|
||||
S3Output:
|
||||
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
|
||||
LocalPath: /opt/ml/processing/output_test/
|
||||
S3UploadMode: EndOfJob
|
||||
- OutputName: valid_data
|
||||
S3Output:
|
||||
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
|
||||
LocalPath: /opt/ml/processing/output_valid/
|
||||
S3UploadMode: EndOfJob
|
||||
instance_type: ml.m5.xlarge
|
||||
instance_count: 1
|
||||
volume_size: 50
|
||||
max_run_time: 1800
|
||||
role: ((ROLE_ARN))
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
PipelineDefinition: resources/definition/training_pipeline.py
|
||||
TestName: assume-role-training
|
||||
Timeout: 3600
|
||||
ExpectedTrainingImage: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1
|
||||
Arguments:
|
||||
region: ((REGION))
|
||||
image: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1
|
||||
assume_role: ((ASSUME_ROLE_ARN))
|
||||
training_input_mode: File
|
||||
hyperparameters:
|
||||
k: "10"
|
||||
feature_dim: "784"
|
||||
channels:
|
||||
- ChannelName: train
|
||||
DataSource:
|
||||
S3DataSource:
|
||||
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/data
|
||||
S3DataType: S3Prefix
|
||||
S3DataDistributionType: FullyReplicated
|
||||
CompressionType: None
|
||||
RecordWrapperType: None
|
||||
InputMode: File
|
||||
instance_type: ml.m5.xlarge
|
||||
instance_count: 1
|
||||
volume_size: 50
|
||||
max_run_time: 3600
|
||||
model_artifact_path: s3://((DATA_BUCKET))/mnist_kmeans_example/output
|
||||
network_isolation: "True"
|
||||
traffic_encryption: "False"
|
||||
spot_instance: "False"
|
||||
max_wait_time: 3600
|
||||
checkpoint_config: "{}"
|
||||
role: ((ROLE_ARN))
|
||||
|
|
@ -23,6 +23,7 @@ def processing_pipeline(
|
|||
output_config={},
|
||||
network_isolation=False,
|
||||
role="",
|
||||
assume_role="",
|
||||
):
|
||||
sagemaker_process_op(
|
||||
region=region,
|
||||
|
|
@ -39,6 +40,7 @@ def processing_pipeline(
|
|||
output_config=output_config,
|
||||
network_isolation=network_isolation,
|
||||
role=role,
|
||||
assume_role=assume_role,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ def training_pipeline(
|
|||
checkpoint_config="{}",
|
||||
vpc_security_group_ids="",
|
||||
vpc_subnets="",
|
||||
assume_role="",
|
||||
role="",
|
||||
):
|
||||
sagemaker_train_op(
|
||||
|
|
@ -50,6 +51,7 @@ def training_pipeline(
|
|||
checkpoint_config=checkpoint_config,
|
||||
vpc_security_group_ids=vpc_security_group_ids,
|
||||
vpc_subnets=vpc_subnets,
|
||||
assume_role=assume_role,
|
||||
role=role,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ CLUSTER_REGION="${3:-us-east-1}"
|
|||
SERVICE_NAMESPACE="${4:-kubeflow}"
|
||||
SERVICE_ACCOUNT="${5:-pipeline-runner}"
|
||||
aws_account=$(aws sts get-caller-identity --query Account --output text)
|
||||
trustfile="trust.json"
|
||||
trust_file="trust.json"
|
||||
assume_role_file="assume-role.json"
|
||||
|
||||
cwd=$(dirname $(realpath $0))
|
||||
|
||||
|
|
@ -36,15 +37,27 @@ function get_oidc_id {
|
|||
# Parameter:
|
||||
# $1: Name of the trust file to generate.
|
||||
# Create the IAM role named by ${ROLE_NAME} unless it already exists.
# A newly created role gets the AmazonSageMakerFullAccess managed policy and
# an inline "AllowAssumeRole" policy permitting sts:AssumeRole on any role in
# the current account (${aws_account}).
# Parameter:
#    $1: Path of the trust policy JSON file used as the role's
#        assume-role policy document.
function create_namespaced_iam_role {
  local trust_file_path="${1}"

  # Check if role already exists. Testing the command directly (instead of
  # inspecting $? afterwards) keeps this working when the caller runs with
  # `set -e`.
  if aws iam get-role --role-name "${ROLE_NAME}"; then
    echo "A role for this cluster and namespace already exists in this account, assuming SageMaker and AssumeRole access and proceeding."
  else
    echo "IAM Role does not exist, creating a new Role for the cluster"
    aws iam create-role --role-name "${ROLE_NAME}" --assume-role-policy-document "file://${trust_file_path}" --output=text --query "Role.Arn"
    aws iam attach-role-policy --role-name "${ROLE_NAME}" --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess

    # The account id is passed as a printf argument rather than spliced into
    # the format string, so its content is never interpreted as a format.
    printf '{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": "sts:AssumeRole",
      "Resource": "arn:aws:iam::%s:role/*"
    }
  ]
}\n' "${aws_account}" > "${assume_role_file}"
    aws iam put-role-policy --role-name "${ROLE_NAME}" --policy-name AllowAssumeRole --policy-document "file://${assume_role_file}"

    # The policy document is only needed for the call above; remove it so the
    # generated file does not linger between runs.
    rm -f "${assume_role_file}"
  fi
}
|
||||
|
||||
|
|
@ -58,11 +71,11 @@ function delete_generated_file {
|
|||
echo "Get the OIDC ID for the cluster"
|
||||
get_oidc_id
|
||||
echo "Delete the trust json file if it already exists"
|
||||
delete_generated_file "${trustfile}"
|
||||
delete_generated_file "${trust_file}"
|
||||
echo "Generate a trust json"
|
||||
"$cwd"/generate_trust_policy ${CLUSTER_REGION} ${aws_account} ${oidc_id} ${SERVICE_NAMESPACE} ${SERVICE_ACCOUNT} > "${trustfile}"
|
||||
"$cwd"/generate_trust_policy ${CLUSTER_REGION} ${aws_account} ${oidc_id} ${SERVICE_NAMESPACE} ${SERVICE_ACCOUNT} > "${trust_file}"
|
||||
echo "Create the IAM Role using these values"
|
||||
create_namespaced_iam_role "${trustfile}"
|
||||
create_namespaced_iam_role "${trust_file}"
|
||||
echo "Cleanup for the next run"
|
||||
delete_generated_file "${trustfile}"
|
||||
delete_generated_file "${trust_file}"
|
||||
|
||||
|
|
|
|||
|
|
@ -26,10 +26,12 @@ EKS_PRIVATE_SUBNETS=${EKS_PRIVATE_SUBNETS:-""}
|
|||
MINIO_LOCAL_PORT=${MINIO_LOCAL_PORT:-9000}
|
||||
KFP_NAMESPACE=${KFP_NAMESPACE:-"kubeflow"}
|
||||
KFP_SERVICE_ACCOUNT=${KFP_SERVICE_ACCOUNT:-"pipeline-runner"}
|
||||
AWS_ACCOUNT_ID=${AWS_ACCOUNT_ID:-"$(aws sts get-caller-identity --query=Account --output=text)"}
|
||||
|
||||
PYTEST_MARKER=${PYTEST_MARKER:-""}
|
||||
S3_DATA_BUCKET=${S3_DATA_BUCKET:-""}
|
||||
SAGEMAKER_EXECUTION_ROLE_ARN=${SAGEMAKER_EXECUTION_ROLE_ARN:-""}
|
||||
ASSUMED_ROLE_NAME=${ASSUMED_ROLE_NAME:-""}
|
||||
|
||||
SKIP_FSX_TESTS=${SKIP_FSX_TESTS:-"false"}
|
||||
|
||||
|
|
@ -78,7 +80,8 @@ function cleanup() {
|
|||
set +e
|
||||
|
||||
cleanup_kfp
|
||||
delete_generated_role
|
||||
delete_assumed_role
|
||||
delete_oidc_role
|
||||
|
||||
if [[ "${SKIP_FSX_TESTS}" == "false" ]]; then
|
||||
delete_fsx_instance
|
||||
|
|
@ -141,21 +144,60 @@ function install_kfp() {
|
|||
echo "[Installing KFP] Pipeline pods are ready"
|
||||
}
|
||||
|
||||
function generate_iam_role_name() {
|
||||
function generate_oidc_role_name() {
|
||||
OIDC_ROLE_NAME="$(echo "${DEPLOY_NAME}-kubeflow-role" | cut -c1-64)"
|
||||
OIDC_ROLE_ARN="arn:aws:iam::$(aws sts get-caller-identity --query=Account --output=text):role/${OIDC_ROLE_NAME}"
|
||||
OIDC_ROLE_ARN="arn:aws:iam::${AWS_ACCOUNT_ID}:role/${OIDC_ROLE_NAME}"
|
||||
}
|
||||
|
||||
function install_generated_role() {
|
||||
function install_oidc_role() {
  # Annotate the KFP service account with the OIDC role ARN through the
  # eks.amazonaws.com/role-arn annotation.
  local annotation_patch='{"metadata": {"annotations": {"eks.amazonaws.com/role-arn": "'"${OIDC_ROLE_ARN}"'"}}}'

  kubectl patch serviceaccount -n ${KFP_NAMESPACE} ${KFP_SERVICE_ACCOUNT} --patch "${annotation_patch}"
}
|
||||
|
||||
function delete_generated_role() {
|
||||
function delete_oidc_role() {
  # Tear down the IAM role associated with the cluster that is being deleted.
  # The managed policy is detached and the inline AllowAssumeRole policy is
  # deleted before the role itself is deleted.
  local sagemaker_policy_arn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess"

  aws iam detach-role-policy --role-name "${OIDC_ROLE_NAME}" --policy-arn "${sagemaker_policy_arn}"
  aws iam delete-role-policy --role-name "${OIDC_ROLE_NAME}" --policy-name AllowAssumeRole
  aws iam delete-role --role-name "${OIDC_ROLE_NAME}"
}
|
||||
|
||||
function generate_assumed_role() {
  # Resolve the role the tests will assume and export its ARN in ASSUMED_ROLE_ARN.
  # When ASSUMED_ROLE_NAME is not supplied (e.g. through the env file), a role
  # named after DEPLOY_NAME is created and CREATED_ASSUMED_ROLE is set to "true"
  # so that cleanup knows it owns the role.
  if [[ -z "${ASSUMED_ROLE_NAME}" ]]; then
    # IAM role names are limited to 64 characters. Trim the deploy-name prefix
    # (64 - 13 for "-assumed-role") so the suffix survives and the name cannot
    # collide with the similarly derived OIDC role name.
    ASSUMED_ROLE_NAME="$(echo "${DEPLOY_NAME}" | cut -c1-51)-assumed-role"
    CREATED_ASSUMED_ROLE="true"

    # Create a trust file that allows the OIDC role to authenticate
    local assumed_trust_file="assumed-role-trust.json"
    printf '{
      "Version": "2012-10-17",
      "Statement": [
        {
          "Effect": "Allow",
          "Principal": {
            "AWS": "arn:aws:iam::'"${AWS_ACCOUNT_ID}"':root"
          },
          "Action": "sts:AssumeRole",
          "Condition": {}
        }
      ]
    }' > "${assumed_trust_file}"
    aws iam create-role --role-name "${ASSUMED_ROLE_NAME}" --assume-role-policy-document "file://${assumed_trust_file}" --output=text --query "Role.Arn"
    aws iam attach-role-policy --role-name "${ASSUMED_ROLE_NAME}" --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess

    # The trust document is only needed while creating the role.
    rm -f "${assumed_trust_file}"
  fi

  # Generate the ARN using the role name
  ASSUMED_ROLE_ARN="arn:aws:iam::${AWS_ACCOUNT_ID}:role/${ASSUMED_ROLE_NAME}"
}
|
||||
|
||||
function delete_assumed_role() {
  # Remove the assumed role only when this script created it; a role supplied
  # by the user through ASSUMED_ROLE_NAME is left untouched.
  if [[ -z "${ASSUMED_ROLE_NAME}" ]]; then
    return 0
  fi
  if [[ "${CREATED_ASSUMED_ROLE:-false}" != "true" ]]; then
    return 0
  fi

  # Detach the managed policy first, then delete the role for this deployment.
  aws iam detach-role-policy --role-name "${ASSUMED_ROLE_NAME}" --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
  aws iam delete-role --role-name "${ASSUMED_ROLE_NAME}"
}
|
||||
|
||||
function cleanup_kfp() {
|
||||
# Clean up Minio
|
||||
if [[ ! -z "${MINIO_PID:-}" ]]; then
|
||||
|
|
@ -187,12 +229,16 @@ else
|
|||
wait
|
||||
fi
|
||||
|
||||
generate_iam_role_name
|
||||
generate_oidc_role_name
|
||||
"$cwd"/generate_iam_role ${EKS_CLUSTER_NAME} ${OIDC_ROLE_NAME} ${REGION} ${KFP_NAMESPACE} ${KFP_SERVICE_ACCOUNT}
|
||||
install_kfp
|
||||
install_generated_role
|
||||
install_oidc_role
|
||||
|
||||
pytest_args=( --region "${REGION}" --role-arn "${SAGEMAKER_EXECUTION_ROLE_ARN}" --s3-data-bucket "${S3_DATA_BUCKET}" --minio-service-port "${MINIO_LOCAL_PORT}" --kfp-namespace "${KFP_NAMESPACE}" )
|
||||
generate_assumed_role
|
||||
|
||||
pytest_args=( --region "${REGION}" --role-arn "${SAGEMAKER_EXECUTION_ROLE_ARN}" \
|
||||
--s3-data-bucket "${S3_DATA_BUCKET}" --kfp-namespace "${KFP_NAMESPACE}" \
|
||||
--minio-service-port "${MINIO_LOCAL_PORT}" --assume-role-arn "${ASSUMED_ROLE_ARN}")
|
||||
|
||||
if [[ "${SKIP_FSX_TESTS}" == "true" ]]; then
|
||||
pytest_args+=( -m "not fsx_test" )
|
||||
|
|
|
|||
|
|
@ -42,6 +42,10 @@ def get_fsx_id():
|
|||
return os.environ.get("FSX_ID")
|
||||
|
||||
|
||||
def get_assume_role_arn():
    """Return the ASSUME_ROLE_ARN environment variable, or None when it is unset."""
    return os.getenv("ASSUME_ROLE_ARN")
|
||||
|
||||
|
||||
def get_algorithm_image_registry(region, algorithm):
|
||||
return get_image_uri(region, algorithm).split(".")[0]
|
||||
|
||||
|
|
@ -84,6 +88,7 @@ def replace_placeholders(input_filename, output_filename):
|
|||
"((FSX_ID))": get_fsx_id(),
|
||||
"((FSX_SUBNET))": get_fsx_subnet(),
|
||||
"((FSX_SECURITY_GROUP))": get_fsx_security_group(),
|
||||
"((ASSUME_ROLE_ARN))": get_assume_role_arn()
|
||||
}
|
||||
|
||||
filedata = ""
|
||||
|
|
|
|||
|
|
@ -53,6 +53,20 @@ class BatchTransformTestCase(unittest.TestCase):
|
|||
call('/tmp/output', 's3://fake-bucket/output')
|
||||
])
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn."""
    # Mock the whole utils module, keeping only the real default-argument
    # helper that the parser relies on.
    mocked_utils = MagicMock()
    mocked_utils.add_default_client_arguments = _utils.add_default_client_arguments
    mocked_utils.create_transform_job.return_value = 'test-batch-job'
    batch_transform._utils = mocked_utils

    batch_transform.main(required_args + ['--assume_role', 'my-role'])

    mocked_utils.get_sagemaker_client.assert_called_once_with(
        'us-west-2', None, assume_role_arn='my-role'
    )
|
||||
|
||||
|
||||
def test_batch_transform(self):
|
||||
mock_client = MagicMock()
|
||||
|
|
|
|||
|
|
@ -41,6 +41,20 @@ class DeployTestCase(unittest.TestCase):
|
|||
call('/tmp/output', 'test-endpoint-name')
|
||||
])
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn."""
    # Mock the whole utils module, keeping only the real default-argument
    # helper that the parser relies on.
    mocked_utils = MagicMock()
    mocked_utils.add_default_client_arguments = _utils.add_default_client_arguments
    mocked_utils.deploy_model.return_value = 'test-endpoint-name'
    deploy._utils = mocked_utils

    deploy.main(required_args + ['--assume_role', 'my-role'])

    mocked_utils.get_sagemaker_client.assert_called_once_with(
        'us-west-2', None, assume_role_arn='my-role'
    )
|
||||
|
||||
def test_deploy_model(self):
|
||||
mock_client = MagicMock()
|
||||
mock_args = self.parser.parse_args(required_args + ['--endpoint_name', 'test-endpoint-name', '--endpoint_config_name', 'test-endpoint-config-name'])
|
||||
|
|
|
|||
|
|
@ -56,6 +56,20 @@ class GroundTruthTestCase(unittest.TestCase):
|
|||
call('/tmp/model-output', 'arn:aws:sagemaker:us-east-1:999999999999:labeling-job')
|
||||
])
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn."""
    # Mock the whole utils module, keeping only the real default-argument
    # helper that the parser relies on.
    mocked_utils = MagicMock()
    mocked_utils.add_default_client_arguments = _utils.add_default_client_arguments
    mocked_utils.get_labeling_job_outputs.return_value = (
        's3://fake-bucket/output',
        'arn:aws:sagemaker:us-east-1:999999999999:labeling-job',
    )
    ground_truth._utils = mocked_utils

    ground_truth.main(required_args + ['--assume_role', 'my-role'])

    mocked_utils.get_sagemaker_client.assert_called_once_with(
        'us-west-2', None, assume_role_arn='my-role'
    )
|
||||
|
||||
def test_ground_truth(self):
|
||||
mock_client = MagicMock()
|
||||
mock_args = self.parser.parse_args(required_args)
|
||||
|
|
|
|||
|
|
@ -87,6 +87,21 @@ class HyperparameterTestCase(unittest.TestCase):
|
|||
call('/tmp/best_hyperparameters_output_path', {"key_1": "best_hp_1"}, json_encode=True),
|
||||
call('/tmp/training_image_output_path', 'training-image')
|
||||
])
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn."""
    # Mock the whole utils module, keeping only the real default-argument
    # helper that the parser relies on.
    mocked_utils = MagicMock()
    mocked_utils.add_default_client_arguments = _utils.add_default_client_arguments
    mocked_utils.create_hyperparameter_tuning_job.return_value = 'job-name'
    mocked_utils.get_best_training_job_and_hyperparameters.return_value = (
        'best_job',
        {"key_1": "best_hp_1"},
    )
    hpo._utils = mocked_utils

    hpo.main(required_args + ['--assume_role', 'my-role'])

    mocked_utils.get_sagemaker_client.assert_called_once_with(
        'us-west-2', None, assume_role_arn='my-role'
    )
|
||||
|
||||
def test_create_hyperparameter_tuning_job(self):
|
||||
mock_client = MagicMock()
|
||||
|
|
@ -125,7 +140,7 @@ class HyperparameterTestCase(unittest.TestCase):
|
|||
|
||||
def test_main_stop_hyperparameter_tuning_job(self):
|
||||
hpo._utils = MagicMock()
|
||||
hpo._utils.create_processing_job.return_value = 'job-name'
|
||||
hpo._utils.create_hyperparameter_tuning_job.return_value = 'job-name'
|
||||
|
||||
try:
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
|
|
|
|||
|
|
@ -43,6 +43,20 @@ class ModelTestCase(unittest.TestCase):
|
|||
call('/tmp/output', 'model_test')
|
||||
])
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn."""
    # Mock the whole utils module, keeping only the real default-argument
    # helper that the parser relies on.
    mocked_utils = MagicMock()
    mocked_utils.add_default_client_arguments = _utils.add_default_client_arguments
    mocked_utils.create_model.return_value = 'model_test'
    create_model._utils = mocked_utils

    create_model.main(required_args + ['--assume_role', 'my-role'])

    mocked_utils.get_sagemaker_client.assert_called_once_with(
        'us-west-2', None, assume_role_arn='my-role'
    )
|
||||
|
||||
def test_create_model(self):
|
||||
mock_client = MagicMock()
|
||||
mock_args = self.parser.parse_args(required_args)
|
||||
|
|
|
|||
|
|
@ -68,6 +68,21 @@ class ProcessTestCase(unittest.TestCase):
|
|||
call('/tmp/output_artifacts_output_path', mock_outputs, json_encode=True)
|
||||
])
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn."""
    # Mock the whole utils module, keeping only the real default-argument
    # helper that the parser relies on.
    mocked_utils = MagicMock()
    mocked_utils.add_default_client_arguments = _utils.add_default_client_arguments
    mocked_utils.create_processing_job.return_value = 'job-name'
    mocked_utils.get_processing_job_outputs.return_value = {
        'val1': 's3://1',
        'val2': 's3://2',
    }
    process._utils = mocked_utils

    process.main(required_args + ['--assume_role', 'my-role'])

    mocked_utils.get_sagemaker_client.assert_called_once_with(
        'us-west-2', None, assume_role_arn='my-role'
    )
|
||||
|
||||
def test_create_processing_job(self):
|
||||
mock_client = MagicMock()
|
||||
mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job'])
|
||||
|
|
|
|||
|
|
@ -57,6 +57,22 @@ class TrainTestCase(unittest.TestCase):
|
|||
call('/tmp/training_image_output_path', 'training-image')
|
||||
])
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn."""
    # Mock the whole utils module, keeping only the real default-argument
    # helper that the parser relies on.
    mocked_utils = MagicMock()
    mocked_utils.add_default_client_arguments = _utils.add_default_client_arguments
    mocked_utils.create_training_job.return_value = 'job-name'
    mocked_utils.get_image_from_job.return_value = 'training-image'
    mocked_utils.get_model_artifacts_from_job.return_value = 'model-artifacts'
    train._utils = mocked_utils

    train.main(required_args + ['--assume_role', 'my-role'])

    mocked_utils.get_sagemaker_client.assert_called_once_with(
        'us-west-2', None, assume_role_arn='my-role'
    )
|
||||
|
||||
def test_create_training_job(self):
|
||||
mock_client = MagicMock()
|
||||
mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job'])
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import unittest
|
||||
import json
|
||||
|
||||
from unittest.mock import patch, call, Mock, MagicMock, mock_open
|
||||
from unittest.mock import patch, call, Mock, MagicMock, mock_open, ANY
|
||||
from boto3.session import Session
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from common import _utils
|
||||
|
|
@ -66,3 +67,37 @@ class UtilsTestCase(unittest.TestCase):
|
|||
mock_path("/tmp/test-output").write_text.assert_called_with(
|
||||
json.dumps(case)
|
||||
)
|
||||
|
||||
def test_assume_default_boto3_session(self):
    """Without a role ARN, get_boto3_session returns a Session for the region
    and never calls the patched boto3 module object."""
    boto3_patcher = patch("common._utils.boto3", MagicMock())
    with boto3_patcher as fake_boto3:
        session = _utils.get_boto3_session("us-east-1")

        assert isinstance(session, Session)
        assert session.region_name == "us-east-1"
        fake_boto3.assert_not_called()
|
||||
|
||||
@patch("common._utils.DeferredRefreshableCredentials", MagicMock())
@patch("common._utils.AssumeRoleCredentialFetcher", MagicMock())
def test_assume_role_boto3_session(self):
    """With a role ARN, the returned Session has an AssumeRoleProvider as its
    first credential provider."""
    session = _utils.get_boto3_session("us-east-1", role_arn="abc123")

    assert isinstance(session, Session)
    assert session.region_name == "us-east-1"

    # Reach into the session internals to confirm our provider sits first in
    # the credential provider list.
    credential_resolver = session._session._components.get_component('credential_provider')
    first_provider = credential_resolver.providers[0]
    assert isinstance(first_provider, _utils.AssumeRoleProvider)
|
||||
|
||||
def test_assumed_sagemaker_client(self):
    """get_sagemaker_client builds its client from a session created with the
    given region and assume-role ARN."""
    mock_sm_client = MagicMock()

    # patch.object restores the real get_boto3_session on exit. Assigning a
    # MagicMock to the module attribute directly would leak into every test
    # that runs afterwards.
    with patch.object(_utils, "get_boto3_session") as mock_get_session:
        # Mock the client("SageMaker", ...) return value
        mock_get_session.return_value.client.return_value = mock_sm_client

        client = _utils.get_sagemaker_client("us-east-1", assume_role_arn="abc123")

    assert client == mock_sm_client

    mock_get_session.assert_called_once_with("us-east-1", "abc123")
    mock_get_session.return_value.client.assert_called_once_with(
        "sagemaker", endpoint_url=None, config=ANY
    )
|
||||
|
||||
|
|
@ -23,6 +23,20 @@ class WorkTeamTestCase(unittest.TestCase):
|
|||
def test_create_parser(self):
|
||||
self.assertIsNotNone(self.parser)
|
||||
|
||||
def test_main_assumes_role(self):
    """main() forwards --assume_role to get_sagemaker_client as assume_role_arn.

    Named test_main_assumes_role rather than test_main: the class already
    defines a test_main further down, which would replace this method and
    keep it from ever running.
    """
    # Mock out all of utils except parser
    workteam._utils = MagicMock()
    workteam._utils.add_default_client_arguments = _utils.add_default_client_arguments

    # Set some static returns
    workteam._utils.create_workteam.return_value = 'arn:aws:sagemaker:us-east-1:999999999999:work-team'

    assume_role_args = required_args + ['--assume_role', 'my-role']

    workteam.main(assume_role_args)

    workteam._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
|
||||
|
||||
def test_main(self):
|
||||
# Mock out all of utils except parser
|
||||
workteam._utils = MagicMock()
|
||||
|
|
|
|||
|
|
@ -12,7 +12,8 @@ For model training using AWS SageMaker.
|
|||
Argument | Description | Optional | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :----------| :----------| :---------- | :----------|
|
||||
region | The region where the cluster launches | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
|
||||
job_name | The name of the Ground Truth job. Must be unique within the same AWS account and AWS region | Yes | String | | LabelingJob-[datetime]-[random id]|
|
||||
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | String | | |
|
||||
image | The registry path of the Docker image that contains the training algorithm | Yes | String | | |
|
||||
|
|
|
|||
|
|
@ -94,6 +94,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: tags
|
||||
description: 'Key-value pairs, to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -104,12 +108,13 @@ outputs:
|
|||
- {name: training_image, description: 'The registry path of the Docker image that contains the training algorithm'}
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.6.0
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
train.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--job_name, {inputValue: job_name},
|
||||
--role, {inputValue: role},
|
||||
--image, {inputValue: image},
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
|
||||
logging.info('Submitting Training Job to SageMaker...')
|
||||
job_name = _utils.create_training_job(client, vars(args))
|
||||
|
|
@ -78,7 +78,7 @@ def main(argv=None):
|
|||
except:
|
||||
raise
|
||||
finally:
|
||||
cw_client = _utils.get_cloudwatch_client(args.region)
|
||||
cw_client = _utils.get_cloudwatch_client(args.region, assume_role_arn=args.assume_role)
|
||||
_utils.print_logs_for_job(cw_client, '/aws/sagemaker/TrainingJobs', job_name)
|
||||
|
||||
image = _utils.get_image_from_job(client, job_name)
|
||||
|
|
|
|||
|
|
@ -11,13 +11,14 @@ For creating a private workteam from pre-existing Amazon Cognito user groups usi
|
|||
Argument | Description | Optional | Data type | Accepted values | Default |
|
||||
:--- | :---------- | :----------| :----------| :---------- | :----------|
|
||||
region | The region where the cluster launches | No | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
|
||||
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
|
||||
team_name | The name of your work team | No | String | | |
|
||||
description | A description of the work team | No | String | | |
|
||||
user_pool | An identifier for a user pool, which must be in the same region as the service that you are calling | No | String | | |
|
||||
user_groups | An identifier for user groups separated by commas | No | String | | |
|
||||
client_id | An identifier for an application client, which you must create using Amazon Cognito | No | String | | |
|
||||
sns_topic | The ARN for the SNS topic to which notifications should be published | Yes | String | | |
|
||||
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | Yes | String | | |
|
||||
tags | Key-value pairs to categorize AWS resources | Yes | Dict | | {} |
|
||||
|
||||
Notes:
|
||||
|
|
|
|||
|
|
@ -28,6 +28,10 @@ inputs:
|
|||
description: 'The endpoint URL for the private link VPC endpoint.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: assume_role
|
||||
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
|
||||
default: ''
|
||||
type: String
|
||||
- name: tags
|
||||
description: 'Key-value pairs to categorize AWS resources.'
|
||||
default: '{}'
|
||||
|
|
@ -36,12 +40,13 @@ outputs:
|
|||
- {name: workteam_arn, description: 'The ARN of the workteam.'}
|
||||
implementation:
|
||||
container:
|
||||
image: amazon/aws-sagemaker-kfp-components:0.5.3
|
||||
image: amazon/aws-sagemaker-kfp-components:0.7.0
|
||||
command: ['python3']
|
||||
args: [
|
||||
workteam.py,
|
||||
--region, {inputValue: region},
|
||||
--endpoint_url, {inputValue: endpoint_url},
|
||||
--assume_role, {inputValue: assume_role},
|
||||
--team_name, {inputValue: team_name},
|
||||
--description, {inputValue: description},
|
||||
--user_pool, {inputValue: user_pool},
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def main(argv=None):
|
|||
args = parser.parse_args(argv)
|
||||
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
|
||||
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
|
||||
logging.info('Submitting a create workteam request to SageMaker...')
|
||||
workteam_arn = _utils.create_workteam(client, vars(args))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue