tfjs-examples/date-conversion-attention/train.js

245 lines
8.5 KiB
JavaScript

/**
* @license
* Copyright 2019 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/**
* Training an attention LSTM sequence-to-sequence decoder to translate
* various date formats into the ISO date format.
*
* Inspired by and loosely based on
* https://github.com/wanasit/katakana/blob/master/notebooks/Attention-based%20Sequence-to-Sequence%20in%20Keras.ipynb
*/
import * as fs from 'fs';
import * as shelljs from 'shelljs';
import * as argparse from 'argparse';
import * as tf from '@tensorflow/tfjs';
import * as dateFormat from './date_format';
import {createModel, runSeq2SeqInference} from './model';
/**
 * Generate sets of data for training.
 *
 * Every calendar date from January 1, 1950 through December 31, 2049 is
 * enumerated and shuffled, then split into training, validation and test
 * subsets. Each training and validation date is rendered once per input
 * format in `dateFormat.INPUT_FNS`, so the number of examples in the
 * returned tensors is the number of dates in the subset multiplied by the
 * number of input formats.
 *
 * @param {number} trainSplit Training split. Must be >0 and <1.
 * @param {number} valSplit Validation split. Must be >0 and <1. The sum of
 *   `trainSplit` and `valSplit` must not exceed 1; whatever is left over
 *   becomes the test set.
 * @return An `Object` consisting of
 *   - trainEncoderInput, as a `tf.Tensor` of shape
 *     `[numTrainExamples, inputLength]`
 *   - trainDecoderInput, as a `tf.Tensor` of shape
 *     `[numTrainExamples, outputLength]`. The first element of every
 *     example has been set as START_CODE (the sequence-start symbol).
 *   - trainDecoderOutput, as a one-hot encoded `tf.Tensor` of shape
 *     `[numTrainExamples, outputLength, outputVocabSize]`.
 *   - valEncoderInput, same as trainEncoderInput, but for the validation set.
 *   - valDecoderInput, same as trainDecoderInput, but for the validation set.
 *   - valDecoderOutput, same as trainDecoderOutput, but for the validation
 *     set.
 *   - testDateTuples, date tuples ([year, month, day]) for the test set.
 */
export function generateDataForTraining(trainSplit = 0.25, valSplit = 0.15) {
  tf.util.assert(
      trainSplit > 0 && valSplit > 0 && trainSplit + valSplit <= 1,
      `Invalid trainSplit (${trainSplit}) and valSplit (${valSplit})`);

  // Enumerate every day in [MIN_YEAR, MAX_YEAR) as a [year, month, day]
  // tuple. `Date` months are 0-based, hence the `+ 1`.
  const dateTuples = [];
  const MIN_YEAR = 1950;
  const MAX_YEAR = 2050;
  for (let date = new Date(MIN_YEAR, 0, 1);
       date.getFullYear() < MAX_YEAR;
       date.setDate(date.getDate() + 1)) {
    dateTuples.push([date.getFullYear(), date.getMonth() + 1, date.getDate()]);
  }
  tf.util.shuffle(dateTuples);

  const numTrain = Math.floor(dateTuples.length * trainSplit);
  const numVal = Math.floor(dateTuples.length * valSplit);
  console.log(`Number of dates used for training: ${numTrain}`);
  console.log(`Number of dates used for validation: ${numVal}`);
  console.log(
      `Number of dates used for testing: ` +
      `${dateTuples.length - numTrain - numVal}`);

  /**
   * Convert date tuples into the encoder-input, decoder-input and
   * decoder-output tensors of the model.
   *
   * @param {Array<[number, number, number]>} tuples Date tuples to convert.
   * @return {{encoderInput, decoderInput, decoderOutput}} The three tensors.
   */
  function dateTuplesToTensor(tuples) {
    return tf.tidy(() => {
      const numFormats = dateFormat.INPUT_FNS.length;

      // Format-major ordering: all dates in the first input format, then all
      // dates in the second one, and so on. This matches the ordering
      // produced by `tile()` on the decoder tensors below.
      const inputStrings = [];
      for (const formatFn of dateFormat.INPUT_FNS) {
        for (const tuple of tuples) {
          inputStrings.push(formatFn(tuple));
        }
      }
      const encoderInput = dateFormat.encodeInputDateStrings(inputStrings);

      // The target strings are encoded once and shared by the decoder input
      // and the decoder output.
      const targetStrings =
          tuples.map(tuple => dateFormat.dateTupleToYYYYDashMMDashDD(tuple));
      const targetIndices = dateFormat.encodeOutputDateStrings(targetStrings);
      const numDates = targetIndices.shape[0];
      const outputLength = targetIndices.shape[1];

      // One-step time shift: the decoder input is the target sequence
      // shifted to the right by one time step relative to the decoder
      // output. START_CODE fills the first position and the last target
      // token is dropped. This accounts for the step-by-step decoding that
      // happens during inference time.
      const startColumn = tf.ones([numDates, 1]).mul(dateFormat.START_CODE);
      const shiftedTargets = targetIndices.asType('float32').slice(
          [0, 0], [numDates, outputLength - 1]);
      const decoderInput =
          tf.concat([startColumn, shiftedTargets], 1).tile([numFormats, 1]);

      const decoderOutput =
          tf.oneHot(targetIndices, dateFormat.OUTPUT_VOCAB.length)
              .tile([numFormats, 1, 1]);

      return {encoderInput, decoderInput, decoderOutput};
    });
  }

  const {
    encoderInput: trainEncoderInput,
    decoderInput: trainDecoderInput,
    decoderOutput: trainDecoderOutput
  } = dateTuplesToTensor(dateTuples.slice(0, numTrain));
  const {
    encoderInput: valEncoderInput,
    decoderInput: valDecoderInput,
    decoderOutput: valDecoderOutput
  } = dateTuplesToTensor(dateTuples.slice(numTrain, numTrain + numVal));
  // Everything that is neither training nor validation data is held out.
  const testDateTuples = dateTuples.slice(numTrain + numVal);

  return {
    trainEncoderInput,
    trainDecoderInput,
    trainDecoderOutput,
    valEncoderInput,
    valDecoderInput,
    valDecoderOutput,
    testDateTuples
  };
}
/**
 * Define and parse the command-line options of the training script.
 *
 * @return The object produced by `ArgumentParser.parseArgs()`, holding one
 *   property per registered flag (`gpu`, `epochs`, `batchSize`, `trainSplit`,
 *   `valSplit`, `savePath`, `logDir` and `logUpdateFreq`).
 */
