From 53bfcfe9db1550948263c6c5d52f005ffa63b536 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 6 Aug 2018 13:26:37 -0700 Subject: [PATCH] add support for binary model loading through browser http handler (#1207) * add support for binary model loading through browser http handler * address review comments * fixed the tests * refactor the loadWeights method --- src/io/browser_http.ts | 103 +++-- src/io/browser_http_test.ts | 756 +++++++++++++++++++++++------------- 2 files changed, 570 insertions(+), 289 deletions(-) diff --git a/src/io/browser_http.ts b/src/io/browser_http.ts index d8b9db8034..a71941bea6 100644 --- a/src/io/browser_http.ts +++ b/src/io/browser_http.ts @@ -42,16 +42,17 @@ export class BrowserHTTPRequest implements IOHandler { 'browserHTTPRequest is not supported outside the web browser without a fetch polyfill.'); } - if (Array.isArray(path)) { - throw new Error( - `Handling of multiple ${path.length} HTTP URLs (e.g., ` + - `for loading FrozenModel) is not implemented yet.`); - } - assert( path != null && path.length > 0, 'URL path for browserHTTPRequest must not be null, undefined or ' + 'empty.'); + + if (Array.isArray(path)) { + assert( + path.length === 2, + 'URL paths for browserHTTPRequest must have a length of 2, ' + + `(actual length is ${path.length}).`); + } this.path = path; if (requestInit != null && requestInit.body != null) { @@ -118,6 +119,42 @@ export class BrowserHTTPRequest implements IOHandler { * @returns The loaded model artifacts (if loading succeeds). */ async load(): Promise { + return Array.isArray(this.path) ? this.loadBinaryModel() : + this.loadJSONModel(); + } + + /** + * Loads the model topology file and build the in memory graph of the model. + */ + private async loadBinaryTopology(): Promise { + try { + const response = await fetch(this.path[0], this.requestInit); + return await response.arrayBuffer(); + } catch (error) { + throw new Error(`${this.path[0]} not found. ${error}`); + } + } + + protected async loadBinaryModel(): Promise { + const graphPromise = this.loadBinaryTopology(); + const manifestPromise = await fetch(this.path[1], this.requestInit); + + const [modelTopology, weightsManifestResponse] = + await Promise.all([graphPromise, manifestPromise]); + + const weightsManifest = + await weightsManifestResponse.json() as WeightsManifestConfig; + + let weightSpecs: WeightsManifestEntry[]; + let weightData: ArrayBuffer; + if (weightsManifest != null) { + [weightSpecs, weightData] = await this.loadWeights(weightsManifest); + } + + return {modelTopology, weightSpecs, weightData}; + } + + protected async loadJSONModel(): Promise { const modelConfigRequest = await fetch(this.path as string, this.requestInit); const modelConfig = await modelConfigRequest.json(); @@ -136,28 +173,36 @@ export class BrowserHTTPRequest implements IOHandler { if (weightsManifest != null) { const weightsManifest = modelConfig['weightsManifest'] as WeightsManifestConfig; - weightSpecs = []; - for (const entry of weightsManifest) { - weightSpecs.push(...entry.weights); - } + [weightSpecs, weightData] = await this.loadWeights(weightsManifest); + } - let pathPrefix = - (this.path as string).substring(0, this.path.lastIndexOf('/')); - if (!pathPrefix.endsWith('/')) { - pathPrefix = pathPrefix + '/'; - } + return {modelTopology, weightSpecs, weightData}; + } - const fetchURLs: string[] = []; - weightsManifest.forEach(weightsGroup => { - weightsGroup.paths.forEach(path => { - fetchURLs.push(pathPrefix + path); - }); - }); - weightData = concatenateArrayBuffers( - await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit)); + private async loadWeights(weightsManifest: WeightsManifestConfig): + Promise<[WeightsManifestEntry[], ArrayBuffer]> { + const weightPath = Array.isArray(this.path) ? this.path[1] : this.path; + + const weightSpecs = []; + for (const entry of weightsManifest) { + weightSpecs.push(...entry.weights); } - return {modelTopology, weightSpecs, weightData}; + let pathPrefix = weightPath.substring(0, weightPath.lastIndexOf('/')); + if (!pathPrefix.endsWith('/')) { + pathPrefix = pathPrefix + '/'; + } + const fetchURLs: string[] = []; + weightsManifest.forEach(weightsGroup => { + weightsGroup.paths.forEach(path => { + fetchURLs.push(pathPrefix + path); + }); + }); + return [ + weightSpecs, + concatenateArrayBuffers( + await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit)) + ]; } } @@ -309,11 +354,11 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter); * HTTP request to server using `fetch`. It can contain fields such as * `method`, `credentials`, `headers`, `mode`, etc. See * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request - * for more information. `requestInit` must not have a body, because the body - * will be set by TensorFlow.js. File blobs representing - * the model topology (filename: 'model.json') and the weights of the - * model (filename: 'model.weights.bin') will be appended to the body. - * If `requestInit` has a `body`, an Error will be thrown. + * for more information. `requestInit` must not have a body, because the + * body will be set by TensorFlow.js. File blobs representing the model + * topology (filename: 'model.json') and the weights of the model (filename: + * 'model.weights.bin') will be appended to the body. If `requestInit` has a + * `body`, an Error will be thrown. * @returns An instance of `IOHandler`. */ export function browserHTTPRequest( diff --git a/src/io/browser_http_test.ts b/src/io/browser_http_test.ts index 11d2ffd8de..1cab9bc668 100644 --- a/src/io/browser_http_test.ts +++ b/src/io/browser_http_test.ts @@ -99,7 +99,7 @@ describeWithFlags('browserHTTPRequest-load fetch-polyfill', NODE_ENVS, () => { }); }; - it('1 group, 2 weights, 1 path', done => { + it('1 group, 2 weights, 1 path', (done: DoneFn) => { const weightManifest1: tf.io.WeightsManifestConfig = [{ paths: ['weightfile0'], weights: [ @@ -186,7 +186,7 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => { }); }); - it('Save topology and weights, default POST method', done => { + it('Save topology and weights, default POST method', (done: DoneFn) => { const testStartDate = new Date(); const handler = tf.io.getSaveHandlers('http://model-upload-test')[0]; handler.save(artifacts1) @@ -239,7 +239,7 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => { }); }); - it('Save topology only, default POST method', done => { + it('Save topology only, default POST method', (done: DoneFn) => { const testStartDate = new Date(); const handler = tf.io.getSaveHandlers('http://model-upload-test')[0]; const topologyOnlyArtifacts = {modelTopology: modelTopology1}; @@ -278,7 +278,7 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => { }); }); - it('Save topology and weights, PUT method, extra headers', done => { + it('Save topology and weights, PUT method, extra headers', (done: DoneFn) => { const testStartDate = new Date(); const handler = tf.io.browserHTTPRequest('model-upload-test', { method: 'PUT', @@ -344,7 +344,7 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => { }); }); - it('404 response causes Error', done => { + it('404 response causes Error', (done: DoneFn) => { const handler = tf.io.getSaveHandlers('http://invalid/path')[0]; handler.save(artifacts1) .then(saveResult => { @@ -385,296 +385,532 @@ describeWithFlags('browserHTTPRequest-save', CHROME_ENVS, () => { }); describeWithFlags('browserHTTPRequest-load', BROWSER_ENVS, () => { - let requestInits: RequestInit[]; + describe('JSON model', () => { + let requestInits: RequestInit[]; + + const setupFakeWeightFiles = (fileBufferMap: { + [filename: string]: string|Float32Array|Int32Array|ArrayBuffer|Uint8Array| + Uint16Array + }) => { + spyOn(window, 'fetch').and.callFake((path: string, init: RequestInit) => { + requestInits.push(init); + return new Response(fileBufferMap[path]); + }); + }; - const setupFakeWeightFiles = (fileBufferMap: { - [filename: string]: string|Float32Array|Int32Array|ArrayBuffer|Uint8Array| - Uint16Array - }) => { - spyOn(window, 'fetch').and.callFake((path: string, init: RequestInit) => { - requestInits.push(init); - return new Response(fileBufferMap[path]); + beforeEach(() => { + requestInits = []; }); - }; - beforeEach(() => { - requestInits = []; - }); + it('1 group, 2 weights, 1 path', (done: DoneFn) => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData = new Float32Array([1, 3, 3, 7, 4]); + setupFakeWeightFiles({ + './model.json': JSON.stringify( + {modelTopology: modelTopology1, weightsManifest: weightManifest1}), + './weightfile0': floatData, + }); + + const handler = tf.io.browserHTTPRequest('./model.json'); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs) + .toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(floatData); + expect(requestInits).toEqual([{}, {}]); + done(); + }) + .catch(err => done.fail(err.stack)); + }); - it('1 group, 2 weights, 1 path', done => { - const weightManifest1: tf.io.WeightsManifestConfig = [{ - paths: ['weightfile0'], - weights: [ + it('1 group, 2 weights, 1 path, with requestInit', (done: DoneFn) => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData = new Float32Array([1, 3, 3, 7, 4]); + setupFakeWeightFiles({ + './model.json': JSON.stringify( + {modelTopology: modelTopology1, weightsManifest: weightManifest1}), + './weightfile0': floatData, + }); + + const handler = tf.io.browserHTTPRequest( + './model.json', {headers: {'header_key_1': 'header_value_1'}}); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs) + .toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(floatData); + expect(requestInits).toEqual([ + {headers: {'header_key_1': 'header_value_1'}}, + {headers: {'header_key_1': 'header_value_1'}} + ]); + done(); + }) + .catch(err => done.fail(err.stack)); + }); + + it('1 group, 2 weight, 2 paths', (done: DoneFn) => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0', 'weightfile1'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData1 = new Float32Array([1, 3, 3]); + const floatData2 = new Float32Array([7, 4]); + setupFakeWeightFiles({ + './model.json': JSON.stringify( + {modelTopology: modelTopology1, weightsManifest: weightManifest1}), + './weightfile0': floatData1, + './weightfile1': floatData2, + }); + + const handler = tf.io.browserHTTPRequest('./model.json'); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs) + .toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(new Float32Array([1, 3, 3, 7, 4])); + done(); + }) + .catch(err => done.fail(err.stack)); + }); + + it('2 groups, 2 weight, 2 paths', (done: DoneFn) => { + const weightsManifest: tf.io.WeightsManifestConfig = [ { - name: 'dense/kernel', - shape: [3, 1], - dtype: 'float32', + paths: ['weightfile0'], + weights: [{ + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }] }, { - name: 'dense/bias', - shape: [2], - dtype: 'float32', + paths: ['weightfile1'], + weights: [{ + name: 'dense/bias', + shape: [2], + dtype: 'float32', + }], } - ] - }]; - const floatData = new Float32Array([1, 3, 3, 7, 4]); - setupFakeWeightFiles({ - './model.json': JSON.stringify( - {modelTopology: modelTopology1, weightsManifest: weightManifest1}), - './weightfile0': floatData, + ]; + const floatData1 = new Float32Array([1, 3, 3]); + const floatData2 = new Float32Array([7, 4]); + setupFakeWeightFiles({ + './model.json': + JSON.stringify({modelTopology: modelTopology1, weightsManifest}), + './weightfile0': floatData1, + './weightfile1': floatData2, + }); + + const handler = tf.io.browserHTTPRequest('./model.json'); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs) + .toEqual(weightsManifest[0].weights.concat( + weightsManifest[1].weights)); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(new Float32Array([1, 3, 3, 7, 4])); + done(); + }) + .catch(err => done.fail(err.stack)); }); - const handler = tf.io.browserHTTPRequest('./model.json'); - handler.load() - .then(modelArtifacts => { - expect(modelArtifacts.modelTopology).toEqual(modelTopology1); - expect(modelArtifacts.weightSpecs) - .toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)) - .toEqual(floatData); - expect(requestInits).toEqual([{}, {}]); - done(); - }) - .catch(err => done.fail(err.stack)); - }); - - it('1 group, 2 weights, 1 path, with requestInit', done => { - const weightManifest1: tf.io.WeightsManifestConfig = [{ - paths: ['weightfile0'], - weights: [ + it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', (done: DoneFn) => { + const weightsManifest: tf.io.WeightsManifestConfig = [ { - name: 'dense/kernel', - shape: [3, 1], - dtype: 'float32', + paths: ['weightfile0'], + weights: [{ + name: 'fooWeight', + shape: [3, 1], + dtype: 'int32', + }] }, { - name: 'dense/bias', - shape: [2], - dtype: 'float32', + paths: ['weightfile1'], + weights: [{ + name: 'barWeight', + shape: [2], + dtype: 'bool', + }], } - ] - }]; - const floatData = new Float32Array([1, 3, 3, 7, 4]); - setupFakeWeightFiles({ - './model.json': JSON.stringify( - {modelTopology: modelTopology1, weightsManifest: weightManifest1}), - './weightfile0': floatData, + ]; + const floatData1 = new Int32Array([1, 3, 3]); + const floatData2 = new Uint8Array([7, 4]); + setupFakeWeightFiles({ + 'path1/model.json': + JSON.stringify({modelTopology: modelTopology1, weightsManifest}), + 'path1/weightfile0': floatData1, + 'path1/weightfile1': floatData2, + }); + + const handler = tf.io.browserHTTPRequest('path1/model.json'); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs) + .toEqual(weightsManifest[0].weights.concat( + weightsManifest[1].weights)); + expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) + .toEqual(new Int32Array([1, 3, 3])); + expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14))) + .toEqual(new Uint8Array([7, 4])); + done(); + }) + .catch(err => done.fail(err.stack)); }); - const handler = tf.io.browserHTTPRequest( - './model.json', {headers: {'header_key_1': 'header_value_1'}}); - handler.load() - .then(modelArtifacts => { - expect(modelArtifacts.modelTopology).toEqual(modelTopology1); - expect(modelArtifacts.weightSpecs) - .toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)) - .toEqual(floatData); - expect(requestInits).toEqual([ - {headers: {'header_key_1': 'header_value_1'}}, - {headers: {'header_key_1': 'header_value_1'}} - ]); - done(); - }) - .catch(err => done.fail(err.stack)); - }); + it('topology only', (done: DoneFn) => { + setupFakeWeightFiles({ + './model.json': JSON.stringify({modelTopology: modelTopology1}), + }); + + const handler = tf.io.browserHTTPRequest('./model.json'); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelTopology1); + expect(modelArtifacts.weightSpecs).toBeUndefined(); + expect(modelArtifacts.weightData).toBeUndefined(); + done(); + }) + .catch(err => done.fail(err.stack)); + }); - it('1 group, 2 weight, 2 paths', done => { - const weightManifest1: tf.io.WeightsManifestConfig = [{ - paths: ['weightfile0', 'weightfile1'], - weights: [ + it('weights only', (done: DoneFn) => { + const weightsManifest: tf.io.WeightsManifestConfig = [ { - name: 'dense/kernel', - shape: [3, 1], - dtype: 'float32', + paths: ['weightfile0'], + weights: [{ + name: 'fooWeight', + shape: [3, 1], + dtype: 'int32', + }] }, { - name: 'dense/bias', - shape: [2], - dtype: 'float32', + paths: ['weightfile1'], + weights: [{ + name: 'barWeight', + shape: [2], + dtype: 'float32', + }], } - ] - }]; - const floatData1 = new Float32Array([1, 3, 3]); - const floatData2 = new Float32Array([7, 4]); - setupFakeWeightFiles({ - './model.json': JSON.stringify( - {modelTopology: modelTopology1, weightsManifest: weightManifest1}), - './weightfile0': floatData1, - './weightfile1': floatData2, + ]; + const floatData1 = new Int32Array([1, 3, 3]); + const floatData2 = new Float32Array([-7, -4]); + setupFakeWeightFiles({ + 'path1/model.json': JSON.stringify({weightsManifest}), + 'path1/weightfile0': floatData1, + 'path1/weightfile1': floatData2, + }); + + const handler = tf.io.browserHTTPRequest('path1/model.json'); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toBeUndefined(); + expect(modelArtifacts.weightSpecs) + .toEqual(weightsManifest[0].weights.concat( + weightsManifest[1].weights)); + expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) + .toEqual(new Int32Array([1, 3, 3])); + expect(new Float32Array(modelArtifacts.weightData.slice(12, 20))) + .toEqual(new Float32Array([-7, -4])); + done(); + }) + .catch(err => done.fail(err.stack)); }); - const handler = tf.io.browserHTTPRequest('./model.json'); - handler.load() - .then(modelArtifacts => { - expect(modelArtifacts.modelTopology).toEqual(modelTopology1); - expect(modelArtifacts.weightSpecs) - .toEqual(weightManifest1[0].weights); - expect(new Float32Array(modelArtifacts.weightData)) - .toEqual(new Float32Array([1, 3, 3, 7, 4])); - done(); - }) - .catch(err => done.fail(err.stack)); + it('Missing modelTopology and weightsManifest leads to error', + (done: DoneFn) => { + setupFakeWeightFiles({'path1/model.json': JSON.stringify({})}); + const handler = tf.io.browserHTTPRequest('path1/model.json'); + handler.load() + .then(modelTopology1 => { + done.fail( + 'Loading from missing modelTopology and weightsManifest ' + + 'succeeded expectedly.'); + }) + .catch(err => { + expect(err.message) + .toMatch(/contains neither model topology or manifest/); + done(); + }); + }); }); - it('2 groups, 2 weight, 2 paths', done => { - const weightsManifest: tf.io.WeightsManifestConfig = [ - { - paths: ['weightfile0'], - weights: [{ - name: 'dense/kernel', - shape: [3, 1], - dtype: 'float32', - }] - }, - { - paths: ['weightfile1'], - weights: [{ - name: 'dense/bias', - shape: [2], - dtype: 'float32', - }], - } - ]; - const floatData1 = new Float32Array([1, 3, 3]); - const floatData2 = new Float32Array([7, 4]); - setupFakeWeightFiles({ - './model.json': - JSON.stringify({modelTopology: modelTopology1, weightsManifest}), - './weightfile0': floatData1, - './weightfile1': floatData2, - }); + describe('Binary model', () => { + let requestInits: RequestInit[]; + let modelData: ArrayBuffer; - const handler = tf.io.browserHTTPRequest('./model.json'); - handler.load() - .then(modelArtifacts => { - expect(modelArtifacts.modelTopology).toEqual(modelTopology1); - expect(modelArtifacts.weightSpecs) - .toEqual(weightsManifest[0].weights.concat( - weightsManifest[1].weights)); - expect(new Float32Array(modelArtifacts.weightData)) - .toEqual(new Float32Array([1, 3, 3, 7, 4])); - done(); - }) - .catch(err => done.fail(err.stack)); - }); + const setupFakeWeightFiles = (fileBufferMap: { + [filename: string]: string|Float32Array|Int32Array|ArrayBuffer|Uint8Array| + Uint16Array + }) => { + spyOn(window, 'fetch').and.callFake((path: string, init: RequestInit) => { + requestInits.push(init); + return new Response(fileBufferMap[path]); + }); + }; - it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', done => { - const weightsManifest: tf.io.WeightsManifestConfig = [ - { - paths: ['weightfile0'], - weights: [{ - name: 'fooWeight', - shape: [3, 1], - dtype: 'int32', - }] - }, - { - paths: ['weightfile1'], - weights: [{ - name: 'barWeight', - shape: [2], - dtype: 'bool', - }], - } - ]; - const floatData1 = new Int32Array([1, 3, 3]); - const floatData2 = new Uint8Array([7, 4]); - setupFakeWeightFiles({ - 'path1/model.json': - JSON.stringify({modelTopology: modelTopology1, weightsManifest}), - 'path1/weightfile0': floatData1, - 'path1/weightfile1': floatData2, + beforeEach(() => { + requestInits = []; + modelData = new ArrayBuffer(5); }); - const handler = tf.io.browserHTTPRequest('path1/model.json'); - handler.load() - .then(modelArtifacts => { - expect(modelArtifacts.modelTopology).toEqual(modelTopology1); - expect(modelArtifacts.weightSpecs) - .toEqual(weightsManifest[0].weights.concat( - weightsManifest[1].weights)); - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) - .toEqual(new Int32Array([1, 3, 3])); - expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14))) - .toEqual(new Uint8Array([7, 4])); - done(); - }) - .catch(err => done.fail(err.stack)); - }); + it('1 group, 2 weights, 1 path', (done: DoneFn) => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData = new Float32Array([1, 3, 3, 7, 4]); + setupFakeWeightFiles({ + './model.pb': modelData, + './weights_manifest.json': JSON.stringify(weightManifest1), + './weightfile0': floatData, + }); + + const handler = + tf.io.browserHTTPRequest(['./model.pb', './weights_manifest.json']); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelData); + expect(modelArtifacts.weightSpecs) + .toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(floatData); + expect(requestInits).toEqual([{}, {}, {}]); + done(); + }) + .catch(err => done.fail(err.stack)); + }); - it('topology only', done => { - setupFakeWeightFiles({ - './model.json': JSON.stringify({modelTopology: modelTopology1}), + it('1 group, 2 weights, 1 path, with requestInit', (done: DoneFn) => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData = new Float32Array([1, 3, 3, 7, 4]); + + setupFakeWeightFiles({ + './model.pb': modelData, + './weights_manifest.json': JSON.stringify(weightManifest1), + './weightfile0': floatData, + }); + + const handler = tf.io.browserHTTPRequest( + ['./model.pb', './weights_manifest.json'], + {headers: {'header_key_1': 'header_value_1'}}); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelData); + expect(modelArtifacts.weightSpecs) + .toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(floatData); + expect(requestInits).toEqual([ + {headers: {'header_key_1': 'header_value_1'}}, + {headers: {'header_key_1': 'header_value_1'}}, + {headers: {'header_key_1': 'header_value_1'}}, + ]); + done(); + }) + .catch(err => done.fail(err.stack)); }); - const handler = tf.io.browserHTTPRequest('./model.json'); - handler.load() - .then(modelArtifacts => { - expect(modelArtifacts.modelTopology).toEqual(modelTopology1); - expect(modelArtifacts.weightSpecs).toBeUndefined(); - expect(modelArtifacts.weightData).toBeUndefined(); - done(); - }) - .catch(err => done.fail(err.stack)); - }); + it('1 group, 2 weight, 2 paths', (done: DoneFn) => { + const weightManifest1: tf.io.WeightsManifestConfig = [{ + paths: ['weightfile0', 'weightfile1'], + weights: [ + { + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }, + { + name: 'dense/bias', + shape: [2], + dtype: 'float32', + } + ] + }]; + const floatData1 = new Float32Array([1, 3, 3]); + const floatData2 = new Float32Array([7, 4]); + setupFakeWeightFiles({ + './model.pb': modelData, + './weights_manifest.json': JSON.stringify(weightManifest1), + './weightfile0': floatData1, + './weightfile1': floatData2, + }); + + const handler = + tf.io.browserHTTPRequest(['./model.pb', './weights_manifest.json']); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelData); + expect(modelArtifacts.weightSpecs) + .toEqual(weightManifest1[0].weights); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(new Float32Array([1, 3, 3, 7, 4])); + done(); + }) + .catch(err => done.fail(err.stack)); + }); - it('weights only', done => { - const weightsManifest: tf.io.WeightsManifestConfig = [ - { - paths: ['weightfile0'], - weights: [{ - name: 'fooWeight', - shape: [3, 1], - dtype: 'int32', - }] - }, - { - paths: ['weightfile1'], - weights: [{ - name: 'barWeight', - shape: [2], - dtype: 'float32', - }], - } - ]; - const floatData1 = new Int32Array([1, 3, 3]); - const floatData2 = new Float32Array([-7, -4]); - setupFakeWeightFiles({ - 'path1/model.json': JSON.stringify({weightsManifest}), - 'path1/weightfile0': floatData1, - 'path1/weightfile1': floatData2, + it('2 groups, 2 weight, 2 paths', (done: DoneFn) => { + const weightsManifest: tf.io.WeightsManifestConfig = [ + { + paths: ['weightfile0'], + weights: [{ + name: 'dense/kernel', + shape: [3, 1], + dtype: 'float32', + }] + }, + { + paths: ['weightfile1'], + weights: [{ + name: 'dense/bias', + shape: [2], + dtype: 'float32', + }], + } + ]; + const floatData1 = new Float32Array([1, 3, 3]); + const floatData2 = new Float32Array([7, 4]); + setupFakeWeightFiles({ + './model.pb': modelData, + './weights_manifest.json': JSON.stringify(weightsManifest), + './weightfile0': floatData1, + './weightfile1': floatData2, + }); + + const handler = + tf.io.browserHTTPRequest(['./model.pb', './weights_manifest.json']); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelData); + expect(modelArtifacts.weightSpecs) + .toEqual(weightsManifest[0].weights.concat( + weightsManifest[1].weights)); + expect(new Float32Array(modelArtifacts.weightData)) + .toEqual(new Float32Array([1, 3, 3, 7, 4])); + done(); + }) + .catch(err => done.fail(err.stack)); }); - const handler = tf.io.browserHTTPRequest('path1/model.json'); - handler.load() - .then(modelArtifacts => { - expect(modelArtifacts.modelTopology).toBeUndefined(); - expect(modelArtifacts.weightSpecs) - .toEqual(weightsManifest[0].weights.concat( - weightsManifest[1].weights)); - expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) - .toEqual(new Int32Array([1, 3, 3])); - expect(new Float32Array(modelArtifacts.weightData.slice(12, 20))) - .toEqual(new Float32Array([-7, -4])); - done(); - }) - .catch(err => done.fail(err.stack)); - }); + it('2 groups, 2 weight, 2 paths, Int32 and Uint8 Data', (done: DoneFn) => { + const weightsManifest: tf.io.WeightsManifestConfig = [ + { + paths: ['weightfile0'], + weights: [{ + name: 'fooWeight', + shape: [3, 1], + dtype: 'int32', + }] + }, + { + paths: ['weightfile1'], + weights: [{ + name: 'barWeight', + shape: [2], + dtype: 'bool', + }], + } + ]; + const floatData1 = new Int32Array([1, 3, 3]); + const floatData2 = new Uint8Array([7, 4]); + setupFakeWeightFiles({ + 'path1/model.pb': modelData, + 'path2/weights_manifest.json': JSON.stringify(weightsManifest), + 'path2/weightfile0': floatData1, + 'path2/weightfile1': floatData2, + }); + + const handler = tf.io.browserHTTPRequest( + ['path1/model.pb', 'path2/weights_manifest.json']); + handler.load() + .then(modelArtifacts => { + expect(modelArtifacts.modelTopology).toEqual(modelData); + expect(modelArtifacts.weightSpecs) + .toEqual(weightsManifest[0].weights.concat( + weightsManifest[1].weights)); + expect(new Int32Array(modelArtifacts.weightData.slice(0, 12))) + .toEqual(new Int32Array([1, 3, 3])); + expect(new Uint8Array(modelArtifacts.weightData.slice(12, 14))) + .toEqual(new Uint8Array([7, 4])); + done(); + }) + .catch(err => done.fail(err.stack)); + }); - it('Missing modelTopology and weightsManifest leads to error', done => { - setupFakeWeightFiles({'path1/model.json': JSON.stringify({})}); - const handler = tf.io.browserHTTPRequest('path1/model.json'); - handler.load() - .then(modelTopology1 => { - done.fail( - 'Loading from missing modelTopology and weightsManifest ' + - 'succeeded expectedly.'); - }) - .catch(err => { - expect(err.message) - .toMatch(/contains neither model topology or manifest/); - done(); - }); + it('the url path length is not 2 should leads to error', () => { + expect(() => tf.io.browserHTTPRequest(['path1/model.pb'])).toThrow(); + }); }); });