hub/examples/text_embeddings/export_test.py

128 lines
4.7 KiB
Python

# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for text embedding exporting tool."""
import logging
import os
from distutils.version import LooseVersion
import numpy as np
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub
from examples.text_embeddings import export
_MOCK_EMBEDDING = "\n".join(
["cat 1.11 2.56 3.45", "dog 1 2 3", "mouse 0.5 0.1 0.6"])
class ExportTokenEmbeddingTest(tf.test.TestCase):
  """Tests for `export.load` and `export.export_module_from_file`.

  Each test writes `_MOCK_EMBEDDING` to a file in the test's temporary
  directory. Module tests then export a TF-Hub module into that same
  directory and evaluate it in a fresh graph.
  """

  def setUp(self):
    """Writes `_MOCK_EMBEDDING` to `mock_embedding_file.txt` in the temp dir."""
    # Run the base-class setup so TF's per-test state (cached session,
    # default graph, random seeds) is reset like in any other tf.test case.
    super(ExportTokenEmbeddingTest, self).setUp()
    self._embedding_file_path = os.path.join(self.get_temp_dir(),
                                             "mock_embedding_file.txt")
    with tf.gfile.GFile(self._embedding_file_path, mode="w") as f:
      f.write(_MOCK_EMBEDDING)

  def _export_module(self, preprocess_text):
    """Exports a module built from the mock embedding file.

    Args:
      preprocess_text: forwarded to `export.export_module_from_file`.

    Returns:
      The export path (the test's temp dir), suitable for `hub.Module`.
    """
    export_path = self.get_temp_dir()
    export.export_module_from_file(
        embedding_file=self._embedding_file_path,
        export_path=export_path,
        parse_line_fn=export.parse_line,
        num_oov_buckets=1,
        preprocess_text=preprocess_text)
    return export_path

  def _embed(self, module_path, tokens):
    """Loads the module at `module_path` and evaluates it on `tokens`.

    Args:
      module_path: directory previously passed to `export_module_from_file`.
      tokens: list of Python strings fed to the module as a string constant.

    Returns:
      The value of the module's output tensor as returned by `Session.run`.
    """
    with tf.Graph().as_default():
      hub_module = hub.Module(module_path)
      embeddings = hub_module(tf.constant(tokens))
      with tf.Session() as session:
        # The module needs both its lookup tables and its variables
        # initialized before it can be evaluated.
        session.run(tf.tables_initializer())
        session.run(tf.global_variables_initializer())
        return session.run(embeddings)

  def testEmbeddingLoaded(self):
    vocabulary, embeddings = export.load(self._embedding_file_path,
                                         export.parse_line)
    self.assertEqual((3,), np.shape(vocabulary))
    self.assertEqual((3, 3), np.shape(embeddings))

  def testExportTokenEmbeddingModule(self):
    # Without text preprocessing each input string is embedded as one token.
    # "lizard" is not in the vocabulary and is expected to embed to zeros.
    module_path = self._export_module(preprocess_text=False)
    self.assertAllClose(
        self._embed(module_path, ["cat", "lizard", "dog"]),
        [[1.11, 2.56, 3.45], [0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])

  def testExportFulltextEmbeddingModule(self):
    # The expected rows are consistent with a sqrt(n)-normalized sum of token
    # vectors, with punctuation ignored and unknown tokens contributing zeros
    # (e.g. "cat cat" -> sqrt(2) * cat).
    # NOTE(review): confirm the combiner against export.py.
    module_path = self._export_module(preprocess_text=True)
    self.assertAllClose(
        self._embed(module_path,
                    ["cat", "cat cat", "lizard. dog", "cat? dog", ""]),
        [[1.11, 2.56, 3.45], [1.57, 3.62, 4.88], [0.70, 1.41, 2.12],
         [1.49, 3.22, 4.56], [0.0, 0.0, 0.0]],
        rtol=0.02)

  def testEmptyInput(self):
    # A batch made only of empty strings should yield all-zero rows.
    module_path = self._export_module(preprocess_text=True)
    self.assertAllClose(
        self._embed(module_path, ["", "", ""]),
        [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
        rtol=0.02)

  def testEmptyLeading(self):
    # A leading empty string must not disturb the embedding of later rows.
    module_path = self._export_module(preprocess_text=True)
    self.assertAllClose(
        self._embed(module_path, ["", "cat dog"]),
        [[0.0, 0.0, 0.0], [1.49, 3.22, 4.56]],
        rtol=0.02)
if __name__ == "__main__":
  # The tests build explicit graphs and run them through tf.Session, so they
  # are only run when TF is in graph mode; under eager execution just warn.
  if not tf.executing_eagerly():
    tf.test.main()
  else:
    logging.warning("Skipping running tests for TF Version: %s running eagerly.",
                    tf.__version__)