mirror of https://github.com/kubeflow/examples.git
Add negative sampling to Transformer network (#167)
* Add negative sampling to Transformer network
* Add generate data flag, can skip t2t-datagen step
This commit is contained in:
parent b6a3c4c0ea
commit c5f13464b4
@@ -27,15 +27,20 @@ class SimilarityTransformer(t2t_model.T2TModel):
     with tf.variable_scope('code_embedding'):
       code_embedding = self.encode(features, 'targets')
 
-    cosine_dist = tf.losses.cosine_distance(
-        tf.nn.l2_normalize(string_embedding, axis=1),
-        tf.nn.l2_normalize(code_embedding, axis=1),
-        axis=1, reduction=tf.losses.Reduction.NONE)
+    string_embedding_norm = tf.nn.l2_normalize(string_embedding, axis=1)
+    code_embedding_norm = tf.nn.l2_normalize(code_embedding, axis=1)
 
-    # TODO(sanyamkapoor): need negative sampling, won't be all ones anymore.
-    labels = tf.one_hot(tf.ones(
-        tf.shape(features['targets'])[0], tf.int32), 2)
-    logits = tf.concat([cosine_dist, 1 - cosine_dist], axis=1)
+    # All-vs-All cosine distance matrix, reshaped as row-major.
+    cosine_dist = 1.0 - tf.matmul(string_embedding_norm, code_embedding_norm,
+                                  transpose_b=True)
+    cosine_dist_flat = tf.reshape(cosine_dist, [-1, 1])
+
+    # Positive samples on the diagonal, reshaped as row-major.
+    label_matrix = tf.eye(tf.shape(cosine_dist)[0], dtype=tf.int32)
+    label_matrix_flat = tf.reshape(label_matrix, [-1])
+
+    logits = tf.concat([1.0 - cosine_dist_flat, cosine_dist_flat], axis=1)
+    labels = tf.one_hot(label_matrix_flat, 2)
 
     loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels,
                                                    logits=logits)
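The change replaces the per-pair cosine distance (with all-positive labels) by an all-vs-all comparison within the batch: the n matching (docstring, code) pairs sit on the diagonal of the n x n distance matrix, and every off-diagonal entry serves as a negative sample. Below is a minimal NumPy sketch of that construction, not part of the commit; the function name and the toy batch are made up here for illustration.

import numpy as np

def negative_sampling_loss(string_emb, code_emb):
    """Sketch of the all-vs-all loss; assumes rows of the two matrices are paired."""
    # L2-normalize so the dot product equals cosine similarity.
    s = string_emb / np.linalg.norm(string_emb, axis=1, keepdims=True)
    c = code_emb / np.linalg.norm(code_emb, axis=1, keepdims=True)

    # All-vs-all cosine distance matrix, shape [n, n], flattened row-major.
    cosine_dist = 1.0 - s @ c.T
    cosine_dist_flat = cosine_dist.reshape(-1, 1)

    # Positive pairs sit on the diagonal; every other entry is a negative sample.
    label_matrix_flat = np.eye(len(s), dtype=np.int64).reshape(-1)   # [n*n]
    labels = np.eye(2)[label_matrix_flat]                            # one-hot, [n*n, 2]

    # Two-class logits per pair, mirroring the diff: [1 - distance, distance].
    logits = np.concatenate([1.0 - cosine_dist_flat, cosine_dist_flat], axis=1)

    # Element-wise sigmoid cross-entropy, same formula as
    # tf.nn.sigmoid_cross_entropy_with_logits.
    loss = np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    return loss.mean()

# Toy batch of 4 paired (docstring, code) embeddings of dimension 8.
rng = np.random.default_rng(0)
print(negative_sampling_loss(rng.normal(size=(4, 8)), rng.normal(size=(4, 8))))

With batch size n this yields n positive entries and n*(n-1) negatives per step, which is where the "negative sampling" in the commit title comes from.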
@@ -49,6 +49,7 @@ local baseParams = std.extVar("__ksonnet/params").components["t2t-job"];
   getTrainerCmd(params):: {
     local trainer = [
       "t2t-trainer",
+      "--generate_data",
       "--problem=" + params.problem,
       "--data_dir=" + params.dataDir,
       "--output_dir=" + params.outputDir,
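The ksonnet change adds only the --generate_data flag, which tells t2t-trainer to generate the dataset itself so the separate t2t-datagen step can be skipped, as the commit message notes. For illustration, a small Python sketch of the flattened command this component would emit; the parameter values below are hypothetical, not taken from the repo.

# Hypothetical parameter values, for illustration only.
params = {
    "problem": "github_function_docstring",  # assumed problem name
    "dataDir": "gs://my-bucket/data",        # assumed
    "outputDir": "gs://my-bucket/output",    # assumed
}

trainer_cmd = [
    "t2t-trainer",
    "--generate_data",                 # new flag: generate data in-process
    "--problem=" + params["problem"],
    "--data_dir=" + params["dataDir"],
    "--output_dir=" + params["outputDir"],
]
print(" ".join(trainer_cmd))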