pipelines/components/aws/sagemaker/train/component.yaml

name: 'SageMaker - Training Job'
description: |
  Train Machine Learning and Deep Learning Models using SageMaker
inputs:
  - name: region
    description: 'The region where the training job launches.'
  - name: job_name
    description: 'The name of the training job.'
    default: ''
  - name: role
    description: 'The Amazon Resource Name (ARN) that Amazon SageMaker assumes to perform tasks on your behalf.'
  - name: image
    description: 'The registry path of the Docker image that contains the training algorithm.'
    default: ''
  - name: algorithm_name
    description: 'The name of the algorithm resource to use for the training job. Do not specify a value for this if using a training image.'
    default: ''
  - name: metric_definitions
    description: 'A dictionary of name-regex pairs that specifies the metrics the algorithm emits.'
    default: '{}'
  - name: training_input_mode
    description: 'The input mode that the algorithm supports: File or Pipe.'
    default: 'File'
  - name: hyperparameters
    description: 'A dictionary of hyperparameters for the algorithm.'
    default: '{}'
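  # Both `metric_definitions` and `hyperparameters` are passed as JSON-encoded
  # dictionaries of strings. Illustrative values (the metric names, regexes,
  # and hyperparameters below are examples, not defaults of this component):
  #   metric_definitions: '{"train:error": "Train_error=(.*?);"}'
  #   hyperparameters:    '{"epochs": "10", "learning_rate": "0.1"}'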
  - name: channels
    description: 'A list of dicts specifying the input channels. Must have at least one.'
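  # `channels` is a JSON-encoded list of SageMaker channel specifications.
  # A minimal sketch of a single channel (field names follow the SageMaker
  # CreateTrainingJob API; the bucket and prefix are placeholders):
  #   [{"ChannelName": "train",
  #     "DataSource": {"S3DataSource": {"S3Uri": "s3://my-bucket/train",
  #                                     "S3DataType": "S3Prefix",
  #                                     "S3DataDistributionType": "FullyReplicated"}},
  #     "CompressionType": "None",
  #     "RecordWrapperType": "None",
  #     "InputMode": "File"}]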
  - name: data_location_1
    description: 'The S3 URI of the input data source for channel 1.'
    default: ''
  - name: data_location_2
    description: 'The S3 URI of the input data source for channel 2.'
    default: ''
  - name: data_location_3
    description: 'The S3 URI of the input data source for channel 3.'
    default: ''
  - name: data_location_4
    description: 'The S3 URI of the input data source for channel 4.'
    default: ''
  - name: data_location_5
    description: 'The S3 URI of the input data source for channel 5.'
    default: ''
  - name: data_location_6
    description: 'The S3 URI of the input data source for channel 6.'
    default: ''
  - name: data_location_7
    description: 'The S3 URI of the input data source for channel 7.'
    default: ''
  - name: data_location_8
    description: 'The S3 URI of the input data source for channel 8.'
    default: ''
  - name: instance_type
    description: 'The ML compute instance type.'
    default: 'ml.m4.xlarge'
  - name: instance_count
    description: 'The number of ML compute instances to use in each training job.'
    default: '1'
  - name: volume_size
    description: 'The size, in GB, of the ML storage volume that you want to provision.'
    default: '1'
  - name: resource_encryption_key
    description: 'The AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s).'
    default: ''
  - name: max_run_time
    description: 'The maximum run time in seconds for the training job.'
    default: '86400'
  - name: model_artifact_path
    description: 'Identifies the S3 path where you want Amazon SageMaker to store the model artifacts.'
  - name: output_encryption_key
    description: 'The AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts.'
    default: ''
  - name: vpc_security_group_ids
    description: 'The VPC security group IDs, in the form sg-xxxxxxxx.'
    default: ''
  - name: vpc_subnets
    description: 'The IDs of the subnets in the VPC to which you want to connect your training job.'
    default: ''
  - name: network_isolation
    description: 'Isolates the training container.'
    default: 'True'
  - name: traffic_encryption
    description: 'Encrypts all communications between ML compute instances in distributed training.'
    default: 'False'
  - name: tags
    description: 'Key-value pairs to categorize AWS resources.'
    default: '{}'
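  # Note: KFP passes every inputValue to the container as a string, which is
  # why the boolean and dictionary defaults above are quoted. `tags` is a
  # JSON-encoded object; an illustrative (not default) value:
  #   '{"team": "ml", "env": "dev"}'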
outputs:
  - {name: model_artifact_url, description: 'The URL of the model artifacts'}
  - {name: job_name, description: 'The training job name'}
  - {name: training_image, description: 'The registry path of the Docker image that contains the training algorithm'}
implementation:
  container:
    image: carowang/kubeflow-pipeline-aws-sm:20190809-02
    command: ['python']
    args: [
      train.py,
      --region, {inputValue: region},
      --job_name, {inputValue: job_name},
      --role, {inputValue: role},
      --image, {inputValue: image},
      --algorithm_name, {inputValue: algorithm_name},
      --metric_definitions, {inputValue: metric_definitions},
      --training_input_mode, {inputValue: training_input_mode},
      --hyperparameters, {inputValue: hyperparameters},
      --channels, {inputValue: channels},
      --data_location_1, {inputValue: data_location_1},
      --data_location_2, {inputValue: data_location_2},
      --data_location_3, {inputValue: data_location_3},
      --data_location_4, {inputValue: data_location_4},
      --data_location_5, {inputValue: data_location_5},
      --data_location_6, {inputValue: data_location_6},
      --data_location_7, {inputValue: data_location_7},
      --data_location_8, {inputValue: data_location_8},
      --instance_type, {inputValue: instance_type},
      --instance_count, {inputValue: instance_count},
      --volume_size, {inputValue: volume_size},
      --resource_encryption_key, {inputValue: resource_encryption_key},
      --max_run_time, {inputValue: max_run_time},
      --model_artifact_path, {inputValue: model_artifact_path},
      --output_encryption_key, {inputValue: output_encryption_key},
      --vpc_security_group_ids, {inputValue: vpc_security_group_ids},
      --vpc_subnets, {inputValue: vpc_subnets},
      --network_isolation, {inputValue: network_isolation},
      --traffic_encryption, {inputValue: traffic_encryption},
      --tags, {inputValue: tags}
    ]
    fileOutputs:
      model_artifact_url: /tmp/model_artifact_url.txt
      job_name: /tmp/job_name.txt
      training_image: /tmp/training_image.txt
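
The component above can be loaded directly into a pipeline with the KFP SDK. A minimal sketch, assuming the KFP v1 Python SDK; the region, role ARN, image URI, and S3 paths are placeholders, not values taken from the component:

import json

import kfp
from kfp import components, dsl

sagemaker_train_op = components.load_component_from_file(
    'pipelines/components/aws/sagemaker/train/component.yaml')


@dsl.pipeline(name='sagemaker-training-demo')
def train_pipeline():
    train = sagemaker_train_op(
        region='us-east-1',                                    # placeholder region
        role='arn:aws:iam::111122223333:role/SageMakerRole',   # placeholder ARN
        image='111122223333.dkr.ecr.us-east-1.amazonaws.com/my-algo:latest',  # placeholder
        # Required `channels` input, serialized to the JSON string the
        # component expects.
        channels=json.dumps([{
            'ChannelName': 'train',
            'DataSource': {'S3DataSource': {
                'S3Uri': 's3://my-bucket/train',   # placeholder bucket/prefix
                'S3DataType': 'S3Prefix',
                'S3DataDistributionType': 'FullyReplicated'}},
            'InputMode': 'File'}]),
        model_artifact_path='s3://my-bucket/output',           # placeholder
        instance_type='ml.m4.xlarge',
    )
    # Downstream steps can consume the component's declared outputs, e.g.
    # train.outputs['model_artifact_url'].


if __name__ == '__main__':
    kfp.compiler.Compiler().compile(train_pipeline, 'train_pipeline.yaml')

Only the inputs without defaults (region, role, channels, model_artifact_path) must be supplied; everything else falls back to the defaults declared in the spec.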