Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
add support for binary model loading through browser http handler (#1207
Browse files Browse the repository at this point in the history
)

* add support for binary model loading through browser http handler

* address review comments

* fixed the tests

* refactor the loadWeights method
  • Loading branch information
pyu10055 authored Aug 6, 2018
1 parent a51dffc commit 53bfcfe
Show file tree
Hide file tree
Showing 2 changed files with 570 additions and 289 deletions.
103 changes: 74 additions & 29 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,17 @@ export class BrowserHTTPRequest implements IOHandler {
'browserHTTPRequest is not supported outside the web browser without a fetch polyfill.');
}

if (Array.isArray(path)) {
throw new Error(
`Handling of multiple ${path.length} HTTP URLs (e.g., ` +
`for loading FrozenModel) is not implemented yet.`);
}

assert(
path != null && path.length > 0,
'URL path for browserHTTPRequest must not be null, undefined or ' +
'empty.');

if (Array.isArray(path)) {
assert(
path.length === 2,
'URL paths for browserHTTPRequest must have a length of 2, ' +
`(actual length is ${path.length}).`);
}
this.path = path;

if (requestInit != null && requestInit.body != null) {
Expand Down Expand Up @@ -118,6 +119,42 @@ export class BrowserHTTPRequest implements IOHandler {
* @returns The loaded model artifacts (if loading succeeds).
*/
async load(): Promise<ModelArtifacts> {
return Array.isArray(this.path) ? this.loadBinaryModel() :
this.loadJSONModel();
}

/**
* Loads the model topology file and build the in memory graph of the model.
*/
private async loadBinaryTopology(): Promise<ArrayBuffer> {
try {
const response = await fetch(this.path[0], this.requestInit);
return await response.arrayBuffer();
} catch (error) {
throw new Error(`${this.path[0]} not found. ${error}`);
}
}

protected async loadBinaryModel(): Promise<ModelArtifacts> {
const graphPromise = this.loadBinaryTopology();
const manifestPromise = await fetch(this.path[1], this.requestInit);

const [modelTopology, weightsManifestResponse] =
await Promise.all([graphPromise, manifestPromise]);

const weightsManifest =
await weightsManifestResponse.json() as WeightsManifestConfig;

let weightSpecs: WeightsManifestEntry[];
let weightData: ArrayBuffer;
if (weightsManifest != null) {
[weightSpecs, weightData] = await this.loadWeights(weightsManifest);
}

return {modelTopology, weightSpecs, weightData};
}

protected async loadJSONModel(): Promise<ModelArtifacts> {
const modelConfigRequest =
await fetch(this.path as string, this.requestInit);
const modelConfig = await modelConfigRequest.json();
Expand All @@ -136,28 +173,36 @@ export class BrowserHTTPRequest implements IOHandler {
if (weightsManifest != null) {
const weightsManifest =
modelConfig['weightsManifest'] as WeightsManifestConfig;
weightSpecs = [];
for (const entry of weightsManifest) {
weightSpecs.push(...entry.weights);
}
[weightSpecs, weightData] = await this.loadWeights(weightsManifest);
}

let pathPrefix =
(this.path as string).substring(0, this.path.lastIndexOf('/'));
if (!pathPrefix.endsWith('/')) {
pathPrefix = pathPrefix + '/';
}
return {modelTopology, weightSpecs, weightData};
}

const fetchURLs: string[] = [];
weightsManifest.forEach(weightsGroup => {
weightsGroup.paths.forEach(path => {
fetchURLs.push(pathPrefix + path);
});
});
weightData = concatenateArrayBuffers(
await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit));
private async loadWeights(weightsManifest: WeightsManifestConfig):
Promise<[WeightsManifestEntry[], ArrayBuffer]> {
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;

const weightSpecs = [];
for (const entry of weightsManifest) {
weightSpecs.push(...entry.weights);
}

return {modelTopology, weightSpecs, weightData};
let pathPrefix = weightPath.substring(0, weightPath.lastIndexOf('/'));
if (!pathPrefix.endsWith('/')) {
pathPrefix = pathPrefix + '/';
}
const fetchURLs: string[] = [];
weightsManifest.forEach(weightsGroup => {
weightsGroup.paths.forEach(path => {
fetchURLs.push(pathPrefix + path);
});
});
return [
weightSpecs,
concatenateArrayBuffers(
await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit))
];
}
}

Expand Down Expand Up @@ -309,11 +354,11 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* HTTP request to server using `fetch`. It can contain fields such as
* `method`, `credentials`, `headers`, `mode`, etc. See
* https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
* for more information. `requestInit` must not have a body, because the body
* will be set by TensorFlow.js. File blobs representing
* the model topology (filename: 'model.json') and the weights of the
* model (filename: 'model.weights.bin') will be appended to the body.
* If `requestInit` has a `body`, an Error will be thrown.
* for more information. `requestInit` must not have a body, because the
* body will be set by TensorFlow.js. File blobs representing the model
* topology (filename: 'model.json') and the weights of the model (filename:
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
* `body`, an Error will be thrown.
* @returns An instance of `IOHandler`.
*/
export function browserHTTPRequest(
Expand Down
Loading

0 comments on commit 53bfcfe

Please sign in to comment.