// tfjs-examples/simple-object-detection/train.js
//
// 258 lines
// 9.1 KiB
// JavaScript

/**
* @license
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
const fs = require('fs');
const path = require('path');
const argparse = require('argparse');
const canvas = require('canvas');
const tf = require('@tensorflow/tfjs');
const synthesizer = require('./synthetic_images');
// Width and height, in pixels, of the synthetic-image canvas. Chosen to
// match the input size of MobileNet.
const CANVAS_SIZE = 224;

// Layers whose names begin with one of these prefixes are the ones that
// get unfrozen during the fine-tuning phase.
const topLayerGroupNames = ['conv_pw_9', 'conv_pw_10', 'conv_pw_11'];

// The truncated base ends at the ReLU layer belonging to the last group
// listed above.
const topLayerName = `${topLayerGroupNames.slice(-1)[0]}_relu`;

// Per-column factors applied to `yTrue` inside the loss function. Only the
// first column (the 0-1 shape indicator) is scaled up, so that shape and
// bounding-box predictions make balanced contributions to the final loss
// value.
const LABEL_MULTIPLIER = [
  CANVAS_SIZE,  // shape indicator
  1, 1, 1, 1    // bounding-box columns are left unscaled
];
/**
 * Custom loss function for object detection.
 *
 * The loss is a single mean squared error taken over all five label
 * columns at once. Before the comparison, `yTrue` is multiplied by
 * `LABEL_MULTIPLIER`, which scales up only the first column (the 0-1 shape
 * indicator) so that shape and bounding-box predictions make balanced
 * contributions to the final loss value.
 *
 * @param {tf.Tensor} yTrue True labels. Shape: [batchSize, 5].
 *   The first column is a 0-1 indicator for whether the shape is a triangle
 *   (0) or a rectangle (1). The remaining four columns are the bounding box
 *   of the target shape: [left, right, top, bottom], in units of pixels.
 *   The bounding-box values are in the range [0, CANVAS_SIZE).
 * @param {tf.Tensor} yPred Predicted labels. Shape: the same as `yTrue`.
 * @return {tf.Tensor} The value produced by `tf.metrics.meanSquaredError`
 *   for the scaled true labels and the predictions.
 */
function customLossFunction(yTrue, yPred) {
  const computeLoss = () => {
    // Only the labels are rescaled; the predictions are used as they are.
    const scaledLabels = yTrue.mul(LABEL_MULTIPLIER);
    return tf.metrics.meanSquaredError(scaledLabels, yPred);
  };
  // Run inside tf.tidy() so the intermediate scaled tensor is cleaned up.
  return tf.tidy(computeLoss);
}
/**
 * Loads MobileNet, removes the top part, and freezes all the layers.
 *
 * The top removal and layer freezing are preparation for transfer learning.
 *
 * Also gets handles to the layers that will be unfrozen during the
 * fine-tuning phase of the training, i.e. the layers whose names start with
 * one of the prefixes in `topLayerGroupNames`.
 *
 * @return {Promise<{truncatedBase: tf.LayersModel,
 *     fineTuningLayers: tf.layers.Layer[]}>}
 *   `truncatedBase`: the truncated MobileNet, with all layers frozen.
 *   `fineTuningLayers`: the (currently frozen) layers of the truncated base
 *   that may be unfrozen later for fine-tuning.
 * @throws {Error} Through `tf.util.assert`, if no layer of the truncated
 *   base matches any prefix in `topLayerGroupNames`.
 */
