84 lines
3.3 KiB
YAML
84 lines
3.3 KiB
YAML
# Component identity: display name and one-line summary.
name: Create fully connected pytorch network
description: Creates fully-connected network in PyTorch ScriptModule format
metadata:
  annotations:
    # Contact for the component author.
    author: Alexey Volkov <alexey.volkov@ark-kun.com>
    # URL of the authoritative copy of this component definition.
    canonical_location: 'https://raw.githubusercontent.com/Ark-kun/pipeline_components/master/components/PyTorch/Create_fully_connected_network/component.yaml'
|
|
inputs:
# JSON array of layer widths; the embedded program decodes it with json.loads
# and builds one Linear layer per consecutive pair of sizes.
- name: layer_sizes
  type: JsonArray
# Name of the activation function; looked up on torch, then torch.nn.functional.
- name: activation_name
  type: String
  default: relu
  optional: true
# Seed handed to torch.manual_seed before the layers are created.
# The default stays a quoted string; the program converts it with type=int.
- name: random_seed
  type: Integer
  default: '0'
  optional: true
outputs:
# File the scripted network is saved to (passed to the program as --network).
- name: network
  type: PyTorchScriptModule
|
|
implementation:
  container:
    image: pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
    command:
    - sh
    # -e: abort on the first failing command; -c: run the next item as a script.
    - -ec
    # Launcher: the item after this script becomes "$0" for `sh -c`, so the
    # script writes "$0" to a temporary file and runs it with python3,
    # forwarding the remaining arguments ("$@") unchanged.
    - |
      program_path=$(mktemp)
      printf "%s" "$0" > "$program_path"
      python3 -u "$program_path" "$@"
|
|
- |
|
|
def _make_parent_dirs_and_return_path(file_path: str):
|
|
import os
|
|
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
|
return file_path
|
|
|
|
def create_fully_connected_pytorch_network(
    layer_sizes,
    network_path,
    activation_name = 'relu',
    random_seed = 0,
):
    '''Creates fully-connected network in PyTorch ScriptModule format

    Args:
        layer_sizes: Sequence of layer widths. Every consecutive pair
            (layer_sizes[i], layer_sizes[i + 1]) becomes one torch.nn.Linear.
        network_path: Path the scripted network is saved to.
        activation_name: Attribute name looked up on `torch` first and on
            `torch.nn.functional` second.
        random_seed: Value passed to torch.manual_seed before any layer is
            created.

    Raises:
        ValueError: If neither lookup yields an activation.
    '''
    import torch
    torch.manual_seed(random_seed)

    # Prefer the top-level torch function; fall back to torch.nn.functional.
    activation_fn = getattr(torch, activation_name, None)
    if not activation_fn:
        activation_fn = getattr(torch.nn.functional, activation_name, None)
    if not activation_fn:
        raise ValueError(f'Activation "{activation_name}" was not found.')

    # Module wrapper so the chosen function can sit inside torch.nn.Sequential.
    class ActivationLayer(torch.nn.Module):
        def forward(self, input):
            return activation_fn(input)

    modules = []
    final_linear_idx = len(layer_sizes) - 2
    size_pairs = zip(layer_sizes, layer_sizes[1:])
    for pair_idx, (in_features, out_features) in enumerate(size_pairs):
        modules.append(torch.nn.Linear(in_features, out_features))
        # An activation follows every Linear layer except the last one.
        if pair_idx != final_linear_idx:
            modules.append(ActivationLayer())

    script_module = torch.jit.script(torch.nn.Sequential(*modules))
    print(script_module)
    script_module.save(network_path)
|
|
|
|
import json
import argparse

# Command-line entry point: parse the flags supplied by the component's
# `args` list and forward them as keyword arguments to the builder function.
_parser = argparse.ArgumentParser(
    prog='Create fully connected pytorch network',
    description='Creates fully-connected network in PyTorch ScriptModule format',
)

# (flag, destination keyword, converter, required) in registration order.
# default=argparse.SUPPRESS keeps omitted optional flags out of the parsed
# namespace, so the function's own defaults apply.
for _flag, _dest, _converter, _is_required in (
    ("--layer-sizes", "layer_sizes", json.loads, True),
    ("--activation-name", "activation_name", str, False),
    ("--random-seed", "random_seed", int, False),
    ("--network", "network_path", _make_parent_dirs_and_return_path, True),
):
    _parser.add_argument(
        _flag,
        dest=_dest,
        type=_converter,
        required=_is_required,
        default=argparse.SUPPRESS,
    )

_parsed_args = vars(_parser.parse_args())

_outputs = create_fully_connected_pytorch_network(**_parsed_args)
|
|
# Arguments handed to the embedded program (sibling of `command` under
# implementation.container).
args:
- --layer-sizes
- {inputValue: layer_sizes}
# Optional inputs are only forwarded when a value was supplied, so the
# program's own defaults apply otherwise.
- if:
    cond: {isPresent: activation_name}
    then:
    - --activation-name
    - {inputValue: activation_name}
- if:
    cond: {isPresent: random_seed}
    then:
    - --random-seed
    - {inputValue: random_seed}
# Path where the program must write the scripted network.
- --network
- {outputPath: network}
|