pipelines/components/PyTorch/Create_fully_connected_network/component.py

45 lines
1.7 KiB
Python

from kfp.components import create_component_from_func, OutputPath
def create_fully_connected_pytorch_network(
layer_sizes: list,
network_path: OutputPath('PyTorchScriptModule'),
activation_name: str = 'relu',
random_seed: int = 0,
):
'''Creates fully-connected network in PyTorch ScriptModule format'''
import torch
torch.manual_seed(random_seed)
activation = getattr(torch, activation_name, None) or getattr(torch.nn.functional, activation_name, None)
if not activation:
raise ValueError(f'Activation "{activation_name}" was not found.')
class ActivationLayer(torch.nn.Module):
def forward(self, input):
return activation(input)
layers = []
for layer_idx in range(len(layer_sizes) - 1):
layer = torch.nn.Linear(layer_sizes[layer_idx], layer_sizes[layer_idx + 1])
layers.append(layer)
if layer_idx < len(layer_sizes) - 2:
layers.append(ActivationLayer())
network = torch.nn.Sequential(*layers)
script_module = torch.jit.script(network)
print(script_module)
script_module.save(network_path)
if __name__ == '__main__':
create_fully_connected_pytorch_network_op = create_component_from_func(
create_fully_connected_pytorch_network,
output_component_file='component.yaml',
base_image='pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime',
packages_to_install=[],
annotations={
"author": "Alexey Volkov <alexey.volkov@ark-kun.com>",
"canonical_location": "https://raw.githubusercontent.com/Ark-kun/pipeline_components/master/components/PyTorch/Create_fully_connected_network/component.yaml",
},
)