# File viewer listing metadata: 32 lines, 2.6 KiB, YAML.
---
# Pipeline component specification: human-readable identity of the component.
name: Keras - Train classifier
description: Trains classifier using Keras sequential model
# Parameters accepted by the component. Every entry here is referenced by an
# `inputValue` placeholder in `implementation.container.args`.
inputs:
  - name: training_set_features_path
    type:
      GcsPath:
        data_type: TSV
    description: 'Local or GCS path to the training set features table.'

  - name: training_set_labels_path
    type:
      GcsPath:
        data_type: TSV
    description: 'Local or GCS path to the training set labels (each label is a class index from 0 to num-classes - 1).'

  # Destination for the trained model. The same name is also declared under
  # `outputs`.
  # TODO: Remove GcsUri and move to outputs once artifact passing support is
  # checked in.
  - name: output_model_uri
    type:
      GcsPath:
        data_type: Keras model
    description: 'Local or GCS path specifying where to save the trained model. The model (topology + weights + optimizer state) is saved in HDF5 format and can be loaded back by calling keras.models.load_model'

  # NOTE(review): the declared type is GcsPath, but the description says this
  # is a JSON string and it is passed by value as --model-config-json.
  # Confirm which form the training script expects.
  - name: model_config
    type:
      GcsPath:
        data_type: Keras model config json
    description: 'JSON string containing the serialized model structure. Can be obtained by calling model.to_json() on a Keras model.'

  - name: number_of_classes
    type: Integer
    description: 'Number of classifier classes.'

  # Defaults are deliberately quoted strings, exactly as originally written.
  - name: number_of_epochs
    type: Integer
    default: '100'
    description: 'Number of epochs to train the model. An epoch is an iteration over the entire `x` and `y` data provided.'

  - name: batch_size
    type: Integer
    default: '32'
    description: 'Number of samples per gradient update.'
# Values produced by the component. The single entry is referenced by the
# `outputPath` placeholder in `implementation.container.args`.
outputs:
  # TODO: Remove GcsUri and make it a proper output once artifact passing
  # support is checked in.
  - name: output_model_uri
    type:
      GcsPath:
        data_type: Keras model
    description: 'GCS path where the trained model has been saved. The model (topology + weights + optimizer state) is saved in HDF5 format and can be loaded back by calling keras.models.load_model'
# Descriptive annotations: the component author and the canonical URL of
# this specification file.
metadata:
  annotations:
    author: Alexey Volkov <alexey.volkov@ark-kun.com>
    canonical_location: 'https://raw.githubusercontent.com/Ark-kun/pipeline_components/master/components/sample/keras/train_classifier/component.yaml'
# How the component is executed: a container image, its entry-point command,
# and the argument list built from the declared inputs and outputs.
implementation:
  container:
    # NOTE(review): the image reference carries no tag or digest; consider
    # pinning one so runs are reproducible.
    image: gcr.io/ml-pipeline/sample/keras/train_classifier
    command: [python3, /pipelines/component/src/train.py]
    # Flags alternate with their values. Each `inputValue` placeholder names
    # an entry under `inputs`; the `outputPath` placeholder names the entry
    # under `outputs`.
    args:
      - --training-set-features-path
      - {inputValue: training_set_features_path}
      - --training-set-labels-path
      - {inputValue: training_set_labels_path}
      - --output-model-path
      - {inputValue: output_model_uri}
      - --model-config-json
      - {inputValue: model_config}
      - --num-classes
      - {inputValue: number_of_classes}
      - --num-epochs
      - {inputValue: number_of_epochs}
      - --batch-size
      - {inputValue: batch_size}

      # Output: `output_model_uri` is handed over as an output path,
      # separately from the --output-model-path input value above.
      - --output-model-path-file
      - {outputPath: output_model_uri}