diff --git a/src/index.ts b/src/index.ts
index 6ea655a656..b416f8907c 100644
--- a/src/index.ts
+++ b/src/index.ts
@@ -26,6 +26,7 @@ import * as environment from './environment';
 import {Environment} from './environment';
 // Serialization.
 import * as io from './io/io';
+import * as serialization from './serialization';
 import * as test_util from './test_util';
 import * as util from './util';
 import {version} from './version';
@@ -63,7 +64,7 @@ export {doc} from './doc';
 export const nextFrame = BrowserUtil.nextFrame;
 
 // Second level exports.
-export {environment, io, test_util, util, webgl};
+export {environment, io, serialization, test_util, util, webgl};
 
 // Backend specific.
 export {KernelBackend, BackendTimingInfo} from './kernels/backend';
diff --git a/src/optimizers/adadelta_optimizer.ts b/src/optimizers/adadelta_optimizer.ts
index f5825f17e8..82e4bef911 100644
--- a/src/optimizers/adadelta_optimizer.ts
+++ b/src/optimizers/adadelta_optimizer.ts
@@ -18,25 +18,31 @@
 import {ENV} from '../environment';
 import {keep, tidy} from '../globals';
 import {scalar, zerosLike} from '../ops/ops';
+// tslint:disable-next-line:max-line-length
+import {ConfigDict, Serializable, SerializableConstructor, SerializationMap} from '../serialization';
 import {Scalar} from '../tensor';
 import {NamedVariableMap} from '../types';
+
 import {Optimizer} from './optimizer';
 
 /** @doclink Optimizer */
 export class AdadeltaOptimizer extends Optimizer {
+  static className = 'AdadeltaOptimizer';
   private c: Scalar;
-  private epsilon: Scalar;
-  private rho: Scalar;
+  private epsilonScalar: Scalar;
+  private rhoScalar: Scalar;
   private oneMinusRho: Scalar;
 
   private accumulatedGrads: NamedVariableMap = {};
   private accumulatedUpdates: NamedVariableMap = {};
 
-  constructor(learningRate: number, rho: number, epsilon = 1e-8) {
+  constructor(
+      protected learningRate: number, protected rho: number,
+      protected epsilon = 1e-8) {
     super();
     this.c = keep(scalar(-learningRate));
-    this.epsilon = keep(scalar(epsilon));
-    this.rho = keep(scalar(rho));
+    this.epsilonScalar = keep(scalar(epsilon));
+    this.rhoScalar = keep(scalar(rho));
     this.oneMinusRho = keep(scalar(1 - rho));
   }
 
@@ -64,16 +70,16 @@ export class AdadeltaOptimizer extends Optimizer {
 
       tidy(() => {
         const newAccumulatedGrad =
-            this.rho.mul(accumulatedGrad)
+            this.rhoScalar.mul(accumulatedGrad)
                 .add(this.oneMinusRho.mul(gradient.square()));
 
-        const updates = accumulatedUpdate.add(this.epsilon)
+        const updates = accumulatedUpdate.add(this.epsilonScalar)
                             .sqrt()
-                            .div(accumulatedGrad.add(this.epsilon).sqrt())
+                            .div(accumulatedGrad.add(this.epsilonScalar).sqrt())
                             .mul(gradient);
 
         const newAccumulatedUpdate =
-            this.rho.mul(accumulatedUpdate)
+            this.rhoScalar.mul(accumulatedUpdate)
                 .add(this.oneMinusRho.mul(updates.square()));
 
         this.accumulatedGrads[variableName].assign(newAccumulatedGrad);
@@ -87,8 +93,8 @@
 
   dispose() {
     this.c.dispose();
-    this.epsilon.dispose();
-    this.rho.dispose();
+    this.epsilonScalar.dispose();
+    this.rhoScalar.dispose();
     this.oneMinusRho.dispose();
     if (this.accumulatedUpdates != null) {
       Object.keys(this.accumulatedUpdates)
@@ -97,4 +103,16 @@
           .forEach(name => this.accumulatedGrads[name].dispose());
     }
   }
+  getConfig(): ConfigDict {
+    return {
+      learningRate: this.learningRate,
+      rho: this.rho,
+      epsilon: this.epsilon
+    };
+  }
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(config.learningRate, config.rho, config.epsilon);
+  }
 }
+SerializationMap.register(AdadeltaOptimizer);
diff --git a/src/optimizers/adadelta_optimizer_test.ts b/src/optimizers/adadelta_optimizer_test.ts
index eedbabbf47..a8a7d9406a 100644
--- a/src/optimizers/adadelta_optimizer_test.ts
+++ b/src/optimizers/adadelta_optimizer_test.ts
@@ -16,8 +16,8 @@
  */
 
 import * as tf from '../index';
-import {ALL_ENVS, expectArraysClose} from '../test_util';
 import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
 
 describeWithFlags('AdadeltaOptimizer', ALL_ENVS, () => {
   it('basic', () => {
@@ -75,4 +75,10 @@ describeWithFlags('AdadeltaOptimizer', ALL_ENVS, () => {
     // The only tensor remaining is the argument to variable().
     expect(tf.memory().numTensors).toBe(1);
   });
+  it('serialization round-trip', () => {
+    const originalOpt = tf.train.adadelta(0.1, 0.2, 2e-8);
+    const reserialized = tf.AdadeltaOptimizer.fromConfig(
+        tf.AdadeltaOptimizer, originalOpt.getConfig());
+    expect(reserialized.getConfig()).toEqual(originalOpt.getConfig());
+  });
 });
diff --git a/src/optimizers/adagrad_optimizer.ts b/src/optimizers/adagrad_optimizer.ts
index a3add1e46e..43ce255506 100644
--- a/src/optimizers/adagrad_optimizer.ts
+++ b/src/optimizers/adagrad_optimizer.ts
@@ -18,6 +18,8 @@
 import {ENV} from '../environment';
 import {keep, tidy} from '../globals';
 import {fill, scalar} from '../ops/ops';
+// tslint:disable-next-line:max-line-length
+import {ConfigDict, Serializable, SerializableConstructor, SerializationMap} from '../serialization';
 import {Scalar} from '../tensor';
 import {NamedVariableMap} from '../types';
 
@@ -25,6 +27,7 @@ import {Optimizer} from './optimizer';
 
 /** @doclink Optimizer */
 export class AdagradOptimizer extends Optimizer {
+  static className = 'AdagradOptimizer';
   private c: Scalar;
   private epsilon: Scalar;
 
@@ -73,4 +76,15 @@
           .forEach(name => this.accumulatedGrads[name].dispose());
     }
   }
+  getConfig(): ConfigDict {
+    return {
+      learningRate: this.learningRate,
+      initialAccumulatorValue: this.initialAccumulatorValue,
+    };
+  }
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(config.learningRate, config.initialAccumulatorValue);
+  }
 }
+SerializationMap.register(AdagradOptimizer);
diff --git a/src/optimizers/adagrad_optimizer_test.ts b/src/optimizers/adagrad_optimizer_test.ts
index baf2e0dacc..ae3f1c671e 100644
--- a/src/optimizers/adagrad_optimizer_test.ts
+++ b/src/optimizers/adagrad_optimizer_test.ts
@@ -16,8 +16,8 @@
  */
 
 import * as tf from '../index';
-import {ALL_ENVS, expectArraysClose} from '../test_util';
 import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
 
 describeWithFlags('AdagradOptimizer', ALL_ENVS, () => {
   it('basic', () => {
@@ -69,4 +69,10 @@ describeWithFlags('AdagradOptimizer', ALL_ENVS, () => {
     // The only tensor remaining is the argument to variable().
     expect(tf.memory().numTensors).toBe(1);
   });
+  it('serialization round-trip', () => {
+    const originalOpt = tf.train.adagrad(0.1, 0.2);
+    const reserialized = tf.AdagradOptimizer.fromConfig(
+        tf.AdagradOptimizer, originalOpt.getConfig());
+    expect(reserialized.getConfig()).toEqual(originalOpt.getConfig());
+  });
 });
diff --git a/src/optimizers/adam_optimizer.ts b/src/optimizers/adam_optimizer.ts
index c509b2b235..25417e9904 100644
--- a/src/optimizers/adam_optimizer.ts
+++ b/src/optimizers/adam_optimizer.ts
@@ -18,15 +18,19 @@
 import {ENV} from '../environment';
 import {keep, tidy} from '../globals';
 import {scalar, zerosLike} from '../ops/ops';
+// tslint:disable-next-line:max-line-length
+import {ConfigDict, Serializable, SerializableConstructor, SerializationMap} from '../serialization';
 import {Scalar, Variable} from '../tensor';
 import {NamedVariableMap} from '../types';
+
 import {Optimizer} from './optimizer';
 
 export class AdamOptimizer extends Optimizer {
+  static className = 'AdamOptimizer';
   private c: Scalar;
-  private eps: Scalar;
-  private beta1: Scalar;
-  private beta2: Scalar;
+  private epsScalar: Scalar;
+  private beta1Scalar: Scalar;
+  private beta2Scalar: Scalar;
   private accBeta1: Variable;
   private accBeta2: Variable;
   private oneMinusBeta1: Scalar;
@@ -37,14 +41,14 @@ export class AdamOptimizer extends Optimizer {
   private accumulatedSecondMoment: NamedVariableMap = {};
 
   constructor(
-      protected learningRate: number, beta1: number, beta2: number,
-      epsilon = 1e-8) {
+      protected learningRate: number, protected beta1: number,
+      protected beta2: number, protected epsilon = 1e-8) {
     super();
     this.c = keep(scalar(-learningRate));
-    this.eps = keep(scalar(epsilon));
+    this.epsScalar = keep(scalar(epsilon));
     // b1, b2 keep initial value of beta* hyperparameters.
-    this.beta1 = keep(scalar(beta1));
-    this.beta2 = keep(scalar(beta2));
+    this.beta1Scalar = keep(scalar(beta1));
+    this.beta2Scalar = keep(scalar(beta2));
     tidy(() => {
       // accB* will be updated by batch.
       this.accBeta1 = scalar(beta1).variable();
@@ -77,10 +81,10 @@ export class AdamOptimizer extends Optimizer {
       const firstMoment = this.accumulatedFirstMoment[variableName];
       const secondMoment = this.accumulatedSecondMoment[variableName];
 
-      const newFirstMoment =
-          this.beta1.mul(firstMoment).add(this.oneMinusBeta1.mul(gradient));
+      const newFirstMoment = this.beta1Scalar.mul(firstMoment)
+                                 .add(this.oneMinusBeta1.mul(gradient));
       const newSecondMoment =
-          this.beta2.mul(secondMoment)
+          this.beta2Scalar.mul(secondMoment)
               .add(this.oneMinusBeta2.mul(gradient.square()));
 
       const biasCorrectedFirstMoment = newFirstMoment.div(oneMinusAccBeta1);
@@ -89,23 +93,24 @@ export class AdamOptimizer extends Optimizer {
       this.accumulatedFirstMoment[variableName].assign(newFirstMoment);
       this.accumulatedSecondMoment[variableName].assign(newSecondMoment);
 
-      const newValue = this.c
-                           .mul(biasCorrectedFirstMoment.div(this.eps.add(
-                               biasCorrectedSecondMoment.sqrt())))
-                           .add(value);
+      const newValue =
+          this.c
+              .mul(biasCorrectedFirstMoment.div(
+                  this.epsScalar.add(biasCorrectedSecondMoment.sqrt())))
+              .add(value);
       value.assign(newValue);
     }
 
-    this.accBeta1.assign(this.accBeta1.mul(this.beta1));
-    this.accBeta2.assign(this.accBeta2.mul(this.beta2));
+    this.accBeta1.assign(this.accBeta1.mul(this.beta1Scalar));
+    this.accBeta2.assign(this.accBeta2.mul(this.beta2Scalar));
     });
   }
 
   dispose() {
     this.c.dispose();
-    this.eps.dispose();
-    this.beta1.dispose();
-    this.beta2.dispose();
+    this.epsScalar.dispose();
+    this.beta1Scalar.dispose();
+    this.beta2Scalar.dispose();
     this.accBeta1.dispose();
     this.accBeta2.dispose();
     this.oneMinusBeta1.dispose();
@@ -122,4 +127,18 @@
           .forEach(name => this.accumulatedSecondMoment[name].dispose());
     }
   }
+  getConfig(): ConfigDict {
+    return {
+      learningRate: this.learningRate,
+      beta1: this.beta1,
+      beta2: this.beta2,
+      epsilon: this.epsilon,
+    };
+  }
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(
+        config.learningRate, config.beta1, config.beta2, config.epsilon);
+  }
 }
+SerializationMap.register(AdamOptimizer);
diff --git a/src/optimizers/adam_optimizer_test.ts b/src/optimizers/adam_optimizer_test.ts
index f4351f6b5b..739fb7a4cd 100644
--- a/src/optimizers/adam_optimizer_test.ts
+++ b/src/optimizers/adam_optimizer_test.ts
@@ -16,8 +16,8 @@
  */
 
 import * as tf from '../index';
-import {ALL_ENVS, expectArraysClose} from '../test_util';
 import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
 
 describeWithFlags('AdamOptimizer', ALL_ENVS, () => {
   it('basic', () => {
@@ -79,4 +79,10 @@ describeWithFlags('AdamOptimizer', ALL_ENVS, () => {
     // The only tensor remaining should be the argument to variable().
     expect(tf.memory().numTensors).toBe(1);
   });
+  it('serialization round-trip', () => {
+    const originalOpt = tf.train.adam(0.1, 0.2, 0.3, 2e-8);
+    const reserialized =
+        tf.AdamOptimizer.fromConfig(tf.AdamOptimizer, originalOpt.getConfig());
+    expect(reserialized.getConfig()).toEqual(originalOpt.getConfig());
+  });
 });
diff --git a/src/optimizers/adamax_optimizer.ts b/src/optimizers/adamax_optimizer.ts
index 24c9d2bc1b..7718b29c61 100644
--- a/src/optimizers/adamax_optimizer.ts
+++ b/src/optimizers/adamax_optimizer.ts
@@ -18,17 +18,21 @@
 import {ENV} from '../environment';
 import {keep, tidy} from '../globals';
 import {scalar, zerosLike} from '../ops/ops';
+// tslint:disable-next-line:max-line-length
+import {ConfigDict, Serializable, SerializableConstructor, SerializationMap} from '../serialization';
 import {Scalar, Variable} from '../tensor';
 import {NamedVariableMap} from '../types';
+
 import {Optimizer} from './optimizer';
 
 export class AdamaxOptimizer extends Optimizer {
+  static className = 'AdamaxOptimizer';
   private c: Scalar;
-  private eps: Scalar;
+  private epsScalar: Scalar;
   private accBeta1: Variable;
-  private beta1: Scalar;
-  private beta2: Scalar;
-  private decay: Scalar;
+  private beta1Scalar: Scalar;
+  private beta2Scalar: Scalar;
+  private decayScalar: Scalar;
   private oneMinusBeta1: Scalar;
   private one: Scalar;
   private iteration: Variable;
@@ -37,16 +41,17 @@ export class AdamaxOptimizer extends Optimizer {
   private accumulatedWeightedInfNorm: NamedVariableMap = {};
 
   constructor(
-      protected learningRate: number, beta1: number, beta2: number,
-      epsilon = 1e-8, decay = 0.0) {
+      protected learningRate: number, protected beta1: number,
+      protected beta2: number, protected epsilon = 1e-8,
+      protected decay = 0.0) {
     super();
     this.c = keep(scalar(-learningRate));
-    this.eps = keep(scalar(epsilon));
+    this.epsScalar = keep(scalar(epsilon));
     // b1, b2 keep initial value of beta* hyperparameters.
-    this.beta1 = keep(scalar(beta1));
-    this.beta2 = keep(scalar(beta2));
+    this.beta1Scalar = keep(scalar(beta1));
+    this.beta2Scalar = keep(scalar(beta2));
 
-    this.decay = keep(scalar(decay));
+    this.decayScalar = keep(scalar(decay));
 
     tidy(() => {
       this.iteration = scalar(0).variable();
@@ -60,7 +65,7 @@ export class AdamaxOptimizer extends Optimizer {
   applyGradients(variableGradients: NamedVariableMap) {
     tidy(() => {
       const oneMinusAccBeta1 = this.one.sub(this.accBeta1);
-      const lr = this.c.div(this.one.add(this.decay.mul(this.iteration)));
+      const lr = this.c.div(this.one.add(this.decayScalar.mul(this.iteration)));
 
       for (const variableName in variableGradients) {
         const value = ENV.engine.registeredVariables[variableName];
@@ -79,10 +84,10 @@ export class AdamaxOptimizer extends Optimizer {
         const firstMoment = this.accumulatedFirstMoment[variableName];
         const weightedInfNorm = this.accumulatedWeightedInfNorm[variableName];
 
-        const newFirstMoment =
-            this.beta1.mul(firstMoment).add(this.oneMinusBeta1.mul(gradient));
+        const newFirstMoment = this.beta1Scalar.mul(firstMoment)
+                                   .add(this.oneMinusBeta1.mul(gradient));
 
-        const ut0 = this.beta2.mul(weightedInfNorm);
+        const ut0 = this.beta2Scalar.mul(weightedInfNorm);
         const ut1 = gradient.abs();
 
         const newWeightedInfNorm = ut0.maximum(ut1);
@@ -93,26 +98,26 @@ export class AdamaxOptimizer extends Optimizer {
 
         const newValue =
             lr.div(oneMinusAccBeta1)
-                .mul(newFirstMoment.div(this.eps.add(newWeightedInfNorm)))
+                .mul(newFirstMoment.div(this.epsScalar.add(newWeightedInfNorm)))
                 .add(value);
 
         value.assign(newValue);
       }
 
       this.iteration.assign(this.iteration.add(this.one));
-      this.accBeta1.assign(this.accBeta1.mul(this.beta1));
+      this.accBeta1.assign(this.accBeta1.mul(this.beta1Scalar));
     });
   }
 
   dispose() {
     this.c.dispose();
-    this.eps.dispose();
+    this.epsScalar.dispose();
     this.accBeta1.dispose();
-    this.beta1.dispose();
-    this.beta2.dispose();
+    this.beta1Scalar.dispose();
+    this.beta2Scalar.dispose();
     this.oneMinusBeta1.dispose();
 
-    this.decay.dispose();
+    this.decayScalar.dispose();
     this.iteration.dispose();
     this.one.dispose();
@@ -127,4 +132,20 @@
           .forEach(name => this.accumulatedWeightedInfNorm[name].dispose());
     }
   }
+  getConfig(): ConfigDict {
+    return {
+      learningRate: this.learningRate,
+      beta1: this.beta1,
+      beta2: this.beta2,
+      epsilon: this.epsilon,
+      decay: this.decay
+    };
+  }
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(
+        config.learningRate, config.beta1, config.beta2, config.epsilon,
+        config.decay);
+  }
 }
+SerializationMap.register(AdamaxOptimizer);
diff --git a/src/optimizers/adamax_optimizer_test.ts b/src/optimizers/adamax_optimizer_test.ts
index 824867364f..590513358b 100644
--- a/src/optimizers/adamax_optimizer_test.ts
+++ b/src/optimizers/adamax_optimizer_test.ts
@@ -16,8 +16,8 @@
  */
 
 import * as tf from '../index';
-import {ALL_ENVS, expectArraysClose} from '../test_util';
 import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
 
 describeWithFlags('AdamaxOptimizer', ALL_ENVS, () => {
   it('basic', () => {
@@ -110,4 +110,10 @@ describeWithFlags('AdamaxOptimizer', ALL_ENVS, () => {
     // The only tensor remaining should be the argument to variable().
     expect(tf.memory().numTensors).toBe(1);
   });
+  it('serialization round-trip', () => {
+    const originalOpt = tf.train.adamax(0.1, 0.2, 0.3, 2e-8, 0.1);
+    const reserialized = tf.AdamaxOptimizer.fromConfig(
+        tf.AdamaxOptimizer, originalOpt.getConfig());
+    expect(reserialized.getConfig()).toEqual(originalOpt.getConfig());
+  });
 });
diff --git a/src/optimizers/momentum_optimizer.ts b/src/optimizers/momentum_optimizer.ts
index b8ea921f15..1c7a4a68dc 100644
--- a/src/optimizers/momentum_optimizer.ts
+++ b/src/optimizers/momentum_optimizer.ts
@@ -18,12 +18,16 @@
 import {ENV} from '../environment';
 import {tidy} from '../globals';
 import {scalar, zerosLike} from '../ops/ops';
+// tslint:disable-next-line:max-line-length
+import {ConfigDict, Serializable, SerializableConstructor, SerializationMap} from '../serialization';
 import {Scalar, Tensor} from '../tensor';
 import {NamedVariableMap} from '../types';
+
 import {SGDOptimizer} from './sgd_optimizer';
 
 /** @doclink Optimizer */
 export class MomentumOptimizer extends SGDOptimizer {
+  static className = 'MomentumOptimizer';
   private m: Scalar;
   private accumulations: NamedVariableMap;
 
@@ -82,4 +86,17 @@ export class MomentumOptimizer extends SGDOptimizer {
   setMomentum(momentum: number) {
     this.momentum = momentum;
   }
+
+  getConfig(): ConfigDict {
+    return {
+      learningRate: this.learningRate,
+      momentum: this.momentum,
+      useNesterov: this.useNesterov
+    };
+  }
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(config.learningRate, config.momentum, config.useNesterov);
+  }
 }
+SerializationMap.register(MomentumOptimizer);
diff --git a/src/optimizers/momentum_optimizer_test.ts b/src/optimizers/momentum_optimizer_test.ts
index d41364b738..7cd7241934 100644
--- a/src/optimizers/momentum_optimizer_test.ts
+++ b/src/optimizers/momentum_optimizer_test.ts
@@ -16,8 +16,8 @@
  */
 
 import * as tf from '../index';
-import {ALL_ENVS, expectArraysClose} from '../test_util';
 import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
 
 describeWithFlags('MomentumOptimizer', ALL_ENVS, () => {
   it('basic', () => {
@@ -117,4 +117,10 @@ describeWithFlags('MomentumOptimizer', ALL_ENVS, () => {
     // The only tensor remaining is the argument to variable().
     expect(tf.memory().numTensors).toBe(1);
   });
+  it('serialization round-trip', () => {
+    const originalOpt = tf.train.momentum(0.1, 0.2, true);
+    const reserialized = tf.MomentumOptimizer.fromConfig(
+        tf.MomentumOptimizer, originalOpt.getConfig());
+    expect(reserialized.getConfig()).toEqual(originalOpt.getConfig());
+  });
 });
diff --git a/src/optimizers/optimizer.ts b/src/optimizers/optimizer.ts
index e0e50e027d..0e15858575 100644
--- a/src/optimizers/optimizer.ts
+++ b/src/optimizers/optimizer.ts
@@ -17,11 +17,12 @@
 
 import {doc} from '../doc';
 import {variableGrads} from '../globals';
+import {Serializable} from '../serialization';
 import {Scalar, Variable} from '../tensor';
 import {NamedTensorMap} from '../types';
 
 @doc({heading: 'Training', subheading: 'Classes', namespace: 'train'})
-export abstract class Optimizer {
+export abstract class Optimizer extends Serializable {
   /**
    * Executes `f()` and minimizes the scalar output of `f()` by computing
    * gradients of y with respect to the list of trainable variables provided by
diff --git a/src/optimizers/rmsprop_optimizer.ts b/src/optimizers/rmsprop_optimizer.ts
index 765342dfbd..8518d8091d 100644
--- a/src/optimizers/rmsprop_optimizer.ts
+++ b/src/optimizers/rmsprop_optimizer.ts
@@ -18,16 +18,20 @@
 import {ENV} from '../environment';
 import {keep, tidy} from '../globals';
 import {scalar, zerosLike} from '../ops/ops';
+// tslint:disable-next-line:max-line-length
+import {ConfigDict, Serializable, SerializableConstructor, SerializationMap} from '../serialization';
 import {Scalar} from '../tensor';
 import {NamedVariableMap} from '../types';
+
 import {Optimizer} from './optimizer';
 
 /** @doclink Optimizer */
 export class RMSPropOptimizer extends Optimizer {
+  static className = 'RMSPropOptimizer';
   private c: Scalar;
-  private epsilon: Scalar;
-  private decay: Scalar;
-  private momentum: Scalar;
+  private epsilonScalar: Scalar;
+  private decayScalar: Scalar;
+  private momentumScalar: Scalar;
   private oneMinusDecay: Scalar;
   private centered: boolean;
 
@@ -36,14 +40,14 @@ export class RMSPropOptimizer extends Optimizer {
   private accumulatedMoments: NamedVariableMap = {};
 
   constructor(
-      protected learningRate: number, decay = 0.9, momentum = 0.0,
-      epsilon = 1e-8, centered = false) {
+      protected learningRate: number, protected decay = 0.9,
+      protected momentum = 0.0, protected epsilon = 1e-8, centered = false) {
     super();
     this.c = keep(scalar(learningRate));
-    this.epsilon = keep(scalar(epsilon));
-    this.decay = keep(scalar(decay));
-    this.momentum = keep(scalar(momentum));
+    this.epsilonScalar = keep(scalar(epsilon));
+    this.decayScalar = keep(scalar(decay));
+    this.momentumScalar = keep(scalar(momentum));
     this.oneMinusDecay = keep(scalar(1 - decay));
     this.centered = centered;
   }
@@ -80,21 +84,22 @@ export class RMSPropOptimizer extends Optimizer {
 
       tidy(() => {
         const newAccumulatedMeanSquare =
-            this.decay.mul(accumulatedMeanSquare)
+            this.decayScalar.mul(accumulatedMeanSquare)
                 .add(this.oneMinusDecay.mul(gradient.square()));
 
         if (this.centered) {
           // Centered gradient
           const newAccumulatedMeanGrad =
-              this.decay.mul(accumulatedMeanGrad)
+              this.decayScalar.mul(accumulatedMeanGrad)
                   .add(this.oneMinusDecay.mul(gradient));
 
           const newAccumulatedMoments =
-              this.momentum.mul(accumulatedMoments)
+              this.momentumScalar.mul(accumulatedMoments)
                   .add(this.c.mul(gradient).div(
-                      newAccumulatedMeanSquare.sub(
-                          newAccumulatedMeanGrad.square().add(
-                              this.epsilon)).sqrt()));
+                      newAccumulatedMeanSquare
+                          .sub(newAccumulatedMeanGrad.square().add(
+                              this.epsilonScalar))
+                          .sqrt()));
 
           this.accumulatedMeanSquares[variableName].assign(
               newAccumulatedMeanSquare);
@@ -107,13 +112,13 @@ export class RMSPropOptimizer extends Optimizer {
         } else {
           // Plain gradient
           const newAccumulatedMeanSquare =
-              this.decay.mul(accumulatedMeanSquare)
+              this.decayScalar.mul(accumulatedMeanSquare)
                   .add(this.oneMinusDecay.mul(gradient.square()));
 
           const newAccumulatedMoments =
-              this.momentum.mul(accumulatedMoments)
+              this.momentumScalar.mul(accumulatedMoments)
                   .add(this.c.mul(gradient).div(
-                      newAccumulatedMeanSquare.add(this.epsilon).sqrt()));
+                      newAccumulatedMeanSquare.add(this.epsilonScalar).sqrt()));
 
           this.accumulatedMeanSquares[variableName].assign(
               newAccumulatedMeanSquare);
@@ -128,9 +133,9 @@ export class RMSPropOptimizer extends Optimizer {
 
   dispose() {
     this.c.dispose();
-    this.epsilon.dispose();
-    this.decay.dispose();
-    this.momentum.dispose();
+    this.epsilonScalar.dispose();
+    this.decayScalar.dispose();
+    this.momentumScalar.dispose();
     this.oneMinusDecay.dispose();
     if (this.accumulatedMeanSquares != null) {
       Object.keys(this.accumulatedMeanSquares)
@@ -145,4 +150,21 @@
           .forEach(name => this.accumulatedMoments[name].dispose());
     }
   }
+
+  getConfig(): ConfigDict {
+    return {
+      learningRate: this.learningRate,
+      decay: this.decay,
+      momentum: this.momentum,
+      epsilon: this.epsilon,
+      centered: this.centered
+    };
+  }
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(
+        config.learningRate, config.decay, config.momentum, config.epsilon,
+        config.centered);
+  }
 }
+SerializationMap.register(RMSPropOptimizer);
diff --git a/src/optimizers/rmsprop_optimizer_test.ts b/src/optimizers/rmsprop_optimizer_test.ts
index a7cd661ecc..8d4da432a6 100644
--- a/src/optimizers/rmsprop_optimizer_test.ts
+++ b/src/optimizers/rmsprop_optimizer_test.ts
@@ -16,8 +16,8 @@
  */
 
 import * as tf from '../index';
-import {ALL_ENVS, expectArraysClose} from '../test_util';
 import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
 
 describeWithFlags('RMSPropOptimizer', ALL_ENVS, () => {
   it('basic', () => {
@@ -147,4 +147,10 @@ describeWithFlags('RMSPropOptimizer', ALL_ENVS, () => {
     // The only tensor remaining is the argument to variable().
     expect(tf.memory().numTensors).toBe(1);
   });
+  it('serialization round-trip', () => {
+    const originalOpt = tf.train.rmsprop(0.1, 0.5, 0.1, 1e-7, true);
+    const reserialized = tf.RMSPropOptimizer.fromConfig(
+        tf.RMSPropOptimizer, originalOpt.getConfig());
+    expect(reserialized.getConfig()).toEqual(originalOpt.getConfig());
+  });
 });
diff --git a/src/optimizers/sgd_optimizer.ts b/src/optimizers/sgd_optimizer.ts
index 14c0b44936..888d610efe 100644
--- a/src/optimizers/sgd_optimizer.ts
+++ b/src/optimizers/sgd_optimizer.ts
@@ -18,6 +18,8 @@
 import {ENV} from '../environment';
 import {keep, tidy} from '../globals';
 import {scalar} from '../ops/ops';
+// tslint:disable-next-line:max-line-length
+import {ConfigDict, Serializable, SerializableConstructor, SerializationMap} from '../serialization';
 import {Scalar} from '../tensor';
 import {NamedTensorMap} from '../types';
 
@@ -25,6 +27,7 @@ import {Optimizer} from './optimizer';
 
 /** @doclink Optimizer */
 export class SGDOptimizer extends Optimizer {
+  static className = 'SGDOptimizer';
   protected c: Scalar;
 
   constructor(protected learningRate: number) {
@@ -59,4 +62,13 @@ export class SGDOptimizer extends Optimizer {
   dispose() {
     this.c.dispose();
   }
+
+  getConfig(): ConfigDict {
+    return {learningRate: this.learningRate};
+  }
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(config.learningRate);
+  }
 }
+SerializationMap.register(SGDOptimizer);
diff --git a/src/optimizers/sgd_optimizer_test.ts b/src/optimizers/sgd_optimizer_test.ts
index b9c688806e..70f0fe442c 100644
--- a/src/optimizers/sgd_optimizer_test.ts
+++ b/src/optimizers/sgd_optimizer_test.ts
@@ -16,8 +16,8 @@
  */
 
 import * as tf from '../index';
-import {ALL_ENVS, expectArraysClose} from '../test_util';
 import {describeWithFlags} from '../jasmine_util';
+import {ALL_ENVS, expectArraysClose} from '../test_util';
 
 describeWithFlags('SGDOptimizer', ALL_ENVS, () => {
   it('basic', () => {
@@ -54,4 +54,11 @@ describeWithFlags('SGDOptimizer', ALL_ENVS, () => {
     // The only tensor remaining is the argument to variable().
     expect(tf.memory().numTensors).toBe(1);
   });
+  it('serialization round-trip', () => {
+    const learningRate = .1;
+    const originalOpt = tf.train.sgd(learningRate);
+    const reserialized =
+        tf.SGDOptimizer.fromConfig(tf.SGDOptimizer, originalOpt.getConfig());
+    expect(reserialized.getConfig()).toEqual(originalOpt.getConfig());
+  });
 });
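Every optimizer above now carries a static `className`, a `getConfig()`/`fromConfig()` pair, and a `SerializationMap.register` call. A minimal round-trip sketch of that API, mirroring the new tests (the package entry point is assumed; any of the optimizers works the same way):

import * as tf from '@tensorflow/tfjs-core';  // assumed entry point

// Serialize the hyperparameters to a plain ConfigDict...
const original = tf.train.momentum(0.1, 0.9, /* useNesterov */ true);
const config = original.getConfig();
// {learningRate: 0.1, momentum: 0.9, useNesterov: true}

// ...and rebuild an equivalent optimizer from it.
const restored = tf.MomentumOptimizer.fromConfig(tf.MomentumOptimizer, config);
console.log(restored.getConfig());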
diff --git a/src/serialization.ts b/src/serialization.ts
new file mode 100644
index 0000000000..33884b39f6
--- /dev/null
+++ b/src/serialization.ts
@@ -0,0 +1,131 @@
+/**
+ * @license
+ * Copyright 2018 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * Types to support JSON-esque data structures internally.
+ *
+ * Internally ConfigDicts use camelCase keys and values where the values are
+ * class names to be instantiated. On the python side, these will be
+ * snake_case. Internally we allow Enums into the values for better type
+ * safety, but these need to be converted to raw primitives (usually strings)
+ * for round-tripping with python.
+ *
+ * toConfig returns the TS-friendly representation. model.toJSON() returns the
+ * pythonic version as that's the portable format. If you need to python-ify a
+ * non-model-level toConfig output, you'll need to use convertTsToPythonic
+ * from serialization_utils in -Layers.
+ */
+export type ConfigDictValue =
+    boolean|number|string|null|ConfigDictArray|ConfigDict;
+export interface ConfigDict { [key: string]: ConfigDictValue; }
+export interface ConfigDictArray extends Array<ConfigDictValue> {}
+
+/**
+ * Type to represent the class-type of Serializable objects.
+ *
+ * I.e. the class prototype with access to the constructor and any
+ * static members/methods. Instance methods are not listed here.
+ *
+ * Source for this idea: https://stackoverflow.com/a/43607255
+ */
+export type SerializableConstructor<T extends Serializable> = {
+  // tslint:disable-next-line:no-any
+  new (...args: any[]): T; className: string; fromConfig: FromConfigMethod<T>;
+};
+export type FromConfigMethod<T extends Serializable> =
+    (cls: SerializableConstructor<T>, config: ConfigDict) => T;
+
+/**
+ * Serializable defines the serialization contract.
+ *
+ * TFJS requires serializable classes to return their className when asked
+ * to avoid issues with minification.
+ */
+export abstract class Serializable {
+  /**
+   * Return the class name for this class to use in serialization contexts.
+   *
+   * Generally speaking this will be the same thing that constructor.name
+   * would have returned. However, the class name needs to be robust
+   * against minification for serialization/deserialization to work properly.
+   *
+   * There are also places such as initializers.VarianceScaling, where
+   * implementation details between different languages led to different
+   * class hierarchies and a non-leaf node is used for serialization purposes.
+   */
+  getClassName(): string {
+    return (this.constructor as SerializableConstructor<Serializable>)
+        .className;
+  }
+
+  /**
+   * Return all the non-weight state needed to serialize this object.
+   */
+  abstract getConfig(): ConfigDict;
+
+  /**
+   * Creates an instance of T from a ConfigDict.
+   *
+   * This works for most descendants of Serializable. A few need to
+   * provide special handling.
+   * @param cls A Constructor for the class to instantiate.
+   * @param config The Configuration for the object.
+   */
+  static fromConfig<T extends Serializable>(
+      cls: SerializableConstructor<T>, config: ConfigDict): T {
+    return new cls(config);
+  }
+}
+
+/**
+ * Maps string keys to class constructors.
+ *
+ * Used during (de)serialization from the cross-language JSON format, which
+ * requires the class name in the serialization format to match the class
+ * name as used in Python, should it exist.
+ */
+export class SerializationMap {
+  private static instance: SerializationMap;
+  classNameMap: {
+    [className: string]: [
+      SerializableConstructor<Serializable>, FromConfigMethod<Serializable>
+    ]
+  };
+
+  private constructor() {
+    this.classNameMap = {};
+  }
+
+  /**
+   * Returns the singleton instance of the map.
+   */
+  static getMap(): SerializationMap {
+    if (SerializationMap.instance == null) {
+      SerializationMap.instance = new SerializationMap();
+    }
+    return SerializationMap.instance;
+  }
+
+  /**
+   * Registers the class as serializable.
+   */
+  static register<T extends Serializable>(cls: SerializableConstructor<T>) {
+    this.getMap().classNameMap[cls.className] = [cls, cls.fromConfig];
+  }
+}
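Because each optimizer registers itself via `SerializationMap.register`, callers can also rebuild an instance from just the registered class name and a config, without referencing the concrete class. A sketch of such a lookup; the `deserialize` helper below is hypothetical and not part of this diff:

import {ConfigDict, Serializable, SerializationMap} from './serialization';

// Hypothetical helper: look up the registered [constructor, fromConfig] pair
// by className and use it to re-instantiate the object.
function deserialize(className: string, config: ConfigDict): Serializable {
  const entry = SerializationMap.getMap().classNameMap[className];
  if (entry == null) {
    throw new Error(`Unknown class name: ${className}`);
  }
  const [cls, fromConfig] = entry;
  return fromConfig(cls, config);
}

// e.g. clone an optimizer generically:
// const opt = tf.train.rmsprop(0.1, 0.5);
// const clone = deserialize(opt.getClassName(), opt.getConfig());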