feat(components): PyTorch - Convert to ONNX from PyTorch ScriptModule (#5207)
This commit is contained in:
parent
dfa756317e
commit
fc7afdd0a0
|
|
@ -0,0 +1,36 @@
|
|||
from kfp.components import create_component_from_func, InputPath, OutputPath
|
||||
|
||||
|
||||
def convert_to_onnx_from_pytorch_script_module(
|
||||
model_path: InputPath('PyTorchScriptModule'),
|
||||
converted_model_path: OutputPath('OnnxModel'),
|
||||
list_of_input_shapes: list,
|
||||
):
|
||||
'''Creates fully-connected network in PyTorch ScriptModule format'''
|
||||
import torch
|
||||
model = torch.jit.load(model_path)
|
||||
example_inputs = [
|
||||
torch.ones(*input_shape)
|
||||
for input_shape in list_of_input_shapes
|
||||
]
|
||||
example_outputs = model.forward(*example_inputs)
|
||||
torch.onnx.export(
|
||||
model=model,
|
||||
args=example_inputs,
|
||||
f=converted_model_path,
|
||||
verbose=True,
|
||||
training=torch.onnx.TrainingMode.EVAL,
|
||||
example_outputs=example_outputs,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
convert_to_onnx_from_pytorch_script_module_op = create_component_from_func(
|
||||
convert_to_onnx_from_pytorch_script_module,
|
||||
output_component_file='component.yaml',
|
||||
base_image='pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime',
|
||||
packages_to_install=[],
|
||||
annotations={
|
||||
"author": "Alexey Volkov <alexey.volkov@ark-kun.com>",
|
||||
},
|
||||
)
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
name: Convert to onnx from pytorch script module
|
||||
description: Creates fully-connected network in PyTorch ScriptModule format
|
||||
metadata:
|
||||
annotations: {author: Alexey Volkov <alexey.volkov@ark-kun.com>}
|
||||
inputs:
|
||||
- {name: model, type: PyTorchScriptModule}
|
||||
- {name: list_of_input_shapes, type: JsonArray}
|
||||
outputs:
|
||||
- {name: converted_model, type: OnnxModel}
|
||||
implementation:
|
||||
container:
|
||||
image: pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
|
||||
command:
|
||||
- sh
|
||||
- -ec
|
||||
- |
|
||||
program_path=$(mktemp)
|
||||
printf "%s" "$0" > "$program_path"
|
||||
python3 -u "$program_path" "$@"
|
||||
- |
|
||||
def _make_parent_dirs_and_return_path(file_path: str):
|
||||
import os
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
return file_path
|
||||
|
||||
def convert_to_onnx_from_pytorch_script_module(
|
||||
model_path,
|
||||
converted_model_path,
|
||||
list_of_input_shapes,
|
||||
):
|
||||
'''Creates fully-connected network in PyTorch ScriptModule format'''
|
||||
import torch
|
||||
model = torch.jit.load(model_path)
|
||||
example_inputs = [
|
||||
torch.ones(*input_shape)
|
||||
for input_shape in list_of_input_shapes
|
||||
]
|
||||
example_outputs = model.forward(*example_inputs)
|
||||
torch.onnx.export(
|
||||
model=model,
|
||||
args=example_inputs,
|
||||
f=converted_model_path,
|
||||
verbose=True,
|
||||
training=torch.onnx.TrainingMode.EVAL,
|
||||
example_outputs=example_outputs,
|
||||
)
|
||||
|
||||
import json
|
||||
import argparse
|
||||
_parser = argparse.ArgumentParser(prog='Convert to onnx from pytorch script module', description='Creates fully-connected network in PyTorch ScriptModule format')
|
||||
_parser.add_argument("--model", dest="model_path", type=str, required=True, default=argparse.SUPPRESS)
|
||||
_parser.add_argument("--list-of-input-shapes", dest="list_of_input_shapes", type=json.loads, required=True, default=argparse.SUPPRESS)
|
||||
_parser.add_argument("--converted-model", dest="converted_model_path", type=_make_parent_dirs_and_return_path, required=True, default=argparse.SUPPRESS)
|
||||
_parsed_args = vars(_parser.parse_args())
|
||||
|
||||
_outputs = convert_to_onnx_from_pytorch_script_module(**_parsed_args)
|
||||
args:
|
||||
- --model
|
||||
- {inputPath: model}
|
||||
- --list-of-input-shapes
|
||||
- {inputValue: list_of_input_shapes}
|
||||
- --converted-model
|
||||
- {outputPath: converted_model}
|
||||
Loading…
Reference in New Issue