diff --git a/e2e/benchmarks/browserstack-benchmark/README.md b/e2e/benchmarks/browserstack-benchmark/README.md
index aa6e72c3d30..e7d1de7283a 100644
--- a/e2e/benchmarks/browserstack-benchmark/README.md
+++ b/e2e/benchmarks/browserstack-benchmark/README.md
@@ -33,6 +33,10 @@ The Multi-device benchmark tool can benchmark the performance (time, memory) of
Then you can see `> Running socket on port: 8001` on your Command-line interface.
3. Open http://localhost:8001/ and start to benchmark.
+ 3.1 If you want to benchmark a code snippet, update [`benchmarkCodeSnippet`](/~https://github.com/tensorflow/tfjs/pull/6704/files#diff-a7c2ef12f0f2bc1a6cabb45bc9850aa68d10644cd2786e6505456e5537dccadbR92) with your code snippet before running `node app.js`, then select `codeSnippet` under `model name`:
+
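+    For example, to benchmark a 512x512 matrix multiplication (a hypothetical snippet; adjust the setup and the `predict` body to your own use case), the two marked regions in `benchmarkCodeSnippet` could look like:
+
+    ```js
+    // Example only: benchmarks tf.matMul on random 512x512 inputs.
+    /* Please set up the environment for your code snippet here. */
+    /* Start */
+    const a = tf.randomUniform([512, 512]);
+    const b = tf.randomUniform([512, 512]);
+    /* End */
+
+    /* Please put the code snippet you want to benchmark in the predict function. */
+    /* Start */
+    const predict = () => {
+      return tf.matMul(a, b);
+    };
+    /* End */
+    ```
+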
diff --git a/e2e/benchmarks/browserstack-benchmark/benchmark_models.js b/e2e/benchmarks/browserstack-benchmark/benchmark_models.js
index 9130fbe4c00..dc1c4b9902a 100644
--- a/e2e/benchmarks/browserstack-benchmark/benchmark_models.js
+++ b/e2e/benchmarks/browserstack-benchmark/benchmark_models.js
@@ -58,7 +58,66 @@ async function getBenchmarkSummary(timeInfo, memoryInfo, modelName = 'model') {
const KARMA_SERVER = './base';
-describe('benchmark models', () => {
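+// Loads the selected model (or a custom model by URL) and benchmarks its
+// inference time and memory usage.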
+async function benchmarkModel(benchmarkParameters) {
+ // Load the model.
+ const benchmark = benchmarks[benchmarkParameters.model];
+ const numRuns = benchmarkParameters.numRuns;
+ let model;
+ if (benchmarkParameters.model === 'custom') {
+ if (benchmarkParameters.modelUrl == null) {
+      throw new Error('Please provide a model URL for the custom model.');
+ }
+ model = await loadModelByUrl(benchmarkParameters.modelUrl);
+ } else {
+ model = await benchmark.load();
+ }
+
+ // Benchmark.
+ let timeInfo;
+ let memoryInfo;
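+  // Prefer the benchmark's own predict function when one is defined;
+  // otherwise generate a synthetic input and time the model directly.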
+ if (benchmark.predictFunc != null) {
+ const predict = benchmark.predictFunc();
+ timeInfo = await timeInference(() => predict(model), numRuns);
+ memoryInfo = await profileInference(() => predict(model));
+ } else {
+ const input = generateInput(model);
+ timeInfo = await timeModelInference(model, input, numRuns);
+ memoryInfo = await profileModelInference(model, input);
+ }
+
+  return `<tfjs_benchmark>${
+      JSON.stringify({timeInfo, memoryInfo})}</tfjs_benchmark>`;
+}
+
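+// Benchmarks the hand-edited snippet below; runs when 'codeSnippet' is
+// selected as the model name in the tool.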
+async function benchmarkCodeSnippet(benchmarkParameters) {
+  /* Please set up the environment for your code snippet here. */
+ /* Start */
+ const img = tf.randomUniform([1, 240, 240, 3], 0, 1000);
+ const filter = tf.randomUniform([3, 3, 3, 3], 0, 1000);
+ /* End */
+
+  /* Please put the code snippet you want to benchmark in the predict function. */
+ /* Start */
+ const predict = () => {
+ return tf.conv2d(img, filter, 1, 'same');
+ };
+ /* End */
+
+ // Warm up.
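+  // The first run is excluded from the measurements so that one-time costs
+  // (e.g. shader or kernel compilation) do not skew the results.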
+ await timeInference(predict, 1);
+
+ // Benchmark code snippet.
+  const timeInfo = await timeInference(predict, benchmarkParameters.numRuns);
+  const memoryInfo = await profileInference(predict);
+
+  return `<tfjs_benchmark>${JSON.stringify({
+    timeInfo,
+    memoryInfo,
+    codeSnippet: predict.toString()
+  })}</tfjs_benchmark>`;
+}
+
+describe('BrowserStack benchmark', () => {
let benchmarkParameters;
beforeAll(async () => {
jasmine.DEFAULT_TIMEOUT_INTERVAL = 1000000;
@@ -66,39 +125,20 @@ describe('benchmark models', () => {
benchmarkParameters = await response.json();
});
- it(`benchmark model`, async () => {
+ it(`benchmark`, async () => {
try {
+      // Set up the benchmark environment.
await tf.setBackend(benchmarkParameters.backend);
- // Load the model.
- const benchmark = benchmarks[benchmarkParameters.model];
- const numRuns = benchmarkParameters.numRuns;
- let model;
- if (benchmarkParameters.model === 'custom') {
- if (benchmarkParameters.modelUrl == null) {
- throw new Error('Please provide model url for the custom model.');
- }
- model = await loadModelByUrl(benchmarkParameters.modelUrl);
- } else {
- model = await benchmark.load();
- }
-
- // Benchmark.
- let timeInfo;
- let memoryInfo;
- if (benchmark.predictFunc != null) {
- const predict = benchmark.predictFunc();
- timeInfo = await timeInference(() => predict(model), numRuns);
- memoryInfo = await profileInference(() => predict(model));
+ // Run benchmark and stringify results.
+ let resultStr;
+ if (benchmarkParameters.model === 'codeSnippet') {
+ resultStr = await benchmarkCodeSnippet(benchmarkParameters);
} else {
- const input = generateInput(model);
- timeInfo = await timeModelInference(model, input, numRuns);
- memoryInfo = await profileModelInference(model, input);
+ resultStr = await benchmarkModel(benchmarkParameters);
}
// Report results.
-      const resultStr = `<tfjs_benchmark>${
-          JSON.stringify({timeInfo, memoryInfo})}</tfjs_benchmark>`;
console.log(resultStr);
} catch (error) {
      console.log(`<tfjs_error>${error}</tfjs_error>`);
diff --git a/e2e/benchmarks/browserstack-benchmark/index.js b/e2e/benchmarks/browserstack-benchmark/index.js
index 085d147880b..6642e9d9d60 100644
--- a/e2e/benchmarks/browserstack-benchmark/index.js
+++ b/e2e/benchmarks/browserstack-benchmark/index.js
@@ -536,6 +536,10 @@ function drawBenchmarkResultSummaryTable(benchmarkResult) {
values.push(['Number of kernels', memoryInfo.kernels.length]);
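+  // For code snippet benchmarks, also surface the benchmarked source
+  // alongside the time and memory metrics.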
+ if ('codeSnippet' in benchmarkResult) {
+ values.push(['Code snippet', benchmarkResult.codeSnippet]);
+ }
+
const surface = {
name: 'Benchmark Summary',
tab: tabId,
@@ -632,7 +636,9 @@ function showModelSelection() {
const modelFolder = gui.addFolder('Model');
let modelUrlController = null;
- modelFolder.add(state.benchmark, 'model', Object.keys(benchmarks))
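+  // 'codeSnippet' is a pseudo-entry next to the real models; selecting it
+  // benchmarks the snippet inlined in benchmark_models.js instead of
+  // loading a model.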
+ modelFolder
+ .add(
+ state.benchmark, 'model', [...Object.keys(benchmarks), 'codeSnippet'])
.name('model name')
.onChange(async model => {
if (model === 'custom') {