mirror of https://github.com/tensorflow/tfjs.git
Add wasm support to inference binary (#3529)
FEATURE Add wasm support to inference binary
This commit is contained in:
parent fedd3f524b
commit bde1db0939

@@ -5,6 +5,7 @@
   "private": false,
   "main": "dist/index.js",
   "types": "dist/index.d.ts",
+  "bin": "dist/index.js",
   "engines": {
     "yarn": ">= 1.3.2"
   },

@@ -14,9 +15,10 @@
   },
   "license": "Apache-2.0",
   "devDependencies": {
-    "@tensorflow/tfjs-core": "2.0.0",
-    "@tensorflow/tfjs-converter": "2.0.0",
-    "@tensorflow/tfjs-backend-cpu": "2.0.0",
+    "@tensorflow/tfjs-core": "2.0.1",
+    "@tensorflow/tfjs-converter": "2.0.1",
+    "@tensorflow/tfjs-backend-cpu": "2.0.1",
+    "@tensorflow/tfjs-backend-wasm": "2.0.1",
     "@types/jasmine": "~3.0.0",
     "@types/rimraf": "~3.0.0",
     "clang-format": "~1.2.4",

@@ -31,7 +33,7 @@
     "build": "tsc",
     "test": "ts-node --skip-ignore -P tsconfig.test.json src/test_node.ts",
     "test-ci": "yarn test",
-    "build-binary": "yarn build && pkg dist/index.js --targets=node10-macos-x64,node10-linux-x64,node10-win-x64 --out-path=binaries",
+    "build-binary": "yarn build && pkg . --targets=node10-macos-x64,node10-linux-x64,node10-win-x64 --out-path=binaries",
     "test-python": "./scripts/run_python.sh"
   },
   "dependencies": {

@@ -45,5 +47,8 @@
     "@tensorflow/tfjs-core": "2.0.0",
     "@tensorflow/tfjs-converter": "2.0.0",
     "@tensorflow/tfjs-backend-cpu": "2.0.0"
   },
+  "pkg": {
+    "assets": "node_modules/@tensorflow/tfjs-backend-wasm/dist/tfjs-backend-wasm.wasm"
+  }
 }

@@ -21,7 +21,7 @@ from __future__ import print_function
 import subprocess


-def predict(binary_path, model_path, inputs_dir, outputs_dir):
+def predict(binary_path, model_path, inputs_dir, outputs_dir, backend=None):
   """Use tfjs binary to make inference and store output in file.

   Args:

@@ -33,9 +33,8 @@ def predict(binary_path, model_path, inputs_dir, outputs_dir):
       files.
     outputs_dir: Directory to write the outputs files, including data, shape
       and dtype files.
-
-  Returns:
-    stdout from the subprocess.
+    backend: Optional. Choose which TensorFlow.js backend to use. Supported
+      backends include cpu and wasm. Default: cpu
   """
   model_path_option = '--model_path=' + model_path
   inputs_dir_option = '--inputs_dir=' + inputs_dir

@@ -46,13 +45,17 @@ def predict(binary_path, model_path, inputs_dir, outputs_dir):
       outputs_dir_option
   ]
+
+  if backend:
+    backend_option = '--backend=' + backend
+    tfjs_inference_command.append(backend_option)
+
   popen = subprocess.Popen(
       tfjs_inference_command,
       stdin=subprocess.PIPE,
       stdout=subprocess.PIPE,
       stderr=subprocess.PIPE)
   stdout, stderr = popen.communicate()

   if popen.returncode != 0:
     raise ValueError('Inference failed with status %d\nstderr:\n%s' %
                      (popen.returncode, stderr))
-  return stdout
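
For context, a minimal sketch (not part of the commit) of how the updated helper might be called with the new argument; the paths below are borrowed from the test later in this diff and are only illustrative:

  import tempfile

  import inference

  # Illustrative paths, mirroring the test setup; adjust to your checkout.
  binary_path = '../binaries/tfjs-inference-linux'
  model_path = '../test_data/model.json'
  inputs_dir = '../test_data'
  outputs_dir = tempfile.mkdtemp()

  # backend may be 'cpu' or 'wasm'; when omitted, no --backend flag is
  # passed and the binary falls back to its default ('cpu').
  inference.predict(binary_path, model_path, inputs_dir, outputs_dir,
                    backend='wasm')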

@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tests for tfjs-inference binary."""

 # To test the binary, you need to manually run `yarn build-binary` first.
 # This test only supports running in Linux.

@@ -32,35 +31,37 @@ import inference
 class InferenceTest(tf.test.TestCase):

   def testInference(self):
-    binary_path = os.path.join('../binaries', 'index-linux')
-    model_path = os.path.join('../test_data', 'model.json')
-    test_data_dir = os.path.join('../test_data')
-    tmp_dir = tempfile.mkdtemp()
+    backends = ['cpu', 'wasm']
+    for backend in backends:
+      binary_path = os.path.join('../binaries', 'tfjs-inference-linux')
+      model_path = os.path.join('../test_data', 'model.json')
+      test_data_dir = os.path.join('../test_data')
+      tmp_dir = tempfile.mkdtemp()

-    inference.predict(binary_path, model_path, test_data_dir, tmp_dir)
+      inference.predict(binary_path, model_path, test_data_dir, tmp_dir, backend)

-    with open(os.path.join(tmp_dir, 'data.json'), 'rt') as f:
-      ys_values = json.load(f)
+      with open(os.path.join(tmp_dir, 'data.json'), 'rt') as f:
+        ys_values = json.load(f)

-    # The output is a list of tensor data in the form of dict.
-    # Example output:
-    # [{"0":0.7567615509033203,"1":-0.18349379301071167,"2":0.7567615509033203,"3":-0.18349379301071167}]
-    ys_values = [list(y.values()) for y in ys_values]
+      # The output is a list of tensor data in the form of dict.
+      # Example output:
+      # [{"0":0.7567615509033203,"1":-0.18349379301071167,"2":0.7567615509033203,"3":-0.18349379301071167}]
+      ys_values = [list(y.values()) for y in ys_values]

-    with open(os.path.join(tmp_dir, 'shape.json'), 'rt') as f:
-      ys_shapes = json.load(f)
+      with open(os.path.join(tmp_dir, 'shape.json'), 'rt') as f:
+        ys_shapes = json.load(f)

-    with open(os.path.join(tmp_dir, 'dtype.json'), 'rt') as f:
-      ys_dtypes = json.load(f)
+      with open(os.path.join(tmp_dir, 'dtype.json'), 'rt') as f:
+        ys_dtypes = json.load(f)

-    self.assertAllClose(ys_values[0], [
-        0.7567615509033203, -0.18349379301071167, 0.7567615509033203,
-        -0.18349379301071167
-    ])
-    self.assertAllEqual(ys_shapes[0], [2, 2])
-    self.assertEqual(ys_dtypes[0], 'float32')
-    # Cleanup tmp dir.
-    shutil.rmtree(tmp_dir)
+      self.assertAllClose(ys_values[0], [
+          0.7567615509033203, -0.18349379301071167, 0.7567615509033203,
+          -0.18349379301071167
+      ])
+      self.assertAllEqual(ys_shapes[0], [2, 2])
+      self.assertEqual(ys_dtypes[0], 'float32')
+      # Cleanup tmp dir.
+      shutil.rmtree(tmp_dir)


 if __name__ == '__main__':

@@ -1,2 +1,2 @@
-numpy>=1.16.4
-tensorflow-cpu>=2.1.0<3
+numpy>=1.16.4, <1.19.0
+tensorflow-cpu>=2.1.0, <3

@@ -18,19 +18,19 @@
 /**
  * This file is used to load a saved model and perform inference.
  * Run this script in console:
- * ts-node inference.ts --model_path=MODEL_PATH -inputs_dir=INPUTS_DIR
- * -outputs_dir=OUTPUTS_DIR
+ * ts-node inference.ts --model_path=MODEL_PATH --inputs_dir=INPUTS_DIR
+ * --outputs_dir=OUTPUTS_DIR
  *
  * For help, run:
  * ts-node inference.ts -h
  */

-import '@tensorflow/tfjs-backend-cpu';
-
+import '@tensorflow/tfjs-backend-wasm'
+import '@tensorflow/tfjs-backend-cpu'
 import * as tfconv from '@tensorflow/tfjs-converter';
 import * as tfc from '@tensorflow/tfjs-core';
 import * as fs from 'fs';
-import {join} from 'path';
+import * as path from 'path';
 import * as yargs from 'yargs';

 import {FileHandler} from './file_handler';

@@ -44,6 +44,7 @@ interface Options {
   inputs_data_file: string;
   inputs_shape_file: string;
   inputs_dtype_file: string;
+  backend: string;
 }
 // tslint:enable:enforce-name-casing


@@ -82,23 +83,38 @@ async function main() {
       description: 'Filename of the input dtype file.',
       type: 'string',
       default: 'dtype.json'
     },
+    backend: {
+      description: 'Choose which tfjs backend to use. Supported backends: ' +
+          'cpu|wasm',
+      type: 'string',
+      default: 'cpu'
+    }
   });

   const options = argParser.argv as {} as Options;

+  if (options.backend === 'wasm') {
+    await tfc.setBackend('wasm');
+  } else if (options.backend === 'cpu') {
+    await tfc.setBackend('cpu');
+  } else {
+    throw new Error(
+        'Only cpu and wasm backend is supported, but got ' + options.backend);
+  }
+
   const model =
       await tfconv.loadGraphModel(new FileHandler(options.model_path));

   // Read in input files.
   const inputsDataString = fs.readFileSync(
-      join(options.inputs_dir, options.inputs_data_file), 'utf8');
+      path.join(options.inputs_dir, options.inputs_data_file), 'utf8');
   const inputsData = JSON.parse(inputsDataString);
   const inputsShapeString = fs.readFileSync(
-      join(options.inputs_dir, options.inputs_shape_file), 'utf8');
+      path.join(options.inputs_dir, options.inputs_shape_file), 'utf8');
   const inputsShape = JSON.parse(inputsShapeString);
   const inputsDtypeString = fs.readFileSync(
-      join(options.inputs_dir, options.inputs_dtype_file), 'utf8');
+      path.join(options.inputs_dir, options.inputs_dtype_file), 'utf8');
   const inputsDtype = JSON.parse(inputsDtypeString);

   const xs = createInputTensors(inputsData, inputsShape, inputsDtype);
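
For reference, a hedged sketch (not part of the commit) of invoking the packaged binary directly with the new flag; the flag names come from the yargs options above, while the binary and input paths mirror the test and the output directory is only hypothetical:

  import subprocess

  # --backend defaults to 'cpu' when omitted; any value other than 'cpu' or
  # 'wasm' makes the script above throw. The output directory is hypothetical.
  subprocess.check_call([
      '../binaries/tfjs-inference-linux',
      '--model_path=../test_data/model.json',
      '--inputs_dir=../test_data',
      '--outputs_dir=/tmp/tfjs_outputs',
      '--backend=wasm',
  ])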

@@ -122,11 +138,11 @@ async function main() {
   }

   fs.writeFileSync(
-      join(options.outputs_dir, 'data.json'), JSON.stringify(ysData));
+      path.join(options.outputs_dir, 'data.json'), JSON.stringify(ysData));
   fs.writeFileSync(
-      join(options.outputs_dir, 'shape.json'), JSON.stringify(ysShape));
+      path.join(options.outputs_dir, 'shape.json'), JSON.stringify(ysShape));
   fs.writeFileSync(
-      join(options.outputs_dir, 'dtype.json'), JSON.stringify(ysDtype));
+      path.join(options.outputs_dir, 'dtype.json'), JSON.stringify(ysDtype));
 }

 /**

@@ -35,23 +35,30 @@
     "@nodelib/fs.scandir" "2.1.3"
     fastq "^1.6.0"

-"@tensorflow/tfjs-backend-cpu@2.0.0":
-  version "2.0.0"
-  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.0.0.tgz#26d3ed8a6d814e751adc22e6036e40ed8940b5ac"
-  integrity sha512-eYj8CBjL8v2gHaYdS7JN1swi9kQYOHenMYBkf4khhW83ViZwpPmISyYzun8fy2gBlv28Y7juMmXHSncSDWfI1Q==
+"@tensorflow/tfjs-backend-cpu@2.0.1":
+  version "2.0.1"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-cpu/-/tfjs-backend-cpu-2.0.1.tgz#959a5bbc7f956ff37c4fbced2db75cd299ce76c4"
+  integrity sha512-ZTDdq+O6AgeOrkek42gmPWz2T0r8Y6dBGjEFWkCMLI/5v3KnkodUkHRQOUoIN5hiaPXnBp6425DpwT9CfxxJOg==
   dependencies:
     "@types/seedrandom" "2.4.27"
     seedrandom "2.4.3"

-"@tensorflow/tfjs-converter@2.0.0":
-  version "2.0.0"
-  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-2.0.0.tgz#6242c618202f04d88fa308d68a2dfcdd81ec020c"
-  integrity sha512-IFjjx2qe7M2UwwYJvCm2+OgpC+kooCWEjC8mOOoFV/o+g9/Q0RMohRDvffiqUuYx5Usi/vbjhlUBccsy/MhE3g==
+"@tensorflow/tfjs-backend-wasm@2.0.1":
+  version "2.0.1"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-backend-wasm/-/tfjs-backend-wasm-2.0.1.tgz#6551b3bd1de7079a481750e9dafeb59555261086"
+  integrity sha512-OYMPn3wPwuV4vYlgfqMaE9CICTvUEsNP/IkwJVaArzoN7txHZGikt+SQ3+lojh4MOUTN+9wxT1xkr05k7bivGg==
+  dependencies:
+    "@types/emscripten" "~0.0.34"

-"@tensorflow/tfjs-core@2.0.0":
-  version "2.0.0"
-  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-2.0.0.tgz#c18e963d0332255dc37d6ff58fa0076402ae00d0"
-  integrity sha512-GB02Lyjp7NLKbjCOW6S3Vx2CkkUwtJFt8fY7Zaoyy/ANB4Iw8eiHJV0308CClrFfjA+UKU3TrO+bOQfeCJaEUw==
+"@tensorflow/tfjs-converter@2.0.1":
+  version "2.0.1"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-converter/-/tfjs-converter-2.0.1.tgz#0696455e6b6ed14e6f5f9cd937f8f2015a16569f"
+  integrity sha512-AI4oUZ3Tv8l7fXeuLNJ3/vIp8shMo/VmtBlhIJye8i5FwMqSlZf984q3Jk6ES4lOxUdkmDehILf7uVNQX2Yb/w==
+
+"@tensorflow/tfjs-core@2.0.1":
+  version "2.0.1"
+  resolved "https://registry.yarnpkg.com/@tensorflow/tfjs-core/-/tfjs-core-2.0.1.tgz#c64928423028e9e1821f7205367b1ff1f57ae3af"
+  integrity sha512-LCmEXeGFgR3ai+ywGDYBqt4aCOSzEBlVKEflF1gAT22YcQuYh+/X4f58jY3yXfC+cn/FfIJFc2uj8b+D0MNWLQ==
   dependencies:
     "@types/offscreencanvas" "~2019.3.0"
     "@types/seedrandom" "2.4.27"

@@ -65,6 +72,11 @@
   resolved "https://registry.yarnpkg.com/@types/color-name/-/color-name-1.1.1.tgz#1c1261bbeaa10a8055bbc5d8ab84b7b2afc846a0"
   integrity sha512-rr+OQyAjxze7GgWrSaJwydHStIhHq2lvY3BOC2Mj7KnzI7XK0Uw1TOOdI9lDoajEbSWLiYgoo4f1R51erQfhPQ==

+"@types/emscripten@~0.0.34":
+  version "0.0.34"
+  resolved "https://registry.yarnpkg.com/@types/emscripten/-/emscripten-0.0.34.tgz#12b4a344274fb102ff2f6c877b37587bc3e46008"
+  integrity sha512-QSb9ojDincskc+uKMI0KXp8e1NALFINCrMlp8VGKGcTSxeEyRTTKyjWw75NYrCZHUsVEEEpr1tYHpbtaC++/sQ==
+
 "@types/glob@*":
   version "7.1.2"
   resolved "https://registry.yarnpkg.com/@types/glob/-/glob-7.1.2.tgz#06ca26521353a545d94a0adc74f38a59d232c987"