From a3fba36dceb453fbe25210e7c00f86d9144a0782 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Mon, 23 Apr 2018 23:45:45 -0400 Subject: [PATCH 01/12] Add basic types and helper methods for model exporting --- src/index.ts | 8 +- src/io/io_utils.ts | 124 ++++++++++++++++++++++++++++ src/io/io_utils_test.ts | 110 ++++++++++++++++++++++++ src/io/types.ts | 47 +++++++++++ src/{ => io}/weights_loader.ts | 21 ++--- src/{ => io}/weights_loader_test.ts | 8 +- 6 files changed, 295 insertions(+), 23 deletions(-) create mode 100644 src/io/io_utils.ts create mode 100644 src/io/io_utils_test.ts create mode 100644 src/io/types.ts rename src/{ => io}/weights_loader.ts (93%) rename src/{ => io}/weights_loader_test.ts (98%) diff --git a/src/index.ts b/src/index.ts index 05260cd89b..0f0c423750 100644 --- a/src/index.ts +++ b/src/index.ts @@ -24,6 +24,11 @@ import * as test_util from './test_util'; import * as util from './util'; import {version} from './version'; +// Serialization. +// tslint:disable-next-line:max-line-length +export {decodeTensors, encodeTensors} from './io/io_utils'; +export {WeightsManifestConfig} from './io/types'; +export {loadWeights} from './io/weights_loader'; // Optimizers. export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; @@ -36,9 +41,6 @@ export {SGDOptimizer} from './optimizers/sgd_optimizer'; // tslint:disable-next-line:max-line-length export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor'; export {DataType, Rank, ShapeMap} from './types'; -// Serialization. -export {WeightsManifestConfig} from './weights_loader'; -export {loadWeights} from './weights_loader'; export * from './ops/ops'; export {LSTMCellFunc} from './ops/lstm'; diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts new file mode 100644 index 0000000000..5d975e7eb8 --- /dev/null +++ b/src/io/io_utils.ts @@ -0,0 +1,124 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {tensor} from '../index'; +import {Tensor} from '../tensor'; +import {NamedTensorMap} from '../types'; +import {WeightsManifestEntry} from './types'; + +/** + * Encode a map from names to Tensors as an ArrayBuffer. + * + * @param tensors A map ("dict") from names to tensors. + * @returns A `Promise` of + * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s + * concatenated. + * - An `Array` of `WeightManifestEntry`s, carrying information including + * tensor names, `dtype`s and shapes. + * @throws Error: on unsupported tensor `dtype`. + */ +export async function encodeTensors(tensors: NamedTensorMap): + Promise<[ArrayBuffer, WeightsManifestEntry[]]> { + const specs: WeightsManifestEntry[] = []; + const dataPromises: Array> = []; + for (const name in tensors) { + const tensor = tensors[name]; + + if (tensor.dtype !== 'float32' && tensor.dtype !== 'int32') { + throw new Error(`Unsupported dtype: ${tensor.dtype}`); + } + specs.push({name, shape: tensor.shape, dtype: tensor.dtype}); + dataPromises.push(tensor.data()); + } + const tensorValues = await Promise.all(dataPromises); + return [concatenateTypedArrays(tensorValues), specs]; +} + +export function decodeTensors( + buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap { + const out: NamedTensorMap = {}; + let offset = 0; + for (const spec of specs) { + const name = spec.name; + const dtype = spec.dtype; + const shape = spec.shape; + + let numel = 1; + for (const dim of shape) { + numel *= dim; + } + let bytes: number; + let value: Tensor; + if (dtype === 'float32') { + bytes = numel * 4; + value = tensor(new Float32Array(buffer, offset, numel), shape, 'float32'); + } else if (dtype === 'int32') { + bytes = numel * 4; + value = tensor(new Int32Array(buffer, offset, numel), shape, 'int32'); + } else if (dtype === 'bool') { + bytes = numel; + value = tensor(new Uint8Array(buffer, offset, numel), shape, 'bool'); + } else { + throw new Error(`Unsupported dtype: ${dtype}`); + } + out[name] = value; + offset += bytes; + } + return out; +} + +/** + * Concatenate TypedArrays into an ArrayBuffer. + */ +export function concatenateTypedArrays( + xs: Array): ArrayBuffer { + if (xs === null) { + return null; + } + if (xs === undefined) { + return undefined; + } + if (xs.length === 0) { + return new ArrayBuffer(0); + } + + let totalByteLength = 0; + for (const x of xs) { + // tslint:disable-next-line:no-any + if (x as any instanceof Float32Array || x instanceof Int32Array) { + totalByteLength += x.length * 4; + // tslint:disable-next-line:no-any + } else if (x as any instanceof Uint8Array) { + totalByteLength += x.length; + } else { + throw new Error(`Unsupported type array subtype: ${x.constructor.name}`); + } + } + + const y = new Uint8Array(totalByteLength); + let offset = 0; + for (const x of xs) { + y.set(new Uint8Array(x.buffer), offset); + if (x instanceof Float32Array || x instanceof Int32Array) { + offset += x.length * 4; + } else { + offset += x.length; + } + } + + return y.buffer; +} diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts new file mode 100644 index 0000000000..7b526da746 --- /dev/null +++ b/src/io/io_utils_test.ts @@ -0,0 +1,110 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {scalar, tensor1d, tensor2d, test_util} from '../index'; +import {NamedTensorMap} from '../types'; + +import {concatenateTypedArrays, encodeTensors} from './io_utils'; + +// import {WeightsManifestEntry} from './types'; + +describe('concatenateTypedArrays', () => { + it('Single float arrays', () => { + const x = new Float32Array([1.1, 2.2, 3.3]); + const buffer = concatenateTypedArrays([x]); + + expect(buffer.byteLength).toEqual(12); + const z = Array.from(new Float32Array(buffer, 0, 3)); + test_util.expectArraysClose(z, [1.1, 2.2, 3.3]); + }); + + it('Float arrays', () => { + const x = new Float32Array([1.1, 2.2, 3.3]); + const y = new Float32Array([-1.1, -2.2, -3.3]); + const buffer = concatenateTypedArrays([x, y]); + + expect(buffer.byteLength).toEqual(24); + const z = Array.from(new Float32Array(buffer, 0, 6)); + test_util.expectArraysClose(z, [1.1, 2.2, 3.3, -1.1, -2.2, -3.3]); + }); + + it('Single int32 arrays', () => { + const x = new Int32Array([11, 22, 33]); + const buffer = concatenateTypedArrays([x]); + + expect(buffer.byteLength).toEqual(12); + const z = Array.from(new Int32Array(buffer, 0, 3)); + test_util.expectArraysClose(z, [11, 22, 33]); + }); + + it('Int32 arrays', () => { + const x = new Int32Array([11, 22, 33]); + const y = new Int32Array([-11, -22, -33]); + const buffer = concatenateTypedArrays([x, y]); + + expect(buffer.byteLength).toEqual(24); + const z = Array.from(new Int32Array(buffer, 0, 6)); + test_util.expectArraysClose(z, [11, 22, 33, -11, -22, -33]); + }); + + it('Single uint8 arrays', () => { + const x = new Uint8Array([11, 22, 33]); + const buffer = concatenateTypedArrays([x]); + + expect(buffer.byteLength).toEqual(3); + const z = Array.from(new Uint8Array(buffer, 0, 3)); + test_util.expectArraysClose(z, [11, 22, 33]); + }); + + it('Uint8 arrays', () => { + const x = new Uint8Array([11, 22, 33]); + const y = new Uint8Array([111, 122, 133]); + const buffer = concatenateTypedArrays([x, y]); + + expect(buffer.byteLength).toEqual(6); + const z = Array.from(new Uint8Array(buffer, 0, 6)); + test_util.expectArraysClose(z, [11, 22, 33, 111, 122, 133]); + }); +}); + +describe('encodeTensors', () => { + it('Float32 tensors', async done => { + // const entries: WeightsManifestEntry[] = [{ + // name: 'x1', + // dtype: 'float32', + // shape: [2, 2], + // }, { + // name: 'x2', + // dtype: 'float32', + // shape: [], + // }, { + // name: 'x3', + // dtype: 'float32', + // shape: [3], + // }]; + const tensors: NamedTensorMap = { + x1: tensor2d([[10, 20], [30, 40]]), + x2: scalar(42), + x3: tensor1d([-1, -3, -3, -7]), + }; + encodeTensors(tensors).then(dataAndManifest => { + const data = dataAndManifest[0]; + expect(data.byteLength).toEqual(4 * (4 + 1 + 4)); + done(); + }); + }); +}); diff --git a/src/io/types.ts b/src/io/types.ts new file mode 100644 index 0000000000..2f04f906d7 --- /dev/null +++ b/src/io/types.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +export type WeightsManifestConfig = WeightsManifestGroupConfig[]; + +export interface WeightsManifestEntry { + name: string; + shape: number[]; + dtype: 'float32'|'int32'|'bool'; +} + +export interface WeightsManifestGroupConfig { + paths: string[]; + weights: WeightsManifestEntry[]; +} + +export class SaveResult { + success: boolean; + + resposnes?: Response[]; + + errors?: Array<{}|string>; +} + +export interface ModelArtifact { + modelTopology: {}|ArrayBuffer; + weightsManifest: WeightsManifestConfig; + weightsData: ArrayBuffer; +} + +export type SaveHandler = (modelArtifact: ModelArtifact) => Promise; + +export type LoadHandler = () => Promise; diff --git a/src/weights_loader.ts b/src/io/weights_loader.ts similarity index 93% rename from src/weights_loader.ts rename to src/io/weights_loader.ts index 97a3251a8f..db1e07d1ae 100644 --- a/src/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -15,20 +15,10 @@ * ============================================================================= */ -import {tensor} from './ops/ops'; -import {NamedTensorMap} from './types'; -import * as util from './util'; - -export type WeightsManifestConfig = WeightsManifestGroupConfig[]; -export interface WeightsManifestGroupConfig { - paths: string[]; - weights: WeightsManifestEntry[]; -} -export interface WeightsManifestEntry { - name: string; - shape: number[]; - dtype: 'float32'|'int32'; -} +import {tensor} from '../ops/ops'; +import {NamedTensorMap} from '../types'; +import * as util from '../util'; +import {WeightsManifestConfig, WeightsManifestEntry} from './types'; const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = { 'float32': 4, @@ -58,8 +48,7 @@ export async function loadWeights( const groupIndicesToFetchMap = manifest.map(() => false); const groupWeightsToFetch: { [group: number]: Array<{ - manifestEntry: WeightsManifestEntry; - groupOffset: number; + manifestEntry: WeightsManifestEntry; groupOffset: number; sizeBytes: number; }> } = {}; diff --git a/src/weights_loader_test.ts b/src/io/weights_loader_test.ts similarity index 98% rename from src/weights_loader_test.ts rename to src/io/weights_loader_test.ts index 94655db8ff..d34b2bce11 100644 --- a/src/weights_loader_test.ts +++ b/src/io/weights_loader_test.ts @@ -14,10 +14,10 @@ * limitations under the License. * ============================================================================= */ -import * as tf from './index'; -import {CPU_ENVS, expectArraysClose} from './test_util'; -import {describeWithFlags} from './jasmine_util'; -import {WeightsManifestConfig} from './weights_loader'; +import * as tf from '../index'; +import {describeWithFlags} from '../jasmine_util'; +import {CPU_ENVS, expectArraysClose} from '../test_util'; +import {WeightsManifestConfig} from './types'; describeWithFlags('loadWeights', CPU_ENVS, () => { const setupFakeWeightFiles = From d1a03ab9047632e00167888e1da6acc7e2031e64 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Tue, 24 Apr 2018 12:11:44 -0400 Subject: [PATCH 02/12] WIP1 --- src/index.ts | 4 +- src/io/io_utils.ts | 7 +- src/io/io_utils_test.ts | 283 ++++++++++++++++++++++++++++++++++------ src/io/types.ts | 121 +++++++++++++++-- 4 files changed, 359 insertions(+), 56 deletions(-) diff --git a/src/index.ts b/src/index.ts index 0f0c423750..10b377f7e5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -25,9 +25,9 @@ import * as util from './util'; import {version} from './version'; // Serialization. -// tslint:disable-next-line:max-line-length export {decodeTensors, encodeTensors} from './io/io_utils'; -export {WeightsManifestConfig} from './io/types'; +// tslint:disable-next-line:max-line-length +export {IOHandler, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './io/types'; export {loadWeights} from './io/weights_loader'; // Optimizers. export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 5d975e7eb8..28bd8aa59a 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -38,7 +38,8 @@ export async function encodeTensors(tensors: NamedTensorMap): for (const name in tensors) { const tensor = tensors[name]; - if (tensor.dtype !== 'float32' && tensor.dtype !== 'int32') { + if (tensor.dtype !== 'float32' && tensor.dtype !== 'int32' && + tensor.dtype !== 'bool') { throw new Error(`Unsupported dtype: ${tensor.dtype}`); } specs.push({name, shape: tensor.shape, dtype: tensor.dtype}); @@ -76,7 +77,7 @@ export function decodeTensors( throw new Error(`Unsupported dtype: ${dtype}`); } out[name] = value; - offset += bytes; + offset += bytes; } return out; } @@ -105,7 +106,7 @@ export function concatenateTypedArrays( } else if (x as any instanceof Uint8Array) { totalByteLength += x.length; } else { - throw new Error(`Unsupported type array subtype: ${x.constructor.name}`); + throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`); } } diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 7b526da746..01dc7a7458 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -15,10 +15,11 @@ * ============================================================================= */ -import {scalar, tensor1d, tensor2d, test_util} from '../index'; +import {scalar, tensor1d, tensor2d} from '../index'; +import {expectArraysEqual} from '../test_util'; import {NamedTensorMap} from '../types'; -import {concatenateTypedArrays, encodeTensors} from './io_utils'; +import {concatenateTypedArrays, decodeTensors, encodeTensors} from './io_utils'; // import {WeightsManifestEntry} from './types'; @@ -26,85 +27,287 @@ describe('concatenateTypedArrays', () => { it('Single float arrays', () => { const x = new Float32Array([1.1, 2.2, 3.3]); const buffer = concatenateTypedArrays([x]); - expect(buffer.byteLength).toEqual(12); - const z = Array.from(new Float32Array(buffer, 0, 3)); - test_util.expectArraysClose(z, [1.1, 2.2, 3.3]); + expect(new Float32Array(buffer, 0, 3)).toEqual(x); }); it('Float arrays', () => { const x = new Float32Array([1.1, 2.2, 3.3]); const y = new Float32Array([-1.1, -2.2, -3.3]); const buffer = concatenateTypedArrays([x, y]); - expect(buffer.byteLength).toEqual(24); - const z = Array.from(new Float32Array(buffer, 0, 6)); - test_util.expectArraysClose(z, [1.1, 2.2, 3.3, -1.1, -2.2, -3.3]); + expect(new Float32Array(buffer, 0, 3)).toEqual(x); + expect(new Float32Array(buffer, 12, 3)).toEqual(y); }); - it('Single int32 arrays', () => { const x = new Int32Array([11, 22, 33]); const buffer = concatenateTypedArrays([x]); - expect(buffer.byteLength).toEqual(12); - const z = Array.from(new Int32Array(buffer, 0, 3)); - test_util.expectArraysClose(z, [11, 22, 33]); + expect(new Int32Array(buffer, 0, 3)).toEqual(x); }); it('Int32 arrays', () => { const x = new Int32Array([11, 22, 33]); const y = new Int32Array([-11, -22, -33]); const buffer = concatenateTypedArrays([x, y]); - expect(buffer.byteLength).toEqual(24); - const z = Array.from(new Int32Array(buffer, 0, 6)); - test_util.expectArraysClose(z, [11, 22, 33, -11, -22, -33]); + expect(new Int32Array(buffer, 0, 3)).toEqual(x); + expect(new Int32Array(buffer, 12, 3)).toEqual(y); }); it('Single uint8 arrays', () => { const x = new Uint8Array([11, 22, 33]); const buffer = concatenateTypedArrays([x]); - expect(buffer.byteLength).toEqual(3); - const z = Array.from(new Uint8Array(buffer, 0, 3)); - test_util.expectArraysClose(z, [11, 22, 33]); + expect(new Uint8Array(buffer, 0, 3)).toEqual(x); }); it('Uint8 arrays', () => { const x = new Uint8Array([11, 22, 33]); const y = new Uint8Array([111, 122, 133]); const buffer = concatenateTypedArrays([x, y]); - expect(buffer.byteLength).toEqual(6); - const z = Array.from(new Uint8Array(buffer, 0, 6)); - test_util.expectArraysClose(z, [11, 22, 33, 111, 122, 133]); + expect(new Uint8Array(buffer, 0, 3)).toEqual(x); + expect(new Uint8Array(buffer, 3, 3)).toEqual(y); + }); + + it('Mixed Uint8, Int32 and Float32 arrays', () => { + const x = new Uint8Array([0, 1, 1, 0]); + const y = new Int32Array([10, 20, 30, 40]); + const z = new Float32Array([-1.1, -2.2, -3.3, -4.4]); + const buffer = concatenateTypedArrays([x, y, z]); + expect(buffer.byteLength).toEqual(1 * 4 + 4 * 4 + 4 * 4); + expect(new Uint8Array(buffer, 0, 4)).toEqual(x); + expect(new Int32Array(buffer, 4, 4)).toEqual(y); + expect(new Float32Array(buffer, 20, 4)).toEqual(z); + }); + + it('null and undefined inputs', () => { + expect(concatenateTypedArrays(null)).toEqual(null); + expect(concatenateTypedArrays(undefined)).toEqual(undefined); + }); + + it('empty input array', () => { + expect(concatenateTypedArrays([]).byteLength).toEqual(0); + }); + + it('Unsupported dtype', () => { + const x = new Int16Array([0, 1, 1, 0]); + // tslint:disable-next-line:no-any + expect(() => concatenateTypedArrays([x as any])) + .toThrowError(/Unsupported TypedArray subtype: Int16Array/); }); }); describe('encodeTensors', () => { it('Float32 tensors', async done => { - // const entries: WeightsManifestEntry[] = [{ - // name: 'x1', - // dtype: 'float32', - // shape: [2, 2], - // }, { - // name: 'x2', - // dtype: 'float32', - // shape: [], - // }, { - // name: 'x3', - // dtype: 'float32', - // shape: [3], - // }]; const tensors: NamedTensorMap = { x1: tensor2d([[10, 20], [30, 40]]), x2: scalar(42), - x3: tensor1d([-1, -3, -3, -7]), + x3: tensor1d([-1.3, -3.7, 1.3, 3.7]), + }; + encodeTensors(tensors) + .then(dataAndSpecs => { + const data = dataAndSpecs[0]; + const specs = dataAndSpecs[1]; + expect(data.byteLength).toEqual(4 * (4 + 1 + 4)); + expect(new Float32Array(data, 0, 4)).toEqual(new Float32Array([ + 10, 20, 30, 40 + ])); + expect(new Float32Array(data, 16, 1)).toEqual(new Float32Array([42])); + expect(new Float32Array(data, 20, 4)).toEqual(new Float32Array([ + -1.3, -3.7, 1.3, 3.7 + ])); + expect(specs).toEqual([ + { + name: 'x1', + dtype: 'float32', + shape: [2, 2], + }, + { + name: 'x2', + dtype: 'float32', + shape: [], + }, + { + name: 'x3', + dtype: 'float32', + shape: [4], + } + ]); + done(); + }) + .catch(err => { + console.error(err.stack); + }); + }); + + it('Int32 tensors', async done => { + const tensors: NamedTensorMap = { + x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), + x2: scalar(42, 'int32'), + x3: tensor1d([-1, -3, -3, -7], 'int32'), }; - encodeTensors(tensors).then(dataAndManifest => { - const data = dataAndManifest[0]; - expect(data.byteLength).toEqual(4 * (4 + 1 + 4)); - done(); - }); + encodeTensors(tensors) + .then(dataAndSpecs => { + const data = dataAndSpecs[0]; + const specs = dataAndSpecs[1]; + expect(data.byteLength).toEqual(4 * (4 + 1 + 4)); + expect(new Int32Array(data, 0, 4)).toEqual(new Int32Array([ + 10, 20, 30, 40 + ])); + expect(new Int32Array(data, 16, 1)).toEqual(new Int32Array([42])); + expect(new Int32Array(data, 20, 4)).toEqual(new Int32Array([ + -1, -3, -3, -7 + ])); + expect(specs).toEqual([ + { + name: 'x1', + dtype: 'int32', + shape: [2, 2], + }, + { + name: 'x2', + dtype: 'int32', + shape: [], + }, + { + name: 'x3', + dtype: 'int32', + shape: [4], + } + ]); + done(); + }) + .catch(err => { + console.error(err.stack); + }); + }); + + it('Bool tensors', async done => { + const tensors: NamedTensorMap = { + x1: tensor2d([[true, false], [false, true]], [2, 2], 'bool'), + x2: scalar(false, 'bool'), + x3: tensor1d([false, true, true, false], 'bool'), + }; + encodeTensors(tensors) + .then(dataAndSpecs => { + const data = dataAndSpecs[0]; + const specs = dataAndSpecs[1]; + expect(data.byteLength).toEqual(4 + 1 + 4); + expect(new Uint8Array(data, 0, 4)).toEqual(new Uint8Array([ + 1, 0, 0, 1 + ])); + expect(new Uint8Array(data, 4, 1)).toEqual(new Uint8Array([0])); + expect(new Uint8Array(data, 5, 4)).toEqual(new Uint8Array([ + 0, 1, 1, 0 + ])); + expect(specs).toEqual([ + { + name: 'x1', + dtype: 'bool', + shape: [2, 2], + }, + { + name: 'x2', + dtype: 'bool', + shape: [], + }, + { + name: 'x3', + dtype: 'bool', + shape: [4], + } + ]); + done(); + }) + .catch(err => { + console.error(err.stack); + }); + }); + + it('Mixed dtype tensors', async done => { + const tensors: NamedTensorMap = { + x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), + x2: scalar(13.37, 'float32'), + x3: tensor1d([true, false, false, true], 'bool'), + }; + encodeTensors(tensors) + .then(dataAndSpecs => { + const data = dataAndSpecs[0]; + const specs = dataAndSpecs[1]; + expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 4); + expect(new Int32Array(data, 0, 4)).toEqual(new Int32Array([ + 10, 20, 30, 40 + ])); + expect(new Float32Array(data, 16, 1)) + .toEqual(new Float32Array([13.37])); + expect(new Uint8Array(data, 20, 4)).toEqual(new Uint8Array([ + 1, 0, 0, 1 + ])); + expect(specs).toEqual([ + { + name: 'x1', + dtype: 'int32', + shape: [2, 2], + }, + { + name: 'x2', + dtype: 'float32', + shape: [], + }, + { + name: 'x3', + dtype: 'bool', + shape: [4], + } + ]); + done(); + }) + .catch(err => { + console.error(err.stack); + }); + }); +}); + +describe('decodeTensors', () => { + it('Mixed dtype tensors', async done => { + const tensors: NamedTensorMap = { + x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), + x2: scalar(13.37, 'float32'), + x3: tensor1d([true, false, false, true], 'bool'), + y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), + }; + encodeTensors(tensors) + .then(dataAndSpecs => { + const data = dataAndSpecs[0]; + const specs = dataAndSpecs[1]; + expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 4 + 4 * 3); + const decoded = decodeTensors(data, specs); + expect(Object.keys(decoded).length).toEqual(4); + expectArraysEqual(decoded['x1'], tensors['x1']); + expectArraysEqual(decoded['x2'], tensors['x2']); + expectArraysEqual(decoded['x3'], tensors['x3']); + expectArraysEqual(decoded['y1'], tensors['y1']); + done(); + }) + .catch(err => { + console.error(err.stack); + }); + }); + + it('Unsupported dtype raises Error', () => { + const buffer = new ArrayBuffer(4); + // tslint:disable-next-line:no-any + const specs: any = [ + { + name: 'x', + dtype: 'int16', + shape: [], + }, + {name: 'y', dtype: 'int16', shape: []} + ]; + expect(() => decodeTensors(buffer, specs)) + .toThrowError(/Unsupported dtype: int16/); }); }); diff --git a/src/io/types.ts b/src/io/types.ts index 2f04f906d7..0cad50d7b2 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -15,33 +15,132 @@ * ============================================================================= */ +/* Type definitions for exporting and importing of models. */ + +/** + * A weight manifest. + * + * The weight manifest consists of an ordered list of weight-manifest groups. + * Each weight-manifest group ("group" for short hereafter) consists of a + * number of weight values stored in a number of paths. + * See the documentation of `WeightManifestGroupConfig` below for more details. + */ export type WeightsManifestConfig = WeightsManifestGroupConfig[]; +/** + * A weight-manifest group. + * + * Consists of an ordered list of weight values encoded in binary format, + * sotred in an ordered list of paths. + */ +export interface WeightsManifestGroupConfig { + /** + * An ordered list of paths. + * + * Paths are intentionally abstract in order to be general. For example, they + * can be relative URL paths or relative paths on the file system. + */ + paths: string[]; + + /** + * Specifications of the weights stored in the paths. + */ + weights: WeightsManifestEntry[]; +} + +/** + * An entry in the weight manifest. + * + * The entry contains specification of a weight. + */ export interface WeightsManifestEntry { + /** + * Name of the weight, e.g., 'Dense_1/bias' + */ name: string; + + /** + * Shape of the weight. + */ shape: number[]; - dtype: 'float32'|'int32'|'bool'; -} -export interface WeightsManifestGroupConfig { - paths: string[]; - weights: WeightsManifestEntry[]; + /** + * Data type of the weight. + */ + dtype: 'float32'|'int32'|'bool'; } +/** + * Result of a saving operation. + */ export class SaveResult { + /** + * Whether the saving was successful. + */ success: boolean; + /** + * HTTP responses from the server that handled the model-saving request (if + * any). This is applicable only to server-based saving routes. + */ resposnes?: Response[]; + /** + * Error messages and related data (if any). + */ errors?: Array<{}|string>; } -export interface ModelArtifact { - modelTopology: {}|ArrayBuffer; - weightsManifest: WeightsManifestConfig; - weightsData: ArrayBuffer; +/** + * The serialized artifacts of a model, including topology and weights. + * + * The `modelTopology`, `weightSpecs` and `weightData` fields of this interface + * are optional, in order to support topology- or weights-only saving and + * loading. + */ +export interface ModelArtifacts { + /** + * Model topology. + * + * For Keras-style `tf.Model`s, this is a JSON object. + * For TensorFlow-style models (e.g., `FrozenModel`), this is a binary buffer + * carrying the `GraphDef` protocol buffer. + */ + modelTopology?: {}|ArrayBuffer; + + /** + * Weight specifications. + * + * This corresponds to the weightsData below. + */ + weightSpecs?: WeightsManifestEntry[]; + + /** + * Binary buffer for all weight values concatenated in the order specified + * by `weightSpecs`. + */ + weightData?: ArrayBuffer; } -export type SaveHandler = (modelArtifact: ModelArtifact) => Promise; +/** + * Type definition for handlers of loading opertaions. + */ +export type LoadHandler = () => Promise; + +/** + * Type definition for handlers of saving opertaions. + */ +export type SaveHandler = (modelArtifact: ModelArtifacts) => + Promise; -export type LoadHandler = () => Promise; +/** + * Interface for a model import/export handler. + * + * The `save` and `load` handlers are both optional, in order to allow handlers + * that support only saving or loading. + */ +// tslint:disable-next-line:interface-name +export interface IOHandler { + save?: SaveHandler; + load?: LoadHandler; +} From 711365fa0cfcb223ca6f75b137ccbf2b21335d08 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Tue, 24 Apr 2018 12:17:56 -0400 Subject: [PATCH 03/12] WIP2 --- src/io/io_utils.ts | 11 +++++++++++ src/io/types.ts | 19 ++++++++++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 28bd8aa59a..3376d6947b 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -49,6 +49,17 @@ export async function encodeTensors(tensors: NamedTensorMap): return [concatenateTypedArrays(tensorValues), specs]; } +/** + * Decode flat ArrayBuffer as named Tensors. + * + * @param buffer A flat ArrayBuffer carrying the binary values of the tensors + * concatenated in the order specified in `specs`. + * @param specs Specifications of the names, dtypes and shapes of the tensors + * whose value are encoded by `buffer`. + * @return A map from tensor name to tensor value, with the names corresponding + * to names in `specs`. + * @throws Error, if any of the tensors has unsupported dtype. + */ export function decodeTensors( buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap { const out: NamedTensorMap = {}; diff --git a/src/io/types.ts b/src/io/types.ts index 0cad50d7b2..3cc0c47976 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -31,7 +31,7 @@ export type WeightsManifestConfig = WeightsManifestGroupConfig[]; * A weight-manifest group. * * Consists of an ordered list of weight values encoded in binary format, - * sotred in an ordered list of paths. + * stored in an ordered list of paths. */ export interface WeightsManifestGroupConfig { /** @@ -70,12 +70,25 @@ export interface WeightsManifestEntry { dtype: 'float32'|'int32'|'bool'; } +/** + * Options for saving a model. + */ +export interface SaveConfig { + /** + * Whether to save only the trainable weights of the model, ignoring the + * untrainable ones. + */ + trainableOnly?: boolean; + + // TODO(cais): Add more fields if necessary. +} + /** * Result of a saving operation. */ -export class SaveResult { +export interface SaveResult { /** - * Whether the saving was successful. + * Whether the saving succeeded. */ success: boolean; From 4172414b77b664776579db4b778fb375c1e0578c Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 00:05:40 -0400 Subject: [PATCH 04/12] Respond to review comments --- src/index.ts | 3 +- src/io/io_utils.ts | 53 ++++++++++++++++------------------- src/io/io_utils_test.ts | 52 +++++++++++++++++----------------- src/io/types.ts | 6 ++-- src/io/weights_loader.ts | 3 +- src/io/weights_loader_test.ts | 2 +- 6 files changed, 58 insertions(+), 61 deletions(-) diff --git a/src/index.ts b/src/index.ts index a32293f7a1..9b8ee61f50 100644 --- a/src/index.ts +++ b/src/index.ts @@ -31,10 +31,11 @@ import * as util from './util'; import {version} from './version'; // Serialization. -export {decodeTensors, encodeTensors} from './io/io_utils'; +export {decodeWeights, encodeWeights} from './io/io_utils'; // tslint:disable-next-line:max-line-length export {IOHandler, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './io/types'; export {loadWeights} from './io/weights_loader'; + // Optimizers. export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 3376d6947b..0daca264de 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -17,11 +17,14 @@ import {tensor} from '../index'; import {Tensor} from '../tensor'; -import {NamedTensorMap} from '../types'; +import {NamedTensorMap, TypedArray} from '../types'; +import {sizeFromShape} from '../util'; + import {WeightsManifestEntry} from './types'; /** - * Encode a map from names to Tensors as an ArrayBuffer. + * Encode a map from names to weight values as an ArrayBuffer, along with an + * `Array` of `WeightsManifestEntry` as specification of the encoded weights. * * @param tensors A map ("dict") from names to tensors. * @returns A `Promise` of @@ -31,26 +34,26 @@ import {WeightsManifestEntry} from './types'; * tensor names, `dtype`s and shapes. * @throws Error: on unsupported tensor `dtype`. */ -export async function encodeTensors(tensors: NamedTensorMap): - Promise<[ArrayBuffer, WeightsManifestEntry[]]> { +export async function encodeWeights(tensors: NamedTensorMap): + Promise<{data: ArrayBuffer, specs: WeightsManifestEntry[]}> { + // TODO(adarob, cais): Support quantization. const specs: WeightsManifestEntry[] = []; - const dataPromises: Array> = []; + const dataPromises: Array> = []; for (const name in tensors) { - const tensor = tensors[name]; + const t = tensors[name]; - if (tensor.dtype !== 'float32' && tensor.dtype !== 'int32' && - tensor.dtype !== 'bool') { - throw new Error(`Unsupported dtype: ${tensor.dtype}`); + if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool') { + throw new Error(`Unsupported dtype in weight '${name}': ${t.dtype}`); } - specs.push({name, shape: tensor.shape, dtype: tensor.dtype}); - dataPromises.push(tensor.data()); + specs.push({name, shape: t.shape, dtype: t.dtype}); + dataPromises.push(t.data()); } const tensorValues = await Promise.all(dataPromises); - return [concatenateTypedArrays(tensorValues), specs]; + return {data: concatenateTypedArrays(tensorValues), specs}; } /** - * Decode flat ArrayBuffer as named Tensors. + * Decode flat ArrayBuffer as weights. * * @param buffer A flat ArrayBuffer carrying the binary values of the tensors * concatenated in the order specified in `specs`. @@ -60,8 +63,9 @@ export async function encodeTensors(tensors: NamedTensorMap): * to names in `specs`. * @throws Error, if any of the tensors has unsupported dtype. */ -export function decodeTensors( +export function decodeWeights( buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap { + // TODO(adarob, cais): Support quantization. const out: NamedTensorMap = {}; let offset = 0; for (const spec of specs) { @@ -69,10 +73,7 @@ export function decodeTensors( const dtype = spec.dtype; const shape = spec.shape; - let numel = 1; - for (const dim of shape) { - numel *= dim; - } + const numel = sizeFromShape(shape); let bytes: number; let value: Tensor; if (dtype === 'float32') { @@ -85,7 +86,7 @@ export function decodeTensors( bytes = numel; value = tensor(new Uint8Array(buffer, offset, numel), shape, 'bool'); } else { - throw new Error(`Unsupported dtype: ${dtype}`); + throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } out[name] = value; offset += bytes; @@ -96,16 +97,10 @@ export function decodeTensors( /** * Concatenate TypedArrays into an ArrayBuffer. */ -export function concatenateTypedArrays( - xs: Array): ArrayBuffer { +export function concatenateTypedArrays(xs: TypedArray[]): ArrayBuffer { + // TODO(adarob, cais): Support quantization. if (xs === null) { - return null; - } - if (xs === undefined) { - return undefined; - } - if (xs.length === 0) { - return new ArrayBuffer(0); + throw new Error(`Invalid input value: ${JSON.stringify(xs)}`); } let totalByteLength = 0; diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 01dc7a7458..ff07b8d466 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -15,11 +15,13 @@ * ============================================================================= */ -import {scalar, tensor1d, tensor2d} from '../index'; +import {decodeWeights, encodeWeights} from '../index'; + +import {scalar, tensor1d, tensor2d} from '../ops/ops'; import {expectArraysEqual} from '../test_util'; import {NamedTensorMap} from '../types'; -import {concatenateTypedArrays, decodeTensors, encodeTensors} from './io_utils'; +import {concatenateTypedArrays} from './io_utils'; // import {WeightsManifestEntry} from './types'; @@ -83,8 +85,8 @@ describe('concatenateTypedArrays', () => { }); it('null and undefined inputs', () => { - expect(concatenateTypedArrays(null)).toEqual(null); - expect(concatenateTypedArrays(undefined)).toEqual(undefined); + expect(() => concatenateTypedArrays(null)).toThrow(); + expect(() => concatenateTypedArrays(undefined)).toThrow(); }); it('empty input array', () => { @@ -99,17 +101,17 @@ describe('concatenateTypedArrays', () => { }); }); -describe('encodeTensors', () => { +describe('encodeWeights', () => { it('Float32 tensors', async done => { const tensors: NamedTensorMap = { x1: tensor2d([[10, 20], [30, 40]]), x2: scalar(42), x3: tensor1d([-1.3, -3.7, 1.3, 3.7]), }; - encodeTensors(tensors) + encodeWeights(tensors) .then(dataAndSpecs => { - const data = dataAndSpecs[0]; - const specs = dataAndSpecs[1]; + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; expect(data.byteLength).toEqual(4 * (4 + 1 + 4)); expect(new Float32Array(data, 0, 4)).toEqual(new Float32Array([ 10, 20, 30, 40 @@ -148,10 +150,10 @@ describe('encodeTensors', () => { x2: scalar(42, 'int32'), x3: tensor1d([-1, -3, -3, -7], 'int32'), }; - encodeTensors(tensors) + encodeWeights(tensors) .then(dataAndSpecs => { - const data = dataAndSpecs[0]; - const specs = dataAndSpecs[1]; + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; expect(data.byteLength).toEqual(4 * (4 + 1 + 4)); expect(new Int32Array(data, 0, 4)).toEqual(new Int32Array([ 10, 20, 30, 40 @@ -190,10 +192,10 @@ describe('encodeTensors', () => { x2: scalar(false, 'bool'), x3: tensor1d([false, true, true, false], 'bool'), }; - encodeTensors(tensors) + encodeWeights(tensors) .then(dataAndSpecs => { - const data = dataAndSpecs[0]; - const specs = dataAndSpecs[1]; + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; expect(data.byteLength).toEqual(4 + 1 + 4); expect(new Uint8Array(data, 0, 4)).toEqual(new Uint8Array([ 1, 0, 0, 1 @@ -232,10 +234,10 @@ describe('encodeTensors', () => { x2: scalar(13.37, 'float32'), x3: tensor1d([true, false, false, true], 'bool'), }; - encodeTensors(tensors) + encodeWeights(tensors) .then(dataAndSpecs => { - const data = dataAndSpecs[0]; - const specs = dataAndSpecs[1]; + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 4); expect(new Int32Array(data, 0, 4)).toEqual(new Int32Array([ 10, 20, 30, 40 @@ -270,7 +272,7 @@ describe('encodeTensors', () => { }); }); -describe('decodeTensors', () => { +describe('decodeWeights', () => { it('Mixed dtype tensors', async done => { const tensors: NamedTensorMap = { x1: tensor2d([[10, 20], [30, 40]], [2, 2], 'int32'), @@ -278,12 +280,12 @@ describe('decodeTensors', () => { x3: tensor1d([true, false, false, true], 'bool'), y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), }; - encodeTensors(tensors) + encodeWeights(tensors) .then(dataAndSpecs => { - const data = dataAndSpecs[0]; - const specs = dataAndSpecs[1]; + const data = dataAndSpecs.data; + const specs = dataAndSpecs.specs; expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 4 + 4 * 3); - const decoded = decodeTensors(data, specs); + const decoded = decodeWeights(data, specs); expect(Object.keys(decoded).length).toEqual(4); expectArraysEqual(decoded['x1'], tensors['x1']); expectArraysEqual(decoded['x2'], tensors['x2']); @@ -307,7 +309,7 @@ describe('decodeTensors', () => { }, {name: 'y', dtype: 'int16', shape: []} ]; - expect(() => decodeTensors(buffer, specs)) - .toThrowError(/Unsupported dtype: int16/); + expect(() => decodeWeights(buffer, specs)) + .toThrowError(/Unsupported dtype in weight \'x\': int16/); }); }); diff --git a/src/io/types.ts b/src/io/types.ts index 3cc0c47976..2c0398abbb 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -79,8 +79,6 @@ export interface SaveConfig { * untrainable ones. */ trainableOnly?: boolean; - - // TODO(cais): Add more fields if necessary. } /** @@ -96,7 +94,7 @@ export interface SaveResult { * HTTP responses from the server that handled the model-saving request (if * any). This is applicable only to server-based saving routes. */ - resposnes?: Response[]; + responses?: Response[]; /** * Error messages and related data (if any). diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index db1e07d1ae..6c3d74c586 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -42,6 +42,7 @@ export async function loadWeights( // single weight from a group, the whole group will be fetched. At a future // date, we should support fetching only the individual shards within a // group that are needed to reconstruct the requested weight. + // TODO(cais): Use `decodeWeights` for implementation. // Collect all the groups, weights, and their relative offsets to be // fetched. diff --git a/src/io/weights_loader_test.ts b/src/io/weights_loader_test.ts index d34b2bce11..805ac0150d 100644 --- a/src/io/weights_loader_test.ts +++ b/src/io/weights_loader_test.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2018 Google Inc. All Rights Reserved. + * Copyright 2018 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at From 748daa6f2ff951c82ac223c03d8fc67058c29054 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 09:57:32 -0400 Subject: [PATCH 05/12] Export SaveConfig for layer use --- src/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/index.ts b/src/index.ts index 9b8ee61f50..f68ed0b518 100644 --- a/src/index.ts +++ b/src/index.ts @@ -33,7 +33,7 @@ import {version} from './version'; // Serialization. export {decodeWeights, encodeWeights} from './io/io_utils'; // tslint:disable-next-line:max-line-length -export {IOHandler, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './io/types'; +export {IOHandler, LoadHandler, ModelArtifacts, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './io/types'; export {loadWeights} from './io/weights_loader'; // Optimizers. From a70d3a0ea1e23b813319692d8adc53298f14ab7f Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 12:48:21 -0400 Subject: [PATCH 06/12] Respond to further review comments. --- src/io/io_utils.ts | 37 ++++++++++++++++++++++++------------- src/io/io_utils_test.ts | 2 -- src/io/types.ts | 8 ++++++++ src/io/weights_loader.ts | 9 +++------ 4 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 0daca264de..d755c18f37 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {tensor} from '../index'; +import {ArrayOps} from '../ops/array_ops'; import {Tensor} from '../tensor'; import {NamedTensorMap, TypedArray} from '../types'; import {sizeFromShape} from '../util'; @@ -26,6 +26,10 @@ import {WeightsManifestEntry} from './types'; * Encode a map from names to weight values as an ArrayBuffer, along with an * `Array` of `WeightsManifestEntry` as specification of the encoded weights. * + * This function does not perform sharding. + * + * This function is the reverse of `decodeWeights`. + * * @param tensors A map ("dict") from names to tensors. * @returns A `Promise` of * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s @@ -55,6 +59,10 @@ export async function encodeWeights(tensors: NamedTensorMap): /** * Decode flat ArrayBuffer as weights. * + * This function does not handle sharding. + * + * This function is the reverse of `encodeWeights`. + * * @param buffer A flat ArrayBuffer carrying the binary values of the tensors * concatenated in the order specified in `specs`. * @param specs Specifications of the names, dtypes and shapes of the tensors @@ -73,18 +81,21 @@ export function decodeWeights( const dtype = spec.dtype; const shape = spec.shape; - const numel = sizeFromShape(shape); + const size = sizeFromShape(shape); let bytes: number; let value: Tensor; if (dtype === 'float32') { - bytes = numel * 4; - value = tensor(new Float32Array(buffer, offset, numel), shape, 'float32'); + bytes = size * 4; + value = ArrayOps.tensor( + new Float32Array(buffer, offset, size), shape, 'float32'); } else if (dtype === 'int32') { - bytes = numel * 4; - value = tensor(new Int32Array(buffer, offset, numel), shape, 'int32'); + bytes = size * 4; + value = + ArrayOps.tensor(new Int32Array(buffer, offset, size), shape, 'int32'); } else if (dtype === 'bool') { - bytes = numel; - value = tensor(new Uint8Array(buffer, offset, numel), shape, 'bool'); + bytes = size; + value = + ArrayOps.tensor(new Uint8Array(buffer, offset, size), shape, 'bool'); } else { throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } @@ -104,9 +115,9 @@ export function concatenateTypedArrays(xs: TypedArray[]): ArrayBuffer { } let totalByteLength = 0; - for (const x of xs) { + xs.forEach(x => { // tslint:disable-next-line:no-any - if (x as any instanceof Float32Array || x instanceof Int32Array) { + if (x as any instanceof Float32Array || x as any instanceof Int32Array) { totalByteLength += x.length * 4; // tslint:disable-next-line:no-any } else if (x as any instanceof Uint8Array) { @@ -114,18 +125,18 @@ export function concatenateTypedArrays(xs: TypedArray[]): ArrayBuffer { } else { throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`); } - } + }); const y = new Uint8Array(totalByteLength); let offset = 0; - for (const x of xs) { + xs.forEach(x => { y.set(new Uint8Array(x.buffer), offset); if (x instanceof Float32Array || x instanceof Int32Array) { offset += x.length * 4; } else { offset += x.length; } - } + }); return y.buffer; } diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index ff07b8d466..507c76ca73 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -23,8 +23,6 @@ import {NamedTensorMap} from '../types'; import {concatenateTypedArrays} from './io_utils'; -// import {WeightsManifestEntry} from './types'; - describe('concatenateTypedArrays', () => { it('Single float arrays', () => { const x = new Float32Array([1.1, 2.2, 3.3]); diff --git a/src/io/types.ts b/src/io/types.ts index 2c0398abbb..6c35e9cc8b 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -17,6 +17,14 @@ /* Type definitions for exporting and importing of models. */ +/** + * A map from Tensor dtype to number of bytes per element of the Tensor. + */ +export const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = { + 'float32': 4, + 'int32': 4 +}; + /** * A weight manifest. * diff --git a/src/io/weights_loader.ts b/src/io/weights_loader.ts index 6c3d74c586..da650b040a 100644 --- a/src/io/weights_loader.ts +++ b/src/io/weights_loader.ts @@ -15,15 +15,12 @@ * ============================================================================= */ +// tslint:disable:max-line-length import {tensor} from '../ops/ops'; import {NamedTensorMap} from '../types'; import * as util from '../util'; -import {WeightsManifestConfig, WeightsManifestEntry} from './types'; - -const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = { - 'float32': 4, - 'int32': 4 -}; +import {DTYPE_VALUE_SIZE_MAP, WeightsManifestConfig, WeightsManifestEntry} from './types'; +// tslint:enable:max-line-length /** * Reads a weights manifest JSON configuration, fetches the weights and From b2cd37beb25f3e29342d2e851e645d842a4a756f Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 15:05:30 -0400 Subject: [PATCH 07/12] Add not-implemented Error for quantization in decodeWeights --- src/io/io_utils.ts | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index 625fb47769..e64f1ae7e0 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -81,6 +81,12 @@ export function decodeWeights( const dtype = spec.dtype; const shape = spec.shape; + if (spec.quantization != null) { + throw new Error( + `decodeWeights does not support quantization yet, but encountered ` + + `weight '${name} wit quantization.'`); + } + const size = sizeFromShape(shape); let bytes: number; let value: Tensor; From f9f28a111742d0b94a3066cc0047ff773fcb32f1 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 15:47:13 -0400 Subject: [PATCH 08/12] Add missing bool field to DTYPE_VALUE_SIZE_MAP --- src/io/io_utils.ts | 5 ++--- src/io/types.ts | 3 ++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index e64f1ae7e0..ece238a4f0 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -88,9 +88,7 @@ export function decodeWeights( } const size = sizeFromShape(shape); - let bytes: number; let value: Tensor; - bytes = size * DTYPE_VALUE_SIZE_MAP[dtype]; if (dtype === 'float32') { value = ArrayOps.tensor( new Float32Array(buffer, offset, size), shape, 'float32'); @@ -104,7 +102,8 @@ export function decodeWeights( throw new Error(`Unsupported dtype in weight '${name}': ${dtype}`); } out[name] = value; - offset += bytes; + + offset += size * DTYPE_VALUE_SIZE_MAP[dtype]; } return out; } diff --git a/src/io/types.ts b/src/io/types.ts index fa4cdbe41f..85d540a585 100644 --- a/src/io/types.ts +++ b/src/io/types.ts @@ -24,7 +24,8 @@ export const DTYPE_VALUE_SIZE_MAP: {[dtype: string]: number} = { 'float32': 4, 'int32': 4, 'uint16': 2, - 'uint8': 1 + 'uint8': 1, + 'bool': 1, }; /** From eeae1dbd9acee5325d97313590604fbcce517cc7 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 16:36:20 -0400 Subject: [PATCH 09/12] Move serialization related types and functions to tf.io.* Including loadWeights. --- src/index.ts | 9 +++------ src/io/io.ts | 36 +++++++++++++++++++++++++++++++++++ src/io/io_utils_test.ts | 16 ++++++++-------- src/io/weights_loader_test.ts | 33 ++++++++++++++++---------------- 4 files changed, 64 insertions(+), 30 deletions(-) create mode 100644 src/io/io.ts diff --git a/src/index.ts b/src/index.ts index f68ed0b518..c80b4bce0b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -24,18 +24,14 @@ import './kernels/backend_cpu'; import {BrowserUtil} from './browser_util'; import * as environment from './environment'; import {Environment} from './environment'; +// Serialization. +import * as io from './io/io'; import * as gpgpu_util from './kernels/webgl/gpgpu_util'; import * as webgl_util from './kernels/webgl/webgl_util'; import * as test_util from './test_util'; import * as util from './util'; import {version} from './version'; -// Serialization. -export {decodeWeights, encodeWeights} from './io/io_utils'; -// tslint:disable-next-line:max-line-length -export {IOHandler, LoadHandler, ModelArtifacts, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './io/types'; -export {loadWeights} from './io/weights_loader'; - // Optimizers. export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer'; export {AdagradOptimizer} from './optimizers/adagrad_optimizer'; @@ -48,6 +44,7 @@ export {SGDOptimizer} from './optimizers/sgd_optimizer'; // tslint:disable-next-line:max-line-length export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor'; export {DataType, Rank, ShapeMap} from './types'; +export {io}; export * from './ops/ops'; export {LSTMCellFunc} from './ops/lstm'; diff --git a/src/io/io.ts b/src/io/io.ts new file mode 100644 index 0000000000..4739af4ab0 --- /dev/null +++ b/src/io/io.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {decodeWeights, encodeWeights} from './io_utils'; +// tslint:disable-next-line:max-line-length +import {IOHandler, LoadHandler, ModelArtifacts, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; +import {loadWeights} from './weights_loader'; + +// tslint:disable-next-line:max-line-length +export { + decodeWeights, + encodeWeights, + IOHandler, + LoadHandler, + loadWeights, + ModelArtifacts, + SaveConfig, + SaveHandler, + SaveResult, + WeightsManifestConfig, + WeightsManifestEntry +}; diff --git a/src/io/io_utils_test.ts b/src/io/io_utils_test.ts index 507c76ca73..81084370b8 100644 --- a/src/io/io_utils_test.ts +++ b/src/io/io_utils_test.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {decodeWeights, encodeWeights} from '../index'; +import * as tf from '../index'; import {scalar, tensor1d, tensor2d} from '../ops/ops'; import {expectArraysEqual} from '../test_util'; @@ -106,7 +106,7 @@ describe('encodeWeights', () => { x2: scalar(42), x3: tensor1d([-1.3, -3.7, 1.3, 3.7]), }; - encodeWeights(tensors) + tf.io.encodeWeights(tensors) .then(dataAndSpecs => { const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; @@ -148,7 +148,7 @@ describe('encodeWeights', () => { x2: scalar(42, 'int32'), x3: tensor1d([-1, -3, -3, -7], 'int32'), }; - encodeWeights(tensors) + tf.io.encodeWeights(tensors) .then(dataAndSpecs => { const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; @@ -190,7 +190,7 @@ describe('encodeWeights', () => { x2: scalar(false, 'bool'), x3: tensor1d([false, true, true, false], 'bool'), }; - encodeWeights(tensors) + tf.io.encodeWeights(tensors) .then(dataAndSpecs => { const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; @@ -232,7 +232,7 @@ describe('encodeWeights', () => { x2: scalar(13.37, 'float32'), x3: tensor1d([true, false, false, true], 'bool'), }; - encodeWeights(tensors) + tf.io.encodeWeights(tensors) .then(dataAndSpecs => { const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; @@ -278,12 +278,12 @@ describe('decodeWeights', () => { x3: tensor1d([true, false, false, true], 'bool'), y1: tensor2d([-10, -20, -30], [3, 1], 'float32'), }; - encodeWeights(tensors) + tf.io.encodeWeights(tensors) .then(dataAndSpecs => { const data = dataAndSpecs.data; const specs = dataAndSpecs.specs; expect(data.byteLength).toEqual(4 * 4 + 4 * 1 + 1 * 4 + 4 * 3); - const decoded = decodeWeights(data, specs); + const decoded = tf.io.decodeWeights(data, specs); expect(Object.keys(decoded).length).toEqual(4); expectArraysEqual(decoded['x1'], tensors['x1']); expectArraysEqual(decoded['x2'], tensors['x2']); @@ -307,7 +307,7 @@ describe('decodeWeights', () => { }, {name: 'y', dtype: 'int16', shape: []} ]; - expect(() => decodeWeights(buffer, specs)) + expect(() => tf.io.decodeWeights(buffer, specs)) .toThrowError(/Unsupported dtype in weight \'x\': int16/); }); }); diff --git a/src/io/weights_loader_test.ts b/src/io/weights_loader_test.ts index 45685149ed..2b705291cd 100644 --- a/src/io/weights_loader_test.ts +++ b/src/io/weights_loader_test.ts @@ -38,7 +38,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { }]; const weightsNamesToFetch = ['weight0']; - tf.loadWeights(manifest, './', weightsNamesToFetch) + tf.io.loadWeights(manifest, './', weightsNamesToFetch) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); @@ -66,7 +66,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { }]; // Load the first weight. - tf.loadWeights(manifest, './', ['weight0']) + tf.io.loadWeights(manifest, './', ['weight0']) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); @@ -94,7 +94,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { }]; // Load the second weight. - tf.loadWeights(manifest, './', ['weight1']) + tf.io.loadWeights(manifest, './', ['weight1']) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); @@ -122,7 +122,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { }]; // Load all weights. - tf.loadWeights(manifest, './', ['weight0', 'weight1']) + tf.io.loadWeights(manifest, './', ['weight0', 'weight1']) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); @@ -161,7 +161,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { }]; // Load all weights. - tf.loadWeights(manifest, './', ['weight0', 'weight1']) + tf.io.loadWeights(manifest, './', ['weight0', 'weight1']) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); @@ -198,7 +198,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { 'weights': [{'name': 'weight0', 'dtype': 'float32', 'shape': [5, 2]}] }]; - tf.loadWeights(manifest, './', ['weight0']) + tf.io.loadWeights(manifest, './', ['weight0']) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(3); @@ -240,7 +240,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { ] }]; - tf.loadWeights(manifest, './', ['weight0', 'weight1']) + tf.io.loadWeights(manifest, './', ['weight0', 'weight1']) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(3); @@ -284,7 +284,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { } ]; - tf.loadWeights(manifest, './', ['weight0', 'weight1']) + tf.io.loadWeights(manifest, './', ['weight0', 'weight1']) .then(weights => { // Only the first group should be fetched. expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); @@ -329,7 +329,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { } ]; - tf.loadWeights(manifest, './', ['weight0', 'weight2']) + tf.io.loadWeights(manifest, './', ['weight0', 'weight2']) .then(weights => { // Both groups need to be fetched. expect((window.fetch as jasmine.Spy).calls.count()).toBe(2); @@ -375,7 +375,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { ]; // Don't pass a third argument to loadWeights to load all weights. - tf.loadWeights(manifest, './') + tf.io.loadWeights(manifest, './') .then(weights => { // Both groups need to be fetched. expect((window.fetch as jasmine.Spy).calls.count()).toBe(2); @@ -417,7 +417,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { const weightsNamesToFetch = ['doesntexist']; try { - await tf.loadWeights(manifest, './', weightsNamesToFetch); + await tf.io.loadWeights(manifest, './', weightsNamesToFetch); done.fail(); } catch (e) { done(); @@ -439,7 +439,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { const weightsNamesToFetch = ['weight0']; try { - await tf.loadWeights(manifest, './', weightsNamesToFetch); + await tf.io.loadWeights(manifest, './', weightsNamesToFetch); done.fail(); } catch (e) { done(); @@ -455,8 +455,9 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { }]; const weightsNamesToFetch = ['weight0']; - tf.loadWeights( - manifest, './', weightsNamesToFetch, {credentials: 'include'}) + tf.io + .loadWeights( + manifest, './', weightsNamesToFetch, {credentials: 'include'}) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); expect(window.fetch).toHaveBeenCalledWith('./weightfile0', { @@ -495,7 +496,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { }]; const weightsNamesToFetch = ['weight0', 'weight1']; - tf.loadWeights(manifest, './', weightsNamesToFetch) + tf.io.loadWeights(manifest, './', weightsNamesToFetch) .then(weights => { expect((window.fetch as jasmine.Spy).calls.count()).toBe(1); @@ -557,7 +558,7 @@ describeWithFlags('loadWeights', CPU_ENVS, () => { } ]; - tf.loadWeights(manifest, './', ['weight0', 'weight2']) + tf.io.loadWeights(manifest, './', ['weight0', 'weight2']) .then(weights => { // Both groups need to be fetched. expect((window.fetch as jasmine.Spy).calls.count()).toBe(2); From ac3eb47885e8222fbfa23ccaf46011f39166277b Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 17:13:47 -0400 Subject: [PATCH 10/12] Respond to Daniel's comment --- src/index.ts | 3 +- src/io/io.ts | 1 - src/io/local_storage.ts | 76 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 3 deletions(-) create mode 100644 src/io/local_storage.ts diff --git a/src/index.ts b/src/index.ts index c80b4bce0b..380177c290 100644 --- a/src/index.ts +++ b/src/index.ts @@ -44,7 +44,6 @@ export {SGDOptimizer} from './optimizers/sgd_optimizer'; // tslint:disable-next-line:max-line-length export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor'; export {DataType, Rank, ShapeMap} from './types'; -export {io}; export * from './ops/ops'; export {LSTMCellFunc} from './ops/lstm'; @@ -64,7 +63,7 @@ export {doc} from './doc'; export const nextFrame = BrowserUtil.nextFrame; // Second level exports. -export {environment, test_util, util}; +export {environment, io, test_util, util}; // WebGL specific utils. export const webgl = { diff --git a/src/io/io.ts b/src/io/io.ts index 4739af4ab0..d8ac83cfd7 100644 --- a/src/io/io.ts +++ b/src/io/io.ts @@ -20,7 +20,6 @@ import {decodeWeights, encodeWeights} from './io_utils'; import {IOHandler, LoadHandler, ModelArtifacts, SaveConfig, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './types'; import {loadWeights} from './weights_loader'; -// tslint:disable-next-line:max-line-length export { decodeWeights, encodeWeights, diff --git a/src/io/local_storage.ts b/src/io/local_storage.ts new file mode 100644 index 0000000000..1d5371a857 --- /dev/null +++ b/src/io/local_storage.ts @@ -0,0 +1,76 @@ +/** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {IOHandler, ModelArtifacts, SaveResult} from './types'; + +const PATH_SEPARATOR = '/'; +const PATH_PREFIX = 'tensorflowjs_models'; +const TOPOLOGY_SUFFIX = 'topology'; +const WEIGHT_SPECS_SUFFIX = 'json'; +const WEIGHT_DATA_SUFFIX = 'weights'; + +export class BrowserLocalStorage implements IOHandler { + protected readonly modelPath: string; + + constructor(modelPath: string) { + if (modelPath == null) { + throw new Error('modelPath cannot be null or undefined.'); + } + if (!modelPath) { + throw new Error('modelPath must not be empty.'); + } + this.modelPath = modelPath; + } + + async save(modelArtifact: ModelArtifacts): Promise { + if (!(window && window.localStorage)) { + return { + success: false, + errors: ['The current environment does not support local storage.'], + }; + } + + if (modelArtifact.modelTopology instanceof ArrayBuffer) { + throw new Error( + 'BrowserLocalStorage.save() does not support saving protocol ' + + 'buffers as model topology yet.'); + } else { + try { + window.localStorage.localStorage.setItem( + [PATH_PREFIX, this.modelPath, TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), + JSON.stringify(modelArtifact.modelTopology)); + window.localStorage.localStorage.setItem( + [PATH_PREFIX, this.modelPath, WEIGHT_SPECS_SUFFIX].join( + PATH_SEPARATOR), + JSON.stringify(modelArtifact.weightSpecs)); + window.localStorage.localStorage.setItem( + [PATH_PREFIX, this.modelPath, WEIGHT_DATA_SUFFIX].join( + PATH_SEPARATOR), + JSON.stringify(modelArtifact.weightSpecs)); + } catch (err) { + return { + success: false, + errors: [err], + }; + } + } + + return { + success: true, + }; + } +} From 0ea20f0cd03b79e1cc8b0df82f7d71e1e54d0820 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 17:15:20 -0400 Subject: [PATCH 11/12] Delete unintended file --- src/io/local_storage.ts | 76 ----------------------------------------- 1 file changed, 76 deletions(-) delete mode 100644 src/io/local_storage.ts diff --git a/src/io/local_storage.ts b/src/io/local_storage.ts deleted file mode 100644 index 1d5371a857..0000000000 --- a/src/io/local_storage.ts +++ /dev/null @@ -1,76 +0,0 @@ -/** - * @license - * Copyright 2018 Google LLC. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * ============================================================================= - */ - -import {IOHandler, ModelArtifacts, SaveResult} from './types'; - -const PATH_SEPARATOR = '/'; -const PATH_PREFIX = 'tensorflowjs_models'; -const TOPOLOGY_SUFFIX = 'topology'; -const WEIGHT_SPECS_SUFFIX = 'json'; -const WEIGHT_DATA_SUFFIX = 'weights'; - -export class BrowserLocalStorage implements IOHandler { - protected readonly modelPath: string; - - constructor(modelPath: string) { - if (modelPath == null) { - throw new Error('modelPath cannot be null or undefined.'); - } - if (!modelPath) { - throw new Error('modelPath must not be empty.'); - } - this.modelPath = modelPath; - } - - async save(modelArtifact: ModelArtifacts): Promise { - if (!(window && window.localStorage)) { - return { - success: false, - errors: ['The current environment does not support local storage.'], - }; - } - - if (modelArtifact.modelTopology instanceof ArrayBuffer) { - throw new Error( - 'BrowserLocalStorage.save() does not support saving protocol ' + - 'buffers as model topology yet.'); - } else { - try { - window.localStorage.localStorage.setItem( - [PATH_PREFIX, this.modelPath, TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), - JSON.stringify(modelArtifact.modelTopology)); - window.localStorage.localStorage.setItem( - [PATH_PREFIX, this.modelPath, WEIGHT_SPECS_SUFFIX].join( - PATH_SEPARATOR), - JSON.stringify(modelArtifact.weightSpecs)); - window.localStorage.localStorage.setItem( - [PATH_PREFIX, this.modelPath, WEIGHT_DATA_SUFFIX].join( - PATH_SEPARATOR), - JSON.stringify(modelArtifact.weightSpecs)); - } catch (err) { - return { - success: false, - errors: [err], - }; - } - } - - return { - success: true, - }; - } -} From 4b8d53d39bddd1c1dccffdb0e0984c2ba83373be Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 25 Apr 2018 17:16:54 -0400 Subject: [PATCH 12/12] fix typo --- src/io/io_utils.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/io_utils.ts b/src/io/io_utils.ts index ece238a4f0..2e9efa1b77 100644 --- a/src/io/io_utils.ts +++ b/src/io/io_utils.ts @@ -84,7 +84,7 @@ export function decodeWeights( if (spec.quantization != null) { throw new Error( `decodeWeights does not support quantization yet, but encountered ` + - `weight '${name} wit quantization.'`); + `weight '${name} with quantization.'`); } const size = sizeFromShape(shape);