---
# Component header: display name plus a one-line summary of what it does.
name: 'Sagemaker - Training Job'
# Folded scalar with default (clip) chomping: the value is the single line
# below followed by one trailing newline.
description: >
  Train Machine Learning and Deep Learning Models using SageMaker
# Inputs accepted by the component. The container args in `implementation`
# forward each input value to train.py as a command-line argument, so every
# default is written as a quoted string (including numbers, booleans and
# JSON-like literals) to keep YAML from re-typing it.
# Inputs with no `default` key: region, role, channels, model_artifact_path.
inputs:
  # --- Job identity and permissions ---
  - name: region
    description: 'The region where the training job launches.'
  - name: job_name
    description: 'The name of the batch training job.'
    default: ''
  - name: role
    description: 'The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.'

  # --- Algorithm selection and configuration ---
  # Per the descriptions, `image` and `algorithm_name` are alternatives:
  # leave `algorithm_name` empty when a training image is supplied.
  - name: image
    description: 'The registry path of the Docker image that contains the training algorithm.'
    default: ''
  - name: algorithm_name
    description: 'The name of the algorithm resource to use for the training job. Do not specify a value for this if using training image.'
    default: ''
  - name: metric_definitions
    description: 'The dictionary of name-regex pairs specify the metrics that the algorithm emits.'
    default: '{}'
  - name: training_input_mode
    description: 'The input mode that the algorithm supports. File or Pipe.'
    default: 'File'
  - name: hyperparameters
    description: 'Dictionary of hyperparameters for the algorithm.'
    default: '{}'

  # --- Input data channels ---
  - name: channels
    description: 'A list of dicts specifying the input channels. Must have at least one.'
  # Up to eight per-channel S3 locations; each defaults to the empty string.
  - name: data_location_1
    description: 'The S3 URI of the input data source for channel 1.'
    default: ''
  - name: data_location_2
    description: 'The S3 URI of the input data source for channel 2.'
    default: ''
  - name: data_location_3
    description: 'The S3 URI of the input data source for channel 3.'
    default: ''
  - name: data_location_4
    description: 'The S3 URI of the input data source for channel 4.'
    default: ''
  - name: data_location_5
    description: 'The S3 URI of the input data source for channel 5.'
    default: ''
  - name: data_location_6
    description: 'The S3 URI of the input data source for channel 6.'
    default: ''
  - name: data_location_7
    description: 'The S3 URI of the input data source for channel 7.'
    default: ''
  - name: data_location_8
    description: 'The S3 URI of the input data source for channel 8.'
    default: ''

  # --- Compute resources and limits ---
  - name: instance_type
    description: 'The ML compute instance type.'
    default: 'ml.m4.xlarge'
  - name: instance_count
    description: 'The number of ML compute instances to use in each training job.'
    default: '1'
  # NOTE(review): the unit is not stated here; presumably GB, matching
  # SageMaker's VolumeSizeInGB -- confirm against train.py.
  - name: volume_size
    description: 'The size of the ML storage volume that you want to provision.'
    default: '1'
  - name: resource_encryption_key
    description: 'The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s).'
    default: ''
  - name: max_run_time
    description: 'The maximum run time in seconds for the training job.'
    default: '86400'

  # --- Model output ---
  - name: model_artifact_path
    description: 'Identifies the S3 path where you want Amazon SageMaker to store the model artifacts.'
  - name: output_encryption_key
    description: 'The AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts.'
    default: ''

  # --- Networking and isolation ---
  - name: vpc_security_group_ids
    description: 'The VPC security group IDs, in the form sg-xxxxxxxx.'
    default: ''
  - name: vpc_subnets
    description: 'The ID of the subnets in the VPC to which you want to connect your training job.'
    default: ''
  # 'True' / 'False' stay quoted so they remain strings rather than being
  # parsed as YAML booleans.
  - name: network_isolation
    description: 'Isolates the training container.'
    default: 'True'
  - name: traffic_encryption
    description: 'Encrypts all communications between ML compute instances in distributed training.'
    default: 'False'

  # --- Resource tagging ---
  - name: tags
    description: 'Key-value pairs, to categorize AWS resources.'
    default: '{}'
# Values the component produces. Each name here has a matching entry under
# implementation.container.fileOutputs giving the file it is read from.
outputs:
  - name: model_artifact_url
    description: 'Model artifacts url'
  - name: job_name
    description: 'Training job name'
  - name: training_image
    description: 'The registry path of the Docker image that contains the training algorithm'
# Container that runs the training step: `python train.py` followed by one
# `--flag <value>` pair per declared input.
implementation:
  container:
    image: carowang/kubeflow-pipeline-aws-sm:20190809-02
    command:
      - python
    # Order matters: train.py comes first, then every flag is immediately
    # followed by the input value it carries.
    args:
      - train.py
      - --region
      - inputValue: region
      - --job_name
      - inputValue: job_name
      - --role
      - inputValue: role
      - --image
      - inputValue: image
      - --algorithm_name
      - inputValue: algorithm_name
      - --metric_definitions
      - inputValue: metric_definitions
      - --training_input_mode
      - inputValue: training_input_mode
      - --hyperparameters
      - inputValue: hyperparameters
      - --channels
      - inputValue: channels
      - --data_location_1
      - inputValue: data_location_1
      - --data_location_2
      - inputValue: data_location_2
      - --data_location_3
      - inputValue: data_location_3
      - --data_location_4
      - inputValue: data_location_4
      - --data_location_5
      - inputValue: data_location_5
      - --data_location_6
      - inputValue: data_location_6
      - --data_location_7
      - inputValue: data_location_7
      - --data_location_8
      - inputValue: data_location_8
      - --instance_type
      - inputValue: instance_type
      - --instance_count
      - inputValue: instance_count
      - --volume_size
      - inputValue: volume_size
      - --resource_encryption_key
      - inputValue: resource_encryption_key
      - --max_run_time
      - inputValue: max_run_time
      - --model_artifact_path
      - inputValue: model_artifact_path
      - --output_encryption_key
      - inputValue: output_encryption_key
      - --vpc_security_group_ids
      - inputValue: vpc_security_group_ids
      - --vpc_subnets
      - inputValue: vpc_subnets
      - --network_isolation
      - inputValue: network_isolation
      - --traffic_encryption
      - inputValue: traffic_encryption
      - --tags
      - inputValue: tags
    # Maps each declared output name to the file inside the container that
    # holds its value.
    fileOutputs:
      model_artifact_url: /tmp/model_artifact_url.txt
      job_name: /tmp/job_name.txt
      training_image: /tmp/training_image.txt