Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

Allow different dtypes in binary math ops #1432

Merged
merged 7 commits into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/kernels/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ export class MathBackendCPU implements KernelBackend {
[b.strides[1], 1, b.strides[0]];

const size = leftDim * rightDim;
const result = new Float32Array(batchDim * size);

const result = buffer([batchDim, leftDim, rightDim], a.dtype);
const resVals = result.values as TypedArray;
const blockSize = this.blockSize;

for (let b = 0; b < batchDim; b++) {
Expand All @@ -428,15 +428,14 @@ export class MathBackendCPU implements KernelBackend {
sum += aValues[b * aBatch + i * aOuterStep + k * aInnerStep] *
bValues[k * bInnerStep + j * bOuterStep + b * bBatch];
}
result[b * size + (i * rightDim + j)] += sum;
resVals[b * size + (i * rightDim + j)] += sum;
}
}
}
}
}
}

return ops.tensor3d(result, [batchDim, leftDim, rightDim]);
return result.toTensor() as Tensor3D;
}

multiply(a: Tensor, b: Tensor): Tensor {
Expand Down
84 changes: 58 additions & 26 deletions src/ops/arithmetic_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,14 @@ describeWithFlags('div', ALL_ENVS, () => {
expectArraysClose(result, expected);
});

it('throws when passed tensors of different types', () => {
const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]);
const b = tf.tensor2d([1, 2, 3, 4, 2, 5], [2, 3], 'int32');
it('upcasts when dtypes dont match', () => {
let res = tf.div(tf.scalar(6, 'int32'), tf.scalar(3, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [2]);

expect(() => tf.div(a, b)).toThrowError();
expect(() => tf.div(b, a)).toThrowError();
res = tf.div(tf.scalar(6, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [6]);
});

it('throws when passed tensors of different shapes', () => {
Expand Down Expand Up @@ -580,11 +582,18 @@ describeWithFlags('mul', ALL_ENVS, () => {
expect(() => tf.mul(tf.scalar(1), {} as tf.Tensor))
.toThrowError(/Argument 'b' passed to 'mul' must be a Tensor/);
});
it('throws when dtypes dont match', () => {
expect(() => tf.mul(tf.scalar(1, 'int32'), tf.scalar(1)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(int32\) and second\(float32\) input must match/);
it('upcasts when dtypes dont match', () => {
let res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(3, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [6]);

res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [2]);

res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(false, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [0]);
});

it('accepts a tensor-like object', () => {
Expand Down Expand Up @@ -1149,11 +1158,26 @@ describeWithFlags('add', ALL_ENVS, () => {
.toThrowError(/Argument 'b' passed to 'add' must be a Tensor/);
});

it('throws when dtypes dont match', () => {
expect(() => tf.add(tf.scalar(1, 'int32'), tf.scalar(1)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(int32\) and second\(float32\) input must match/);
it('upcasts when dtypes dont match', () => {
let res = tf.add(tf.scalar(1, 'int32'), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [2]);

res = tf.add(tf.scalar(1, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [2]);

res = tf.add(tf.scalar(1, 'int32'), tf.scalar(false, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [1]);

res = tf.add(tf.complex(4, 7), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [5, 7]);

res = tf.add(tf.complex(4, 7), tf.scalar(1, 'int32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [5, 7]);
});

it('accepts a tensor-like object', () => {
Expand Down Expand Up @@ -1495,18 +1519,26 @@ describeWithFlags('sub', ALL_ENVS, () => {
expect(() => tf.sub(tf.scalar(1), {} as tf.Tensor))
.toThrowError(/Argument 'b' passed to 'sub' must be a Tensor/);
});
it('throws when dtypes dont match', () => {
expect(() => tf.sub(tf.scalar(1, 'int32'), tf.scalar(1)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(int32\) and second\(float32\) input must match/);
});
it('upcasts when dtypes dont match', () => {
let res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('float32');
expectArraysClose(res, [0]);

it('throws when dtypes dont match', () => {
expect(() => tf.sub(tf.scalar(1, 'float32'), tf.complex(1, 2)))
.toThrowError(
// tslint:disable-next-line:max-line-length
/The dtypes of the first\(float32\) and second\(complex64\) input must match/);
res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(true, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [0]);

res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(false, 'bool'));
expect(res.dtype).toBe('int32');
expectArraysClose(res, [1]);

res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'float32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [3, 7]);

res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'int32'));
expect(res.dtype).toBe('complex64');
expectArraysClose(res, [3, 7]);
});

it('accepts a tensor-like object', () => {
Expand Down
60 changes: 29 additions & 31 deletions src/ops/binary_ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import {ENV} from '../environment';
import {KernelBackend} from '../kernels/backend';
import {Tensor} from '../tensor';
import {NamedTensorMap} from '../tensor_types';
import {assertTypesMatch} from '../tensor_util';
import {makeTypesMatch} from '../tensor_util';
import {convertToTensor} from '../tensor_util_env';
import {TensorLike, upcastType} from '../types';
import * as util from '../util';
Expand Down Expand Up @@ -53,9 +53,9 @@ import {neg} from './unary_ops';
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function add_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'add');
const $b = convertToTensor(b, 'b', 'add');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'add');
let $b = convertToTensor(b, 'b', 'add');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
Expand Down Expand Up @@ -172,9 +172,9 @@ function addStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function sub_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'sub');
const $b = convertToTensor(b, 'b', 'sub');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'sub');
let $b = convertToTensor(b, 'b', 'sub');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
Expand Down Expand Up @@ -318,9 +318,9 @@ function powStrict_<T extends Tensor>(base: T, exp: Tensor): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function mul_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'mul');
const $b = convertToTensor(b, 'b', 'mul');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'mul');
let $b = convertToTensor(b, 'b', 'mul');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
Expand Down Expand Up @@ -391,9 +391,9 @@ function mulStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'div');
const $b = convertToTensor(b, 'b', 'div');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'div');
let $b = convertToTensor(b, 'b', 'div');
[$a, $b] = makeTypesMatch($a, $b);

let forwardFunc: (backend: KernelBackend) => Tensor;
if ($a.dtype === 'int32' && $b.dtype === 'int32') {
Expand Down Expand Up @@ -454,9 +454,9 @@ function div_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function floorDiv_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'floorDiv');
const $b = convertToTensor(b, 'b', 'floorDiv');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'floorDiv');
let $b = convertToTensor(b, 'b', 'floorDiv');
[$a, $b] = makeTypesMatch($a, $b);

const forwardFunc = (backend: KernelBackend) => backend.floorDiv($a, $b);
const outShape =
Expand Down Expand Up @@ -526,9 +526,9 @@ function divStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
*/
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function mod_<T extends Tensor>(a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'mod');
const $b = convertToTensor(b, 'b', 'mod');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'mod');
let $b = convertToTensor(b, 'b', 'mod');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
Expand Down Expand Up @@ -598,14 +598,13 @@ function minimum_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
let $a = convertToTensor(a, 'a', 'minimum');
let $b = convertToTensor(b, 'b', 'minimum');
assertTypesMatch($a, $b);
[$a, $b] = makeTypesMatch($a, $b);

if ($a.dtype === 'bool') {
$a = $a.toInt();
}
if ($b.dtype === 'bool') {
$b = $b.toInt();
}

broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
const der = (dy: Tensor) => {
const derA = () => dy.mul($a.lessEqual($b).toFloat());
Expand Down Expand Up @@ -660,14 +659,13 @@ function maximum_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
let $a = convertToTensor(a, 'a', 'maximum');
let $b = convertToTensor(b, 'b', 'maximum');
assertTypesMatch($a, $b);
[$a, $b] = makeTypesMatch($a, $b);

if ($a.dtype === 'bool') {
$a = $a.toInt();
}
if ($b.dtype === 'bool') {
$b = $b.toInt();
}

broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
const der = (dy: Tensor) => {
const derA = () => dy.mul($a.greaterEqual($b).toFloat());
Expand Down Expand Up @@ -721,9 +719,9 @@ function maximumStrict_<T extends Tensor>(a: T|TensorLike, b: T|TensorLike): T {
/** @doc {heading: 'Operations', subheading: 'Arithmetic'} */
function squaredDifference_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'squaredDifference');
const $b = convertToTensor(b, 'b', 'squaredDifference');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'squaredDifference');
let $b = convertToTensor(b, 'b', 'squaredDifference');
[$a, $b] = makeTypesMatch($a, $b);

broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
const der = (dy: Tensor) => {
Expand Down Expand Up @@ -772,9 +770,9 @@ function squaredDifferenceStrict_<T extends Tensor>(
/** @doc {heading: 'Operations', subheading: 'Basic math'} */
function atan2_<T extends Tensor>(
a: Tensor|TensorLike, b: Tensor|TensorLike): T {
const $a = convertToTensor(a, 'a', 'atan2');
const $b = convertToTensor(b, 'b', 'atan2');
assertTypesMatch($a, $b);
let $a = convertToTensor(a, 'a', 'atan2');
let $b = convertToTensor(b, 'b', 'atan2');
[$a, $b] = makeTypesMatch($a, $b);

const outShape =
broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape);
Expand Down
Loading