feat(components): AWS SageMaker - Support for assuming a role (#4212)

* Add client assume role functionality

* Add assume_role to component.yaml files

* Update image to personal

* Update input to force NoneType on empty

* Update integration test setup with assumed role

* Add assume role integration test

* Update boto session to use refreshing credentials

* Update assume role relax trust relationship

* Add check for defined assumed role name

* Add processing assume integ test

* Add assume role unit test for main methods

* Add assume_role to all READMEs

* Update session to use AssumeRoleProvider

* Remove region from child calls to session

* Fix extra region_name in test

* Update assume role processing integ test name

* Add processing integ test to list

* Update assumed role to remain if not generated

* Update license version

* Update image tag to new version

* Add new version to Changelog
Nicholas Thomson 2020-08-03 10:53:43 -07:00 committed by GitHub
parent 704c8c7660
commit 8014a44229
47 changed files with 480 additions and 57 deletions

View File

@@ -4,6 +4,12 @@ The version of the AWS SageMaker Components is determined by the docker image ta
Repository: https://hub.docker.com/repository/docker/amazon/aws-sagemaker-kfp-components
---------------------------------------------
**Change log for version 0.7.0**
- Add functionality to assume role when sending SageMaker requests
> Pull requests : [#4212](https://github.com/kubeflow/pipelines/pull/4212)
**Change log for version 0.6.0**
- Add functionality to stop SageMaker jobs on run termination

View File

@@ -1,4 +1,4 @@
** Amazon SageMaker Components for Kubeflow Pipelines; version 0.6.0 --
** Amazon SageMaker Components for Kubeflow Pipelines; version 0.7.0 --
https://github.com/kubeflow/pipelines/tree/master/components/aws/sagemaker
Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
** boto3; version 1.12.33 -- https://github.com/boto/boto3/

View File

@@ -17,7 +17,8 @@ Create a transform job in AWS SageMaker.
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
region | The region where the endpoint is created | No | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
job_name | The name of the transform job. The name must be unique within an AWS Region in an AWS account | Yes | Yes | String | | is a generated name (combination of model_name and 'BatchTransform' string)|
model_name | The name of the model that you want to use for the transform job. Model name must be the name of an existing Amazon SageMaker model within an AWS Region in an AWS account | No | No | String | | |
max_concurrent | The maximum number of parallel requests that can be sent to each instance in a transform job | Yes | Yes | Integer | | 0 |

View File

@@ -90,6 +90,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: tags
description: 'Key-value pairs to categorize AWS resources.'
default: '{}'
@@ -98,12 +102,13 @@ outputs:
- {name: output_location, description: 'S3 URI of the transform job results.'}
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.6.0
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
batch_transform.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--job_name, {inputValue: job_name},
--model_name, {inputValue: model_name},
--max_concurrent, {inputValue: max_concurrent},

View File

@@ -52,7 +52,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting Batch Transformation request to SageMaker...')
batch_job_name = _utils.create_transform_job(client, vars(args))
@@ -68,7 +68,7 @@ def main(argv=None):
except:
raise
finally:
cw_client = _utils.get_cloudwatch_client(args.region)
cw_client = _utils.get_cloudwatch_client(args.region, assume_role_arn=args.assume_role)
_utils.print_logs_for_job(cw_client, '/aws/sagemaker/TransformJobs', batch_job_name)
_utils.write_output(args.output_location_output_path, args.output_location)

View File

@@ -24,8 +24,18 @@ import json
from pathlib2 import Path
import boto3
import botocore
from boto3.session import Session
from botocore.config import Config
from botocore.credentials import (
AssumeRoleCredentialFetcher,
CredentialResolver,
DeferredRefreshableCredentials,
JSONFileCache
)
from botocore.exceptions import ClientError
from botocore.session import Session as BotocoreSession
from sagemaker.amazon.amazon_estimator import get_image_uri
import logging
@@ -71,6 +81,7 @@ def nullable_string_argument(value):
def add_default_client_arguments(parser):
parser.add_argument('--region', type=str, required=True, help='The region where the training job launches.')
parser.add_argument('--endpoint_url', type=nullable_string_argument, required=False, help='The URL to use when communicating with the SageMaker service.')
parser.add_argument('--assume_role', type=nullable_string_argument, required=False, help='The ARN of an IAM role to assume when connecting to SageMaker.')
def get_component_version():
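
The `nullable_string_argument` type used above explains the "force NoneType on empty" commit: `component.yaml` passes `''` for unset optional inputs, and argparse needs to coerce that back to `None`. A minimal sketch of how such a helper behaves (the body here is an assumption inferred from its usage, not copied from the repo):

```python
import argparse

def nullable_string_argument(value):
    """Coerce empty or whitespace-only CLI values to None (assumed body)."""
    value = value.strip()
    return value if value else None

parser = argparse.ArgumentParser()
parser.add_argument('--assume_role', type=nullable_string_argument, required=False)

# component.yaml passes '' when assume_role is unset; the helper turns it into None
args = parser.parse_args(['--assume_role', ''])
print(args.assume_role is None)  # True
```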
@@ -113,17 +124,59 @@ def print_logs_for_job(cw_client, log_grp, job_name):
logging.error(e)
def get_sagemaker_client(region, endpoint_url=None):
class AssumeRoleProvider(object):
METHOD = 'assume-role'
def __init__(self, fetcher):
self._fetcher = fetcher
def load(self):
return DeferredRefreshableCredentials(
self._fetcher.fetch_credentials,
self.METHOD
)
def get_boto3_session(region, role_arn=None):
"""Creates a boto3 session, optionally assuming a role"""
# By default return a basic session
if role_arn is None:
return Session(region_name=region)
# The following assume role example was taken from
# https://github.com/boto/botocore/issues/761#issuecomment-426037853
# Create a session used to assume role
assume_session = BotocoreSession()
fetcher = AssumeRoleCredentialFetcher(
assume_session.create_client,
assume_session.get_credentials(),
role_arn,
extra_args={
'DurationSeconds': 3600, # 1 hour assume-role session duration by default
},
cache=JSONFileCache()
)
role_session = BotocoreSession()
role_session.register_component(
'credential_provider',
CredentialResolver([AssumeRoleProvider(fetcher)])
)
return Session(region_name=region, botocore_session=role_session)
def get_sagemaker_client(region, endpoint_url=None, assume_role_arn=None):
"""Builds a client to the AWS SageMaker API."""
session_config = botocore.config.Config(
session = get_boto3_session(region, assume_role_arn)
session_config = Config(
user_agent='sagemaker-on-kubeflow-pipelines-v{}'.format(get_component_version())
)
client = boto3.client('sagemaker', region_name=region, endpoint_url=endpoint_url, config=session_config)
client = session.client('sagemaker', endpoint_url=endpoint_url, config=session_config)
return client
def get_cloudwatch_client(region):
client = boto3.client('logs', region_name=region)
def get_cloudwatch_client(region, assume_role_arn=None):
session = get_boto3_session(region, assume_role_arn)
client = session.client('logs')
return client
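
The `AssumeRoleProvider` above plugs an `AssumeRoleCredentialFetcher` into botocore's credential resolver so that STS credentials are fetched lazily and re-fetched when they expire, which is why the boto session stays valid across long-running jobs. The pattern can be illustrated without AWS using a stand-in fetcher (class and function names here are illustrative sketches, not botocore's real API):

```python
import time

class DeferredRefreshableCredentialsSketch:
    """Fetch credentials on first use, then re-fetch only after expiry."""
    def __init__(self, fetch, ttl_seconds=3600):
        self._fetch = fetch          # in the real code, an sts:AssumeRole call
        self._ttl = ttl_seconds
        self._creds = None
        self._expires_at = 0.0

    def get_credentials(self):
        # Refresh lazily: only when nothing is cached or the TTL has elapsed
        if self._creds is None or time.time() >= self._expires_at:
            self._creds = self._fetch()
            self._expires_at = time.time() + self._ttl
        return self._creds

fetch_count = []
def fake_assume_role():
    fetch_count.append(1)
    return {'AccessKeyId': 'ASIA-EXAMPLE', 'SecretAccessKey': 'example'}

provider = DeferredRefreshableCredentialsSketch(fake_assume_role, ttl_seconds=3600)
provider.get_credentials()
provider.get_credentials()
print(len(fetch_count))  # 1: the second call reuses the cached credentials
```

This is also why the PR caches with `JSONFileCache`: repeated component invocations can reuse unexpired credentials instead of calling STS each time.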

View File

@@ -19,7 +19,8 @@ Create an endpoint in AWS SageMaker Hosting Service for model deployment.
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
region | The region where the endpoint is created | No | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
endpoint_config_name | The name of the endpoint configuration | Yes | Yes | String | | |
endpoint_config_tags | Key-value pairs to tag endpoint configurations in AWS | Yes | Yes | Dict | | {} |
endpoint_tags | Key-value pairs to tag the Hosting endpoint in AWS | Yes | Yes | Dict | | {} |

View File

@@ -88,6 +88,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: endpoint_config_tags
description: 'Key-value pairs to categorize AWS resources.'
default: '{}'
@@ -104,12 +108,13 @@ outputs:
- {name: endpoint_name, description: 'Endpoint name'}
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.6.0
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
deploy.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--endpoint_config_name, {inputValue: endpoint_config_name},
--variant_name_1,{inputValue: variant_name_1},
--model_name_1, {inputValue: model_name_1},

View File

@@ -52,7 +52,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting Endpoint request to SageMaker...')
endpoint_name = _utils.deploy_model(client, vars(args))
logging.info('Endpoint creation request submitted. Waiting for completion...')

View File

@@ -11,6 +11,8 @@ For Ground Truth jobs using AWS SageMaker.
Argument | Description | Optional | Data type | Accepted values | Default |
:--- | :---------- | :----------| :----------| :---------- | :----------|
region | The region where the cluster launches | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | String | | |
job_name | The name of the Ground Truth job. Must be unique within the same AWS account and AWS region | Yes | String | | LabelingJob-[datetime]-[random id]|
label_attribute_name | The attribute name to use for the label in the output manifest file | Yes | String | | job_name |
@@ -39,7 +41,6 @@ time_limit | The maximum run time in seconds per training job | No | Int | [30,
task_availibility | The length of time that a task remains available for labeling by human workers | Yes | Int | Public workforce: [1, 43200], other: [1, 864000] | |
max_concurrent_tasks | The maximum number of data objects that can be labeled by human workers at the same time | Yes | Int | [1, 1000] | |
workforce_task_price | The price that you pay for each task performed by a public worker in USD; Specify to the tenth fractions of a cent; Format as "0.000" | Yes | Float | | 0.000 |
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
tags | Key-value pairs to categorize AWS resources | Yes | Dict | | {} |
## Outputs

View File

@@ -110,6 +110,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: tags
description: 'Key-value pairs to categorize AWS resources.'
default: '{}'
@@ -119,12 +123,13 @@ outputs:
- {name: active_learning_model_arn, description: 'The ARN for the most recent Amazon SageMaker model trained as part of automated data labeling.'}
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.6.0
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
ground_truth.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--role, {inputValue: role},
--job_name, {inputValue: job_name},
--label_attribute_name, {inputValue: label_attribute_name},

View File

@@ -60,7 +60,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting Ground Truth Job request to SageMaker...')
_utils.create_labeling_job(client, vars(args))

View File

@@ -11,6 +11,8 @@ For hyperparameter tuning jobs using AWS SageMaker.
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
region | The region where the cluster launches | No | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
job_name | The name of the tuning job. Must be unique within the same AWS account and AWS region | Yes | Yes | String | | HPOJob-[datetime]-[random id] |
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | No | String | | |
image | The registry path of the Docker image that contains the training algorithm | Yes | Yes | String | | |
@@ -44,7 +46,6 @@ max_wait_time | The maximum time in seconds you are willing to wait for a manage
checkpoint_config | Dictionary of information about the output location for managed spot training checkpoint data | Yes | Yes | Dict | | {} |
warm_start_type | Specifies the type of warm start used | Yes | No | String | IdenticalDataAndAlgorithm, TransferLearning | |
parent_hpo_jobs | List of previously completed or stopped hyperparameter tuning jobs to be used as a starting point | Yes | Yes | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | Yes | String | | |
tags | Key-value pairs to categorize AWS resources | Yes | Yes | Dict | | {} |
Notes:

View File

@@ -133,6 +133,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: tags
description: 'Key-value pairs to categorize AWS resources.'
default: '{}'
@@ -150,12 +154,13 @@ outputs:
description: 'The registry path of the Docker image that contains the training algorithm'
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.6.0
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
hyperparameter_tuning.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--job_name, {inputValue: job_name},
--role, {inputValue: role},
--image, {inputValue: image},

View File

@@ -76,7 +76,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting HyperParameter Tuning Job request to SageMaker...')
hpo_job_name = _utils.create_hyperparameter_tuning_job(client, vars(args))

View File

@@ -19,7 +19,8 @@ Create a model in Amazon SageMaker to be used for [creating an endpoint](https:/
Argument | Description | Optional (in pipeline definition) | Optional (in UI) | Data type | Accepted values | Default |
:--- | :---------- | :---------- | :---------- | :----------| :---------- | :----------|
region | The region where the model is created | No | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | Yes | String | | |
tags | Key-value pairs to tag the model created in AWS | Yes | Yes | Dict | | {} |
role | The ARN of the IAM role that Amazon SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs | No | No | String | | |
network_isolation | Isolates the model container. No inbound or outbound network calls can be made to or from the model container | Yes | Yes | Boolean | | True |

View File

@@ -51,6 +51,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: tags
description: 'Key-value pairs to categorize AWS resources.'
default: '{}'
@@ -59,12 +63,13 @@ outputs:
- {name: model_name, description: 'The model name SageMaker created'}
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.6.0
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
create_model.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--model_name, {inputValue: model_name},
--role, {inputValue: role},
--container_host_name, {inputValue: container_host_name},

View File

@@ -41,7 +41,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting model creation request to SageMaker...')
_utils.create_model(client, vars(args))

View File

@@ -11,7 +11,8 @@ For running your data processing workloads, such as feature engineering, data va
Argument | Description | Optional | Data type | Accepted values | Default |
:--- | :---------- | :----------| :----------| :---------- | :----------|
region | The region where the cluster launches | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
job_name | The name of the Processing job. Must be unique within the same AWS account and AWS region | Yes | String | | ProcessingJob-[datetime]-[random id]|
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | String | | |
image | The registry path of the Docker image that contains the processing script | Yes | String | | |

View File

@@ -80,6 +80,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: tags
description: 'Key-value pairs to categorize AWS resources.'
default: '{}'
@@ -89,12 +93,13 @@ outputs:
- {name: output_artifacts, description: 'A dictionary containing the output S3 artifacts'}
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.6.0
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
process.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--job_name, {inputValue: job_name},
--role, {inputValue: role},
--image, {inputValue: image},

View File

@@ -50,7 +50,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting Processing Job to SageMaker...')
job_name = _utils.create_processing_job(client, vars(args))
@@ -68,7 +68,7 @@ def main(argv=None):
except:
raise
finally:
cw_client = _utils.get_cloudwatch_client(args.region)
cw_client = _utils.get_cloudwatch_client(args.region, assume_role_arn=args.assume_role)
_utils.print_logs_for_job(cw_client, '/aws/sagemaker/ProcessingJobs', job_name)
outputs = _utils.get_processing_job_outputs(client, job_name)

View File

@@ -12,4 +12,7 @@ S3_DATA_BUCKET=my-data-bucket
# EKS_EXISTING_CLUSTER=my-eks-cluster
# If you would like to skip the FSx set-up and tests
# SKIP_FSX_TESTS=true
# SKIP_FSX_TESTS=true
# If you have an IAM role that the EKS cluster should assume for the "assume role" tests
# ASSUMED_ROLE_NAME=my-assumed-role

View File

@@ -14,7 +14,8 @@ from utils import argo_utils
pytest.param(
"resources/config/kmeans-algo-mnist-processing",
marks=pytest.mark.canary_test,
)
),
"resources/config/assume-role-processing"
],
)
def test_processingjob(

View File

@@ -15,6 +15,7 @@ from utils import argo_utils
),
pytest.param("resources/config/fsx-mnist-training", marks=pytest.mark.fsx_test),
"resources/config/spot-sample-pipeline-training",
"resources/config/assume-role-training",
],
)
def test_trainingjob(
@@ -74,8 +75,9 @@ def test_trainingjob(
else:
assert f"dkr.ecr.{region}.amazonaws.com" in training_image
assert not argo_utils.error_in_cw_logs(workflow_json["metadata"]["name"]), \
('Found the CloudWatch error message in the log output. Check SageMaker to see if the job has failed.')
assert not argo_utils.error_in_cw_logs(
workflow_json["metadata"]["name"]
), "Found the CloudWatch error message in the log output. Check SageMaker to see if the job has failed."
utils.remove_dir(download_dir)

View File

@@ -18,6 +18,11 @@ def pytest_addoption(parser):
parser.addoption(
"--role-arn", required=True, help="SageMaker execution IAM role ARN",
)
parser.addoption(
"--assume-role-arn",
required=True,
help="The ARN of a role which the assume role tests will assume to access SageMaker.",
)
parser.addoption(
"--s3-data-bucket",
required=True,
@@ -61,6 +66,12 @@ def region(request):
return request.config.getoption("--region")
@pytest.fixture(scope="session", autouse=True)
def assume_role_arn(request):
os.environ["ASSUME_ROLE_ARN"] = request.config.getoption("--assume-role-arn")
return request.config.getoption("--assume-role-arn")
@pytest.fixture(scope="session", autouse=True)
def role_arn(request):
os.environ["ROLE_ARN"] = request.config.getoption("--role-arn")

View File

@@ -0,0 +1,47 @@
PipelineDefinition: resources/definition/processing_pipeline.py
TestName: assume-role-processing
Timeout: 1800
Arguments:
region: ((REGION))
assume_role: ((ASSUME_ROLE_ARN))
job_name: process-mnist-for-kmeans
image: 763104351884.dkr.ecr.((REGION)).amazonaws.com/pytorch-training:1.5.0-cpu-py36-ubuntu16.04
container_entrypoint:
- python
- /opt/ml/processing/code/kmeans_preprocessing.py
input_config:
- InputName: mnist_tar
S3Input:
S3Uri: s3://sagemaker-sample-data-((REGION))/algorithms/kmeans/mnist/mnist.pkl.gz
LocalPath: /opt/ml/processing/input
S3DataType: S3Prefix
S3InputMode: File
S3CompressionType: None
- InputName: source_code
S3Input:
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/processing_code/kmeans_preprocessing.py
LocalPath: /opt/ml/processing/code
S3DataType: S3Prefix
S3InputMode: File
S3CompressionType: None
output_config:
- OutputName: train_data
S3Output:
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
LocalPath: /opt/ml/processing/output_train/
S3UploadMode: EndOfJob
- OutputName: test_data
S3Output:
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
LocalPath: /opt/ml/processing/output_test/
S3UploadMode: EndOfJob
- OutputName: valid_data
S3Output:
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/output/
LocalPath: /opt/ml/processing/output_valid/
S3UploadMode: EndOfJob
instance_type: ml.m5.xlarge
instance_count: 1
volume_size: 50
max_run_time: 1800
role: ((ROLE_ARN))

View File

@@ -0,0 +1,33 @@
PipelineDefinition: resources/definition/training_pipeline.py
TestName: assume-role-training
Timeout: 3600
ExpectedTrainingImage: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1
Arguments:
region: ((REGION))
image: ((KMEANS_REGISTRY)).dkr.ecr.((REGION)).amazonaws.com/kmeans:1
assume_role: ((ASSUME_ROLE_ARN))
training_input_mode: File
hyperparameters:
k: "10"
feature_dim: "784"
channels:
- ChannelName: train
DataSource:
S3DataSource:
S3Uri: s3://((DATA_BUCKET))/mnist_kmeans_example/data
S3DataType: S3Prefix
S3DataDistributionType: FullyReplicated
CompressionType: None
RecordWrapperType: None
InputMode: File
instance_type: ml.m5.xlarge
instance_count: 1
volume_size: 50
max_run_time: 3600
model_artifact_path: s3://((DATA_BUCKET))/mnist_kmeans_example/output
network_isolation: "True"
traffic_encryption: "False"
spot_instance: "False"
max_wait_time: 3600
checkpoint_config: "{}"
role: ((ROLE_ARN))

View File

@@ -23,6 +23,7 @@ def processing_pipeline(
output_config={},
network_isolation=False,
role="",
assume_role="",
):
sagemaker_process_op(
region=region,
@@ -39,6 +40,7 @@ output_config=output_config,
output_config=output_config,
network_isolation=network_isolation,
role=role,
assume_role=assume_role,
)

View File

@@ -27,6 +27,7 @@ def training_pipeline(
checkpoint_config="{}",
vpc_security_group_ids="",
vpc_subnets="",
assume_role="",
role="",
):
sagemaker_train_op(
@@ -50,6 +51,7 @@ checkpoint_config=checkpoint_config,
checkpoint_config=checkpoint_config,
vpc_security_group_ids=vpc_security_group_ids,
vpc_subnets=vpc_subnets,
assume_role=assume_role,
role=role,
)

View File

@@ -12,7 +12,8 @@ CLUSTER_REGION="${3:-us-east-1}"
SERVICE_NAMESPACE="${4:-kubeflow}"
SERVICE_ACCOUNT="${5:-pipeline-runner}"
aws_account=$(aws sts get-caller-identity --query Account --output text)
trustfile="trust.json"
trust_file="trust.json"
assume_role_file="assume-role.json"
cwd=$(dirname $(realpath $0))
@@ -36,15 +37,27 @@ function get_oidc_id {
# Parameter:
# $1: Name of the trust file to generate.
function create_namespaced_iam_role {
local trustfile="${1}"
local trust_file_path="${1}"
# Check if role already exists
aws iam get-role --role-name ${ROLE_NAME}
if [[ $? -eq 0 ]]; then
echo "A role for this cluster and namespace already exists in this account, assuming sagemaker access and proceeding."
echo "A role for this cluster and namespace already exists in this account, assuming SageMaker and AssumeRole access and proceeding."
else
echo "IAM Role does not exist, creating a new Role for the cluster"
aws iam create-role --role-name ${ROLE_NAME} --assume-role-policy-document file://${trustfile} --output=text --query "Role.Arn"
aws iam attach-role-policy --role-name ${ROLE_NAME} --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
aws iam create-role --role-name ${ROLE_NAME} --assume-role-policy-document file://${trust_file_path} --output=text --query "Role.Arn"
aws iam attach-role-policy --role-name ${ROLE_NAME} --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
printf '{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Action": "sts:AssumeRole",
"Resource": "arn:aws:iam::'"${aws_account}"':role/*"
}
]
}' > ${assume_role_file}
aws iam put-role-policy --role-name ${ROLE_NAME} --policy-name AllowAssumeRole --policy-document file://${assume_role_file}
fi
}
@@ -58,11 +71,11 @@ function delete_generated_file {
echo "Get the OIDC ID for the cluster"
get_oidc_id
echo "Delete the trust json file if it already exists"
delete_generated_file "${trustfile}"
delete_generated_file "${trust_file}"
echo "Generate a trust json"
"$cwd"/generate_trust_policy ${CLUSTER_REGION} ${aws_account} ${oidc_id} ${SERVICE_NAMESPACE} ${SERVICE_ACCOUNT} > "${trustfile}"
"$cwd"/generate_trust_policy ${CLUSTER_REGION} ${aws_account} ${oidc_id} ${SERVICE_NAMESPACE} ${SERVICE_ACCOUNT} > "${trust_file}"
echo "Create the IAM Role using these values"
create_namespaced_iam_role "${trustfile}"
create_namespaced_iam_role "${trust_file}"
echo "Cleanup for the next run"
delete_generated_file "${trustfile}"
delete_generated_file "${trust_file}"

View File

@@ -26,10 +26,12 @@ EKS_PRIVATE_SUBNETS=${EKS_PRIVATE_SUBNETS:-""}
MINIO_LOCAL_PORT=${MINIO_LOCAL_PORT:-9000}
KFP_NAMESPACE=${KFP_NAMESPACE:-"kubeflow"}
KFP_SERVICE_ACCOUNT=${KFP_SERVICE_ACCOUNT:-"pipeline-runner"}
AWS_ACCOUNT_ID=${AWS_ACCOUNT_ID:-"$(aws sts get-caller-identity --query=Account --output=text)"}
PYTEST_MARKER=${PYTEST_MARKER:-""}
S3_DATA_BUCKET=${S3_DATA_BUCKET:-""}
SAGEMAKER_EXECUTION_ROLE_ARN=${SAGEMAKER_EXECUTION_ROLE_ARN:-""}
ASSUMED_ROLE_NAME=${ASSUMED_ROLE_NAME:-""}
SKIP_FSX_TESTS=${SKIP_FSX_TESTS:-"false"}
@@ -78,7 +80,8 @@ function cleanup() {
set +e
cleanup_kfp
delete_generated_role
delete_assumed_role
delete_oidc_role
if [[ "${SKIP_FSX_TESTS}" == "false" ]]; then
delete_fsx_instance
@@ -141,21 +144,60 @@ function install_kfp() {
echo "[Installing KFP] Pipeline pods are ready"
}
function generate_iam_role_name() {
function generate_oidc_role_name() {
OIDC_ROLE_NAME="$(echo "${DEPLOY_NAME}-kubeflow-role" | cut -c1-64)"
OIDC_ROLE_ARN="arn:aws:iam::$(aws sts get-caller-identity --query=Account --output=text):role/${OIDC_ROLE_NAME}"
OIDC_ROLE_ARN="arn:aws:iam::${AWS_ACCOUNT_ID}:role/${OIDC_ROLE_NAME}"
}
function install_generated_role() {
function install_oidc_role() {
kubectl patch serviceaccount -n ${KFP_NAMESPACE} ${KFP_SERVICE_ACCOUNT} --patch '{"metadata": {"annotations": {"eks.amazonaws.com/role-arn": "'"${OIDC_ROLE_ARN}"'"}}}'
}
function delete_generated_role() {
function delete_oidc_role() {
# Delete the role associated with the cluster that's being deleted
aws iam detach-role-policy --role-name "${OIDC_ROLE_NAME}" --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
aws iam delete-role-policy --role-name "${OIDC_ROLE_NAME}" --policy-name AllowAssumeRole
aws iam delete-role --role-name "${OIDC_ROLE_NAME}"
}
function generate_assumed_role() {
# If not defined in the env file
if [[ -z "${ASSUMED_ROLE_NAME}" ]]; then
ASSUMED_ROLE_NAME="${DEPLOY_NAME}-assumed-role"
CREATED_ASSUMED_ROLE="true"
# Create a trust file that allows the OIDC role to authenticate
local assumed_trust_file="assumed-role-trust.json"
printf '{
"Version": "2012-10-17",
"Statement": [
{
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::'"${AWS_ACCOUNT_ID}"':root"
},
"Action": "sts:AssumeRole",
"Condition": {}
}
]
}' > "${assumed_trust_file}"
aws iam create-role --role-name "${ASSUMED_ROLE_NAME}" --assume-role-policy-document file://${assumed_trust_file} --output=text --query "Role.Arn"
aws iam attach-role-policy --role-name ${ASSUMED_ROLE_NAME} --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
fi
# Generate the ARN using the role name
ASSUMED_ROLE_ARN="arn:aws:iam::${AWS_ACCOUNT_ID}:role/${ASSUMED_ROLE_NAME}"
}
function delete_assumed_role() {
# Ensure that the automated script created the assumed role
if [[ ! -z "${ASSUMED_ROLE_NAME}" && "${CREATED_ASSUMED_ROLE:-false}" == "true" ]]; then
# Delete the role associated with the cluster that's being deleted
aws iam detach-role-policy --role-name "${ASSUMED_ROLE_NAME}" --policy-arn arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
aws iam delete-role --role-name "${ASSUMED_ROLE_NAME}"
fi
}
function cleanup_kfp() {
# Clean up Minio
if [[ ! -z "${MINIO_PID:-}" ]]; then
@@ -187,12 +229,16 @@ else
wait
fi
generate_iam_role_name
generate_oidc_role_name
"$cwd"/generate_iam_role ${EKS_CLUSTER_NAME} ${OIDC_ROLE_NAME} ${REGION} ${KFP_NAMESPACE} ${KFP_SERVICE_ACCOUNT}
install_kfp
install_generated_role
install_oidc_role
pytest_args=( --region "${REGION}" --role-arn "${SAGEMAKER_EXECUTION_ROLE_ARN}" --s3-data-bucket "${S3_DATA_BUCKET}" --minio-service-port "${MINIO_LOCAL_PORT}" --kfp-namespace "${KFP_NAMESPACE}" )
generate_assumed_role
pytest_args=( --region "${REGION}" --role-arn "${SAGEMAKER_EXECUTION_ROLE_ARN}" \
--s3-data-bucket "${S3_DATA_BUCKET}" --kfp-namespace "${KFP_NAMESPACE}" \
--minio-service-port "${MINIO_LOCAL_PORT}" --assume-role-arn "${ASSUMED_ROLE_ARN}")
if [[ "${SKIP_FSX_TESTS}" == "true" ]]; then
pytest_args+=( -m "not fsx_test" )

@@ -42,6 +42,10 @@ def get_fsx_id():
return os.environ.get("FSX_ID")
def get_assume_role_arn():
return os.environ.get("ASSUME_ROLE_ARN")
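`ASSUME_ROLE_ARN` is exported by the setup script and substituted into the `((ASSUME_ROLE_ARN))` markers in the test pipeline definitions by `replace_placeholders`. A minimal sketch of that substitution step (the standalone helper below is illustrative, not the test suite's actual function):

```python
def render_template(template, replacements):
    # Swap each ((KEY)) marker for its concrete value, mirroring
    # what replace_placeholders does for the pipeline templates.
    for marker, value in replacements.items():
        template = template.replace(marker, value or "")
    return template

rendered = render_template(
    'assume_role: "((ASSUME_ROLE_ARN))"',
    {"((ASSUME_ROLE_ARN))": "arn:aws:iam::123456789012:role/kfp-assumed-role"},
)
```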
def get_algorithm_image_registry(region, algorithm):
return get_image_uri(region, algorithm).split(".")[0]
@@ -84,6 +88,7 @@ def replace_placeholders(input_filename, output_filename):
"((FSX_ID))": get_fsx_id(),
"((FSX_SUBNET))": get_fsx_subnet(),
"((FSX_SECURITY_GROUP))": get_fsx_security_group(),
"((ASSUME_ROLE_ARN))": get_assume_role_arn()
}
filedata = ""

@@ -53,6 +53,20 @@ class BatchTransformTestCase(unittest.TestCase):
call('/tmp/output', 's3://fake-bucket/output')
])
def test_main_assumes_role(self):
# Mock out all of utils except parser
batch_transform._utils = MagicMock()
batch_transform._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
batch_transform._utils.create_transform_job.return_value = 'test-batch-job'
assume_role_args = required_args + ['--assume_role', 'my-role']
batch_transform.main(assume_role_args)
batch_transform._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
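Each component's test exercises the same wiring: `main` parses `--assume_role` (registered by `add_default_client_arguments`) and forwards it to `get_sagemaker_client`. Consistent with the intent of forcing `NoneType` on empty input, an empty string from the component YAML should parse as `None`. A hedged sketch of how that registration could look (function names mirror the tests, but the bodies are illustrative):

```python
import argparse

def nullable_string_argument(value):
    # The component YAML always passes the flag, so an empty
    # string means "not set"; coerce it to None.
    value = value.strip()
    return value if value else None

def add_default_client_arguments(parser):
    parser.add_argument('--region', type=str, required=True,
                        help='The region where the resources are')
    parser.add_argument('--endpoint_url', type=nullable_string_argument,
                        required=False,
                        help='The endpoint URL for the private link VPC endpoint')
    parser.add_argument('--assume_role', type=nullable_string_argument,
                        required=False,
                        help='The ARN of an IAM role to assume when connecting to SageMaker')

parser = argparse.ArgumentParser()
add_default_client_arguments(parser)
args = parser.parse_args(['--region', 'us-west-2', '--assume_role', ''])
```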
def test_batch_transform(self):
mock_client = MagicMock()

@@ -41,6 +41,20 @@ class DeployTestCase(unittest.TestCase):
call('/tmp/output', 'test-endpoint-name')
])
def test_main_assumes_role(self):
# Mock out all of utils except parser
deploy._utils = MagicMock()
deploy._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
deploy._utils.deploy_model.return_value = 'test-endpoint-name'
assume_role_args = required_args + ['--assume_role', 'my-role']
deploy.main(assume_role_args)
deploy._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
def test_deploy_model(self):
mock_client = MagicMock()
mock_args = self.parser.parse_args(required_args + ['--endpoint_name', 'test-endpoint-name', '--endpoint_config_name', 'test-endpoint-config-name'])

@@ -56,6 +56,20 @@ class GroundTruthTestCase(unittest.TestCase):
call('/tmp/model-output', 'arn:aws:sagemaker:us-east-1:999999999999:labeling-job')
])
def test_main_assumes_role(self):
# Mock out all of utils except parser
ground_truth._utils = MagicMock()
ground_truth._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
ground_truth._utils.get_labeling_job_outputs.return_value = ('s3://fake-bucket/output', 'arn:aws:sagemaker:us-east-1:999999999999:labeling-job')
assume_role_args = required_args + ['--assume_role', 'my-role']
ground_truth.main(assume_role_args)
ground_truth._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
def test_ground_truth(self):
mock_client = MagicMock()
mock_args = self.parser.parse_args(required_args)

@@ -87,6 +87,21 @@ class HyperparameterTestCase(unittest.TestCase):
call('/tmp/best_hyperparameters_output_path', {"key_1": "best_hp_1"}, json_encode=True),
call('/tmp/training_image_output_path', 'training-image')
])
def test_main_assumes_role(self):
# Mock out all of utils except parser
hpo._utils = MagicMock()
hpo._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
hpo._utils.create_hyperparameter_tuning_job.return_value = 'job-name'
hpo._utils.get_best_training_job_and_hyperparameters.return_value = 'best_job', {"key_1": "best_hp_1"}
assume_role_args = required_args + ['--assume_role', 'my-role']
hpo.main(assume_role_args)
hpo._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
def test_create_hyperparameter_tuning_job(self):
mock_client = MagicMock()
@@ -125,7 +140,7 @@ class HyperparameterTestCase(unittest.TestCase):
def test_main_stop_hyperparameter_tuning_job(self):
hpo._utils = MagicMock()
hpo._utils.create_processing_job.return_value = 'job-name'
hpo._utils.create_hyperparameter_tuning_job.return_value = 'job-name'
try:
os.kill(os.getpid(), signal.SIGTERM)

@@ -43,6 +43,20 @@ class ModelTestCase(unittest.TestCase):
call('/tmp/output', 'model_test')
])
def test_main_assumes_role(self):
# Mock out all of utils except parser
create_model._utils = MagicMock()
create_model._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
create_model._utils.create_model.return_value = 'model_test'
assume_role_args = required_args + ['--assume_role', 'my-role']
create_model.main(assume_role_args)
create_model._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
def test_create_model(self):
mock_client = MagicMock()
mock_args = self.parser.parse_args(required_args)

@@ -68,6 +68,21 @@ class ProcessTestCase(unittest.TestCase):
call('/tmp/output_artifacts_output_path', mock_outputs, json_encode=True)
])
def test_main_assumes_role(self):
# Mock out all of utils except parser
process._utils = MagicMock()
process._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
process._utils.create_processing_job.return_value = 'job-name'
process._utils.get_processing_job_outputs.return_value = mock_outputs = {'val1': 's3://1', 'val2': 's3://2'}
assume_role_args = required_args + ['--assume_role', 'my-role']
process.main(assume_role_args)
process._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
def test_create_processing_job(self):
mock_client = MagicMock()
mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job'])

@@ -57,6 +57,22 @@ class TrainTestCase(unittest.TestCase):
call('/tmp/training_image_output_path', 'training-image')
])
def test_main_assumes_role(self):
# Mock out all of utils except parser
train._utils = MagicMock()
train._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
train._utils.create_training_job.return_value = 'job-name'
train._utils.get_image_from_job.return_value = 'training-image'
train._utils.get_model_artifacts_from_job.return_value = 'model-artifacts'
assume_role_args = required_args + ['--assume_role', 'my-role']
train.main(assume_role_args)
train._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
def test_create_training_job(self):
mock_client = MagicMock()
mock_args = self.parser.parse_args(required_args + ['--job_name', 'test-job'])

@@ -1,7 +1,8 @@
import unittest
import json
from unittest.mock import patch, call, Mock, MagicMock, mock_open
from unittest.mock import patch, call, Mock, MagicMock, mock_open, ANY
from boto3.session import Session
from botocore.exceptions import ClientError
from common import _utils
@@ -66,3 +67,37 @@ class UtilsTestCase(unittest.TestCase):
mock_path("/tmp/test-output").write_text.assert_called_with(
json.dumps(case)
)
def test_assume_default_boto3_session(self):
with patch("common._utils.boto3", MagicMock()) as mock_boto3:
returned_session = _utils.get_boto3_session("us-east-1")
assert isinstance(returned_session, Session)
assert returned_session.region_name == "us-east-1"
mock_boto3.assert_not_called()
@patch("common._utils.DeferredRefreshableCredentials", MagicMock())
@patch("common._utils.AssumeRoleCredentialFetcher", MagicMock())
def test_assume_role_boto3_session(self):
returned_session = _utils.get_boto3_session("us-east-1", role_arn="abc123")
assert isinstance(returned_session, Session)
assert returned_session.region_name == "us-east-1"
# Bury into the internals to ensure our provider was registered correctly
our_provider = returned_session._session._components.get_component('credential_provider').providers[0]
assert isinstance(our_provider, _utils.AssumeRoleProvider)
def test_assumed_sagemaker_client(self):
_utils.get_boto3_session = MagicMock()
mock_sm_client = MagicMock()
# Mock the client("SageMaker", ...) return value
_utils.get_boto3_session.return_value.client.return_value = mock_sm_client
client = _utils.get_sagemaker_client("us-east-1", assume_role_arn="abc123")
assert client == mock_sm_client
_utils.get_boto3_session.assert_called_once_with("us-east-1", "abc123")
_utils.get_boto3_session.return_value.client.assert_called_once_with("sagemaker", endpoint_url=None, config=ANY)

@@ -23,6 +23,20 @@ class WorkTeamTestCase(unittest.TestCase):
def test_create_parser(self):
self.assertIsNotNone(self.parser)
def test_main_assumes_role(self):
# Mock out all of utils except parser
workteam._utils = MagicMock()
workteam._utils.add_default_client_arguments = _utils.add_default_client_arguments
# Set some static returns
workteam._utils.create_workteam.return_value = 'arn:aws:sagemaker:us-east-1:999999999999:work-team'
assume_role_args = required_args + ['--assume_role', 'my-role']
workteam.main(assume_role_args)
workteam._utils.get_sagemaker_client.assert_called_once_with('us-west-2', None, assume_role_arn='my-role')
def test_main(self):
# Mock out all of utils except parser
workteam._utils = MagicMock()

@@ -12,7 +12,8 @@ For model training using AWS SageMaker.
Argument | Description | Optional | Data type | Accepted values | Default |
:--- | :---------- | :----------| :----------| :---------- | :----------|
region | The region where the cluster launches | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint. | Yes | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
job_name | The name of the Ground Truth job. Must be unique within the same AWS account and AWS region | Yes | String | | LabelingJob-[datetime]-[random id]|
role | The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf | No | String | | |
image | The registry path of the Docker image that contains the training algorithm | Yes | String | | |

@@ -94,6 +94,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: tags
description: 'Key-value pairs, to categorize AWS resources.'
default: '{}'
@@ -104,12 +108,13 @@ outputs:
- {name: training_image, description: 'The registry path of the Docker image that contains the training algorithm'}
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.6.0
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
train.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--job_name, {inputValue: job_name},
--role, {inputValue: role},
--image, {inputValue: image},

@@ -62,7 +62,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting Training Job to SageMaker...')
job_name = _utils.create_training_job(client, vars(args))
@@ -78,7 +78,7 @@ def main(argv=None):
except:
raise
finally:
cw_client = _utils.get_cloudwatch_client(args.region)
cw_client = _utils.get_cloudwatch_client(args.region, assume_role_arn=args.assume_role)
_utils.print_logs_for_job(cw_client, '/aws/sagemaker/TrainingJobs', job_name)
image = _utils.get_image_from_job(client, job_name)

@@ -11,13 +11,14 @@ For creating a private workteam from pre-existing Amazon Cognito user groups usi
Argument | Description | Optional | Data type | Accepted values | Default |
:--- | :---------- | :----------| :----------| :---------- | :----------|
region | The region where the cluster launches | No | String | | |
endpoint_url | The endpoint URL for the private link VPC endpoint | Yes | String | | |
assume_role | The ARN of an IAM role to assume when connecting to SageMaker | Yes | String | | |
team_name | The name of your work team | No | String | | |
description | A description of the work team | No | String | | |
user_pool | An identifier for a user pool, which must be in the same region as the service that you are calling | No | String | | |
user_groups | An identifier for user groups separated by commas | No | String | | |
client_id | An identifier for an application client, which you must create using Amazon Cognito | No | String | | |
sns_topic | The ARN of the SNS topic to which notifications from the work team should be published | Yes | String | | |
tags | Key-value pairs to categorize AWS resources | Yes | Dict | | {} |
Notes:

@@ -28,6 +28,10 @@ inputs:
description: 'The endpoint URL for the private link VPC endpoint.'
default: ''
type: String
- name: assume_role
description: 'The ARN of an IAM role to assume when connecting to SageMaker.'
default: ''
type: String
- name: tags
description: 'Key-value pairs to categorize AWS resources.'
default: '{}'
@@ -36,12 +40,13 @@ outputs:
- {name: workteam_arn, description: 'The ARN of the workteam.'}
implementation:
container:
image: amazon/aws-sagemaker-kfp-components:0.5.3
image: amazon/aws-sagemaker-kfp-components:0.7.0
command: ['python3']
args: [
workteam.py,
--region, {inputValue: region},
--endpoint_url, {inputValue: endpoint_url},
--assume_role, {inputValue: assume_role},
--team_name, {inputValue: team_name},
--description, {inputValue: description},
--user_pool, {inputValue: user_pool},

@@ -36,7 +36,7 @@ def main(argv=None):
args = parser.parse_args(argv)
logging.getLogger().setLevel(logging.INFO)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url)
client = _utils.get_sagemaker_client(args.region, args.endpoint_url, assume_role_arn=args.assume_role)
logging.info('Submitting a create workteam request to SageMaker...')
workteam_arn = _utils.create_workteam(client, vars(args))