mirror of https://github.com/kubeflow/examples.git
Fixed distributed training for LINEAR model (#130)
* Fixed distributed training for LINEAR model * Make line shorter & remove pylint disable unused argument
This commit is contained in:
parent
3bff3339f7
commit
ce2f1db11e
|
|
@ -122,7 +122,7 @@ def linear_serving_input_receiver_fn():
|
|||
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
|
||||
|
||||
|
||||
def main(unused_args): # pylint: disable=unused-argument
|
||||
def main(_):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
|
||||
# Download and load MNIST dataset.
|
||||
|
|
@ -139,40 +139,38 @@ def main(unused_args): # pylint: disable=unused-argument
|
|||
num_epochs=1,
|
||||
shuffle=False)
|
||||
|
||||
training_config = tf.estimator.RunConfig(
|
||||
model_dir=TF_MODEL_DIR, save_summary_steps=100, save_checkpoints_steps=1000)
|
||||
|
||||
if TF_MODEL_TYPE == "LINEAR":
|
||||
# Linear classifier.
|
||||
feature_columns = [
|
||||
tf.feature_column.numeric_column(
|
||||
X_FEATURE, shape=mnist.train.images.shape[1:])]
|
||||
|
||||
classifier = tf.estimator.LinearClassifier(
|
||||
feature_columns=feature_columns, n_classes=N_DIGITS, model_dir=TF_MODEL_DIR)
|
||||
classifier.train(input_fn=train_input_fn, steps=TF_TRAIN_STEPS)
|
||||
scores = classifier.evaluate(input_fn=test_input_fn)
|
||||
print('Accuracy (LinearClassifier): {0:f}'.format(scores['accuracy']))
|
||||
# FIXME This doesn't seem to work. sticking to CNN for the example for now.
|
||||
classifier.export_savedmodel(
|
||||
TF_EXPORT_DIR, linear_serving_input_receiver_fn)
|
||||
feature_columns=feature_columns, n_classes=N_DIGITS,
|
||||
model_dir=TF_MODEL_DIR, config=training_config)
|
||||
export_final = tf.estimator.FinalExporter(
|
||||
TF_EXPORT_DIR, serving_input_receiver_fn=cnn_serving_input_receiver_fn)
|
||||
|
||||
elif TF_MODEL_TYPE == "CNN":
|
||||
# Convolutional network
|
||||
training_config = tf.estimator.RunConfig(
|
||||
model_dir=TF_MODEL_DIR, save_summary_steps=100, save_checkpoints_steps=1000)
|
||||
classifier = tf.estimator.Estimator(
|
||||
model_fn=conv_model, model_dir=TF_MODEL_DIR, config=training_config)
|
||||
export_final = tf.estimator.FinalExporter(
|
||||
TF_EXPORT_DIR, serving_input_receiver_fn=cnn_serving_input_receiver_fn)
|
||||
train_spec = tf.estimator.TrainSpec(
|
||||
input_fn=train_input_fn, max_steps=TF_TRAIN_STEPS)
|
||||
eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn,
|
||||
steps=1,
|
||||
exporters=export_final,
|
||||
throttle_secs=1,
|
||||
start_delay_secs=1)
|
||||
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
|
||||
else:
|
||||
print("No such model type: %s" % TF_MODEL_TYPE)
|
||||
sys.exit(1)
|
||||
|
||||
train_spec = tf.estimator.TrainSpec(
|
||||
input_fn=train_input_fn, max_steps=TF_TRAIN_STEPS)
|
||||
eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn,
|
||||
steps=1,
|
||||
exporters=export_final,
|
||||
throttle_secs=1,
|
||||
start_delay_secs=1)
|
||||
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
|
||||
|
|
|
|||
Loading…
Reference in New Issue