# examples/github_issue_summarization/notebooks/server.py
# (Python, 55 lines, 1.7 KiB)

from __future__ import print_function

import logging

import dill as dpickle
import tornado.escape
import tornado.ioloop
import tornado.web
from keras.models import load_model
from tornado import gen
from tornado.options import define, options, parse_command_line

from seq2seq_utils import Seq2Seq_Inference
# Command-line options; their values are filled in by parse_command_line()
# at the start of main().
define("port", type=int, default=8888,
       help="run on the given port")
define("instances_key", default='instances',
       help="requested instances json object key")
class PredictHandler(tornado.web.RequestHandler):
    """Return issue-title predictions for a batch of issue bodies.

    The request body must be a JSON object whose entry under the
    application's ``request_key`` setting is a list of issue bodies. The
    response is ``{"predictions": [...]}``, with one entry per input body.
    """

    @gen.coroutine
    def post(self):
        """Decode the JSON payload, run the model on each body, write results.

        Raises:
            tornado.web.HTTPError: 400 if the request body is not valid JSON,
                is not a JSON object, or does not hold a list under the
                configured ``request_key``.
        """
        request_key = self.settings['request_key']
        try:
            request_data = tornado.escape.json_decode(self.request.body)
        except ValueError:
            # JSON syntax errors and undecodable bytes both raise ValueError
            # subclasses; they are client errors, not server failures.
            raise tornado.web.HTTPError(
                400, 'request body is not valid JSON', reason='Invalid JSON')
        if not isinstance(request_data, dict) or request_key not in request_data:
            raise tornado.web.HTTPError(
                400, 'request JSON lacks the %r key', request_key,
                reason='Missing request key')
        bodies = request_data[request_key]
        if not isinstance(bodies, list):
            # A bare string would otherwise be iterated character by character.
            raise tornado.web.HTTPError(
                400, '%r must map to a JSON list', request_key,
                reason='Instances must be a list')
        model = self.settings['model']
        # Element [1] of generate_issue_title's result is what this endpoint
        # returns; presumably the generated title text -- confirm against
        # seq2seq_utils.Seq2Seq_Inference.
        predictions = [model.generate_issue_title(body)[1] for body in bodies]
        # Tornado serializes a dict argument to JSON.
        self.write(dict(predictions=predictions))
class IndexHandler(tornado.web.RequestHandler):
    """Answer GET with a fixed greeting (handy to check the server responds)."""

    def get(self):
        """Write the static greeting text to the response."""
        greeting = 'Hello World'
        self.write(greeting)
def _load_pickle(path):
    """Deserialize and return the dill-pickled object stored at *path*."""
    # NOTE: dill/pickle loading can execute arbitrary code; only point this
    # at trusted, locally produced artifacts.
    with open(path, 'rb') as f:
        return dpickle.load(f)


def main():
    """Load the seq2seq model and its preprocessors, then serve HTTP.

    Parses command-line options, and reads ``body_pp.dpkl``,
    ``title_pp.dpkl`` and ``seq2seq_model_tutorial.h5`` relative to the
    current working directory. It registers ``/predict`` and ``/``, listens
    on ``--port``, and then blocks running the Tornado IOLoop.
    """
    parse_command_line()
    # Keyword arguments are evaluated left to right, preserving the original
    # load order: body preprocessor, title preprocessor, then the Keras model.
    model = Seq2Seq_Inference(
        encoder_preprocessor=_load_pickle('body_pp.dpkl'),
        decoder_preprocessor=_load_pickle('title_pp.dpkl'),
        seq2seq_model=load_model('seq2seq_model_tutorial.h5'))
    app = tornado.web.Application(
        [
            (r"/predict", PredictHandler),
            (r"/", IndexHandler),
        ],
        xsrf_cookies=False,
        request_key=options.instances_key,
        model=model)
    app.listen(options.port)
    # Pass the port as a lazy logging argument instead of pre-formatting.
    logging.info('running at http://localhost:%s', options.port)
    tornado.ioloop.IOLoop.current().start()
# Start the server only when this file is run as a script, not on import.
if __name__ == "__main__":
    main()