Add negative sampling to Transformer network (#167)

* Add negative sampling to Transformer network

* Add generate data flag, can skip t2t-datagen step
This commit is contained in:
Sanyam Kapoor 2018-07-04 20:14:22 -07:00 committed by k8s-ci-robot
parent b6a3c4c0ea
commit c5f13464b4
2 changed files with 14 additions and 8 deletions

View File

@ -27,15 +27,20 @@ class SimilarityTransformer(t2t_model.T2TModel):
with tf.variable_scope('code_embedding'):
code_embedding = self.encode(features, 'targets')
cosine_dist = tf.losses.cosine_distance(
tf.nn.l2_normalize(string_embedding, axis=1),
tf.nn.l2_normalize(code_embedding, axis=1),
axis=1, reduction=tf.losses.Reduction.NONE)
string_embedding_norm = tf.nn.l2_normalize(string_embedding, axis=1)
code_embedding_norm = tf.nn.l2_normalize(code_embedding, axis=1)
# TODO(sanyamkapoor): need negative sampling, won't be all ones anymore.
labels = tf.one_hot(tf.ones(
tf.shape(features['targets'])[0], tf.int32), 2)
logits = tf.concat([cosine_dist, 1 - cosine_dist], axis=1)
# All-vs-All cosine distance matrix, reshaped as row-major.
cosine_dist = 1.0 - tf.matmul(string_embedding_norm, code_embedding_norm,
transpose_b=True)
cosine_dist_flat = tf.reshape(cosine_dist, [-1, 1])
# Positive samples on the diagonal, reshaped as row-major.
label_matrix = tf.eye(tf.shape(cosine_dist)[0], dtype=tf.int32)
label_matrix_flat = tf.reshape(label_matrix, [-1])
logits = tf.concat([1.0 - cosine_dist_flat, cosine_dist_flat], axis=1)
labels = tf.one_hot(label_matrix_flat, 2)
loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
logits=logits)

View File

@ -49,6 +49,7 @@ local baseParams = std.extVar("__ksonnet/params").components["t2t-job"];
getTrainerCmd(params):: {
local trainer = [
"t2t-trainer",
"--generate_data",
"--problem=" + params.problem,
"--data_dir=" + params.dataDir,
"--output_dir=" + params.outputDir,