mirror of https://github.com/kubeflow/examples.git
91 lines
2.8 KiB
Python
91 lines
2.8 KiB
Python
#!/usr/bin/env python2.7
|
|
'''
|
|
Copyright 2018 Google LLC
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
https://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
'''
|
|
|
|
from __future__ import print_function
|
|
|
|
import logging
|
|
import grpc
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tensorflow.examples.tutorials.mnist import input_data
|
|
from proto import prediction_pb2
|
|
from proto import prediction_pb2_grpc
|
|
|
|
|
|
def get_prediction(image, server_host='127.0.0.1', server_port=8080,
                   deployment_name="server"):
    """
    Retrieve a prediction from a Seldon model server over gRPC

    :param image: a MNIST image represented as a 1x784 array; only the
                  first row is sent to the server
    :param server_host: the address of the Seldon server
    :param server_port: the port used by the server
    :param deployment_name: the name of the deployment, sent as the
                            'seldon' metadata entry of the call
    :return 0: the values of the tensor contained in the server response
    :raises IOError: if the call fails with an I/O error
    :raises grpc.RpcError: if the gRPC call fails
    """

    # build request from the first row of the input, flattened to 784 values
    chosen = 0
    data = image[chosen].reshape(784)
    datadef = prediction_pb2.DefaultData(
        tensor=prediction_pb2.Tensor(
            shape=data.shape,
            values=data
        )
    )
    request = prediction_pb2.SeldonMessage(data=datadef)

    # retrieve results
    logging.info("connecting to:%s:%i", server_host, server_port)
    channel = grpc.insecure_channel(server_host + ":" + str(server_port))
    try:
        stub = prediction_pb2_grpc.SeldonStub(channel)
        metadata = [('seldon', deployment_name)]
        response = stub.Predict(request=request, metadata=metadata)
    except (IOError, grpc.RpcError) as e:
        # server connection failed: log it and let the real error propagate
        # instead of falling through to an unbound `response`
        logging.error("Could Not Connect to Server: %s", str(e))
        raise
    finally:
        # one channel is opened per call, so release it on every path
        channel.close()

    return response.data.tensor.values
|
|
|
|
|
|
def random_mnist(save_path=None):
    """
    Pull one image out of the MNIST test dataset (a batch of size 1)
    Optionally save the selected image as a file to disk

    :param save_path: the path to save the file to. If None, file is not saved
    :return 0: a 1x784 representation of the MNIST image
    :return 1: the ground truth label associated with the image
    :return 2: a bool representing whether the image file was saved to disk
    """

    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    batch_size = 1
    batch_x, batch_y = mnist.test.next_batch(batch_size)

    saved = False
    if save_path is not None:
        # save image file to disk; saving is best-effort, so a failure is
        # reported through the log and the returned flag rather than raised
        try:
            # scale pixel values by 255 and reshape into a 28x28 8-bit image
            data = (batch_x * 255).astype(np.uint8).reshape(28, 28)
            img = Image.fromarray(data, 'L')
            img.save(save_path)
            saved = True
        except IOError as e:
            logging.warning("Could not save MNIST image to %s: %s",
                            save_path, e)

    # labels are loaded one-hot, so argmax recovers the integer label
    return batch_x, np.argmax(batch_y), saved
|