121 lines
3.9 KiB
Python
121 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""Generate component.yaml from templates"""
|
|
import json
|
|
import os
|
|
import shutil
|
|
import sys
|
|
|
|
import yaml
|
|
|
|
# Absolute path of the directory that contains this script.
CURRENT_FOLDER = os.path.dirname(os.path.abspath(__file__))

# Everything in CURRENT_FOLDER up to the first occurrence of "pipelines",
# with "pipelines" appended again. If "pipelines" does not occur in the
# path, this is simply CURRENT_FOLDER/pipelines.
PIPELINES_HOME = os.path.join(
    CURRENT_FOLDER.partition("pipelines")[0],
    "pipelines",
)

# Directory holding the component.yaml templates.
TEMPLATE_PATH = os.path.join(
    PIPELINES_HOME,
    "components/PyTorch/pytorch-kfp-components/templates",
)

# Output folder name; being relative, it resolves against the current
# working directory.
OUTPUT_YAML_FOLDER = "yaml"
|
|
|
|
|
|
def create_output_folder():
    """Start from an empty `yaml` output folder.

    If something already exists at ``OUTPUT_YAML_FOLDER`` it is removed as a
    directory tree; afterwards a new, empty directory is created there.
    """
    target = OUTPUT_YAML_FOLDER

    # Wipe the results of any previous run before recreating the folder.
    if os.path.exists(target):
        shutil.rmtree(target)

    os.mkdir(target)
|
|
|
|
|
|
def get_templates_list():
    """Return the names of the entries in the `templates` directory.

    The directory at ``TEMPLATE_PATH`` is asserted to exist before it is
    listed. Names are returned as given by ``os.listdir`` (no sorting or
    filtering is applied).
    """
    template_dir = TEMPLATE_PATH
    assert os.path.exists(template_dir)
    return os.listdir(template_dir)
|
|
|
|
|
|
def read_template(template_path: str):
    """Read a `component.yaml` template and return its parsed content.

    Args:
        template_path: Path of the YAML template file to load.

    Returns:
        The object produced by ``yaml.safe_load`` for the file content.

    Raises:
        yaml.YAMLError: If the file content is not valid YAML. The error is
            printed and then re-raised.
    """
    with open(template_path, "r") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            # There is no parsed template to return here. Re-raise the real
            # parse error instead of falling through to a `return` of an
            # unassigned local, which would fail with UnboundLocalError and
            # hide the actual cause.
            print(exc)
            raise
|
|
|
|
|
|
def replace_keys_in_template(template_dict: dict, mapping: dict):
    """Apply dotted-path overrides from `mapping` to `template_dict`.

    Every key of `mapping` is a "."-separated path into the nested template,
    for example ``{"implementation.container.image": "image_name"}``. All
    path parts except the last are looked up one after another; the last
    part is then set to the mapped value in the container that was reached.

    The template is modified in place and the same object is returned.
    """
    for dotted_path, new_value in mapping.items():
        # Split the path into the containers to walk through and the final
        # key that receives the value.
        *parent_keys, leaf_key = dotted_path.split(".")

        node = template_dict
        for part in parent_keys:
            node = node[part]

        node[leaf_key] = new_value

    return template_dict
|
|
|
|
|
|
def write_to_yaml_file(template_dict: dict, yaml_path: str):
    """Serialize `template_dict` as YAML into the file at `yaml_path`.

    The file is opened in write mode, so any existing content is replaced.
    """
    with open(yaml_path, mode="w") as out_file:
        yaml.dump(template_dict, stream=out_file)
|
|
|
|
|
|
def generate_component_yaml(mapping_template_path: str):
    """Populate the output `yaml` folder from the component templates.

    The JSON file at `mapping_template_path` (if it exists) is loaded as a
    mapping from template file name to the key/value overrides for that
    template. The output folder is recreated, and then every template is
    either rewritten with its overrides or, when it has none, copied over
    unchanged.
    """
    mapping: dict = {}
    if os.path.exists(mapping_template_path):
        with open(mapping_template_path) as mapping_file:
            mapping = json.load(mapping_file)

    create_output_folder()

    for template_name in get_templates_list():
        print("Processing {}".format(template_name))

        src = os.path.join(TEMPLATE_PATH, template_name)
        dest = os.path.join(OUTPUT_YAML_FOLDER, template_name)

        if mapping and template_name in mapping:
            # Overrides exist for this template: load it, apply the
            # key/value replacements and save the result.
            loaded = read_template(template_path=src)
            updated = replace_keys_in_template(
                template_dict=loaded, mapping=mapping[template_name]
            )
            write_to_yaml_file(template_dict=updated, yaml_path=dest)
        else:
            # No overrides for this template, so the file is copied into
            # the output folder as it is.
            shutil.copy(src, dest)
|
|
|
|
|
|
if __name__ == "__main__":
    # Exactly one command-line argument is expected: the path of the JSON
    # file that maps template names to their replacements.
    if len(sys.argv) == 2:
        input_template_path = sys.argv[1]
        generate_component_yaml(mapping_template_path=input_template_path)
    else:
        raise Exception(
            "\n\nUsage: "
            "python utils/generate_templates.py "
            "cifar10/template_mapping.json\n\n"
        )
|