mirror of https://github.com/kubeflow/examples.git
55 lines
1.7 KiB
Python
55 lines
1.7 KiB
Python
from __future__ import print_function
|
|
|
|
import logging
|
|
|
|
import tornado.web
|
|
from tornado import gen
|
|
from tornado.options import define, options, parse_command_line
|
|
from keras.models import load_model
|
|
import dill as dpickle
|
|
from seq2seq_utils import Seq2Seq_Inference
|
|
|
|
# Command-line options; their parsed values are read from ``options`` in main().
define(
    "port",
    default=8888,
    help="run on the given port",
    type=int,
)
define(
    "instances_key",
    default='instances',
    help="requested instances json object key",
)
|
|
|
|
|
|
class PredictHandler(tornado.web.RequestHandler):
    """Serve model predictions for JSON POST requests.

    The request body must be a JSON object. The key named by the
    ``request_key`` application setting must hold a list of inputs. Each
    input is passed to ``generate_issue_title`` on the ``model`` application
    setting, and the response is the JSON object ``{"predictions": [...]}``.
    """

    @gen.coroutine
    def post(self):
        """Decode the request, run the model on each instance, write the result.

        Raises:
            tornado.web.HTTPError: status 400 when the body is not valid
                JSON, is not a JSON object, lacks the configured key, or the
                value under that key is not a list.
        """
        request_key = self.settings['request_key']
        model = self.settings['model']

        # Client errors are reported as 400 instead of escaping as an
        # unhandled exception (which Tornado turns into a 500).
        try:
            request_data = tornado.escape.json_decode(self.request.body)
        except ValueError:
            # Covers both JSON syntax errors and undecodable bytes.
            raise tornado.web.HTTPError(
                400, reason="Request body is not valid JSON")

        if not isinstance(request_data, dict) or request_key not in request_data:
            raise tornado.web.HTTPError(
                400, reason="Request body is missing the instances key")

        instances = request_data[request_key]
        if not isinstance(instances, list):
            # A bare string would otherwise be iterated character by character.
            raise tornado.web.HTTPError(
                400, reason="Instances must be a JSON list")

        # NOTE(review): there is no ``yield`` in this method, so the model
        # calls below run synchronously inside the coroutine.
        # Element [1] of the model's return value is used as the prediction;
        # presumably the generated title text -- confirm in seq2seq_utils.
        predictions = [model.generate_issue_title(body)[1] for body in instances]
        self.write(dict(predictions=predictions))
|
|
|
|
|
|
class IndexHandler(tornado.web.RequestHandler):
    """Handler that answers GET requests with a fixed greeting."""

    def get(self):
        """Write the static greeting as the response body."""
        greeting = 'Hello World'
        self.write(greeting)
|
|
|
|
|
|
def main():
    """Load the model artifacts and run the prediction HTTP server.

    Parses the command line, loads the two dill-serialized preprocessors and
    the Keras model from paths relative to the current working directory,
    wraps them in a ``Seq2Seq_Inference`` object, and serves ``/predict`` and
    ``/`` on ``options.port``. The call to ``IOLoop.start()`` at the end
    blocks while the loop is running.
    """
    # Import explicitly rather than relying on other tornado imports to load
    # the ``tornado.ioloop`` submodule as a side effect.
    from tornado.ioloop import IOLoop

    parse_command_line()

    # NOTE: dill (like pickle) can execute arbitrary code while loading;
    # these files must come from a trusted source.
    with open('body_pp.dpkl', 'rb') as f:
        body_pp = dpickle.load(f)
    with open('title_pp.dpkl', 'rb') as f:
        title_pp = dpickle.load(f)

    model = Seq2Seq_Inference(
        encoder_preprocessor=body_pp,
        decoder_preprocessor=title_pp,
        seq2seq_model=load_model('seq2seq_model_tutorial.h5'))

    app = tornado.web.Application(
        [
            (r"/predict", PredictHandler),
            (r"/", IndexHandler),
        ],
        xsrf_cookies=False,
        # Handlers read these two entries back through ``self.settings``.
        request_key=options.instances_key,
        model=model)
    app.listen(options.port)

    # Pass the port as a logger argument so formatting is deferred to the
    # logging framework.
    logging.info('running at http://localhost:%s', options.port)
    IOLoop.current().start()
|
|
|
|
|
|
# Start the server only when this file is executed as a script, so that
# importing the module has no side effect beyond the option definitions.
if __name__ == "__main__":
    main()
|