mirror of https://github.com/kubeflow/examples.git
Cherry pick changes to PredictionDoFn (#226)
* Cherry pick changes to PredictionDoFn
* Disable lint checks for cherry picked file
* Update TODO and notebook install instructions
* Restore CUSTOM_COMMANDS todo
parent 18829159b0
commit 4e015e76a3
@@ -30,9 +30,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# FIXME(sanyamkapoor): The Kubeflow Batch Prediction dependency is installed from a fork for reasons in\n",
-    "# kubeflow/batch-predict#9 and corresponding issue kubeflow/batch-predict#10\n",
-    "! pip2 install https://github.com/activatedgeek/batch-predict/tarball/fix-value-provider\n",
+    "! pip2 install https://github.com/kubeflow/batch-predict/tarball/master\n",
     "\n",
     "! pip2 install -r src/requirements.txt"
    ]
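A quick, hypothetical sanity check (not part of the notebook) is to confirm that the modules the cherry-picked file below depends on now resolve after the install:

# Hypothetical check: these imports are the ones the cherry-picked
# PredictionDoFn uses, so they should succeed after the install above.
from kubeflow_batch_predict import prediction as mlprediction
from kubeflow_batch_predict.dataflow import _aggregators, _error_filter
print(mlprediction.TENSORFLOW_FRAMEWORK_NAME)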
@@ -0,0 +1,215 @@
# pylint: skip-file
"""A cherry-pick of PredictionDoFn from Kubeflow Batch Predict.

TODO: This file should be retired once kubeflow/batch-predict#10 is resolved.
"""
import datetime
import json
import logging
import threading
import traceback

import apache_beam as beam
from apache_beam.options.value_provider import ValueProvider
from apache_beam.utils.windowed_value import WindowedValue

from kubeflow_batch_predict import prediction as mlprediction
from kubeflow_batch_predict.dataflow import _aggregators as aggregators
from kubeflow_batch_predict.dataflow import _error_filter as error_filter
from tensorflow.python.saved_model import tag_constants

DEFAULT_BATCH_SIZE = 1000  # 1K instances per batch when evaluating models.
LOG_SIZE_LIMIT = 1000  # 1K bytes for the input field in log entries.
LOG_NAME = "worker"
_METRICS_NAMESPACE = "cloud_ml_batch_predict"


class PredictionDoFn(beam.DoFn):
  """A DoFn that loads the model, creates a session, and performs prediction.

  The input PCollection consists of a list of strings from the input files.

  The DoFn first loads the model from a given path where meta graph data and
  checkpoint data are exported. Then, if there is only one string input
  tensor or the model needs to preprocess the input, it passes the data
  directly to prediction. Otherwise, it tries to load the data as JSON.

  It then batches the inputs of each instance into one feed_dict, runs the
  session, and predicts the values of interest for all instances. Finally,
  it emits the prediction result for each individual instance.
  """

  class _ModelState(object):
    """Atomic representation of the in-memory state of the model."""

    def __init__(self,
                 model_dir,
                 tags,
                 framework=mlprediction.TENSORFLOW_FRAMEWORK_NAME):
      self.model_dir = model_dir
      client = mlprediction.create_client(framework, model_dir, tags)
      self.model = mlprediction.create_model(client, model_dir, framework)

  _thread_local = threading.local()

  def __init__(self,
               aggregator_dict=None,
               user_project_id="",
               user_job_id="",
               tags=tag_constants.SERVING,
               signature_name="",
               skip_preprocessing=False,
               target="",
               config=None,
               instances_key='instances',
               predictions_key='predictions',
               framework=mlprediction.TENSORFLOW_FRAMEWORK_NAME):
    """Constructor of the Prediction beam.DoFn class.

    Args:
      aggregator_dict: A dict of aggregators mapping counter names to
        aggregators.
      user_project_id: A string. The project to which the logs will be sent.
      user_job_id: A string. The job to which the logs will be sent.
      tags: A comma-separated string containing the tags for the serving
        graph.
      signature_name: A string used to look up the serving signature in the
        signature map.
      skip_preprocessing: bool. Whether to skip preprocessing even when
        the metadata.yaml/metadata.json file exists.
      target: The execution engine to connect to. See target in tf.Session().
        In most cases, users should not set the target.
      config: A ConfigProto proto with configuration options. See config in
        tf.Session().
      instances_key: The key under which input instances are read from each
        element.
      predictions_key: The key under which predictions are stored in each
        emitted element.
      framework: The framework used to train this model. Available frameworks:
        "TENSORFLOW", "SCIKIT_LEARN", and "XGBOOST".

    Side Inputs:
      model_dir: The directory containing the model to load and the
        checkpoint files to restore the session.
    """
    self._target = target
    self._user_project_id = user_project_id
    self._user_job_id = user_job_id
    self._tags = tags
    self._signature_name = signature_name
    self._skip_preprocessing = skip_preprocessing
    self._config = config
    self._aggregator_dict = aggregator_dict
    self._model_state = None
    self._instances_key = instances_key
    self._predictions_key = predictions_key

    self._tag_list = []
    self._framework = framework
    # Metrics.
    self._model_load_seconds_distribution = beam.metrics.Metrics.distribution(
        _METRICS_NAMESPACE, "model_load_seconds")
    self._batch_process_ms_distribution = beam.metrics.Metrics.distribution(
        _METRICS_NAMESPACE, "batch_process_milliseconds")

  def start_bundle(self):
    # Unwrap any ValueProvider arguments so templated runs work (see
    # kubeflow/batch-predict#9 and kubeflow/batch-predict#10).
    if isinstance(self._signature_name, ValueProvider):
      self._signature_name = self._signature_name.get()

    if isinstance(self._tags, ValueProvider):
      self._tags = self._tags.get()

    self._tag_list = self._tags.split(",")

  def process(self, element, model_dir):
    try:
      if isinstance(model_dir, ValueProvider):
        model_dir = model_dir.get()
      framework = self._framework.get() if isinstance(
          self._framework, ValueProvider) else self._framework
      if self._model_state is None:
        if (getattr(self._thread_local, "model_state", None) is None or
            self._thread_local.model_state.model_dir != model_dir):
          start = datetime.datetime.now()
          self._thread_local.model_state = self._ModelState(
              model_dir, self._tag_list, framework)
          self._model_load_seconds_distribution.update(
              int((datetime.datetime.now() - start).total_seconds()))
        self._model_state = self._thread_local.model_state
      else:
        assert self._model_state.model_dir == model_dir

      # Measure the processing time.
      start = datetime.datetime.now()
      # Try to load the data.
      if framework == mlprediction.TENSORFLOW_FRAMEWORK_NAME:
        # Even though predict() checks the signature in TensorFlowModel,
        # we need to duplicate this check here to determine the single string
        # input case.
        self._signature_name, signature = self._model_state.model.get_signature(
            self._signature_name)
        if self._model_state.model.is_single_string_input(signature):
          loaded_data = element
        else:
          loaded_data = [json.loads(d) for d in element]
      else:
        loaded_data = [json.loads(d) for d in element]
      loaded_data = mlprediction.decode_base64(loaded_data)
      instances = loaded_data[self._instances_key]

      # Actual prediction occurs.
      kwargs = {}
      if self._signature_name:
        kwargs = {"signature_name": self._signature_name}
      inputs, predictions = self._model_state.model.predict(instances, **kwargs)

      predictions = list(predictions)

      if self._aggregator_dict:
        self._aggregator_dict[aggregators.AggregatorName.ML_PREDICTIONS].inc(
            len(predictions))

      # For successful processing, record the time.
      td = datetime.datetime.now() - start
      time_delta_in_ms = int(
          td.microseconds / 10**3 + (td.seconds + td.days * 24 * 3600) * 10**3)
      self._batch_process_ms_distribution.update(time_delta_in_ms)

      loaded_data[self._predictions_key] = predictions
      yield loaded_data

    except mlprediction.PredictionError as e:
      logging.error("Got a known exception: [%s]\n%s", str(e),
                    traceback.format_exc())
      clean_error_detail = error_filter.filter_tensorflow_error(e.error_detail)

      # Track in the counter.
      if self._aggregator_dict:
        counter_name = aggregators.AggregatorName.ML_FAILED_PREDICTIONS
        self._aggregator_dict[counter_name].inc(len(element))

      # Re-raise a failure to load the model as a permanent exception to end
      # the Dataflow job.
      if e.error_code == mlprediction.PredictionError.FAILED_TO_LOAD_MODEL:
        raise beam.utils.retry.PermanentException(clean_error_detail)
      try:
        yield beam.pvalue.TaggedOutput("errors", (clean_error_detail,
                                                  element))
      except AttributeError:
        # Older Beam releases used SideOutputValue instead of TaggedOutput.
        yield beam.pvalue.SideOutputValue("errors", (clean_error_detail,
                                                     element))

    except Exception as e:  # pylint: disable=broad-except
      logging.error("Got an unknown exception: [%s].", traceback.format_exc())

      # Track in the counter.
      if self._aggregator_dict:
        counter_name = aggregators.AggregatorName.ML_FAILED_PREDICTIONS
        self._aggregator_dict[counter_name].inc(len(element))

      try:
        yield beam.pvalue.TaggedOutput("errors", (str(e), element))
      except AttributeError:
        yield beam.pvalue.SideOutputValue("errors", (str(e), element))
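The point of this cherry-pick (see kubeflow/batch-predict#9 and kubeflow/batch-predict#10) is that start_bundle and process now unwrap ValueProvider arguments, which lets the DoFn run in templated Dataflow jobs. A minimal sketch of passing model_dir as a runtime value; the options class and pipeline wiring below are illustrative, not part of the commit:

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

from code_search.dataflow.do_fns.prediction_do_fn import PredictionDoFn


class PredictOptions(PipelineOptions):
  """Illustrative options class (not part of the commit)."""

  @classmethod
  def _add_argparse_args(cls, parser):
    # Resolved at runtime, so the pipeline can be staged as a template.
    parser.add_value_provider_argument('--model_dir', type=str)


options = PredictOptions()
with beam.Pipeline(options=options) as p:
  # Each element is a batch (list) of JSON strings; the exact format
  # depends on the model being served.
  batches = p | "ReadBatches" >> beam.Create([['{"instances": [[1.0, 2.0]]}']])
  # model_dir is passed as a ValueProvider side argument; process() calls
  # .get() on it. The DoFn tags failures as "errors".
  results = batches | "Predict" >> beam.ParDo(
      PredictionDoFn(), options.model_dir).with_outputs('errors', main='main')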
@@ -1,6 +1,6 @@
 import apache_beam as beam
-import kubeflow_batch_predict.dataflow.batch_prediction as batch_prediction
 
+import code_search.dataflow.do_fns.prediction_do_fn as pred
 import code_search.dataflow.do_fns.function_embeddings as func_embeddings
 import code_search.dataflow.transforms.github_bigquery as github_bigquery
@@ -27,7 +27,7 @@ class FunctionEmbeddings(beam.PTransform):
     batch_predict = (input_or_inputs
       | "Encoded Function Tokens" >> beam.ParDo(func_embeddings.EncodeFunctionTokens(
           self.problem, self.data_dir))
-      | "Compute Function Embeddings" >> beam.ParDo(batch_prediction.PredictionDoFn(),
+      | "Compute Function Embeddings" >> beam.ParDo(pred.PredictionDoFn(),
           self.saved_model_dir).with_outputs('err',
                                              main='main')
     )
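For reference, with_outputs returns a DoOutputsTuple whose declared tags are exposed as attributes. A hedged sketch of consuming the two streams (the sinks are placeholders; note the cherry-picked DoFn tags its failures "errors", so the tag passed to with_outputs must match for the error stream to be populated):

# Sketch only: batch_predict is the DoOutputsTuple built above.
embeddings = batch_predict.main  # elements with predictions attached
errors = batch_predict.err       # elements routed to the 'err' tag, if any

# Placeholder sinks; the real transforms depend on the surrounding pipeline.
_ = embeddings | "PassEmbeddings" >> beam.Map(lambda x: x)
_ = errors | "PassErrors" >> beam.Map(lambda e: e)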
@@ -10,11 +10,8 @@ with open('requirements.txt', 'r') as f:
 
 CUSTOM_COMMANDS = [
     ['python', '-m', 'spacy', 'download', 'en'],
-    ##
     # TODO(sanyamkapoor): This isn't ideal but no other way for a seamless install right now.
-    # This currently uses a fork due to API limitations (See kubeflow/batch-predict#10). The
-    # API limitations have a workaround via kubeflow/batch-predict#9.
-    ['pip', 'install', 'https://github.com/activatedgeek/batch-predict/tarball/fix-value-provider']
+    ['pip', 'install', 'https://github.com/kubeflow/batch-predict/tarball/master']
 ]
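In a Dataflow worker setup.py, CUSTOM_COMMANDS entries are typically executed by a custom build step. A minimal sketch following the standard Apache Beam pattern (assumed here; the surrounding file is not shown in this diff):

import subprocess
from distutils.command.build import build as _build

import setuptools


class build(_build):  # pylint: disable=invalid-name
  """Build hook that also triggers the custom commands."""
  sub_commands = _build.sub_commands + [('CustomCommands', None)]


class CustomCommands(setuptools.Command):
  """Runs each CUSTOM_COMMANDS entry in a subprocess on the worker."""

  def initialize_options(self):
    pass

  def finalize_options(self):
    pass

  def run(self):
    for command in CUSTOM_COMMANDS:
      subprocess.check_call(command)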