217 lines
7.9 KiB
Python
217 lines
7.9 KiB
Python
import tensorflow as tf
|
|
from tensorflow.python.saved_model import signature_constants
|
|
from tensorflow.python.saved_model import tag_constants
|
|
from nets import nets_factory
|
|
from tensorflow.python.platform import gfile
|
|
import argparse
|
|
import validators
|
|
import os
|
|
import requests
|
|
import tarfile
|
|
from subprocess import Popen, PIPE
|
|
import shutil
|
|
import glob
|
|
import re
|
|
import json
|
|
from tensorflow.python.tools.freeze_graph import freeze_graph
|
|
from tensorflow.python.tools.saved_model_cli import _show_all
|
|
from urllib.parse import urlparse
|
|
from shutil import copyfile
|
|
from google.cloud import storage
|
|
|
|
|
|
def upload_to_gcs(src, dst):
    """Upload the local file ``src`` to the Cloud Storage URL ``dst``.

    ``dst`` is split with ``urlparse``: the network location is used as the
    bucket name and the path, minus its first character, as the object name.
    """
    target = urlparse(dst)
    bucket_id = target.netloc
    # Drop the first character of the URL path (the leading "/").
    object_name = target.path[1:]

    client = storage.Client()
    destination = client.get_bucket(bucket_id).blob(object_name)
    destination.upload_from_filename(src)
|
|
|
|
|
|
def _parse_batch_size(value):
    """Convert the ``--batch_size`` value to an int, or None for a dynamic batch.

    Raises:
        ValueError: if ``value`` is neither an integer, "None" nor "-1".
    """
    if value in ("None", "-1"):
        return None
    return int(value)


def _parse_node_names(summary, kind):
    """Extract node names from the ``summarize_graph`` report.

    Args:
        summary: raw stdout of ``summarize_graph`` as bytes.
        kind: ``"inputs"`` or ``"outputs"``; selects the
            ``Found N possible <kind>`` lines of the report.

    Returns:
        List of the ``name=...`` values found on the matching lines.
    """
    header = re.compile(r"Found [\d]* possible " + kind)
    names = []
    for raw_line in summary.split(b'\n'):
        line = raw_line.decode(errors="replace")
        if header.match(line) is not None:
            names.extend(re.findall(r'name=([\S]*),', line))
    return names


