Cherry pick changes to PredictionDoFn (#226)

* Cherry pick changes to PredictionDoFn

* Disable lint checks for cherry picked file

* Update TODO and notebook install instructions

* Restore CUSTOM_COMMANDS todo
Sanyam Kapoor 2018-08-15 06:21:00 -07:00 committed by k8s-ci-robot
parent 18829159b0
commit 4e015e76a3
4 changed files with 219 additions and 9 deletions


@@ -30,9 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"# FIXME(sanyamkapoor): The Kubeflow Batch Prediction dependency is installed from a fork for reasons in\n",
"# kubeflow/batch-predict#9 and corresponding issue kubeflow/batch-predict#10\n",
"! pip2 install https://github.com/activatedgeek/batch-predict/tarball/fix-value-provider\n",
"! pip2 install https://github.com/kubeflow/batch-predict/tarball/master\n",
"\n",
"! pip2 install -r src/requirements.txt"
]


@@ -0,0 +1,215 @@
# pylint: skip-file
"""A cherry-pick of PredictionDoFn from Kubeflow Batch Predict.

TODO: This file should be retired once kubeflow/batch-predict#10 is resolved.
"""
import datetime
import json
import logging
import threading
import traceback

import apache_beam as beam
from apache_beam.options.value_provider import ValueProvider
from apache_beam.utils.windowed_value import WindowedValue
from kubeflow_batch_predict import prediction as mlprediction
from kubeflow_batch_predict.dataflow import _aggregators as aggregators
from kubeflow_batch_predict.dataflow import _error_filter as error_filter
from tensorflow.python.saved_model import tag_constants

DEFAULT_BATCH_SIZE = 1000  # 1K instances per batch when evaluating models.
LOG_SIZE_LIMIT = 1000  # 1K bytes for the input field in log entries.
LOG_NAME = "worker"

_METRICS_NAMESPACE = "cloud_ml_batch_predict"


class PredictionDoFn(beam.DoFn):
  """A DoFn that loads the model to create a session and performs prediction.

  The input PCollection consists of a list of strings from the input files.

  The DoFn first loads the model from a given path where the meta graph data
  and checkpoint data are exported. Then, if there is only one string input
  tensor or the model needs to preprocess the input, it passes the data
  directly to prediction. Otherwise, it parses the data as JSON.

  It then batches the inputs of each instance into one feed_dict, runs the
  session, and predicts the values of interest for all the instances.
  Finally, it emits the prediction result for each individual instance.
  """

  class _ModelState(object):
    """Atomic representation of the in-memory state of the model."""

    def __init__(self,
                 model_dir,
                 tags,
                 framework=mlprediction.TENSORFLOW_FRAMEWORK_NAME):
      self.model_dir = model_dir
      client = mlprediction.create_client(framework, model_dir, tags)
      self.model = mlprediction.create_model(client, model_dir, framework)

  # Each worker thread holds its own model state, so TF sessions are never
  # shared across threads.
  _thread_local = threading.local()

  def __init__(self,
               aggregator_dict=None,
               user_project_id="",
               user_job_id="",
               tags=tag_constants.SERVING,
               signature_name="",
               skip_preprocessing=False,
               target="",
               config=None,
               instances_key='instances',
               predictions_key='predictions',
               framework=mlprediction.TENSORFLOW_FRAMEWORK_NAME):
    """Constructor of the prediction beam.DoFn class.

    Args:
      aggregator_dict: A dict of aggregators containing maps from counter name
        to the aggregator.
      user_project_id: A string. The project to which the logs will be sent.
      user_job_id: A string. The job to which the logs will be sent.
      tags: A comma-separated string that contains a list of tags for the
        serving graph.
      signature_name: A string to map into the signature map to get the
        serving signature.
      skip_preprocessing: A boolean. Whether to skip preprocessing even when
        the metadata.yaml/metadata.json file exists.
      target: The execution engine to connect to. See target in tf.Session().
        In most cases, users should not set the target.
      config: A ConfigProto proto with configuration options. See config in
        tf.Session().
      instances_key: A string. The key under which instances are read from
        each input element.
      predictions_key: A string. The key under which predictions are stored
        in each emitted element.
      framework: The framework used to train this model. Available frameworks:
        "TENSORFLOW", "SCIKIT_LEARN", and "XGBOOST".

    Side Inputs:
      model_dir: The directory containing the model to load and the
        checkpoint files to restore the session.
    """
    self._target = target
    self._user_project_id = user_project_id
    self._user_job_id = user_job_id
    self._tags = tags
    self._signature_name = signature_name
    self._skip_preprocessing = skip_preprocessing
    self._config = config
    self._aggregator_dict = aggregator_dict
    self._model_state = None
    self._instances_key = instances_key
    self._predictions_key = predictions_key
    self._tag_list = []
    self._framework = framework

    # Metrics.
    self._model_load_seconds_distribution = beam.metrics.Metrics.distribution(
        _METRICS_NAMESPACE, "model_load_seconds")
    self._batch_process_ms_distribution = beam.metrics.Metrics.distribution(
        _METRICS_NAMESPACE, "batch_process_milliseconds")

  def start_bundle(self):
    # Templated Dataflow jobs supply runtime parameters as ValueProviders,
    # which can only be read once the job is running; resolve them here.
    if isinstance(self._signature_name, ValueProvider):
      self._signature_name = self._signature_name.get()
    if isinstance(self._tags, ValueProvider):
      self._tags = self._tags.get()
    self._tag_list = self._tags.split(",")

  def process(self, element, model_dir):
    try:
      # model_dir arrives as a side input and may itself be a ValueProvider
      # that is only resolvable at runtime.
      if isinstance(model_dir, ValueProvider):
        model_dir = model_dir.get()
      framework = (self._framework.get()
                   if isinstance(self._framework, ValueProvider)
                   else self._framework)

      if self._model_state is None:
        # Load the model at most once per worker thread and reuse it across
        # bundles, reloading only when the model directory changes.
        if (getattr(self._thread_local, "model_state", None) is None or
            self._thread_local.model_state.model_dir != model_dir):
          start = datetime.datetime.now()
          self._thread_local.model_state = self._ModelState(
              model_dir, self._tag_list, framework)
          self._model_load_seconds_distribution.update(
              int((datetime.datetime.now() - start).total_seconds()))
        self._model_state = self._thread_local.model_state
      else:
        assert self._model_state.model_dir == model_dir

      # Measure the processing time.
      start = datetime.datetime.now()
      # Try to load the input data.
      if framework == mlprediction.TENSORFLOW_FRAMEWORK_NAME:
        # Even though predict() checks the signature in TensorFlowModel,
        # we need to duplicate this check here to determine the single
        # string input case.
        self._signature_name, signature = self._model_state.model.get_signature(
            self._signature_name)
        if self._model_state.model.is_single_string_input(signature):
          loaded_data = element
        else:
          loaded_data = [json.loads(d) for d in element]
      else:
        loaded_data = [json.loads(d) for d in element]

      # Unwrap any {"b64": ...} JSON values into raw bytes.
      loaded_data = mlprediction.decode_base64(loaded_data)
      instances = loaded_data[self._instances_key]

      # Actual prediction occurs.
      kwargs = {}
      if self._signature_name:
        kwargs = {"signature_name": self._signature_name}
      inputs, predictions = self._model_state.model.predict(instances, **kwargs)
      # predict() may return a lazy iterable; materialize it so it can be
      # counted and attached to the output.
      predictions = list(predictions)
      if self._aggregator_dict:
        self._aggregator_dict[aggregators.AggregatorName.ML_PREDICTIONS].inc(
            len(predictions))

      # For successful processing, record the time.
      td = datetime.datetime.now() - start
      time_delta_in_ms = int(
          td.microseconds / 10**3 + (td.seconds + td.days * 24 * 3600) * 10**3)
      self._batch_process_ms_distribution.update(time_delta_in_ms)

      loaded_data[self._predictions_key] = predictions
      yield loaded_data
    except mlprediction.PredictionError as e:
      logging.error("Got a known exception: [%s]\n%s", str(e),
                    traceback.format_exc())
      clean_error_detail = error_filter.filter_tensorflow_error(e.error_detail)
      # Track in the counter.
      if self._aggregator_dict:
        counter_name = aggregators.AggregatorName.ML_FAILED_PREDICTIONS
        self._aggregator_dict[counter_name].inc(len(element))
      # Re-raise a failure to load the model as a permanent exception so the
      # Dataflow job ends instead of retrying forever.
      if e.error_code == mlprediction.PredictionError.FAILED_TO_LOAD_MODEL:
        raise beam.utils.retry.PermanentException(clean_error_detail)
      # TaggedOutput replaced SideOutputValue in newer Beam releases; fall
      # back so the cherry-pick also works on older SDKs.
      try:
        yield beam.pvalue.TaggedOutput("errors", (clean_error_detail,
                                                  element))
      except AttributeError:
        yield beam.pvalue.SideOutputValue("errors", (clean_error_detail,
                                                     element))
    except Exception as e:  # pylint: disable=broad-except
      logging.error("Got an unknown exception: [%s].", traceback.format_exc())
      # Track in the counter.
      if self._aggregator_dict:
        counter_name = aggregators.AggregatorName.ML_FAILED_PREDICTIONS
        self._aggregator_dict[counter_name].inc(len(element))
      try:
        yield beam.pvalue.TaggedOutput("errors", (str(e), element))
      except AttributeError:
        yield beam.pvalue.SideOutputValue("errors", (str(e), element))
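
Usage sketch (not part of this commit) of wiring the cherry-picked DoFn into a
pipeline. The bucket paths and the input batch are hypothetical; the "errors"
tag matches the TaggedOutput emitted above.

import logging

import apache_beam as beam

import code_search.dataflow.do_fns.prediction_do_fn as pred

with beam.Pipeline() as p:
  results = (p
             | "Create Batch" >> beam.Create([
                 # One batch per element; a single-string-input model receives
                 # this dict as-is and reads it under the 'instances' key.
                 {"instances": ["serialized function 1",
                                "serialized function 2"]},
             ])
             | "Predict" >> beam.ParDo(
                 pred.PredictionDoFn(),
                 "gs://my-bucket/saved_model").with_outputs("errors",
                                                            main="main"))
  # Main output: input dicts with a 'predictions' key added.
  results.main | "Log Predictions" >> beam.Map(logging.info)
  # Errors output: (error_detail, element) tuples.
  results.errors | "Log Errors" >> beam.Map(logging.error)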


@@ -1,6 +1,6 @@
 import apache_beam as beam
-import kubeflow_batch_predict.dataflow.batch_prediction as batch_prediction
+import code_search.dataflow.do_fns.prediction_do_fn as pred
 import code_search.dataflow.do_fns.function_embeddings as func_embeddings
 import code_search.dataflow.transforms.github_bigquery as github_bigquery
@@ -27,7 +27,7 @@ class FunctionEmbeddings(beam.PTransform):
     batch_predict = (input_or_inputs
                      | "Encoded Function Tokens" >> beam.ParDo(func_embeddings.EncodeFunctionTokens(
                          self.problem, self.data_dir))
-                     | "Compute Function Embeddings" >> beam.ParDo(batch_prediction.PredictionDoFn(),
+                     | "Compute Function Embeddings" >> beam.ParDo(pred.PredictionDoFn(),
                                                                    self.saved_model_dir).with_outputs('err',
                                                                                                       main='main')
     )


@@ -10,11 +10,8 @@ with open('requirements.txt', 'r') as f:
 CUSTOM_COMMANDS = [
     ['python', '-m', 'spacy', 'download', 'en'],
     ##
-    # TODO(sanyamkapoor): This isn't ideal but no other way for a seamless install right now.
-    # This currently uses a fork due to API limitations (See kubeflow/batch-predict#10). The
-    # API limitations have a workaround via kubeflow/batch-predict#9.
-    ['pip', 'install', 'https://github.com/activatedgeek/batch-predict/tarball/fix-value-provider']
+    ['pip', 'install', 'https://github.com/kubeflow/batch-predict/tarball/master']
 ]
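
For context: CUSTOM_COMMANDS is the standard Apache Beam setup.py hook for
running extra install steps on Dataflow workers. A minimal sketch of the
runner pattern this file presumably pairs with (following the Beam juliaset
example; not shown in this diff):

import subprocess
from distutils.command.build import build as _build

import setuptools

class build(_build):  # pylint: disable=invalid-name
  """Chains the custom commands into the regular package build."""
  sub_commands = _build.sub_commands + [('CustomCommands', None)]

class CustomCommands(setuptools.Command):
  """Runs each entry of CUSTOM_COMMANDS as a subprocess."""
  user_options = []

  def initialize_options(self):
    pass

  def finalize_options(self):
    pass

  def run(self):
    for command in CUSTOM_COMMANDS:
      subprocess.check_call(command)

Both classes would then be registered via
setuptools.setup(..., cmdclass={'build': build, 'CustomCommands': CustomCommands}).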