pipelines/components/contrib/sample/keras/train_classifier/component.yaml

32 lines
2.6 KiB
YAML

# Component identity: display name and a one-line summary.
name: Keras - Train classifier
description: Trains classifier using Keras sequential model

# Arguments accepted by the component. Defaults are given as quoted strings.
inputs:
- name: training_set_features_path
  type:
    GcsPath:
      data_type: TSV
  description: 'Local or GCS path to the training set features table.'
- name: training_set_labels_path
  type:
    GcsPath:
      data_type: TSV
  description: 'Local or GCS path to the training set labels (each label is a class index from 0 to num-classes - 1).'
# TODO: remove GcsUri and move this entry to outputs once artifact passing
# support is checked in.
- name: output_model_uri
  type:
    GcsPath:
      data_type: Keras model
  description: 'Local or GCS path specifying where to save the trained model. The model (topology + weights + optimizer state) is saved in HDF5 format and can be loaded back by calling keras.models.load_model'
# NOTE(review): the declared type is GcsPath, but the description says this is
# a JSON string rather than a path -- confirm which one is intended.
- name: model_config
  type:
    GcsPath:
      data_type: Keras model config json
  description: 'JSON string containing the serialized model structure. Can be obtained by calling model.to_json() on a Keras model.'
- name: number_of_classes
  type: Integer
  description: 'Number of classifier classes.'
- name: number_of_epochs
  type: Integer
  default: '100'
  description: 'Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided.'
- name: batch_size
  type: Integer
  default: '32'
  description: 'Number of samples per gradient update.'
# Values produced by the component.
outputs:
# TODO: remove GcsUri and make this a proper output once artifact passing
# support is checked in.
- name: output_model_uri
  type:
    GcsPath:
      data_type: Keras model
  description: 'GCS path where the trained model has been saved. The model (topology + weights + optimizer state) is saved in HDF5 format and can be loaded back by calling keras.models.load_model'
# Descriptive annotations for the component: who wrote it and where the
# canonical copy of this file is published.
metadata:
  annotations:
    author: Alexey Volkov <alexey.volkov@ark-kun.com>
    canonical_location: 'https://raw.githubusercontent.com/Ark-kun/pipeline_components/master/components/sample/keras/train_classifier/component.yaml'
# Container-based implementation: the image to run, the entry-point command,
# and the command-line arguments built from the component's inputs/outputs.
implementation:
  container:
    # NOTE(review): no tag or digest is specified for this image; consider
    # pinning one so that runs are reproducible.
    image: gcr.io/ml-pipeline/sample/keras/train_classifier
    command: [python3, /pipelines/component/src/train.py]
    # Each line pairs a command-line flag with the placeholder that supplies
    # its value. `inputValue` refers to an entry under `inputs`; `outputPath`
    # refers to an entry under `outputs`.
    args: [
      --training-set-features-path, {inputValue: training_set_features_path},
      --training-set-labels-path, {inputValue: training_set_labels_path},
      --output-model-path, {inputValue: output_model_uri},
      --model-config-json, {inputValue: model_config},
      --num-classes, {inputValue: number_of_classes},
      --num-epochs, {inputValue: number_of_epochs},
      --batch-size, {inputValue: batch_size},
      --output-model-path-file, {outputPath: output_model_uri},
    ]