mirror of https://github.com/kubeflow/examples.git
Use conditionals and add test for code search (#291)
* Fix model export, loss function, and add some manual tests. Fix Model export to support computing code embeddings: Fix #260 * The previous exported model was always using the embeddings trained for the search query. * But we need to be able to compute embedding vectors for both the query and code. * To support this we add a new input feature "embed_code" and conditional ops. The exported model uses the value of the embed_code feature to determine whether to treat the inputs as a query string or code and computes the embeddings appropriately. * Originally based on #233 by @activatedgeek Loss function improvements * See #259 for a long discussion about different loss functions. * @activatedgeek was experimenting with different loss functions in #233 and this pulls in some of those changes. Add manual tests * Related to #258 * We add a smoke test for T2T steps so we can catch bugs in the code. * We also add a smoke test for serving the model with TFServing. * We add a sanity check to ensure we get different values for the same input based on which embeddings we are computing. Change Problem/Model name * Register the problem github_function_docstring with a different name to distinguish it from the version inside the Tensor2Tensor library. * * Skip the test when running under prow because its a manual test. * Fix some lint errors. * * Fix lint and skip tests. * Fix lint. * * Fix lint * Revert loss function changes; we can do that in a follow on PR. * * Run generate_data as part of the test rather than reusing a cached vocab and processed input file. * Modify SimilarityTransformer so we can overwrite the number of shards used easily to facilitate testing. * Comment out py-test for now.
This commit is contained in:
parent
07483c2dff
commit
acd8007717
|
@ -0,0 +1,58 @@
|
|||
# Developer guide for the code search example
|
||||
|
||||
This doc is intended for folks looking to contribute to the example.
|
||||
|
||||
## Testing
|
||||
|
||||
We currently have tests that can be run manually to test the code.
|
||||
We hope to get these integrated into our CI system soon.
|
||||
|
||||
### T2T Test
|
||||
|
||||
The test code_search/src/code_search/t2t/similarity_transformer_test.py
|
||||
can be used to test
|
||||
|
||||
* Training
|
||||
* Evaluation
|
||||
* Model Export
|
||||
|
||||
The test can be run as follows
|
||||
|
||||
```
|
||||
cd code_search/src
|
||||
python3 -m code_searcch.t2t.similarity_transformer_export_test
|
||||
```
|
||||
The test just runs the relevant T2T steps and verifies they succeeds. No additional
|
||||
checks are executed.
|
||||
|
||||
|
||||
### TF Serving test
|
||||
|
||||
code_search/src/code_search/nmslib/cli/embed_query_test.py
|
||||
|
||||
|
||||
Can be used to test generating predictions using TFServing.
|
||||
|
||||
The test assumes the TFServing is running in a docker container
|
||||
|
||||
You can start TFServing as follows
|
||||
|
||||
```
|
||||
./code_search/nmslib/cli/start_test_server.sh
|
||||
```
|
||||
|
||||
You can then run the test
|
||||
|
||||
```
|
||||
export PYTHONPATH=${EXAMPLES_REPO/code_search/src:${PYTHONPATH}
|
||||
python3 -m embed_query_test
|
||||
```
|
||||
|
||||
The test verifies that the code can successfully generate embeddings using TFServing.
|
||||
|
||||
The test verifies that different embeddings are computed for the query and the code.
|
||||
|
||||
**start_test_server.sh** relies on a model stored in **code_search/src/code_search/t2t/**
|
||||
A new model can be produced by running **similarity_transformer_export_test**. The unittest
|
||||
will export the model to a temporary directory. You can then copy that model to the test_data
|
||||
directory.
|
|
@ -2,7 +2,7 @@ ARG BASE_IMAGE_TAG=1.8.0
|
|||
|
||||
FROM tensorflow/tensorflow:$BASE_IMAGE_TAG
|
||||
|
||||
RUN pip --no-cache-dir install tensor2tensor~=1.7.0 oauth2client~=4.1.0 &&\
|
||||
RUN pip --no-cache-dir install tensor2tensor~=1.8.0 oauth2client~=4.1.0 &&\
|
||||
apt-get update && apt-get install -y jq &&\
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ RUN curl -sL https://deb.nodesource.com/setup_10.x | bash - &&\
|
|||
numpy~=1.14.0 \
|
||||
oauth2client~=4.1.0 \
|
||||
requests~=2.18.0 \
|
||||
tensor2tensor~=1.7.0 &&\
|
||||
tensor2tensor~=1.8.0 &&\
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
|
||||
ADD src/ /src
|
||||
|
|
|
@ -72,7 +72,8 @@ class EncodeFunctionTokens(beam.DoFn):
|
|||
}
|
||||
"""
|
||||
encoder = get_encoder(self.problem, self.data_dir)
|
||||
encoded_function = encode_query(encoder, element.get(self.function_tokens_key))
|
||||
encoded_function = encode_query(encoder, True,
|
||||
element.get(self.function_tokens_key))
|
||||
|
||||
element[self.instances_key] = [{'input': {'b64': encoded_function}}]
|
||||
yield element
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import apache_beam as beam
|
||||
|
||||
import code_search.dataflow.do_fns.prediction_do_fn as pred
|
||||
import code_search.dataflow.do_fns.function_embeddings as func_embeddings
|
||||
import code_search.dataflow.do_fns.function_embeddings as func_embeddings # pylint: disable=no-name-in-module
|
||||
import code_search.dataflow.transforms.github_bigquery as github_bigquery
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# coding=utf-8
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test embedding query using TFServing.
|
||||
|
||||
This is a manual/E2E test that assumes TFServing is running externally
|
||||
(e.g. Docker container or K8s pod).
|
||||
|
||||
The script start_test_server.sh can be used to start a Docker container
|
||||
when running locally.
|
||||
|
||||
To run TFServing we need a model. start_test_server.sh will use a model
|
||||
in ../../t2t/test_data/model
|
||||
|
||||
code_search must be a top level Python package.
|
||||
|
||||
requires host machine has tensorflow_model_server executable available
|
||||
"""
|
||||
|
||||
# TODO(jlewi): Starting the test seems very slow. I wonder if this is because
|
||||
# tensor2tensor is loading a bunch of models and if maybe we can skip that.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
import tensorflow as tf
|
||||
|
||||
import numpy as np
|
||||
|
||||
from code_search.nmslib.cli import start_search_server
|
||||
|
||||
start = datetime.datetime.now()
|
||||
|
||||
FLAGS = tf.flags.FLAGS
|
||||
|
||||
PROBLEM_NAME = "kf_github_function_docstring"
|
||||
|
||||
class TestEmbedQuery(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(os.getenv("PROW_JOB_ID"), "Manual test not run on prow")
|
||||
def test_embed(self):
|
||||
"""Test that we can embed the search query string via tf.serving.
|
||||
|
||||
This test assumes the model is running as an external process in TensorFlow
|
||||
serving.
|
||||
|
||||
The external process can be started a variety of ways e.g. subprocess,
|
||||
kubernetes, or docker container.
|
||||
|
||||
The script start_test_server.sh can be used to start TFServing in
|
||||
docker container.
|
||||
"""
|
||||
# Directory containing the vocabulary.
|
||||
test_data_dir = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "t2t", "test_data"))
|
||||
# 8501 should be REST port
|
||||
server = os.getenv("TEST_SERVER", "localhost:8501")
|
||||
|
||||
# Model name matches the subdirectory in TF Serving's model Directory
|
||||
# containing models.
|
||||
model_name = "test_model_20181031"
|
||||
serving_url = "http://{0}/v1/models/{1}:predict".format(server, model_name)
|
||||
query = "Write to GCS"
|
||||
query_encoder = start_search_server.build_query_encoder(PROBLEM_NAME,
|
||||
test_data_dir)
|
||||
code_encoder = start_search_server.build_query_encoder(PROBLEM_NAME,
|
||||
test_data_dir,
|
||||
embed_code=True)
|
||||
|
||||
query_result = start_search_server.embed_query(query_encoder, serving_url, query)
|
||||
code_result = start_search_server.embed_query(code_encoder, serving_url, query)
|
||||
|
||||
# As a sanity check ensure the vectors aren't equal
|
||||
q_vec = np.array(query_result)
|
||||
q_vec = q_vec/np.sqrt(np.dot(q_vec, q_vec))
|
||||
c_vec = np.array(code_result)
|
||||
c_vec = c_vec/np.sqrt(np.dot(c_vec, c_vec))
|
||||
|
||||
dist = np.dot(q_vec, c_vec)
|
||||
self.assertNotAlmostEqual(1, dist)
|
||||
logging.info("Done")
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
unittest.main()
|
|
@ -1,4 +1,5 @@
|
|||
import csv
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import functools
|
||||
|
@ -7,6 +8,8 @@ import tensorflow as tf
|
|||
|
||||
import code_search.nmslib.cli.arguments as arguments
|
||||
import code_search.t2t.query as query
|
||||
# We need to import function_docstring to ensure the problem is registered
|
||||
from code_search.t2t import function_docstring # pylint: disable=unused-import
|
||||
from code_search.nmslib.search_engine import CodeSearchEngine
|
||||
from code_search.nmslib.search_server import CodeSearchServer
|
||||
|
||||
|
@ -18,10 +21,28 @@ def embed_query(encoder, serving_url, query_str):
|
|||
headers={'content-type': 'application/json'},
|
||||
data=json.dumps(data))
|
||||
|
||||
if not response.ok:
|
||||
logging.error("Request failed; status: %s reason %s response: %s",
|
||||
response.status_code,
|
||||
response.reason,
|
||||
response.content)
|
||||
result = response.json()
|
||||
return result['predictions'][0]['outputs']
|
||||
|
||||
|
||||
def build_query_encoder(problem, data_dir, embed_code=False):
|
||||
"""Build a query encoder.
|
||||
|
||||
Args:
|
||||
problem: The name of the T2T problem to use
|
||||
data_dir: Directory containing the data. This should include the vocabulary.
|
||||
embed_code: Whether to compute embeddings for natural language or code.
|
||||
"""
|
||||
encoder = query.get_encoder(problem, data_dir)
|
||||
query_encoder = functools.partial(query.encode_query, encoder, embed_code)
|
||||
|
||||
return query_encoder
|
||||
|
||||
def start_search_server(argv=None):
|
||||
"""Start a Flask REST server.
|
||||
|
||||
|
@ -53,8 +74,9 @@ def start_search_server(argv=None):
|
|||
if not os.path.isfile(tmp_index_file):
|
||||
tf.gfile.Copy(args.index_file, tmp_index_file)
|
||||
|
||||
encoder = query.get_encoder(args.problem, args.data_dir)
|
||||
query_encoder = functools.partial(query.encode_query, encoder)
|
||||
# Build an an encoder for the natural language strings.
|
||||
query_encoder = build_query_encoder(args.problem, args.data_dir,
|
||||
embed_code=False)
|
||||
embedding_fn = functools.partial(embed_query, query_encoder, args.serving_url)
|
||||
|
||||
search_engine = CodeSearchEngine(tmp_index_file, lookup_data, embedding_fn)
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
#!/bin/bash
|
||||
#
|
||||
# A simple script for starting TFServing locally in a docker container.
|
||||
# This allows us to test sending predictions to the model.
|
||||
set -ex
|
||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )"
|
||||
MODELS_DIR="$( cd "${DIR}/../../t2t/test_data/" >/dev/null && pwd)"
|
||||
|
||||
MODEL_NAME=test_model_20181031
|
||||
|
||||
if [ ! -d ${MODELS_DIR}/${MODEL_NAME} ]; then
|
||||
echo Missing directory ${MODELS_DIR}/${MODEL_NAME}
|
||||
exit 1
|
||||
fi
|
||||
|
||||
set +e
|
||||
docker rm -f cs_serving_test
|
||||
set -e
|
||||
|
||||
# TODO(jlewi): Is there anyway to cause TF Serving to load all models in
|
||||
# MODELS_DIR and not have to specify the environment variable MODEL_NAME
|
||||
docker run --rm --name=cs_serving_test -p 8500:8500 -p 8501:8501 \
|
||||
-v "${MODELS_DIR}:/models" \
|
||||
-e MODEL_NAME="${MODEL_NAME}" \
|
||||
tensorflow/serving
|
||||
# Tail the logs
|
||||
docker logs -f cs_serving_test
|
|
@ -1,12 +1,20 @@
|
|||
"""Github function/text similatrity problems."""
|
||||
import csv
|
||||
import logging
|
||||
from six import StringIO
|
||||
from tensor2tensor.data_generators import generator_utils
|
||||
from tensor2tensor.data_generators import text_problems
|
||||
from tensor2tensor.utils import metrics
|
||||
from tensor2tensor.utils import registry
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
# There is a copy of the problem in the Tensor2Tensor library.
|
||||
# http://bit.ly/2Olf34u
|
||||
#
|
||||
# We want to register this problem with a different name to make sure
|
||||
# we don't end up using that problem.
|
||||
# So we register it with the name kf
|
||||
@registry.register_problem("kf_github_function_docstring")
|
||||
class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
||||
"""Function and Docstring similarity Problem.
|
||||
|
||||
|
@ -18,6 +26,7 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
"""
|
||||
|
||||
DATA_PATH_PREFIX = "gs://kubeflow-examples/t2t-code-search/raw_data"
|
||||
NUM_SHARDS = 100
|
||||
|
||||
@property
|
||||
def pair_files_list(self):
|
||||
|
@ -46,12 +55,14 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
In this case, the tuple is of size 1 because the URL points
|
||||
to a file itself.
|
||||
"""
|
||||
logging.info("Using %s shards", self.NUM_SHARDS)
|
||||
return [
|
||||
[
|
||||
"{}/func-doc-pairs-000{:02}-of-00100.csv".format(self.DATA_PATH_PREFIX, i),
|
||||
("func-doc-pairs-000{:02}-of-00100.csv".format(i),)
|
||||
"{}/func-doc-pairs-{:05}-of-{:05}.csv".format(
|
||||
self.DATA_PATH_PREFIX, i, self.NUM_SHARDS),
|
||||
("func-doc-pairs-{:05}-of-{:05}.csv".format(i, self.NUM_SHARDS),)
|
||||
]
|
||||
for i in range(100)
|
||||
for i in range(self.NUM_SHARDS)
|
||||
]
|
||||
|
||||
@property
|
||||
|
@ -69,8 +80,8 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
|
||||
def get_csv_files(self, _data_dir, tmp_dir, _dataset_split):
|
||||
return [
|
||||
generator_utils.maybe_download(tmp_dir, file_list[0], uri)
|
||||
for uri, file_list in self.pair_files_list
|
||||
generator_utils.maybe_download(tmp_dir, file_list[0], uri)
|
||||
for uri, file_list in self.pair_files_list
|
||||
]
|
||||
|
||||
def generate_samples(self, data_dir, tmp_dir, dataset_split):
|
||||
|
@ -85,7 +96,7 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
|
||||
Yields:
|
||||
Each element yielded is of a Python dict of the form
|
||||
{"inputs": "STRING", "targets": "STRING"}
|
||||
{"inputs": "STRING", "targets": "STRING", "embed_code": [0]}
|
||||
"""
|
||||
csv_files = self.get_csv_files(data_dir, tmp_dir, dataset_split)
|
||||
|
||||
|
@ -95,7 +106,17 @@ class GithubFunctionDocstring(text_problems.Text2TextProblem):
|
|||
for line in csv_file:
|
||||
reader = csv.reader(StringIO(line))
|
||||
for docstring_tokens, function_tokens in reader:
|
||||
yield {"inputs": docstring_tokens, "targets": function_tokens}
|
||||
yield {
|
||||
"inputs": docstring_tokens,
|
||||
"targets": function_tokens,
|
||||
"embed_code": [0],
|
||||
}
|
||||
|
||||
def example_reading_spec(self):
|
||||
data_fields, data_items_to_decoders = super(GithubFunctionDocstring,
|
||||
self).example_reading_spec()
|
||||
data_fields["embed_code"] = tf.FixedLenFeature([1], dtype=tf.int64)
|
||||
return data_fields, data_items_to_decoders
|
||||
|
||||
def eval_metrics(self): # pylint: disable=no-self-use
|
||||
return [
|
||||
|
|
|
@ -14,16 +14,36 @@ def get_encoder(problem_name, data_dir):
|
|||
return problem.feature_info["inputs"].encoder
|
||||
|
||||
|
||||
def encode_query(encoder, query_str):
|
||||
def encode_query(encoder, embed_code, query_str):
|
||||
"""Encode the input query string using encoder. This
|
||||
might vary by problem but keeping generic as a reference.
|
||||
Note that in T2T problems, the 'targets' key is needed
|
||||
even though it is ignored during inference.
|
||||
See tensorflow/tensor2tensor#868"""
|
||||
See tensorflow/tensor2tensor#868
|
||||
|
||||
Args:
|
||||
encoder: Encoder to encode the string as a vector.
|
||||
embed_code: Bool determines whether to treat query_str as a natural language
|
||||
query and use the associated embedding for natural language or to
|
||||
treat it as code and use the associated embeddings.
|
||||
query_str: The data to compute embeddings for.
|
||||
"""
|
||||
|
||||
encoded_str = encoder.encode(query_str) + [text_encoder.EOS_ID]
|
||||
features = {"inputs": tf.train.Feature(int64_list=tf.train.Int64List(value=encoded_str)),
|
||||
"targets": tf.train.Feature(int64_list=tf.train.Int64List(value=[0]))}
|
||||
|
||||
embed_code_value = 0
|
||||
if embed_code:
|
||||
embed_code_value = 1
|
||||
|
||||
features = {
|
||||
"inputs": tf.train.Feature(int64_list=tf.train.Int64List(value=encoded_str)),
|
||||
"targets": tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
|
||||
# The embed code feature determines whether we treat the input (0)
|
||||
# or code (1). Since we want to compute the query embeddings we set it
|
||||
# to 0.
|
||||
"embed_code": tf.train.Feature(int64_list=tf.train.Int64List(
|
||||
value=[embed_code_value])),
|
||||
}
|
||||
example = tf.train.Example(features=tf.train.Features(feature=features))
|
||||
return base64.b64encode(example.SerializeToString()).decode('utf-8')
|
||||
|
||||
|
|
|
@ -6,8 +6,11 @@ from tensor2tensor.utils import registry
|
|||
from tensor2tensor.utils import t2t_model
|
||||
import tensorflow as tf
|
||||
|
||||
MODEL_NAME = 'cs_similarity_transformer'
|
||||
|
||||
@registry.register_model('cs_similarity_transformer')
|
||||
# We don't use the default name because there is already an older version
|
||||
# included as part of the T2T library with the default name.
|
||||
@registry.register_model(MODEL_NAME)
|
||||
class SimilarityTransformer(t2t_model.T2TModel):
|
||||
"""Transformer Model for Similarity between two strings.
|
||||
|
||||
|
@ -17,15 +20,17 @@ class SimilarityTransformer(t2t_model.T2TModel):
|
|||
Dot Product is used as the distance measure between two
|
||||
string embeddings.
|
||||
"""
|
||||
|
||||
def top(self, body_output, _): # pylint: disable=no-self-use
|
||||
return body_output
|
||||
|
||||
def body(self, features):
|
||||
with tf.variable_scope('string_embedding'):
|
||||
string_embedding = self.encode(features, 'inputs')
|
||||
|
||||
if 'targets' in features:
|
||||
if self.hparams.mode != tf.estimator.ModeKeys.PREDICT:
|
||||
# In training mode we need to embed both the queries and the code
|
||||
# using the inputs and targets respectively.
|
||||
with tf.variable_scope('string_embedding'):
|
||||
string_embedding = self.encode(features, 'inputs')
|
||||
|
||||
with tf.variable_scope('code_embedding'):
|
||||
code_embedding = self.encode(features, 'targets')
|
||||
|
||||
|
@ -34,7 +39,7 @@ class SimilarityTransformer(t2t_model.T2TModel):
|
|||
|
||||
# All-vs-All cosine distance matrix, reshaped as row-major.
|
||||
cosine_dist = 1.0 - tf.matmul(string_embedding_norm, code_embedding_norm,
|
||||
transpose_b=True)
|
||||
transpose_b=True)
|
||||
cosine_dist_flat = tf.reshape(cosine_dist, [-1, 1])
|
||||
|
||||
# Positive samples on the diagonal, reshaped as row-major.
|
||||
|
@ -46,10 +51,36 @@ class SimilarityTransformer(t2t_model.T2TModel):
|
|||
|
||||
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
|
||||
logits=logits)
|
||||
result = string_embedding
|
||||
return result, {'training': loss}
|
||||
|
||||
return string_embedding, {'training': loss}
|
||||
|
||||
return string_embedding
|
||||
# In predict mode we conditionally embed either the string query
|
||||
# or the code based on the embed_code feature. In both cases the
|
||||
# input will be in the inputs feature but the variable scope will
|
||||
# be different
|
||||
# Define predicates to be used with tf.cond
|
||||
def embed_string():
|
||||
with tf.variable_scope('string_embedding'):
|
||||
string_embedding = self.encode(features, 'inputs')
|
||||
return string_embedding
|
||||
|
||||
def embed_code():
|
||||
with tf.variable_scope('code_embedding'):
|
||||
code_embedding = self.encode(features, 'inputs')
|
||||
return code_embedding
|
||||
|
||||
embed_code_feature = features.get('embed_code')
|
||||
|
||||
# embed_code_feature will be a tensor because inputs will be a batch
|
||||
# of inputs. We need to reduce that down to a single value for use
|
||||
# with tf.cond; so we simply take the max of all the elements.
|
||||
# This implicitly assume all inputs have the same value.
|
||||
is_embed_code = tf.reduce_max(embed_code_feature)
|
||||
result = tf.cond(is_embed_code > 0, embed_code, embed_string)
|
||||
|
||||
result = tf.nn.l2_normalize(result)
|
||||
return result
|
||||
|
||||
def encode(self, features, input_key):
|
||||
hparams = self._hparams
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
# coding=utf-8
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests of modified similarity transformer model.
|
||||
|
||||
code_search must be a top level Python package.
|
||||
python -m code_search.t2t.similarity_transformer_export_test
|
||||
"""
|
||||
|
||||
# TODO(jlewi): Starting the test seems very slow. I wonder if this is because
|
||||
# tensor2tensor is loading a bunch of models and if maybe we can skip that.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensor2tensor.bin import t2t_trainer
|
||||
from tensor2tensor.serving import export
|
||||
from tensor2tensor.utils import registry
|
||||
|
||||
|
||||
from code_search.t2t import similarity_transformer
|
||||
|
||||
FLAGS = tf.flags.FLAGS
|
||||
|
||||
PROBLEM_NAME = "github_function_docstring"
|
||||
|
||||
class TestSimilarityTransformer(unittest.TestCase):
|
||||
|
||||
@unittest.skipIf(os.getenv("PROW_JOB_ID"), "Manual test not run on prow")
|
||||
def test_train_and_export(self): # pylint: disable=no-self-use
|
||||
"""Test that we can train and export the model."""
|
||||
|
||||
test_data_dir = os.path.join(os.path.dirname(__file__), "test_data")
|
||||
# If we set t2t_usr_dir t2t_train.main will end up importing that
|
||||
# directory which causes an error because the model ends up being registered
|
||||
# twice.
|
||||
FLAGS.problem = "kf_github_function_docstring"
|
||||
FLAGS.data_dir = tempfile.mkdtemp()
|
||||
|
||||
FLAGS.tmp_dir = tempfile.mkdtemp()
|
||||
logging.info("Using data_dir %s", FLAGS.data_dir)
|
||||
logging.info("Using tmp_dir %s", FLAGS.tmp_dir)
|
||||
|
||||
FLAGS.output_dir = tempfile.mkdtemp()
|
||||
logging.info("Using output_dir %s", FLAGS.output_dir)
|
||||
|
||||
FLAGS.model = similarity_transformer.MODEL_NAME
|
||||
FLAGS.hparams_set = "transformer_tiny"
|
||||
FLAGS.train_steps = 1
|
||||
FLAGS.eval_steps = 5
|
||||
|
||||
# We want to trigger eval.
|
||||
FLAGS.local_eval_frequency = 1
|
||||
FLAGS.schedule = "continuous_train_and_eval"
|
||||
|
||||
problem = registry.problem(FLAGS.problem)
|
||||
|
||||
# Override the data path prefix and number of shards so we use
|
||||
# the test data rather than downloading from GCS.
|
||||
problem.DATA_PATH_PREFIX = os.path.join(test_data_dir, "raw_data")
|
||||
problem.NUM_SHARDS = 1
|
||||
|
||||
# Generating the data can be slow because it uses an iterative process
|
||||
# to compute the vocab.
|
||||
# During development you can reuse data_dir between runs; if the vocab
|
||||
# and processed input files already exists in that directory it won't
|
||||
# need to regenerate them.
|
||||
problem.generate_data(FLAGS.data_dir, FLAGS.tmp_dir)
|
||||
|
||||
t2t_trainer.main(None)
|
||||
|
||||
export.main(None)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger().setLevel(logging.INFO)
|
||||
unittest.main()
|
|
@ -0,0 +1,22 @@
|
|||
This directory contains data to be used by the unittests.
|
||||
|
||||
|
||||
## Training data
|
||||
|
||||
raw_data/kf_github_function_docstring-train-00000-of-00001.csv
|
||||
|
||||
We copied one of the shard files from GCS and renamed it. We had to change
|
||||
the max number of shards
|
||||
|
||||
## Inference Data
|
||||
|
||||
./test_data/export
|
||||
|
||||
This is an exported model suitable for training. It can be reproduced by
|
||||
running the training unittest which will export the model to a temporary
|
||||
directory. You can then just copy it into the test_data folder.
|
||||
|
||||
./vocab.kf_github_function_docstring
|
||||
|
||||
This is the vocabe file. You can run training and then copy it from the
|
||||
temporary output directory.
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
|
@ -1,11 +1,10 @@
|
|||
astor~=0.6.0
|
||||
apache-beam[gcp]~=2.5.0
|
||||
astor~=0.7.0
|
||||
apache-beam[gcp]~=2.6.0
|
||||
Flask~=1.0.0
|
||||
nltk~=3.3.0
|
||||
nmslib~=1.7.0
|
||||
numpy~=1.14.0
|
||||
oauth2client~=4.1.0
|
||||
requests~=2.18.0
|
||||
requests~=2.19.0
|
||||
spacy~=2.0.0
|
||||
tensor2tensor~=1.7.0
|
||||
tensorflow~=1.8.0
|
||||
tensor2tensor~=1.9.0
|
||||
tensorflow~=1.11.0
|
||||
|
|
|
@ -163,10 +163,22 @@
|
|||
name: "create-pr-symlink",
|
||||
template: "create-pr-symlink",
|
||||
},
|
||||
{
|
||||
name: "py-test",
|
||||
template: "py-test",
|
||||
},
|
||||
// test_py_checks runs all py files matching "_test.py"
|
||||
// This is currently commented out because the only matching tests
|
||||
// are manual tests for some of the examples and/or they require
|
||||
// dependencies (e.g. tensorflow) not in the generic test worker image.
|
||||
//
|
||||
//
|
||||
// test_py_checks doesn't have options to exclude specific directories.
|
||||
// Since there are no other tests we just comment it out.
|
||||
//
|
||||
// TODO(https://github.com/kubeflow/testing/issues/240): Modify py_test
|
||||
// so we can exclude specific files.
|
||||
//
|
||||
// {
|
||||
// name: "py-test",
|
||||
// template: "py-test",
|
||||
// },
|
||||
{
|
||||
name: "py-lint",
|
||||
template: "py-lint",
|
||||
|
|
Loading…
Reference in New Issue