async function loadTruncatedBase() {
  // TODO(cais): Add unit test.
  const mobilenet = await tf.loadLayersModel(
      'https://storage.googleapis.com/tfjs-models/tfjs/mobilenet_v1_0.25_224/model.json');

  // Build a model that outputs an internal activation of MobileNet.
  const topLayer = mobilenet.getLayer(topLayerName);
  const truncatedBase =
      tf.model({inputs: mobilenet.inputs, outputs: topLayer.output});

  // Freeze every layer of the base, remembering the ones that belong to the
  // fine-tuning groups so the caller can unfreeze them later.
  const fineTuningLayers = [];
  for (const layer of truncatedBase.layers) {
    layer.trainable = false;
    const isFineTuningLayer =
        topLayerGroupNames.some(groupName => layer.name.startsWith(groupName));
    if (isFineTuningLayer) {
      fineTuningLayers.push(layer);
    }
  }

  // A single matching layer is a valid result; only an empty list means the
  // prefixes did not match anything.
  tf.util.assert(
      fineTuningLayers.length > 0,
      `Did not find any layers that match the prefixes ${topLayerGroupNames}`);
  return {truncatedBase, fineTuningLayers};
}
/**
 * Builds a new head (i.e., output sub-model) that will be connected to
 * the top of the truncated base for object detection.
 *
 * The head is a sequential model made of a flatten layer, a 200-unit dense
 * layer with ReLU activation, and a final 5-unit dense layer.
 *
 * @param {tf.Shape} inputShape Input shape of the new model.
 * @returns {tf.Model} The new head model.
 */
function buildNewHead(inputShape) {
  const newHead = tf.sequential();
  const headLayers = [
    tf.layers.flatten({inputShape}),
    tf.layers.dense({units: 200, activation: 'relu'}),
    // Five output units:
    // - The first is a shape indicator: predicts whether the target
    //   shape is a triangle or a rectangle.
    // - The remaining four units are for bounding-box prediction:
    //   [left, right, top, bottom] in the unit of pixels.
    tf.layers.dense({units: 5})
  ];
  for (const headLayer of headLayers) {
    newHead.add(headLayer);
  }
  return newHead;
}
/**
 * Builds the object-detection model from MobileNet.
 *
 * The truncated, frozen MobileNet base is loaded first; a freshly built
 * head is then applied to the base's first output, and the two are wrapped
 * into a single model that shares the base's inputs.
 *
 * @returns {Promise<{model: tf.Model, fineTuningLayers: tf.layers.Layer[]}>}
 *   `model`: the newly-built model for simple object detection.
 *   `fineTuningLayers`: the layers that can be unfrozen during fine-tuning.
 */
async function buildObjectDetectionModel() {
  const base = await loadTruncatedBase();
  const baseOutput = base.truncatedBase.outputs[0];

  // The head's input shape is the base output shape with its leading
  // dimension dropped.
  const headInputShape = baseOutput.shape.slice(1);
  const newHead = buildNewHead(headInputShape);

  const model = tf.model({
    inputs: base.truncatedBase.inputs,
    outputs: newHead.apply(baseOutput)
  });
  return {model, fineTuningLayers: base.fineTuningLayers};
}
/**
 * Script entry point.
 *
 * Parses the command-line flags, generates the synthetic training examples,
 * builds the object-detection model, trains it in two phases (initial
 * transfer learning with the base frozen, then fine-tuning with the
 * fine-tuning layers unfrozen), and finally saves the model to
 * `./dist/object_detection_model`.
 */