def main():
    """Export a pretrained TF-Slim model as a frozen graph and a SavedModel.

    Steps: parse the command line, download and unpack the checkpoint
    archive into a temporary directory, write the inference graph definition,
    detect the graph inputs/outputs with the ``./summarize_graph`` binary,
    freeze the graph with the checkpoint weights, export a SavedModel with a
    predict signature and finally write the output-path and UI metadata files.

    Exits with status 1 when the checkpoint URL is invalid, no checkpoint is
    found in the archive, or the graph inputs/outputs cannot be detected.
    """
    parser = argparse.ArgumentParser(
        description='Slim model generator')
    # model_name and saved_model_dir are used unconditionally below, so
    # declare them required instead of failing later on a None value.
    parser.add_argument('--model_name', type=str, required=True,
                        help='')
    parser.add_argument('--export_dir', type=str, default="/tmp/export_dir",
                        help='GCS or local path to save graph files')
    parser.add_argument('--saved_model_dir', type=str, required=True,
                        help='GCS or local path to save the generated model')
    parser.add_argument('--batch_size', type=str, default=1,
                        help='batch size to be used in the exported model')
    parser.add_argument('--checkpoint_url', type=str, required=True,
                        help='URL to the pretrained compressed checkpoint')
    parser.add_argument('--num_classes', type=int, default=1000,
                        help='number of model classes')
    args = parser.parse_args()

    MODEL = args.model_name
    URL = args.checkpoint_url
    if not validators.url(args.checkpoint_url):
        print('use a valid URL parameter')
        raise SystemExit(1)
    TMP_DIR = "/tmp/slim_tmp"
    NUM_CLASSES = args.num_classes
    MODEL_FILE_NAME = URL.rsplit('/', 1)[-1]
    EXPORT_DIR = args.export_dir
    SAVED_MODEL_DIR = args.saved_model_dir

    # The option is parsed as a string so that "None"/"-1" can request a
    # dynamic batch dimension; numeric values must become ints before they
    # are used in the placeholder shape.
    try:
        batchsize = _parse_batch_size(args.batch_size)
    except ValueError:
        parser.error("--batch_size must be an integer, 'None' or '-1'")

    tmp_graph_file = os.path.join(TMP_DIR, MODEL + '_graph.pb')
    export_graph_file = os.path.join(EXPORT_DIR, MODEL + '_graph.pb')
    frozen_file = os.path.join(EXPORT_DIR, 'frozen_graph_' + MODEL + '.pb')
    archive_path = os.path.join(TMP_DIR, MODEL_FILE_NAME)

    os.makedirs(TMP_DIR, exist_ok=True)

    if not os.path.exists(archive_path):
        print("Downloading and decompressing the model checkpoint...")
        partial_path = archive_path + '.part'
        with requests.get(URL, stream=True) as response:
            # Fail on HTTP errors instead of saving an error page as archive.
            response.raise_for_status()
            with open(partial_path, 'wb') as output:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    output.write(chunk)
        # Only expose the archive under its final name once it is complete,
        # so an interrupted download is not reused on the next run.
        os.replace(partial_path, archive_path)
        # NOTE(security): extractall() trusts the member paths stored in the
        # archive; only use checkpoint URLs from trusted sources.
        with tarfile.open(archive_path) as tar:
            tar.extractall(path=TMP_DIR)
        print("Model checkpoint downloaded and decompressed to:", TMP_DIR)
    else:
        print("Reusing existing model file ", archive_path)

    checkpoint_files = sorted(glob.glob(TMP_DIR + '/*.ckpt*'))
    print("checkpoint", checkpoint_files)
    if not checkpoint_files:
        print("checkpoint file not detected in " + URL)
        raise SystemExit(1)
    # Reduce e.g. "model.ckpt.index" to the checkpoint prefix "model.ckpt".
    m = re.match(r"([\S]*\.ckpt)", checkpoint_files[-1])
    print("checkpoint match", m)
    if m is None:
        print("checkpoint file not detected in " + URL)
        raise SystemExit(1)
    checkpoint = m[0]
    print(checkpoint)

    print("Saving graph def file")
    with tf.Graph().as_default() as graph:
        network_fn = nets_factory.get_network_fn(MODEL,
                                                 num_classes=NUM_CLASSES,
                                                 is_training=False)
        image_size = network_fn.default_image_size
        placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                     shape=[batchsize, image_size,
                                            image_size, 3])
        network_fn(placeholder)
        graph_def = graph.as_graph_def()

        with gfile.GFile(tmp_graph_file, 'wb') as f:
            f.write(graph_def.SerializeToString())

    export_scheme = urlparse(EXPORT_DIR).scheme
    if export_scheme == 'gs':
        upload_to_gcs(tmp_graph_file, export_graph_file)
        print("Graph file saved to ", export_graph_file)
    elif export_scheme == '':
        os.makedirs(EXPORT_DIR, exist_ok=True)
        copyfile(tmp_graph_file, export_graph_file)
        print("Graph file saved to ", export_graph_file)
    else:
        print("Invalid format of model export path")

    print("Analysing graph")
    # Argument list without a shell: the model name is user input and must
    # not be interpreted by a shell.
    p = Popen(["./summarize_graph", "--in_graph=" + tmp_graph_file,
               "--print_structure=false"], stdout=PIPE, stderr=PIPE)
    summary, err = p.communicate()
    inputs = _parse_node_names(summary, "inputs")
    print("inputs", inputs)
    outputs = _parse_node_names(summary, "outputs")
    print("outputs", outputs)
    if not inputs or not outputs:
        # Without node names the freeze/export steps below cannot succeed;
        # report the analyser's own error output instead.
        print("Could not detect graph inputs/outputs with summarize_graph:",
              err.decode(errors="replace"))
        raise SystemExit(1)

    output_node_names = ",".join(outputs)
    print("Creating freezed graph based on pretrained checkpoint")
    freeze_graph(input_graph=tmp_graph_file,
                 input_checkpoint=checkpoint,
                 input_binary=True,
                 clear_devices=True,
                 input_saver='',
                 output_node_names=output_node_names,
                 restore_op_name="save/restore_all",
                 filename_tensor_name="save/Const:0",
                 output_graph=frozen_file,
                 initializer_nodes="")

    # A local target directory is removed first so the builder starts from
    # a path that does not exist yet.
    if urlparse(SAVED_MODEL_DIR).scheme == '' and \
            os.path.exists(SAVED_MODEL_DIR):
        shutil.rmtree(SAVED_MODEL_DIR)

    builder = tf.saved_model.builder.SavedModelBuilder(SAVED_MODEL_DIR)

    with tf.gfile.GFile(frozen_file, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    sigs = {}

    with tf.Session(graph=tf.Graph()) as sess:
        tf.import_graph_def(graph_def, name="")
        g = tf.get_default_graph()
        # Map every detected node name to its first output tensor.
        inp_dic = {name: g.get_tensor_by_name(name + ":0")
                   for name in inputs}
        out_dic = {name: g.get_tensor_by_name(name + ":0")
                   for name in outputs}

        sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
            tf.saved_model.signature_def_utils.predict_signature_def(
                inp_dic, out_dic)

        builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING],
                                             signature_def_map=sigs)
    print("Exporting saved model to:", SAVED_MODEL_DIR + ' ...')
    builder.save()

    print("Saved model exported to:", SAVED_MODEL_DIR)
    _show_all(SAVED_MODEL_DIR)
    pb_visual_writer = tf.summary.FileWriter(SAVED_MODEL_DIR)
    pb_visual_writer.add_graph(sess.graph)
    # Flush the event file to disk before the process exits.
    pb_visual_writer.close()
    # The graph events are written to SAVED_MODEL_DIR, so that is the
    # directory TensorBoard has to be pointed at.
    print("Visualize the model by running: "
          "tensorboard --logdir={}".format(SAVED_MODEL_DIR))

    # Record the output locations in files for later consumers.
    with open('/tmp/saved_model_dir.txt', 'w') as f:
        f.write(SAVED_MODEL_DIR)
    with open('/tmp/export_dir.txt', 'w') as f:
        f.write(EXPORT_DIR)

    artifacts = {
        "version": 1,
        "outputs": [
            {
                "type": "tensorboard",
                "source": SAVED_MODEL_DIR,
            },
        ],
    }
    with open('/mlpipeline-ui-metadata.json', 'w') as f:
        json.dump(artifacts, f)
|
|
|
|
# Run the export only when this file is executed as a script.
if __name__ == '__main__':
    main()
|