Add averageTimeExclFirst to benchmark_util.timeInference (#7231)

FEATURE
This commit is contained in:
Linchenn 2023-01-04 13:33:52 -08:00 committed by GitHub
parent 4260128ed8
commit 09535adfe5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 104 additions and 93 deletions

View File

@ -109,8 +109,8 @@ function generateInputFromDef(inputDefs, isForGraphModel = false) {
generatedRaw.dispose();
} else {
throw new Error(
`The ${inputDef.dtype} dtype of '${inputDef.name}' input ` +
`at model.inputs[${inputDefIndex}] is not supported.`);
`The ${inputDef.dtype} dtype of '${inputDef.name}' input ` +
`at model.inputs[${inputDefIndex}] is not supported.`);
}
tensorArray.push(inputTensor);
});
@ -162,8 +162,8 @@ function getPredictFnForModel(model, input) {
predict = () => model.predict(input);
} else {
throw new Error(
'Predict function was not found. Please provide a tf.GraphModel or ' +
'tf.LayersModel');
'Predict function was not found. Please provide a tf.GraphModel or ' +
'tf.LayersModel');
}
return predict;
}
@ -175,6 +175,8 @@ function getPredictFnForModel(model, input) {
* about the model's inference time:
* - `times`: an array of inference time for each inference
* - `averageTime`: the average time of all inferences
* - `averageTimeExclFirst`: the average time of all inferences except the
* first.
* - `minTime`: the minimum time of all inferences
* - `maxTime`: the maximum time of all inferences
*
@ -211,6 +213,8 @@ async function timeModelInference(model, input, numRuns = 1) {
* time:
* - `times`: an array of inference time for each inference
* - `averageTime`: the average time of all inferences
* - `averageTimeExclFirst`: the average time of all inferences except the
* first.
* - `minTime`: the minimum time of all inferences
* - `maxTime`: the maximum time of all inferences
*
@ -237,8 +241,8 @@ async function timeModelInference(model, input, numRuns = 1) {
async function timeInference(predict, numRuns = 1) {
if (typeof predict !== 'function') {
throw new Error(
'The first parameter should be a function, while ' +
`a(n) ${typeof predict} is found.`);
'The first parameter should be a function, while ' +
`a(n) ${typeof predict} is found.`);
}
const times = [];
@ -254,11 +258,15 @@ async function timeInference(predict, numRuns = 1) {
}
const averageTime = times.reduce((acc, curr) => acc + curr, 0) / times.length;
const averageTimeExclFirst = times.length > 1 ?
times.slice(1).reduce((acc, curr) => acc + curr, 0) / (times.length - 1) :
'NA';
const minTime = Math.min(...times);
const maxTime = Math.max(...times);
const timeInfo = {
times,
averageTime,
averageTimeExclFirst,
minTime,
maxTime
@ -352,9 +360,9 @@ async function downloadValuesFromTensorContainer(tensorContainer) {
* @param numProfiles The number of rounds for profiling the inference process.
*/
async function profileModelInference(
model, input, isTflite = false, numProfiles = 1) {
model, input, isTflite = false, numProfiles = 1) {
const predict = isTflite ? () => tfliteModel.predict(input) :
getPredictFnForModel(model, input);
getPredictFnForModel(model, input);
return profileInference(predict, isTflite, numProfiles);
}
@ -392,8 +400,8 @@ async function profileModelInference(
async function profileInference(predict, isTflite = false, numProfiles = 1) {
if (typeof predict !== 'function') {
throw new Error(
'The first parameter should be a function, while ' +
`a(n) ${typeof predict} is found.`);
'The first parameter should be a function, while ' +
`a(n) ${typeof predict} is found.`);
}
let kernelInfo = {};
@ -431,7 +439,7 @@ async function profileInference(predict, isTflite = false, numProfiles = 1) {
kernelInfo.kernels[i].kernelTimeMs = totalTimeMs / kernelInfos.length;
}
kernelInfo.kernels =
kernelInfo.kernels.sort((a, b) => b.kernelTimeMs - a.kernelTimeMs);
kernelInfo.kernels.sort((a, b) => b.kernelTimeMs - a.kernelTimeMs);
kernelInfo.aggregatedKernels = aggregateKernelTime(kernelInfo.kernels);
return kernelInfo;
}
@ -451,13 +459,13 @@ function aggregateKernelTime(kernels) {
aggregatedKernelTime[kernel.name] = kernel.kernelTimeMs;
} else {
aggregatedKernelTime[kernel.name] =
oldAggregatedKernelTime + kernel.kernelTimeMs;
oldAggregatedKernelTime + kernel.kernelTimeMs;
}
});
return Object.entries(aggregatedKernelTime)
.map(([name, timeMs]) => ({name, timeMs}))
.sort((a, b) => b.timeMs - a.timeMs);
.map(([name, timeMs]) => ({ name, timeMs }))
.sort((a, b) => b.timeMs - a.timeMs);
}
/**
@ -512,7 +520,7 @@ async function setEnvFlags(flagConfig) {
return true;
} else if (typeof flagConfig !== 'object') {
throw new Error(
`An object is expected, while a(n) ${typeof flagConfig} is found.`);
`An object is expected, while a(n) ${typeof flagConfig} is found.`);
}
// Check the validation of flags and values.
@ -523,9 +531,8 @@ async function setEnvFlags(flagConfig) {
}
if (TUNABLE_FLAG_VALUE_RANGE_MAP[flag].indexOf(flagConfig[flag]) === -1) {
throw new Error(
`${flag} value is expected to be in the range [${
TUNABLE_FLAG_VALUE_RANGE_MAP[flag]}], while ${flagConfig[flag]}` +
' is found.');
`${flag} value is expected to be in the range [${TUNABLE_FLAG_VALUE_RANGE_MAP[flag]}], while ${flagConfig[flag]}` +
' is found.');
}
}

View File

@ -30,37 +30,41 @@ describe('test app.js cli', () => {
mockResults = {
'iPhone_XS_1': {
timeInfo: {
times: [216.00000000000045],
averageTime: 216.00000000000045,
times: [218.00000000000045, 216.00000000000045],
averageTime: 217.00000000000045,
averageTimeExclFirst: 216.00000000000045,
minTime: 216.00000000000045,
maxTime: 216.00000000000045
maxTime: 218.00000000000045
},
tabId: 'iPhone_XS_1'
},
'Samsung_Galaxy_S20_1': {
timeInfo: {
times: [428.89999999897555],
averageTime: 428.89999999897555,
times: [428.89999999897555, 430.89999999897555],
averageTime: 429.89999999897555,
averageTimeExclFirst: 430.89999999897555,
minTime: 428.89999999897555,
maxTime: 428.89999999897555
maxTime: 430.89999999897555
},
tabId: 'Samsung_Galaxy_S20_1'
},
'Windows_10_1': {
timeInfo: {
times: [395.8500000001095],
averageTime: 395.8500000001095,
times: [395.8500000001095, 397.8500000001095],
averageTime: 396.8500000001095,
averageTimeExclFirst: 397.8500000001095,
minTime: 395.8500000001095,
maxTime: 395.8500000001095
maxTime: 397.8500000001095
},
tabId: 'Windows_10_1'
},
'OS_X_Catalina_1': {
timeInfo: {
times: [176.19500000728294],
averageTime: 176.19500000728294,
times: [178.19500000728294, 176.19500000728294],
averageTime: 177.19500000728294,
averageTimeExclFirst: 176.19500000728294,
minTime: 176.19500000728294,
maxTime: 176.19500000728294
maxTime: 178.19500000728294
},
tabId: 'OS_X_Catalina_1'
}

View File

@ -49,7 +49,7 @@ async function getBenchmarkSummary(timeInfo, memoryInfo, modelName = 'model') {
const benchmarkSummary = `
benchmark the ${modelName} on ${envSummary}
1st inference time: ${printTime(timeInfo.times[0])}
Average inference time (${numRuns} runs): ${printTime(timeInfo.averageTime)}
Subsequent average inference time (${numRuns} runs): ${printTime(timeInfo.averageTimeExclFirst)}
Best inference time: ${printTime(timeInfo.minTime)}
Peak memory: ${printMemory(memoryInfo.peakBytes)}
`;
@ -85,7 +85,7 @@ async function benchmarkModel(benchmarkParameters) {
memoryInfo = await profileModelInference(model, input);
}
return {timeInfo, memoryInfo};
return { timeInfo, memoryInfo };
}
async function benchmarkCodeSnippet(benchmarkParameters) {
@ -97,7 +97,7 @@ async function benchmarkCodeSnippet(benchmarkParameters) {
if (predict == null) {
throw new Error(
'predict function is suppoed to be defined in codeSnippet.');
'predict function is suppoed to be defined in codeSnippet.');
}
// Warm up.
@ -107,7 +107,7 @@ async function benchmarkCodeSnippet(benchmarkParameters) {
timeInfo = await timeInference(predict, benchmarkParameters.numRuns);
memoryInfo = await profileInference(predict);
return {timeInfo, memoryInfo};
return { timeInfo, memoryInfo };
}
describe('BrowserStack benchmark', () => {
@ -135,11 +135,11 @@ describe('BrowserStack benchmark', () => {
// Get GPU hardware info.
resultObj.gpuInfo =
targetBackend === 'webgl' ? (await getRendererInfo()) : 'MISS';
targetBackend === 'webgl' ? (await getRendererInfo()) : 'MISS';
// Report results.
console.log(
`<tfjs_benchmark>${JSON.stringify(resultObj)}</tfjs_benchmark>`);
`<tfjs_benchmark>${JSON.stringify(resultObj)}</tfjs_benchmark>`);
} catch (error) {
console.log(`<tfjs_error>${error}</tfjs_error>`);
}

View File

@ -86,7 +86,7 @@ limitations under the License.
}
structuredBenchmarkResults[tableName][deviceName][benchmarkTargetName]
= benchmarkReocrd?.value?.timeInfo?.averageTime;
= benchmarkReocrd?.value?.timeInfo?.averageTimeExclFirst;
}
return structuredBenchmarkResults;
}

View File

@ -16,7 +16,7 @@
*/
const TUNABLE_BROWSER_FIELDS =
['os', 'os_version', 'browser', 'browser_version', 'device'];
['os', 'os_version', 'browser', 'browser_version', 'device'];
const WAITING_STATUS_COLOR = '#AAAAAA';
const COMPLETE_STATUS_COLOR = '#357edd';
const ERROR_STATUS_COLOR = '#e8564b';
@ -48,9 +48,9 @@ const state = {
numRuns: 10,
backend: 'webgl',
setupCodeSnippetEnv:
'const img = tf.randomUniform([1, 240, 240, 3], 0, 1000); const filter = tf.randomUniform([3, 3, 3, 3], 0, 1000);',
'const img = tf.randomUniform([1, 240, 240, 3], 0, 1000); const filter = tf.randomUniform([3, 3, 3, 3], 0, 1000);',
codeSnippet:
'predict = () => { return tf.conv2d(img, filter, 2, \'same\');};'
'predict = () => { return tf.conv2d(img, filter, 2, \'same\');};'
},
/**
@ -69,7 +69,7 @@ const state = {
addBrowser: () => {
// Add browser config to `state.browsers` array.
state.browsers.push({...state.browser});
state.browsers.push({ ...state.browser });
// Enable the benchmark button.
benchmarkButton.__li.style.pointerEvents = '';
@ -84,8 +84,8 @@ const state = {
removeBrowser: index => {
if (index >= state.browsers.length) {
throw new Error(
`Invalid index ${index}, while the state.browsers only ` +
`has ${state.browsers.length} items.`);
`Invalid index ${index}, while the state.browsers only ` +
`has ${state.browsers.length} items.`);
}
// Remove the browser from the `state.browsers` array.
@ -127,7 +127,7 @@ const state = {
browserTabIdConfigMap[tabId] = browser;
});
const benchmark = {...state.benchmark};
const benchmark = { ...state.benchmark };
if (state.benchmark.model !== 'custom') {
delete benchmark['modelUrl'];
}
@ -180,7 +180,7 @@ function constructBrowserTree(browsersArray) {
// Route through non-leaf nodes.
for (let fieldIndex = 0; fieldIndex <= TUNABLE_BROWSER_FIELDS.length - 2;
fieldIndex++) {
fieldIndex++) {
const fieldName = TUNABLE_BROWSER_FIELDS[fieldIndex];
if (currentNode[browser[fieldName]] == null) {
currentNode[browser[fieldName]] = {};
@ -190,14 +190,14 @@ function constructBrowserTree(browsersArray) {
// Set the full configuration as the leaf node.
const leafFieldName =
TUNABLE_BROWSER_FIELDS[TUNABLE_BROWSER_FIELDS.length - 1];
TUNABLE_BROWSER_FIELDS[TUNABLE_BROWSER_FIELDS.length - 1];
const leafFieldValue = browser[leafFieldName];
if (currentNode[leafFieldValue] == null) {
currentNode[leafFieldValue] = browser;
} else {
console.warn(
`The browser ${browser} shares the same ` +
'configuration with another browser.');
`The browser ${browser} shares the same ` +
'configuration with another browser.');
}
});
return browserTreeRoot;
@ -213,7 +213,7 @@ function constructBrowserTree(browsersArray) {
* @param {object} currentNode
*/
function updateFollowingFields(
currentFieldIndex, currentFieldValue, currentNode) {
currentFieldIndex, currentFieldValue, currentNode) {
const nextFieldIndex = currentFieldIndex + 1;
if (nextFieldIndex === TUNABLE_BROWSER_FIELDS.length) {
return;
@ -231,7 +231,7 @@ function updateFollowingFields(
// Update the options for the next field.
const nextFieldController = browserSettingControllers[nextFieldIndex].options(
nextFieldAvailableValues);
nextFieldAvailableValues);
// When updating options for a dat.gui controller, a new controller instacne
// will be created, so we need to bind the event again and record the new
@ -266,7 +266,7 @@ function updateFollowingFields(
function showBrowserField(fieldIndex, currentNode) {
const fieldName = TUNABLE_BROWSER_FIELDS[fieldIndex];
const fieldController =
browserFolder.add(state.browser, fieldName, Object.keys(currentNode));
browserFolder.add(state.browser, fieldName, Object.keys(currentNode));
fieldController.onFinishChange(() => {
const newValue = state.browser[fieldName];
@ -312,7 +312,7 @@ function drawTunableBrowserSummaryTable(summaryTabId, browsers) {
// Whenever a browser configuration is removed, this table will be re-drawn,
// so the index (the argument for state.removeBrowser) will be re-assigned.
const removeBrowserButtonElement =
`<button onclick="state.removeBrowser(${index})">Remove</button>`;
`<button onclick="state.removeBrowser(${index})">Remove</button>`;
row.push(removeBrowserButtonElement);
values.push(row);
@ -321,9 +321,9 @@ function drawTunableBrowserSummaryTable(summaryTabId, browsers) {
const surface = {
name: 'Browsers to benchmark',
tab: summaryTabId,
styles: {width: '100%'}
styles: { width: '100%' }
};
tfvis.render.table(surface, {headers, values});
tfvis.render.table(surface, { headers, values });
}
/**
@ -351,9 +351,9 @@ function drawUntunableBrowserSummaryTable(summaryTabId, browserTabIdConfigMap) {
const surface = {
name: 'Browsers to benchmark',
tab: summaryTabId,
styles: {width: '100%'}
styles: { width: '100%' }
};
tfvis.render.table(surface, {headers, values});
tfvis.render.table(surface, { headers, values });
}
function initVisor() {
@ -364,7 +364,7 @@ function initVisor() {
// Bind an event to visor's 'Maximize/Minimize' button.
const visorFullScreenButton =
tfvis.visor().el.getElementsByTagName('button')[0];
tfvis.visor().el.getElementsByTagName('button')[0];
const guiCloseButton = document.getElementsByClassName('close-button')[0];
const originalGuiWidth = gui.domElement.style.width;
@ -499,7 +499,7 @@ function setTabStatus(tabId, status) {
*/
function addLoaderElement(tabId) {
const surface = tfvis.visor().surface(
{name: 'Benchmark Summary', tab: tabId, styles: {width: '100%'}});
{ name: 'Benchmark Summary', tab: tabId, styles: { width: '100%' } });
const loaderElement = document.createElement('div');
loaderElement.className = 'loader';
loaderElement.id = `${tabId}-loader`;
@ -523,7 +523,7 @@ function drawBenchmarkResultSummaryTable(benchmarkResult) {
const headers = ['Field', 'Value'];
const values = [];
const {timeInfo, memoryInfo, tabId} = benchmarkResult;
const { timeInfo, memoryInfo, tabId } = benchmarkResult;
const timeArray = benchmarkResult.timeInfo.times;
const numRuns = timeArray.length;
@ -533,8 +533,8 @@ function drawBenchmarkResultSummaryTable(benchmarkResult) {
values.push(['2nd inference time', printTime(timeArray[1])]);
}
values.push([
`Average inference time (${numRuns} runs)`,
printTime(timeInfo.averageTime)
`Average inference time (${numRuns} runs) except the first`,
printTime(timeInfo.averageTimeExclFirst)
]);
values.push(['Best time', printTime(timeInfo.minTime)]);
values.push(['Worst time', printTime(timeInfo.maxTime)]);
@ -552,9 +552,9 @@ function drawBenchmarkResultSummaryTable(benchmarkResult) {
const surface = {
name: 'Benchmark Summary',
tab: tabId,
styles: {width: '100%'}
styles: { width: '100%' }
};
tfvis.render.table(surface, {headers, values});
tfvis.render.table(surface, { headers, values });
}
async function drawInferenceTimeLineChart(benchmarkResult) {
@ -571,34 +571,34 @@ async function drawInferenceTimeLineChart(benchmarkResult) {
if (index === 0) {
return;
}
values.push({x: index + 1, y: time});
values.push({ x: index + 1, y: time });
});
const surface = {
name: `2nd - ${inferenceTimeArray.length}st Inference Time`,
tab: tabId,
styles: {width: '100%'}
styles: { width: '100%' }
};
const data = {values};
const data = { values };
const drawOptions =
{zoomToFit: true, xLabel: '', yLabel: 'time (ms)', xType: 'ordinal'};
{ zoomToFit: true, xLabel: '', yLabel: 'time (ms)', xType: 'ordinal' };
await tfvis.render.linechart(surface, data, drawOptions);
// Whenever resize the parent div element, re-draw the chart canvas.
try {
const originalCanvasHeight = tfvis.visor()
.surface(surface)
.drawArea.getElementsByTagName('canvas')[0]
.height;
.surface(surface)
.drawArea.getElementsByTagName('canvas')[0]
.height;
const labelElement = tfvis.visor().surface(surface).label;
new ResizeObserver(() => {
// Keep the height of chart/canvas unchanged.
tfvis.visor()
.surface(surface)
.drawArea.getElementsByTagName('canvas')[0]
.height = originalCanvasHeight;
.surface(surface)
.drawArea.getElementsByTagName('canvas')[0]
.height = originalCanvasHeight;
tfvis.render.linechart(surface, data, drawOptions);
}).observe(labelElement);
} catch (e) {
@ -618,9 +618,9 @@ function drawBrowserSettingTable(tabId, browserConf) {
const surface = {
name: 'Browser Setting',
tab: tabId,
styles: {width: '100%'}
styles: { width: '100%' }
};
tfvis.render.table(surface, {headers, values});
tfvis.render.table(surface, { headers, values });
}
function drawBenchmarkParameterTable(tabId) {
@ -636,9 +636,9 @@ function drawBenchmarkParameterTable(tabId) {
const surface = {
name: 'Benchmark Parameter',
tab: tabId,
styles: {width: '100%'}
styles: { width: '100%' }
};
tfvis.render.table(surface, {headers, values});
tfvis.render.table(surface, { headers, values });
}
function showModelSelection() {
@ -646,21 +646,21 @@ function showModelSelection() {
let modelUrlController = null;
modelFolder
.add(
state.benchmark, 'model', [...Object.keys(benchmarks), 'codeSnippet'])
.name('model name')
.onChange(async model => {
if (model === 'custom') {
if (modelUrlController === null) {
modelUrlController = modelFolder.add(state.benchmark, 'modelUrl');
modelUrlController.domElement.querySelector('input').placeholder =
'https://your-domain.com/model-path/model.json';
}
} else if (modelUrlController != null) {
modelFolder.remove(modelUrlController);
modelUrlController = null;
.add(
state.benchmark, 'model', [...Object.keys(benchmarks), 'codeSnippet'])
.name('model name')
.onChange(async model => {
if (model === 'custom') {
if (modelUrlController === null) {
modelUrlController = modelFolder.add(state.benchmark, 'modelUrl');
modelUrlController.domElement.querySelector('input').placeholder =
'https://your-domain.com/model-path/model.json';
}
});
} else if (modelUrlController != null) {
modelFolder.remove(modelUrlController);
modelUrlController = null;
}
});
modelFolder.open();
return modelFolder;
}
@ -688,7 +688,7 @@ function printMemory(bytes) {
}
function onPageLoad() {
gui = new dat.gui.GUI({width: 400});
gui = new dat.gui.GUI({ width: 400 });
gui.domElement.id = 'gui';
socket = io();
@ -714,7 +714,7 @@ function onPageLoad() {
// Enable users to benchmark.
addingBrowserButton =
browserFolder.add(state, 'addBrowser').name('Add browser');
browserFolder.add(state, 'addBrowser').name('Add browser');
benchmarkButton = gui.add(state, 'run').name('Run benchmark');
// Disable the 'Run benchmark' button until a browser is added.