Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

add support for binary model loading through browser http handler #1207

Merged
merged 6 commits into from
Aug 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 74 additions & 29 deletions src/io/browser_http.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,17 @@ export class BrowserHTTPRequest implements IOHandler {
'browserHTTPRequest is not supported outside the web browser without a fetch polyfill.');
}

if (Array.isArray(path)) {
throw new Error(
`Handling of multiple ${path.length} HTTP URLs (e.g., ` +
`for loading FrozenModel) is not implemented yet.`);
}

assert(
path != null && path.length > 0,
'URL path for browserHTTPRequest must not be null, undefined or ' +
'empty.');

if (Array.isArray(path)) {
assert(
path.length === 2,
'URL paths for browserHTTPRequest must have a length of 2, ' +
`(actual length is ${path.length}).`);
}
this.path = path;

if (requestInit != null && requestInit.body != null) {
Expand Down Expand Up @@ -118,6 +119,42 @@ export class BrowserHTTPRequest implements IOHandler {
* @returns The loaded model artifacts (if loading succeeds).
*/
async load(): Promise<ModelArtifacts> {
return Array.isArray(this.path) ? this.loadBinaryModel() :
this.loadJSONModel();
}

/**
* Loads the model topology file and build the in memory graph of the model.
*/
private async loadBinaryTopology(): Promise<ArrayBuffer> {
try {
const response = await fetch(this.path[0], this.requestInit);
return await response.arrayBuffer();
} catch (error) {
throw new Error(`${this.path[0]} not found. ${error}`);
}
}

protected async loadBinaryModel(): Promise<ModelArtifacts> {
const graphPromise = this.loadBinaryTopology();
const manifestPromise = await fetch(this.path[1], this.requestInit);

const [modelTopology, weightsManifestResponse] =
await Promise.all([graphPromise, manifestPromise]);

const weightsManifest =
await weightsManifestResponse.json() as WeightsManifestConfig;

let weightSpecs: WeightsManifestEntry[];
let weightData: ArrayBuffer;
if (weightsManifest != null) {
[weightSpecs, weightData] = await this.loadWeights(weightsManifest);
}

return {modelTopology, weightSpecs, weightData};
}

protected async loadJSONModel(): Promise<ModelArtifacts> {
const modelConfigRequest =
await fetch(this.path as string, this.requestInit);
const modelConfig = await modelConfigRequest.json();
Expand All @@ -136,28 +173,36 @@ export class BrowserHTTPRequest implements IOHandler {
if (weightsManifest != null) {
const weightsManifest =
modelConfig['weightsManifest'] as WeightsManifestConfig;
weightSpecs = [];
for (const entry of weightsManifest) {
weightSpecs.push(...entry.weights);
}
[weightSpecs, weightData] = await this.loadWeights(weightsManifest);
}

let pathPrefix =
(this.path as string).substring(0, this.path.lastIndexOf('/'));
if (!pathPrefix.endsWith('/')) {
pathPrefix = pathPrefix + '/';
}
return {modelTopology, weightSpecs, weightData};
}

const fetchURLs: string[] = [];
weightsManifest.forEach(weightsGroup => {
weightsGroup.paths.forEach(path => {
fetchURLs.push(pathPrefix + path);
});
});
weightData = concatenateArrayBuffers(
await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit));
private async loadWeights(weightsManifest: WeightsManifestConfig):
Promise<[WeightsManifestEntry[], ArrayBuffer]> {
const weightPath = Array.isArray(this.path) ? this.path[1] : this.path;

const weightSpecs = [];
for (const entry of weightsManifest) {
weightSpecs.push(...entry.weights);
}

return {modelTopology, weightSpecs, weightData};
let pathPrefix = weightPath.substring(0, weightPath.lastIndexOf('/'));
if (!pathPrefix.endsWith('/')) {
pathPrefix = pathPrefix + '/';
}
const fetchURLs: string[] = [];
weightsManifest.forEach(weightsGroup => {
weightsGroup.paths.forEach(path => {
fetchURLs.push(pathPrefix + path);
});
});
return [
weightSpecs,
concatenateArrayBuffers(
await loadWeightsAsArrayBuffer(fetchURLs, this.requestInit))
];
}
}

Expand Down Expand Up @@ -309,11 +354,11 @@ IORouterRegistry.registerLoadRouter(httpRequestRouter);
* HTTP request to server using `fetch`. It can contain fields such as
* `method`, `credentials`, `headers`, `mode`, etc. See
* https://developer.mozilla.org/en-US/docs/Web/API/Request/Request
* for more information. `requestInit` must not have a body, because the body
* will be set by TensorFlow.js. File blobs representing
* the model topology (filename: 'model.json') and the weights of the
* model (filename: 'model.weights.bin') will be appended to the body.
* If `requestInit` has a `body`, an Error will be thrown.
* for more information. `requestInit` must not have a body, because the
* body will be set by TensorFlow.js. File blobs representing the model
* topology (filename: 'model.json') and the weights of the model (filename:
* 'model.weights.bin') will be appended to the body. If `requestInit` has a
* `body`, an Error will be thrown.
* @returns An instance of `IOHandler`.
*/
export function browserHTTPRequest(
Expand Down
Loading