feat(components): PyTorch - Added the Create PyTorch Model Archive component (#5630)
* Components - PyTorch - Added the Create PyTorch Model Archive component * Added component to the sample pipeline
This commit is contained in:
parent
e20a1b8bda
commit
b08b29f46e
|
|
@ -0,0 +1,39 @@
|
|||
name: Create PyTorch Model Archive
|
||||
inputs:
|
||||
- {name: Model, type: PyTorchScriptModule}
|
||||
- {name: Model name, type: String, default: model}
|
||||
- {name: Model version, type: String, default: "1.0"}
|
||||
- {name: Handler, type: PythonCode, description: "See https://github.com/pytorch/serve/blob/master/docs/custom_service.md"}
|
||||
outputs:
|
||||
- {name: Model archive, type: PyTorchModelArchive}
|
||||
metadata:
|
||||
annotations:
|
||||
author: Alexey Volkov <alexey.volkov@ark-kun.com>
|
||||
implementation:
|
||||
container:
|
||||
image: pytorch/torchserve:0.3.0-cpu
|
||||
command:
|
||||
- bash
|
||||
- -exc
|
||||
- |
|
||||
model_path=$0
|
||||
handler_path=$1
|
||||
model_name=$2
|
||||
model_version=$3
|
||||
output_model_archive_path=$4
|
||||
|
||||
mkdir -p "$(dirname "$output_model_archive_path")"
|
||||
|
||||
# torch-model-archiver needs the handler to have .py extension
|
||||
cp "$handler_path" handler.py
|
||||
torch-model-archiver --model-name "$model_name" --version "$model_version" --serialized-file "$model_path" --handler handler.py
|
||||
|
||||
# torch-model-archiver does not allow specifying the output path, but always writes to "${model_name}.<format>"
|
||||
expected_model_archive_path="${model_name}.mar"
|
||||
mv "$expected_model_archive_path" "$output_model_archive_path"
|
||||
|
||||
- {inputPath: Model}
|
||||
- {inputPath: Handler}
|
||||
- {inputValue: Model name}
|
||||
- {inputValue: Model version}
|
||||
- {outputPath: Model archive}
|
||||
|
|
@ -3,10 +3,12 @@ from kfp import components
|
|||
|
||||
chicago_taxi_dataset_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/e3337b8bdcd63636934954e592d4b32c95b49129/components/datasets/Chicago%20Taxi/component.yaml')
|
||||
pandas_transform_csv_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/6162d55998b176b50267d351241100bb0ee715bc/components/pandas/Transform_DataFrame/in_CSV_format/component.yaml')
|
||||
download_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/240543e483076ae718f82c6f280441daa2f041fd/components/web/Download/component.yaml')
|
||||
|
||||
create_fully_connected_pytorch_network_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/4e1facea1a270535b515a9e8cc59422d1ad76a9e/components/PyTorch/Create_fully_connected_network/component.yaml')
|
||||
train_pytorch_model_from_csv_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/603342c4b88fe2d69ff07682f702cd3601e883bb/components/PyTorch/Train_PyTorch_model/from_CSV/component.yaml')
|
||||
convert_to_onnx_from_pytorch_script_module_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/e011e4affa85542ef2b24d63fdac27f8d939bbee/components/PyTorch/Convert_to_OnnxModel_from_PyTorchScriptModule/component.yaml')
|
||||
create_pytorch_model_archive_op = components.load_component_from_url('https://raw.githubusercontent.com/kubeflow/pipelines/abc180be2b2b5538d19eb87124684629ec45e620/components/PyTorch/Create_PyTorch_Model_Archive/component.yaml')
|
||||
|
||||
|
||||
def pytorch_pipeline():
|
||||
|
|
@ -47,6 +49,16 @@ def pytorch_pipeline():
|
|||
list_of_input_shapes=[[len(feature_columns)]],
|
||||
)
|
||||
|
||||
# TODO: Use a real working regression handler here. See https://github.com/pytorch/serve/issues/987
|
||||
serving_handler = download_op('https://raw.githubusercontent.com/pytorch/serve/5c03e711a401387a1d42fc01072fcc38b4995b66/ts/torch_handler/base_handler.py').output
|
||||
|
||||
model_archive = create_pytorch_model_archive_op(
|
||||
model=trained_model,
|
||||
handler=serving_handler,
|
||||
# model_name="model", # Optional
|
||||
# model_version="1.0", # Optional
|
||||
).output
|
||||
|
||||
if __name__ == '__main__':
|
||||
import kfp
|
||||
kfp_endpoint=None
|
||||
|
|
|
|||
Loading…
Reference in New Issue