Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Move FFT ops under the spectral namespace. #1309

Merged
merged 6 commits into from
Oct 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ export * from './moving_average';
export * from './strided_slice';
export * from './topk';
export * from './scatter_nd';
export * from './spectral_ops';

export {op} from './operation';

// Second level exports.
import * as losses from './loss_ops';
import * as linalg from './linalg_ops';
import * as image from './image_ops';
export {image, linalg, losses};
import * as spectral from './spectral_ops';

export {image, linalg, losses, spectral};
5 changes: 4 additions & 1 deletion src/ops/spectral_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ import {assert} from '../util';
* const imag = tf.tensor1d([1, 2, 3]);
* const x = tf.complex(real, imag);
*
* x.fft().print();
* x.fft().print(); // tf.spectral.fft(x).print();
* ```
* @param input The complex input to compute an fft over.
*/
/**
* @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
*/
function fft_(input: Tensor1D): Tensor1D {
assert(input.dtype === 'complex64', 'dtype must be complex64');
assert(input.rank === 1, 'input rank must be 1');
Expand Down
41 changes: 22 additions & 19 deletions src/ops/spectral_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {expectArraysClose, ALL_ENVS} from '../test_util';
import {ALL_ENVS, expectArraysClose} from '../test_util';

describeWithFlags('FFT', ALL_ENVS, () => {
it('should return the same value with TensorFlow (2 elements)', () => {
const t1Real = tf.tensor1d([1, 2]);
const t1Imag = tf.tensor1d([1, 1]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [3, 2, -1, 0]);
expectArraysClose(tf.spectral.fft(t1), [3, 2, -1, 0]);
});

it('should calculate FFT from Tensor directly', () => {
Expand All @@ -38,38 +38,41 @@ describeWithFlags('FFT', ALL_ENVS, () => {
const t1Real = tf.tensor1d([1, 2, 3]);
const t1Imag = tf.tensor1d([0, 0, 0]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [6, 0, -1.5, 0.866025, -1.5, -0.866025]);
expectArraysClose(
tf.spectral.fft(t1), [6, 0, -1.5, 0.866025, -1.5, -0.866025]);
});

it('should return the same value as TensorFlow with imaginary (3 elements)',
() => {
const t1Real = tf.tensor1d([1, 2, 3]);
const t1Imag = tf.tensor1d([1, 2, 3]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(
tf.fft(t1), [6, 6, -2.3660252, -0.63397473, -0.6339747, -2.3660254]);
});
() => {
const t1Real = tf.tensor1d([1, 2, 3]);
const t1Imag = tf.tensor1d([1, 2, 3]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(
tf.spectral.fft(t1),
[6, 6, -2.3660252, -0.63397473, -0.6339747, -2.3660254]);
});

it('should return the same value as TensorFlow (negative 3 elements)', () => {
const t1Real = tf.tensor1d([-1, -2, -3]);
const t1Imag = tf.tensor1d([-1, -2, -3]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1),
[-5.9999995, -6, 2.3660252, 0.63397473, 0.6339747, 2.3660254]);
expectArraysClose(
tf.spectral.fft(t1),
[-5.9999995, -6, 2.3660252, 0.63397473, 0.6339747, 2.3660254]);
});

it('should return the same value with TensorFlow (4 elements)', () => {
const t1Real = tf.tensor1d([1, 2, 3, 4]);
const t1Imag = tf.tensor1d([0, 0, 0, 0]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [10, 0, -2, 2, -2, 0, -2, -2]);
expectArraysClose(tf.spectral.fft(t1), [10, 0, -2, 2, -2, 0, -2, -2]);
});

it('should return the same value as TensorFlow with imaginary (4 elements)',
() => {
const t1Real = tf.tensor1d([1, 2, 3, 4]);
const t1Imag = tf.tensor1d([1, 2, 3, 4]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [10, 10, -4, 0, -2, -2, 0, -4]);
});
() => {
const t1Real = tf.tensor1d([1, 2, 3, 4]);
const t1Imag = tf.tensor1d([1, 2, 3, 4]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.spectral.fft(t1), [10, 10, -4, 0, -2, -2, 0, -4]);
});
});
4 changes: 2 additions & 2 deletions src/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ export interface OpHandler {
x: T, begin: number[], end: number[], strides: number[],
beginMask: number, endMask: number): T;
depthToSpace(x: Tensor4D, blockSize: number, dataFormat: string): Tensor4D;
fft(x: Tensor1D): Tensor1D;
spectral: {fft(x: Tensor1D): Tensor1D;};
}

// For tracking tensor creation and disposal.
Expand Down Expand Up @@ -1253,7 +1253,7 @@ export class Tensor<R extends Rank = Rank> {

fft(this: Tensor1D): Tensor1D {
this.throwIfDisposed();
return opHandler.fft(this);
return opHandler.spectral.fft(this);
}
}
Object.defineProperty(Tensor, Symbol.hasInstance, {
Expand Down