Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#7376 Allow seeding the Orthogonal Initializer and use QR instead of gramSchmidt #7377

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 34 additions & 17 deletions tfjs-layers/src/initializers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
* =============================================================================
*/

import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, Tensor2D, tidy, transpose, truncatedNormal, zeros} from '@tensorflow/tfjs-core';
import {DataType, eye, linalg, mul, ones, randomUniform, scalar, serialization, Tensor, tidy, truncatedNormal, util, zeros} from '@tensorflow/tfjs-core';

import * as K from './backend/tfjs_backend';
import {checkDataFormat} from './common';
Expand Down Expand Up @@ -529,41 +529,58 @@ export class Orthogonal extends Initializer {
/** @nocollapse */
static className = 'Orthogonal';
readonly DEFAULT_GAIN = 1;
readonly ELEMENTS_WARN_SLOW = 2000;
protected readonly gain: number;
protected readonly seed: number;

constructor(args?: OrthogonalArgs) {
super();
this.gain = args.gain == null ? this.DEFAULT_GAIN : args.gain;
this.seed = args.seed;

if (this.seed != null) {
throw new NotImplementedError(
'Random seed is not implemented for Orthogonal Initializer yet.');
}
}

apply(shape: Shape, dtype?: DataType): Tensor {
return tidy(() => {
if (shape.length < 2) {
throw new NotImplementedError('Shape must be at least 2D.');
}
if (shape[0] * shape[1] > 2000) {
if (dtype !== 'int32' && dtype !== 'float32' && dtype !== undefined) {
throw new TypeError(`Unsupported data type ${dtype}.`);
}
dtype = dtype as 'int32' | 'float32' | undefined;

// flatten the input shape with the last dimension remaining its
// original shape so it works for conv2d
const numRows = util.sizeFromShape(shape.slice(0, -1));
const numCols = shape[shape.length - 1];
const numElements = numRows * numCols;
if (numElements > this.ELEMENTS_WARN_SLOW) {
console.warn(
`Orthogonal initializer is being called on a matrix with more ` +
`than 2000 (${shape[0] * shape[1]}) elements: ` +
`than ${this.ELEMENTS_WARN_SLOW} (${numElements}) elements: ` +
`Slowness may result.`);
}

// TODO(cais): Add seed support.
const normalizedShape =
shape[0] > shape[1] ? [shape[1], shape[0]] : shape;
const a = K.randomNormal(normalizedShape, 0, 1, 'float32') as Tensor2D;
let q = linalg.gramSchmidt(a) as Tensor2D;
if (shape[0] > shape[1]) {
q = transpose(q);
const flatShape =
[Math.max(numCols, numRows), Math.min(numCols, numRows)];

// Generate a random matrix
const randNormalMat = K.randomNormal(flatShape, 0, 1, dtype, this.seed);

// Compute QR factorization
const qr = linalg.qr(randNormalMat, false);
let qMat = qr[0];
const rMat = qr[1];

// Make Q uniform
const diag = rMat.flatten().stridedSlice(
[0], [Math.min(numCols, numRows) * Math.min(numCols, numRows)],
[Math.min(numCols, numRows) + 1]);
qMat = mul(qMat, diag.sign());
if (numRows < numCols) {
qMat = qMat.transpose();
}
return mul(this.gain, q);

return mul(scalar(this.gain), qMat.reshape(shape));
});
}

Expand Down
17 changes: 17 additions & 0 deletions tfjs-layers/src/initializers_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -831,4 +831,21 @@ describeMathCPUAndWebGL2('Orthogonal Initializer', () => {
expect((model.predict(randomNormal([1, 128, 128, 1])) as Tensor).shape)
.toEqual([1, 128, 128, 1]);
});

it('with configured seed', () => {
const initializerConfig: serialization.ConfigDict = {
className: 'Orthogonal',
config: {seed: 666013}
};

const expectedInitializer = getInitializer(initializerConfig);
const actualInitializer = getInitializer(initializerConfig);

const expected = expectedInitializer.apply([7, 2], 'float32');
const actual = actualInitializer.apply([7, 2], 'float32');

expect(actual.shape).toEqual(expected.shape);
expect(actual.dtype).toEqual(expected.dtype);
expectTensorsClose(actual, expected);
});
});
12 changes: 6 additions & 6 deletions tfjs-layers/src/model_save_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -386,13 +386,13 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
}));
const modelJSON = model.toJSON(null, false);

const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
const qrSpy = spyOn(linalg, 'qr').and.callThrough();
const modelPrime =
await tfl.models.modelFromJSON({modelTopology: modelJSON});
// Make sure modelPrime builds.
modelPrime.predict(zeros([2, 3, 4]));
// Assert the orthogonal initializer has been called.
expect(gramSchmidtSpy).toHaveBeenCalled();
expect(qrSpy).toHaveBeenCalled();
});

it('Partial non-strict load calls weight initializers', async () => {
Expand All @@ -415,15 +415,15 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
expect(savedArtifacts.weightSpecs.length).toEqual(3);
savedArtifacts.weightSpecs = savedArtifacts.weightSpecs.slice(0, 1);

const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
const qrSpy = spyOn(linalg, 'qr').and.callThrough();
const strict = false;
const modelPrime =
await tfl.loadLayersModel(io.fromMemory(savedArtifacts), {strict});
const weightsPrime = modelPrime.getWeights();
expect(weightsPrime.length).toEqual(weights.length);
expectTensorsClose(weightsPrime[0], weights[0]);
// Assert the orthogonal initializer has been called.
expect(gramSchmidtSpy).toHaveBeenCalled();
expect(qrSpy).toHaveBeenCalled();
});

it('loadLayersModel: non-strict load calls weight initializers', async () => {
Expand All @@ -446,15 +446,15 @@ describeMathCPUAndWebGL2('Save-load round trips', () => {
expect(savedArtifacts.weightSpecs.length).toEqual(3);
savedArtifacts.weightSpecs = savedArtifacts.weightSpecs.slice(0, 1);

const gramSchmidtSpy = spyOn(linalg, 'gramSchmidt').and.callThrough();
const qrSpy = spyOn(linalg, 'qr').and.callThrough();
const strict = false;
const modelPrime =
await tfl.loadLayersModel(io.fromMemory(savedArtifacts), {strict});
const weightsPrime = modelPrime.getWeights();
expect(weightsPrime.length).toEqual(weights.length);
expectTensorsClose(weightsPrime[0], weights[0]);
// Assert the orthogonal initializer has been called.
expect(gramSchmidtSpy).toHaveBeenCalled();
expect(qrSpy).toHaveBeenCalled();
});

it('Load model artifact with ndarray-format scalar objects', async () => {
Expand Down