(async function main() {
  // Data-related settings.
  const numCircles = 10;
  const numLines = 10;

  const parser = new argparse.ArgumentParser();
  parser.addArgument('--gpu', {
    action: 'storeTrue',
    help: 'Use tfjs-node-gpu for training (required CUDA and CuDNN)'
  });
  parser.addArgument(
      '--numExamples',
      {type: 'int', defaultValue: 2000, help: 'Number of training exapmles'});
  parser.addArgument('--validationSplit', {
    type: 'float',
    defaultValue: 0.15,
    help: 'Validation split to be used during training'
  });
  parser.addArgument('--batchSize', {
    type: 'int',
    defaultValue: 128,
    help: 'Batch size to be used during training'
  });
  parser.addArgument('--initialTransferEpochs', {
    type: 'int',
    defaultValue: 100,
    help: 'Number of training epochs in the initial transfer ' +
        'learning (i.e., 1st) phase'
  });
  parser.addArgument('--fineTuningEpochs', {
    type: 'int',
    defaultValue: 100,
    help: 'Number of training epochs in the fine-tuning (i.e., 2nd) phase'
  });
  parser.addArgument('--logDir', {
    type: 'string',
    help: 'Optional tensorboard log directory, to which the loss ' +
        'values will be logged during model training.'
  });
  parser.addArgument('--logUpdateFreq', {
    type: 'string',
    defaultValue: 'batch',
    // `choices` makes the parser reject anything other than these values.
    choices: ['batch', 'epoch'],
    help: 'Frequency at which the loss will be logged to tensorboard.'
  });
  const args = parser.parseArgs();

  // Pick the Node.js binding. `tfn` is also used further down to create the
  // TensorBoard callback.
  let tfn;
  if (args.gpu) {
    console.log('Training using GPU.');
    tfn = require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Training using CPU.');
    tfn = require('@tensorflow/tfjs-node');
  }

  // Creates the `callbacks` value for one `model.fit()` call: a new
  // TensorBoard callback when --logDir was given, otherwise null.
  const createFitCallbacks = () => args.logDir == null ?
      null :
      tfn.node.tensorBoard(args.logDir, {updateFreq: args.logUpdateFreq});

  const modelSaveURL = 'file://./dist/object_detection_model';

  const tBegin = tf.util.now();
  console.log(`Generating ${args.numExamples} training examples...`);
  const synthDataCanvas = canvas.createCanvas(CANVAS_SIZE, CANVAS_SIZE);
  const synth =
      new synthesizer.ObjectDetectionImageSynthesizer(synthDataCanvas, tf);
  const {images, targets} =
      await synth.generateExampleBatch(args.numExamples, numCircles, numLines);

  const {model, fineTuningLayers} = await buildObjectDetectionModel();
  model.compile({loss: customLossFunction, optimizer: tf.train.rmsprop(5e-3)});
  model.summary();

  // Initial phase of transfer learning.
  console.log('Phase 1 of 2: initial transfer learning');
  await model.fit(images, targets, {
    epochs: args.initialTransferEpochs,
    batchSize: args.batchSize,
    validationSplit: args.validationSplit,
    callbacks: createFitCallbacks()
  });

  // Fine-tuning phase of transfer learning.
  // Unfreeze layers for fine-tuning.
  for (const layer of fineTuningLayers) {
    layer.trainable = true;
  }
  // Re-compile after changing `trainable`, this time with a smaller
  // learning rate.
  model.compile({loss: customLossFunction, optimizer: tf.train.rmsprop(2e-3)});
  model.summary();

  // Do fine-tuning.
  // The batch size is reduced to avoid CPU/GPU OOM. This has
  // to do with the unfreezing of the fine-tuning layers above,
  // which leads to higher memory consumption during backpropagation.
  // The halved value is floored and clamped so that it stays a positive
  // integer even for odd batch sizes or a batch size of 1.
  const fineTuningBatchSize = Math.max(1, Math.floor(args.batchSize / 2));
  console.log('Phase 2 of 2: fine-tuning phase');
  await model.fit(images, targets, {
    epochs: args.fineTuningEpochs,
    batchSize: fineTuningBatchSize,
    validationSplit: args.validationSplit,
    callbacks: createFitCallbacks()
  });

  // Save model.
  // First make sure that the base directory exists.
  const modelSavePath = modelSaveURL.replace('file://', '');
  const dirName = path.dirname(modelSavePath);
  if (!fs.existsSync(dirName)) {
    fs.mkdirSync(dirName);
  }
  await model.save(modelSaveURL);
  console.log(`Model training took ${(tf.util.now() - tBegin) / 1e3} s`);
  console.log(`Trained model is saved to ${modelSaveURL}`);
  console.log(
      `\nNext, run the following command to test the model in the browser:`);
  console.log(`\n  yarn watch`);
})().catch((err) => {
  // Report the failure and signal it through the exit code instead of
  // leaving the promise rejection unhandled.
  console.error(err);
  process.exitCode = 1;
});