mirror of https://github.com/tensorflow/hub.git
128 lines
4.7 KiB
Python
128 lines
4.7 KiB
Python
# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for text embedding exporting tool."""
|
|
|
|
import logging
|
|
import os
|
|
from distutils.version import LooseVersion
|
|
import numpy as np
|
|
import tensorflow.compat.v1 as tf
|
|
import tensorflow_hub as hub
|
|
|
|
from examples.text_embeddings import export
|
|
|
|
# Contents of the mock embedding file: one "<token> <v1> <v2> <v3>" row per
# line, with no trailing newline after the last row.
_MOCK_EMBEDDING = (
    "cat 1.11 2.56 3.45\n"
    "dog 1 2 3\n"
    "mouse 0.5 0.1 0.6"
)
|
|
|
|
|
|
class ExportTokenEmbeddingTest(tf.test.TestCase):
  """Tests loading a text embedding file and exporting it as a hub module.

  Every test works on the three-token mock embedding file written by `setUp`
  (tokens "cat", "dog" and "mouse", each with a 3-dimensional vector).
  """

  def setUp(self):
    """Writes the mock embedding file into the test's temp directory."""
    # Run the base-class fixture setup first so that overriding `setUp` here
    # does not silently skip it.
    super().setUp()
    self._embedding_file_path = os.path.join(self.get_temp_dir(),
                                             "mock_embedding_file.txt")
    with tf.gfile.GFile(self._embedding_file_path, mode="w") as f:
      f.write(_MOCK_EMBEDDING)

  def _export_and_embed(self, inputs, preprocess_text):
    """Exports the mock embeddings as a module and applies it to `inputs`.

    The module is exported to the test's temp directory with a single OOV
    bucket, loaded back into a fresh graph and evaluated in a new session
    after running the table and global-variable initializers.

    Args:
      inputs: List of strings to feed to the exported module.
      preprocess_text: Forwarded unchanged to
        `export.export_module_from_file`.

    Returns:
      The result of `session.run` on the module output for `inputs`.
    """
    export.export_module_from_file(
        embedding_file=self._embedding_file_path,
        export_path=self.get_temp_dir(),
        parse_line_fn=export.parse_line,
        num_oov_buckets=1,
        preprocess_text=preprocess_text)
    with tf.Graph().as_default():
      hub_module = hub.Module(self.get_temp_dir())
      embeddings = hub_module(tf.constant(inputs))
      with tf.Session() as session:
        session.run(tf.tables_initializer())
        session.run(tf.global_variables_initializer())
        return session.run(embeddings)

  def testEmbeddingLoaded(self):
    """`export.load` yields 3 vocabulary entries and a 3x3 embedding array."""
    vocabulary, embeddings = export.load(self._embedding_file_path,
                                         export.parse_line)
    self.assertEqual((3,), np.shape(vocabulary))
    self.assertEqual((3, 3), np.shape(embeddings))

  def testExportTokenEmbeddingModule(self):
    """Token module: file tokens get their vectors, "lizard" gets zeros."""
    # "lizard" does not appear in the mock embedding file.
    self.assertAllClose(
        self._export_and_embed(["cat", "lizard", "dog"],
                               preprocess_text=False),
        [[1.11, 2.56, 3.45], [0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])

  def testExportFulltextEmbeddingModule(self):
    """Full-text module: checks sentences, punctuation and an empty string."""
    # Expected values are written to two decimals, hence the relative
    # tolerance.
    self.assertAllClose(
        self._export_and_embed(
            ["cat", "cat cat", "lizard. dog", "cat? dog", ""],
            preprocess_text=True),
        [[1.11, 2.56, 3.45], [1.57, 3.62, 4.88], [0.70, 1.41, 2.12],
         [1.49, 3.22, 4.56], [0.0, 0.0, 0.0]],
        rtol=0.02)

  def testEmptyInput(self):
    """Full-text module: a batch of only empty strings embeds to zeros."""
    self.assertAllClose(
        self._export_and_embed(["", "", ""], preprocess_text=True),
        [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
        rtol=0.02)

  def testEmptyLeading(self):
    """Full-text module: an empty first row leaves the next row unaffected."""
    self.assertAllClose(
        self._export_and_embed(["", "cat dog"], preprocess_text=True),
        [[0.0, 0.0, 0.0], [1.49, 3.22, 4.56]],
        rtol=0.02)
|
|
|
|
|
|
if __name__ == "__main__":
  # The tests build graphs and sessions explicitly, so they are only run when
  # TensorFlow is not executing eagerly; otherwise they are skipped with a
  # warning.
  if not tf.executing_eagerly():
    tf.test.main()
  else:
    logging.warning(
        "Skipping running tests for TF Version: %s running eagerly.",
        tf.__version__)
|