Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Add basic types and helper methods for model exporting #990

Merged
merged 16 commits into from
Apr 26, 2018
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ import * as test_util from './test_util';
import * as util from './util';
import {version} from './version';

// Serialization.
export {decodeTensors, encodeTensors} from './io/io_utils';
// tslint:disable-next-line:max-line-length
export {IOHandler, LoadHandler, ModelArtifacts, SaveHandler, SaveResult, WeightsManifestConfig, WeightsManifestEntry} from './io/types';
export {loadWeights} from './io/weights_loader';
// Optimizers.
export {AdadeltaOptimizer} from './optimizers/adadelta_optimizer';
export {AdagradOptimizer} from './optimizers/adagrad_optimizer';
Expand All @@ -42,9 +47,6 @@ export {SGDOptimizer} from './optimizers/sgd_optimizer';
// tslint:disable-next-line:max-line-length
export {Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, TensorBuffer, variable, Variable} from './tensor';
export {DataType, Rank, ShapeMap} from './types';
// Serialization.
export {WeightsManifestConfig} from './weights_loader';
export {loadWeights} from './weights_loader';

export * from './ops/ops';
export {LSTMCellFunc} from './ops/lstm';
Expand Down
136 changes: 136 additions & 0 deletions src/io/io_utils.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/**
* @license
* Copyright 2018 Google Inc. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {tensor} from '../index';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../types';
import {WeightsManifestEntry} from './types';

/**
* Encode a map from names to Tensors as an ArrayBuffer.
*
* @param tensors A map ("dict") from names to tensors.
* @returns A `Promise` of
* - A flat `ArrayBuffer` with all the binary values of the `Tensor`s
* concatenated.
* - An `Array` of `WeightManifestEntry`s, carrying information including
* tensor names, `dtype`s and shapes.
* @throws Error: on unsupported tensor `dtype`.
*/
export async function encodeTensors(tensors: NamedTensorMap):
Promise<[ArrayBuffer, WeightsManifestEntry[]]> {
const specs: WeightsManifestEntry[] = [];
const dataPromises: Array<Promise<Float32Array|Int32Array|Uint8Array>> = [];
for (const name in tensors) {
const tensor = tensors[name];

if (tensor.dtype !== 'float32' && tensor.dtype !== 'int32' &&
tensor.dtype !== 'bool') {
throw new Error(`Unsupported dtype: ${tensor.dtype}`);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we include the name of the tensor here as well.

}
specs.push({name, shape: tensor.shape, dtype: tensor.dtype});
dataPromises.push(tensor.data());
}
const tensorValues = await Promise.all(dataPromises);
return [concatenateTypedArrays(tensorValues), specs];
}

/**
* Decode flat ArrayBuffer as named Tensors.
*
* @param buffer A flat ArrayBuffer carrying the binary values of the tensors
* concatenated in the order specified in `specs`.
* @param specs Specifications of the names, dtypes and shapes of the tensors
* whose value are encoded by `buffer`.
* @return A map from tensor name to tensor value, with the names corresponding
* to names in `specs`.
* @throws Error, if any of the tensors has unsupported dtype.
*/
export function decodeTensors(
buffer: ArrayBuffer, specs: WeightsManifestEntry[]): NamedTensorMap {
const out: NamedTensorMap = {};
let offset = 0;
for (const spec of specs) {
const name = spec.name;
const dtype = spec.dtype;
const shape = spec.shape;

let numel = 1;
for (const dim of shape) {
numel *= dim;
}
let bytes: number;
let value: Tensor;
if (dtype === 'float32') {
bytes = numel * 4;
value = tensor(new Float32Array(buffer, offset, numel), shape, 'float32');
} else if (dtype === 'int32') {
bytes = numel * 4;
value = tensor(new Int32Array(buffer, offset, numel), shape, 'int32');
} else if (dtype === 'bool') {
bytes = numel;
value = tensor(new Uint8Array(buffer, offset, numel), shape, 'bool');
} else {
throw new Error(`Unsupported dtype: ${dtype}`);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above, please include the name fields as well.

}
out[name] = value;
offset += bytes;
}
return out;
}

/**
* Concatenate TypedArrays into an ArrayBuffer.
*/
export function concatenateTypedArrays(
xs: Array<Float32Array|Int32Array|Uint8Array>): ArrayBuffer {
if (xs === null) {
return null;
}
if (xs === undefined) {
return undefined;
}
if (xs.length === 0) {
return new ArrayBuffer(0);
}

let totalByteLength = 0;
for (const x of xs) {
// tslint:disable-next-line:no-any
if (x as any instanceof Float32Array || x instanceof Int32Array) {
totalByteLength += x.length * 4;
// tslint:disable-next-line:no-any
} else if (x as any instanceof Uint8Array) {
totalByteLength += x.length;
} else {
throw new Error(`Unsupported TypedArray subtype: ${x.constructor.name}`);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there minification concerns here with constructor.name?

}
}

const y = new Uint8Array(totalByteLength);
let offset = 0;
for (const x of xs) {
y.set(new Uint8Array(x.buffer), offset);
if (x instanceof Float32Array || x instanceof Int32Array) {
offset += x.length * 4;
} else {
offset += x.length;
}
}

return y.buffer;
}
Loading