Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Commit

Permalink
Move FFT ops under the spectral namespace. (#1309)
Browse files Browse the repository at this point in the history
BUG
  • Loading branch information
Nikhil Thorat authored Oct 11, 2018
1 parent b650bfd commit e48803c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 24 deletions.
5 changes: 3 additions & 2 deletions src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ export * from './moving_average';
export * from './strided_slice';
export * from './topk';
export * from './scatter_nd';
export * from './spectral_ops';

export {op} from './operation';

// Second level exports.
import * as losses from './loss_ops';
import * as linalg from './linalg_ops';
import * as image from './image_ops';
export {image, linalg, losses};
import * as spectral from './spectral_ops';

export {image, linalg, losses, spectral};
5 changes: 4 additions & 1 deletion src/ops/spectral_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,13 @@ import {assert} from '../util';
* const imag = tf.tensor1d([1, 2, 3]);
* const x = tf.complex(real, imag);
*
* x.fft().print();
* x.fft().print(); // tf.spectral.fft(x).print();
* ```
* @param input The complex input to compute an fft over.
*/
/**
* @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'}
*/
function fft_(input: Tensor1D): Tensor1D {
assert(input.dtype === 'complex64', 'dtype must be complex64');
assert(input.rank === 1, 'input rank must be 1');
Expand Down
41 changes: 22 additions & 19 deletions src/ops/spectral_ops_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@

import * as tf from '../index';
import {describeWithFlags} from '../jasmine_util';
import {expectArraysClose, ALL_ENVS} from '../test_util';
import {ALL_ENVS, expectArraysClose} from '../test_util';

describeWithFlags('FFT', ALL_ENVS, () => {
it('should return the same value with TensorFlow (2 elements)', () => {
const t1Real = tf.tensor1d([1, 2]);
const t1Imag = tf.tensor1d([1, 1]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [3, 2, -1, 0]);
expectArraysClose(tf.spectral.fft(t1), [3, 2, -1, 0]);
});

it('should calculate FFT from Tensor directly', () => {
Expand All @@ -38,38 +38,41 @@ describeWithFlags('FFT', ALL_ENVS, () => {
const t1Real = tf.tensor1d([1, 2, 3]);
const t1Imag = tf.tensor1d([0, 0, 0]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [6, 0, -1.5, 0.866025, -1.5, -0.866025]);
expectArraysClose(
tf.spectral.fft(t1), [6, 0, -1.5, 0.866025, -1.5, -0.866025]);
});

it('should return the same value as TensorFlow with imaginary (3 elements)',
() => {
const t1Real = tf.tensor1d([1, 2, 3]);
const t1Imag = tf.tensor1d([1, 2, 3]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(
tf.fft(t1), [6, 6, -2.3660252, -0.63397473, -0.6339747, -2.3660254]);
});
() => {
const t1Real = tf.tensor1d([1, 2, 3]);
const t1Imag = tf.tensor1d([1, 2, 3]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(
tf.spectral.fft(t1),
[6, 6, -2.3660252, -0.63397473, -0.6339747, -2.3660254]);
});

it('should return the same value as TensorFlow (negative 3 elements)', () => {
const t1Real = tf.tensor1d([-1, -2, -3]);
const t1Imag = tf.tensor1d([-1, -2, -3]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1),
[-5.9999995, -6, 2.3660252, 0.63397473, 0.6339747, 2.3660254]);
expectArraysClose(
tf.spectral.fft(t1),
[-5.9999995, -6, 2.3660252, 0.63397473, 0.6339747, 2.3660254]);
});

it('should return the same value with TensorFlow (4 elements)', () => {
const t1Real = tf.tensor1d([1, 2, 3, 4]);
const t1Imag = tf.tensor1d([0, 0, 0, 0]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [10, 0, -2, 2, -2, 0, -2, -2]);
expectArraysClose(tf.spectral.fft(t1), [10, 0, -2, 2, -2, 0, -2, -2]);
});

it('should return the same value as TensorFlow with imaginary (4 elements)',
() => {
const t1Real = tf.tensor1d([1, 2, 3, 4]);
const t1Imag = tf.tensor1d([1, 2, 3, 4]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.fft(t1), [10, 10, -4, 0, -2, -2, 0, -4]);
});
() => {
const t1Real = tf.tensor1d([1, 2, 3, 4]);
const t1Imag = tf.tensor1d([1, 2, 3, 4]);
const t1 = tf.complex(t1Real, t1Imag);
expectArraysClose(tf.spectral.fft(t1), [10, 10, -4, 0, -2, -2, 0, -4]);
});
});
4 changes: 2 additions & 2 deletions src/tensor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ export interface OpHandler {
x: T, begin: number[], end: number[], strides: number[],
beginMask: number, endMask: number): T;
depthToSpace(x: Tensor4D, blockSize: number, dataFormat: string): Tensor4D;
fft(x: Tensor1D): Tensor1D;
spectral: {fft(x: Tensor1D): Tensor1D;};
}

// For tracking tensor creation and disposal.
Expand Down Expand Up @@ -1253,7 +1253,7 @@ export class Tensor<R extends Rank = Rank> {

fft(this: Tensor1D): Tensor1D {
this.throwIfDisposed();
return opHandler.fft(this);
return opHandler.spectral.fft(this);
}
}
Object.defineProperty(Tensor, Symbol.hasInstance, {
Expand Down

0 comments on commit e48803c

Please sign in to comment.