mirror of https://github.com/kubeflow/examples.git
				
				
				
			
		
			
				
	
	
		
			84 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
			
		
		
	
	
			84 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Python
		
	
	
	
| '''
 | |
| Copyright 2018 Google LLC
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     https://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License.
 | |
| '''
 | |
| 
 | |
| import os
 | |
| import uuid
 | |
| import logging
 | |
| from threading import Timer
 | |
| 
 | |
| from flask import Flask, render_template, request
 | |
| from mnist_client import get_prediction, random_mnist
 | |
| 
 | |
| app = Flask(__name__)
 | |
| 
 | |
| 
 | |
| # handle requests to the server
 | |
| @app.route("/")
 | |
| def main():
 | |
|   # get url parameters for HTML template
 | |
|   name_arg = request.args.get('name', 'mnist-classifier')
 | |
|   addr_arg = request.args.get('addr', 'ambassador')
 | |
|   port_arg = request.args.get('port', '80')
 | |
|   args = {"name": name_arg, "addr": addr_arg, "port": port_arg}
 | |
|   logging.info(args)
 | |
| 
 | |
|   output = None
 | |
|   connection = {"text": "", "success": False}
 | |
|   img_id = str(uuid.uuid4())
 | |
|   img_path = "static/tmp/" + img_id + ".png"
 | |
|   try:
 | |
|     # get a random test MNIST image
 | |
|     x, y, _ = random_mnist(img_path)
 | |
|     # get prediction from TensorFlow server
 | |
|     pred = get_prediction(x, server_host=addr_arg, server_port=int(port_arg),
 | |
|                           deployment_name=name_arg)
 | |
|     # if no exceptions thrown, server connection was a success
 | |
|     connection["text"] = "Connected to Seldon GRPC model serving service"
 | |
|     connection["success"] = True
 | |
|     # parse class confidence scores from server prediction
 | |
|     scores_dict = []
 | |
|     for i in range(0, 10):
 | |
|       scores_dict += [{"index": str(i), "val": pred[i]}]
 | |
|     output = {"truth": y,
 | |
|               "img_path": img_path, "scores": scores_dict}
 | |
|   except IOError as e:
 | |
|     # server connection failed
 | |
|     connection["text"] = "Could Not Connect to Server: " + str(e)
 | |
|   # after 10 seconds, delete cached image file from server
 | |
|   t = Timer(10.0, remove_resource, [img_path])
 | |
|   t.start()
 | |
|   # render results using HTML template
 | |
|   return render_template('index.html', output=output,
 | |
|                          connection=connection, args=args)
 | |
| 
 | |
| 
 | |
| def remove_resource(path):
 | |
|   """
 | |
|   attempt to delete file from path. Used to clean up MNIST testing images
 | |
| 
 | |
|   :param path: the path of the file to delete
 | |
|   """
 | |
|   try:
 | |
|     os.remove(path)
 | |
|     logging.info("removed %s", path)
 | |
|   except OSError:
 | |
|     logging.error("no file at %s", path)
 | |
| 
 | |
| 
 | |
| if __name__ == '__main__':
 | |
|   logging.getLogger().setLevel(logging.INFO)
 | |
|   app.run(debug=True, host='0.0.0.0')
 |