Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Improve error message in tensorXd #1111

Merged
merged 5 commits into from
Jun 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/ops/array_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import {ConcatOps} from './concat';
import {operation} from './operation';
import {MPRandGauss} from './rand';
import {SegmentOps} from './segment_ops';
import {assertNonNull} from '../util';

export class ArrayOps {
/**
Expand Down Expand Up @@ -119,6 +120,7 @@ export class ArrayOps {
@doc({heading: 'Tensors', subheading: 'Creation'})
static tensor1d(values: TensorLike1D, dtype: DataType = 'float32'): Tensor1D {
const inferredShape = util.inferShape(values);
assertNonNull(values);
if (inferredShape.length !== 1) {
throw new Error('tensor1d() requires values to be a flat/TypedArray');
}
Expand Down Expand Up @@ -150,6 +152,7 @@ export class ArrayOps {
static tensor2d(
values: TensorLike2D, shape?: [number, number],
dtype: DataType = 'float32'): Tensor2D {
assertNonNull(values);
if (shape != null && shape.length !== 2) {
throw new Error('tensor2d() requires shape to have two numbers');
}
Expand Down Expand Up @@ -192,6 +195,7 @@ export class ArrayOps {
static tensor3d(
values: TensorLike3D, shape?: [number, number, number],
dtype: DataType = 'float32'): Tensor3D {
assertNonNull(values);
if (shape != null && shape.length !== 3) {
throw new Error('tensor3d() requires shape to have three numbers');
}
Expand Down Expand Up @@ -234,6 +238,7 @@ export class ArrayOps {
static tensor4d(
values: TensorLike4D, shape?: [number, number, number, number],
dtype: DataType = 'float32'): Tensor4D {
assertNonNull(values);
if (shape != null && shape.length !== 4) {
throw new Error('tensor4d() requires shape to have four numbers');
}
Expand Down Expand Up @@ -276,6 +281,7 @@ export class ArrayOps {
static tensor5d(
values: TensorLike5D, shape?: [number, number, number, number, number],
dtype: DataType = 'float32'): Tensor5D {
assertNonNull(values);
if (shape != null && shape.length !== 5) {
throw new Error('tensor5d() requires shape to have five numbers');
}
Expand Down Expand Up @@ -319,6 +325,7 @@ export class ArrayOps {
values: TensorLike6D,
shape?: [number, number, number, number, number, number],
dtype: DataType = 'float32'): Tensor6D {
assertNonNull(values);
if (shape != null && shape.length !== 6) {
throw new Error('tensor6d() requires shape to have six numbers');
}
Expand Down
36 changes: 36 additions & 0 deletions src/tensor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,12 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expectArraysClose(a, [1, 2, 3]);
});

it('tf.tensor1d() throw error with null input value', () => {
expect(() => tf.tensor1d(null))
.toThrowError('The input to the tensor constructor ' +
'must be a non-null value.');
});

it('tf.tensor1d() from number[][], shape mismatch', () => {
// tslint:disable-next-line:no-any
expect(() => tf.tensor1d([[1], [2], [3]] as any)).toThrowError();
Expand All @@ -296,6 +302,12 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(() => tf.tensor2d([1, 2, 3, 4])).toThrowError();
});

it('tf.tensor2d() throw error with null input value', () => {
expect(() => tf.tensor2d(null))
.toThrowError('The input to the tensor constructor ' +
'must be a non-null value.');
});

it('tensor3d() from number[][][]', () => {
const a = tf.tensor3d([[[1], [2], [3]], [[4], [5], [6]]], [2, 3, 1]);
expectArraysClose(a, [1, 2, 3, 4, 5, 6]);
Expand All @@ -317,6 +329,12 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(() => tf.tensor3d([1, 2, 3, 4], shape)).toThrowError();
});

it('tf.tensor3d() throw error with null input value', () => {
expect(() => tf.tensor3d(null))
.toThrowError('The input to the tensor constructor ' +
'must be a non-null value.');
});

it('tensor4d() from number[][][][]', () => {
const a = tf.tensor4d([[[[1]], [[2]]], [[[4]], [[5]]]], [2, 2, 1, 1]);
expectArraysClose(a, [1, 2, 4, 5]);
Expand All @@ -340,6 +358,24 @@ describeWithFlags('tensor', ALL_ENVS, () => {
expect(() => tf.tensor4d([1, 2, 3, 4], shape)).toThrowError();
});

it('tf.tensor4d() throw error with null input value', () => {
expect(() => tf.tensor4d(null))
.toThrowError('The input to the tensor constructor ' +
'must be a non-null value.');
});

it('tf.tensor5d() throw error with null input value', () => {
expect(() => tf.tensor5d(null))
.toThrowError('The input to the tensor constructor ' +
'must be a non-null value.');
});

it('tf.tensor6d() throw error with null input value', () => {
expect(() => tf.tensor6d(null))
.toThrowError('The input to the tensor constructor ' +
'must be a non-null value.');
});

it('default dtype', () => {
const a = tf.scalar(3);
expect(a.dtype).toBe('float32');
Expand Down
8 changes: 7 additions & 1 deletion src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
import {Tensor} from './tensor';
// tslint:disable-next-line:max-line-length
import {DataType, DataTypeMap, FlatVector, NamedTensorMap, RecursiveArray, RegularArray, TensorContainer, TensorContainerArray, TypedArray} from './types';
import {DataType, DataTypeMap, FlatVector, NamedTensorMap, RecursiveArray, RegularArray, TensorContainer, TensorContainerArray, TensorLike, TypedArray} from './types';

function assertArgumentIsTensor(
x: Tensor, argName: string, functionName: string) {
Expand Down Expand Up @@ -100,6 +100,12 @@ export function assertTypesMatch(a: Tensor, b: Tensor): void {
` second(${b.dtype}) input must match`);
}

export function assertNonNull(a: TensorLike): void {
assert(
a != null,
`The input to the tensor constructor must be a non-null value.`);
}

// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
Expand Down