function parseArguments() {
  const argParser = new argparse.ArgumentParser({
    description:
        'Train an attention-based date-conversion model in TensorFlow.js'
  });
  argParser.addArgument('--gpu', {
    action: 'storeTrue',
    help: 'Use tfjs-node-gpu to train the model. Requires CUDA/CuDNN.'
  });
  argParser.addArgument('--epochs', {
    type: 'int',
    defaultValue: 2,
    help: 'Number of epochs to train the model for'
  });
  argParser.addArgument('--batchSize', {
    type: 'int',
    defaultValue: 128,
    help: 'Batch size to be used during model training'
  });
  // The flag must not carry stray whitespace: the parser derives the name of
  // the parsed property from it, and `run()` reads `args.trainSplit`.
  argParser.addArgument('--trainSplit', {
    type: 'float',
    defaultValue: 0.25,
    help: 'Fraction of all possible dates to use for training. Must be ' +
        '> 0 and < 1. Its sum with valSplit must be <1.'
  });
  argParser.addArgument('--valSplit', {
    type: 'float',
    defaultValue: 0.15,
    help: 'Fraction of all possible dates to use for validation. Must be ' +
        '> 0 and < 1. Its sum with trainSplit must be <1.'
  });
  argParser.addArgument('--savePath', {
    type: 'string',
    defaultValue: './dist/model',
    help: 'Directory in which the trained model will be saved.'
  });
  argParser.addArgument('--logDir', {
    type: 'string',
    help: 'Optional tensorboard log directory, to which the loss and ' +
        'accuracy will be logged during model training.'
  });
  // `choices` (not `optionStrings`) is the option that restricts the set of
  // accepted values.
  argParser.addArgument('--logUpdateFreq', {
    type: 'string',
    defaultValue: 'batch',
    choices: ['batch', 'epoch'],
    help: 'Frequency at which the loss and accuracy will be logged to ' +
        'tensorboard.'
  });
  return argParser.parseArgs();
}
async function run() {
const args = parseArguments();
let tfn;
if (args.gpu) {
console.log('Using GPU');
tfn = require('@tensorflow/tfjs-node-gpu');
} else {
console.log('Using CPU');
tfn = require('@tensorflow/tfjs-node');
}
const model = createModel(
dateFormat.INPUT_VOCAB.length, dateFormat.OUTPUT_VOCAB.length,
dateFormat.INPUT_LENGTH, dateFormat.OUTPUT_LENGTH);
model.summary();
const {
trainEncoderInput,
trainDecoderInput,
trainDecoderOutput,
valEncoderInput,
valDecoderInput,
valDecoderOutput,
testDateTuples
} = generateDataForTraining(args.trainSplit, args.valSplit);
await model.fit(
[trainEncoderInput, trainDecoderInput], trainDecoderOutput, {
epochs: args.epochs,
batchSize: args.batchSize,
shuffle: true,
validationData: [[valEncoderInput, valDecoderInput], valDecoderOutput],
callbacks: args.logDir == null ? null :
tfn.node.tensorBoard(args.logDir, {updateFreq: args.logUpdateFreq})
});
// Save the model.
if (args.savePath != null && args.savePath.length) {
if (!fs.existsSync(args.savePath)) {
shelljs.mkdir('-p', args.savePath);
}
const saveURL = `file://${args.savePath}`
await model.save(saveURL);
console.log(`Saved model to ${saveURL}`);
}
// Run seq2seq inference tests and print the results to console.
const numTests = 10;
for (let n = 0; n < numTests; ++n) {
for (const testInputFn of dateFormat.INPUT_FNS) {
const inputStr = testInputFn(testDateTuples[n]);
console.log('\n-----------------------');
console.log(`Input string: ${inputStr}`);
const correctAnswer =
dateFormat.dateTupleToYYYYDashMMDashDD(testDateTuples[n]);
console.log(`Correct answer: ${correctAnswer}`);
const {outputStr} = await runSeq2SeqInference(model, inputStr);
const isCorrect = outputStr === correctAnswer;
console.log(
`Model output: ${outputStr} (${isCorrect ? 'OK' : 'WRONG'})` );
}
}
}
// Start training only when this file is executed directly, not when it is
// imported (e.g. for `generateDataForTraining`).
if (require.main === module) {
  run().catch(err => {
    // Report the failure and signal it through the exit status instead of
    // leaving the promise rejection unhandled.
    console.error(err);
    process.exitCode = 1;
  });
}