Fixed distributed training for LINEAR model (#130)

* Fixed distributed training for LINEAR model

* Make line shorter & remove pylint disable unused argument
This commit is contained in:
Maerville 2018-06-13 11:57:28 -07:00 committed by k8s-ci-robot
parent 3bff3339f7
commit ce2f1db11e
1 changed files with 17 additions and 19 deletions

View File

@ -122,7 +122,7 @@ def linear_serving_input_receiver_fn():
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
def main(unused_args): # pylint: disable=unused-argument
def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
# Download and load MNIST dataset.
@ -139,40 +139,38 @@ def main(unused_args): # pylint: disable=unused-argument
num_epochs=1,
shuffle=False)
training_config = tf.estimator.RunConfig(
model_dir=TF_MODEL_DIR, save_summary_steps=100, save_checkpoints_steps=1000)
if TF_MODEL_TYPE == "LINEAR":
# Linear classifier.
feature_columns = [
tf.feature_column.numeric_column(
X_FEATURE, shape=mnist.train.images.shape[1:])]
classifier = tf.estimator.LinearClassifier(
feature_columns=feature_columns, n_classes=N_DIGITS, model_dir=TF_MODEL_DIR)
classifier.train(input_fn=train_input_fn, steps=TF_TRAIN_STEPS)
scores = classifier.evaluate(input_fn=test_input_fn)
print('Accuracy (LinearClassifier): {0:f}'.format(scores['accuracy']))
# FIXME This doesn't seem to work. sticking to CNN for the example for now.
classifier.export_savedmodel(
TF_EXPORT_DIR, linear_serving_input_receiver_fn)
feature_columns=feature_columns, n_classes=N_DIGITS,
model_dir=TF_MODEL_DIR, config=training_config)
export_final = tf.estimator.FinalExporter(
TF_EXPORT_DIR, serving_input_receiver_fn=cnn_serving_input_receiver_fn)
elif TF_MODEL_TYPE == "CNN":
# Convolutional network
training_config = tf.estimator.RunConfig(
model_dir=TF_MODEL_DIR, save_summary_steps=100, save_checkpoints_steps=1000)
classifier = tf.estimator.Estimator(
model_fn=conv_model, model_dir=TF_MODEL_DIR, config=training_config)
export_final = tf.estimator.FinalExporter(
TF_EXPORT_DIR, serving_input_receiver_fn=cnn_serving_input_receiver_fn)
train_spec = tf.estimator.TrainSpec(
input_fn=train_input_fn, max_steps=TF_TRAIN_STEPS)
eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn,
steps=1,
exporters=export_final,
throttle_secs=1,
start_delay_secs=1)
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
else:
print("No such model type: %s" % TF_MODEL_TYPE)
sys.exit(1)
train_spec = tf.estimator.TrainSpec(
input_fn=train_input_fn, max_steps=TF_TRAIN_STEPS)
eval_spec = tf.estimator.EvalSpec(input_fn=test_input_fn,
steps=1,
exporters=export_final,
throttle_secs=1,
start_delay_secs=1)
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
if __name__ == '__main__':
tf.app.run()