# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import argparse
import logging

from common import _utils


def create_parser():
    """Build the command-line parser for the workteam creation script.

    The shared client arguments are registered first through
    ``_utils.add_default_client_arguments`` and the workteam-specific
    options are added afterwards.

    Returns:
        argparse.ArgumentParser: the fully configured parser.
    """
    # The description previously read 'SageMaker Hyperparameter Tuning Job',
    # which did not match what this script does (it creates a workteam).
    parser = argparse.ArgumentParser(description='SageMaker Workteam')
    _utils.add_default_client_arguments(parser)

    parser.add_argument(
        '--team_name',
        type=str,
        required=True,
        help='The name of your work team.',
    )
    parser.add_argument(
        '--description',
        type=str,
        required=True,
        help='A description of the work team.',
    )
    parser.add_argument(
        '--user_pool',
        type=str,
        required=False,
        help='An identifier for a user pool. The user pool must be in the same region as the service that you are calling.',
        default='',
    )
    parser.add_argument(
        '--user_groups',
        type=str,
        required=False,
        help='A list of identifiers for user groups separated by commas.',
        default='',
    )
    parser.add_argument(
        '--client_id',
        type=str,
        required=False,
        help='An identifier for an application client. You must create the app client ID using Amazon Cognito.',
        default='',
    )
    parser.add_argument(
        '--sns_topic',
        type=str,
        required=False,
        help='The ARN for the SNS topic to which notifications should be published.',
        default='',
    )
    parser.add_argument(
        '--tags',
        type=_utils.yaml_or_json_str,
        required=False,
        help='An array of key-value pairs, to categorize AWS resources.',
        default={},
    )
    parser.add_argument(
        '--workteam_arn_output_path',
        type=str,
        default='/tmp/workteam-arn',
        help='Local output path for the file containing the ARN of the workteam.',
    )

    return parser


def main(argv=None):
    """Parse arguments, create the workteam and write its ARN to a file.

    Args:
        argv: list of command-line arguments to parse. When ``None``,
            argparse falls back to ``sys.argv[1:]``.
    """
    parser = create_parser()
    args = parser.parse_args(argv)

    # Raise the root logger to INFO so the progress messages below are shown.
    logging.getLogger().setLevel(logging.INFO)

    # NOTE(review): ``region`` and ``endpoint_url`` are not declared in this
    # file; presumably they come from _utils.add_default_client_arguments.
    client = _utils.get_sagemaker_client(args.region, args.endpoint_url)

    logging.info('Submitting a create workteam request to SageMaker...')
    # All parsed arguments are handed over as a plain dict.
    workteam_arn = _utils.create_workteam(client, vars(args))
    logging.info('Workteam created.')

    _utils.write_output(args.workteam_arn_output_path, workteam_arn)


if __name__ == "__main__":
    main(sys.argv[1:])