diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index 29ca29e217..9ad6241bfc 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -407,8 +407,8 @@ export class MathBackendCPU implements KernelBackend { [b.strides[1], 1, b.strides[0]]; const size = leftDim * rightDim; - const result = new Float32Array(batchDim * size); - + const result = buffer([batchDim, leftDim, rightDim], a.dtype); + const resVals = result.values as TypedArray; const blockSize = this.blockSize; for (let b = 0; b < batchDim; b++) { @@ -428,15 +428,14 @@ export class MathBackendCPU implements KernelBackend { sum += aValues[b * aBatch + i * aOuterStep + k * aInnerStep] * bValues[k * bInnerStep + j * bOuterStep + b * bBatch]; } - result[b * size + (i * rightDim + j)] += sum; + resVals[b * size + (i * rightDim + j)] += sum; } } } } } } - - return ops.tensor3d(result, [batchDim, leftDim, rightDim]); + return result.toTensor() as Tensor3D; } multiply(a: Tensor, b: Tensor): Tensor { diff --git a/src/kernels/backend_webgl.ts b/src/kernels/backend_webgl.ts index 4cafcbd240..9d707a0ee1 100644 --- a/src/kernels/backend_webgl.ts +++ b/src/kernels/backend_webgl.ts @@ -35,7 +35,6 @@ import {DataId, Scalar, setTensorTracker, Tensor, Tensor1D, Tensor2D, Tensor3D, import {DataType, DataTypeMap, DataValues, NumericDataType, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../types'; import * as util from '../util'; import {getTypedArrayFromDType, sizeFromShape} from '../util'; - import {DataMover, DataStorage, KernelBackend} from './backend'; import * as backend_util from './backend_util'; import {mergeRealAndImagArrays} from './complex_util'; @@ -682,6 +681,8 @@ export class MathBackendWebGL implements KernelBackend { return this.multiply(a3D, b3D).sum(axis, true /* keepDims */); } + const dtype = upcastType(a.dtype, b.dtype); + // TODO(/~https://github.com/tensorflow/tfjs/issues/693): Support 3D tensors if (batch === 1) { const aSqueezed = a.as2D(a.shape[1], a.shape[2]); @@ -690,13 +691,17 @@ export class MathBackendWebGL implements KernelBackend { const program = new MatMulPackedProgram( aSqueezed.shape, bSqueezed.shape, [outerShapeA, outerShapeB], transposeA, transposeB); + const output = + this.makePackedTensor(program.outputShape, dtype) as Tensor2D; const result = - this.compileAndRun(program, [aSqueezed, bSqueezed]); - + this.compileAndRun(program, [aSqueezed, bSqueezed], output); return result.reshape([1, result.shape[0], result.shape[1]]); } else { - return this.compileAndRun( - new MatMulProgram(a.shape, b.shape, transposeA, transposeB), [a, b]); + const program = + new MatMulProgram(a.shape, b.shape, transposeA, transposeB); + const output = + this.makeOutputArray(program.outputShape, dtype) as Tensor3D; + return this.compileAndRun(program, [a, b], output); } } @@ -1517,7 +1522,8 @@ export class MathBackendWebGL implements KernelBackend { convInfo.outChannels / convInfo.inChannels === 1) { program = new DepthwiseConvPacked2DProgram(convInfo); return this.compileAndRun( - program, [x, filter], this.makePackedTensor(convInfo.outShape)); + program, [x, filter], + this.makePackedTensor(convInfo.outShape, x.dtype)); } program = new DepthwiseConv2DProgram(convInfo); @@ -1769,8 +1775,9 @@ export class MathBackendWebGL implements KernelBackend { return Tensor.make(shape, {}, dtype) as T; } - private makePackedTensor(shape: number[]): T { - const packedTensor = Tensor.make(shape, {}); + private makePackedTensor(shape: number[], dtype: DataType): + T { + const packedTensor = Tensor.make(shape, {}, dtype); this.texData.get(packedTensor.dataId).isPacked = true; return packedTensor as T; } @@ -1778,7 +1785,7 @@ export class MathBackendWebGL implements KernelBackend { private unpackTensor(input: T): T { const program = new UnpackProgram(input.shape); return this.compileAndRun( - program, [input], Tensor.make(program.outputShape, {})); + program, [input], Tensor.make(program.outputShape, {}, input.dtype)); } private getBatchDim(shape: number[], dimsToSkip = 2): number { @@ -1815,7 +1822,8 @@ export class MathBackendWebGL implements KernelBackend { pageToCpu = true): K { if (output == null) { if (program.usesPackedTextures) { - output = this.makePackedTensor(program.outputShape) as {} as K; + output = this.makePackedTensor(program.outputShape, inputs[0].dtype) as + {} as K; } else { output = this.makeOutputArray(program.outputShape, inputs[0].dtype) as {} as K; @@ -1872,11 +1880,12 @@ export class MathBackendWebGL implements KernelBackend { preProcessProgram = new UnpackProgram(input.shape); processedInput = this.compileAndRun( preProcessProgram, [input], - Tensor.make(preProcessProgram.outputShape, {})); + Tensor.make(preProcessProgram.outputShape, {}, input.dtype)); } else { preProcessProgram = new PackProgram(input.shape); processedInput = this.compileAndRun( - preProcessProgram, [input], this.makePackedTensor(input.shape)); + preProcessProgram, [input], + this.makePackedTensor(input.shape, input.dtype)); } texData = this.texData.get(processedInput.dataId); diff --git a/src/ops/arithmetic_test.ts b/src/ops/arithmetic_test.ts index c58a04bc0c..24ab408bb8 100644 --- a/src/ops/arithmetic_test.ts +++ b/src/ops/arithmetic_test.ts @@ -102,12 +102,14 @@ describeWithFlags('div', ALL_ENVS, () => { expectArraysClose(result, expected); }); - it('throws when passed tensors of different types', () => { - const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); - const b = tf.tensor2d([1, 2, 3, 4, 2, 5], [2, 3], 'int32'); + it('upcasts when dtypes dont match', () => { + let res = tf.div(tf.scalar(6, 'int32'), tf.scalar(3, 'float32')); + expect(res.dtype).toBe('float32'); + expectArraysClose(res, [2]); - expect(() => tf.div(a, b)).toThrowError(); - expect(() => tf.div(b, a)).toThrowError(); + res = tf.div(tf.scalar(6, 'int32'), tf.scalar(true, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [6]); }); it('throws when passed tensors of different shapes', () => { @@ -580,11 +582,18 @@ describeWithFlags('mul', ALL_ENVS, () => { expect(() => tf.mul(tf.scalar(1), {} as tf.Tensor)) .toThrowError(/Argument 'b' passed to 'mul' must be a Tensor/); }); - it('throws when dtypes dont match', () => { - expect(() => tf.mul(tf.scalar(1, 'int32'), tf.scalar(1))) - .toThrowError( - // tslint:disable-next-line:max-line-length - /The dtypes of the first\(int32\) and second\(float32\) input must match/); + it('upcasts when dtypes dont match', () => { + let res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(3, 'float32')); + expect(res.dtype).toBe('float32'); + expectArraysClose(res, [6]); + + res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(true, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [2]); + + res = tf.mul(tf.scalar(2, 'int32'), tf.scalar(false, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [0]); }); it('accepts a tensor-like object', () => { @@ -1149,11 +1158,26 @@ describeWithFlags('add', ALL_ENVS, () => { .toThrowError(/Argument 'b' passed to 'add' must be a Tensor/); }); - it('throws when dtypes dont match', () => { - expect(() => tf.add(tf.scalar(1, 'int32'), tf.scalar(1))) - .toThrowError( - // tslint:disable-next-line:max-line-length - /The dtypes of the first\(int32\) and second\(float32\) input must match/); + it('upcasts when dtypes dont match', () => { + let res = tf.add(tf.scalar(1, 'int32'), tf.scalar(1, 'float32')); + expect(res.dtype).toBe('float32'); + expectArraysClose(res, [2]); + + res = tf.add(tf.scalar(1, 'int32'), tf.scalar(true, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [2]); + + res = tf.add(tf.scalar(1, 'int32'), tf.scalar(false, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [1]); + + res = tf.add(tf.complex(4, 7), tf.scalar(1, 'float32')); + expect(res.dtype).toBe('complex64'); + expectArraysClose(res, [5, 7]); + + res = tf.add(tf.complex(4, 7), tf.scalar(1, 'int32')); + expect(res.dtype).toBe('complex64'); + expectArraysClose(res, [5, 7]); }); it('accepts a tensor-like object', () => { @@ -1495,18 +1519,26 @@ describeWithFlags('sub', ALL_ENVS, () => { expect(() => tf.sub(tf.scalar(1), {} as tf.Tensor)) .toThrowError(/Argument 'b' passed to 'sub' must be a Tensor/); }); - it('throws when dtypes dont match', () => { - expect(() => tf.sub(tf.scalar(1, 'int32'), tf.scalar(1))) - .toThrowError( - // tslint:disable-next-line:max-line-length - /The dtypes of the first\(int32\) and second\(float32\) input must match/); - }); + it('upcasts when dtypes dont match', () => { + let res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(1, 'float32')); + expect(res.dtype).toBe('float32'); + expectArraysClose(res, [0]); - it('throws when dtypes dont match', () => { - expect(() => tf.sub(tf.scalar(1, 'float32'), tf.complex(1, 2))) - .toThrowError( - // tslint:disable-next-line:max-line-length - /The dtypes of the first\(float32\) and second\(complex64\) input must match/); + res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(true, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [0]); + + res = tf.sub(tf.scalar(1, 'int32'), tf.scalar(false, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [1]); + + res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'float32')); + expect(res.dtype).toBe('complex64'); + expectArraysClose(res, [3, 7]); + + res = tf.sub(tf.complex(4, 7), tf.scalar(1, 'int32')); + expect(res.dtype).toBe('complex64'); + expectArraysClose(res, [3, 7]); }); it('accepts a tensor-like object', () => { diff --git a/src/ops/binary_ops.ts b/src/ops/binary_ops.ts index d196940ea6..f0f7f0a1bc 100644 --- a/src/ops/binary_ops.ts +++ b/src/ops/binary_ops.ts @@ -19,7 +19,7 @@ import {ENV} from '../environment'; import {KernelBackend} from '../kernels/backend'; import {Tensor} from '../tensor'; import {NamedTensorMap} from '../tensor_types'; -import {assertTypesMatch} from '../tensor_util'; +import {makeTypesMatch} from '../tensor_util'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike, upcastType} from '../types'; import * as util from '../util'; @@ -53,9 +53,9 @@ import {neg} from './unary_ops'; */ /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function add_(a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'add'); - const $b = convertToTensor(b, 'b', 'add'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'add'); + let $b = convertToTensor(b, 'b', 'add'); + [$a, $b] = makeTypesMatch($a, $b); const outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); @@ -172,9 +172,9 @@ function addStrict_(a: T|TensorLike, b: T|TensorLike): T { */ /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function sub_(a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'sub'); - const $b = convertToTensor(b, 'b', 'sub'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'sub'); + let $b = convertToTensor(b, 'b', 'sub'); + [$a, $b] = makeTypesMatch($a, $b); const outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); @@ -318,9 +318,9 @@ function powStrict_(base: T, exp: Tensor): T { */ /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function mul_(a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'mul'); - const $b = convertToTensor(b, 'b', 'mul'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'mul'); + let $b = convertToTensor(b, 'b', 'mul'); + [$a, $b] = makeTypesMatch($a, $b); const outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); @@ -391,9 +391,9 @@ function mulStrict_(a: T|TensorLike, b: T|TensorLike): T { */ /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function div_(a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'div'); - const $b = convertToTensor(b, 'b', 'div'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'div'); + let $b = convertToTensor(b, 'b', 'div'); + [$a, $b] = makeTypesMatch($a, $b); let forwardFunc: (backend: KernelBackend) => Tensor; if ($a.dtype === 'int32' && $b.dtype === 'int32') { @@ -454,9 +454,9 @@ function div_(a: Tensor|TensorLike, b: Tensor|TensorLike): T { /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function floorDiv_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'floorDiv'); - const $b = convertToTensor(b, 'b', 'floorDiv'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'floorDiv'); + let $b = convertToTensor(b, 'b', 'floorDiv'); + [$a, $b] = makeTypesMatch($a, $b); const forwardFunc = (backend: KernelBackend) => backend.floorDiv($a, $b); const outShape = @@ -526,9 +526,9 @@ function divStrict_(a: T|TensorLike, b: T|TensorLike): T { */ /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function mod_(a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'mod'); - const $b = convertToTensor(b, 'b', 'mod'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'mod'); + let $b = convertToTensor(b, 'b', 'mod'); + [$a, $b] = makeTypesMatch($a, $b); const outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); @@ -598,14 +598,13 @@ function minimum_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { let $a = convertToTensor(a, 'a', 'minimum'); let $b = convertToTensor(b, 'b', 'minimum'); - assertTypesMatch($a, $b); + [$a, $b] = makeTypesMatch($a, $b); if ($a.dtype === 'bool') { $a = $a.toInt(); - } - if ($b.dtype === 'bool') { $b = $b.toInt(); } + broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); const der = (dy: Tensor) => { const derA = () => dy.mul($a.lessEqual($b).toFloat()); @@ -660,14 +659,13 @@ function maximum_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { let $a = convertToTensor(a, 'a', 'maximum'); let $b = convertToTensor(b, 'b', 'maximum'); - assertTypesMatch($a, $b); + [$a, $b] = makeTypesMatch($a, $b); if ($a.dtype === 'bool') { $a = $a.toInt(); - } - if ($b.dtype === 'bool') { $b = $b.toInt(); } + broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); const der = (dy: Tensor) => { const derA = () => dy.mul($a.greaterEqual($b).toFloat()); @@ -721,9 +719,9 @@ function maximumStrict_(a: T|TensorLike, b: T|TensorLike): T { /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ function squaredDifference_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'squaredDifference'); - const $b = convertToTensor(b, 'b', 'squaredDifference'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'squaredDifference'); + let $b = convertToTensor(b, 'b', 'squaredDifference'); + [$a, $b] = makeTypesMatch($a, $b); broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); const der = (dy: Tensor) => { @@ -772,9 +770,9 @@ function squaredDifferenceStrict_( /** @doc {heading: 'Operations', subheading: 'Basic math'} */ function atan2_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'atan2'); - const $b = convertToTensor(b, 'b', 'atan2'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'atan2'); + let $b = convertToTensor(b, 'b', 'atan2'); + [$a, $b] = makeTypesMatch($a, $b); const outShape = broadcast_util.assertAndGetBroadcastShape($a.shape, $b.shape); diff --git a/src/ops/binary_ops_test.ts b/src/ops/binary_ops_test.ts index aba9b1a960..5f68d56e0d 100644 --- a/src/ops/binary_ops_test.ts +++ b/src/ops/binary_ops_test.ts @@ -122,11 +122,13 @@ describeWithFlags('maximum', ALL_ENVS, () => { expectArraysEqual(result, [1, 0, 1, 1]); }); - it('different dtypes throws error', () => { - const a = tf.tensor1d([true, false, false, true], 'float32'); - const b = tf.tensor1d([false, false, true, true], 'int32'); - // tslint:disable-next-line:no-any - expect(() => tf.maximum(a, b as any)).toThrowError(); + it('upcasts when dtypes dont match', () => { + const a = tf.tensor1d([1, 0, 0, 1], 'float32'); + const b = tf.tensor1d([0, 0, 1, 1], 'int32'); + const res = tf.maximum(a, b); + expect(res.shape).toEqual(a.shape); + expect(res.dtype).toBe('float32'); + expectArraysEqual(res, [1, 0, 1, 1]); }); it('propagates NaN', () => { @@ -304,11 +306,19 @@ describeWithFlags('squaredDifference', ALL_ENVS, () => { ]); }); - it('different dtypes throws error', () => { - const a = tf.tensor1d([0.5, 3, -0.1, -4], 'float32'); - const b = tf.tensor1d([2, 3, 1, 4], 'int32'); - // tslint:disable-next-line:no-any - expect(() => tf.squaredDifference(a, b as any)).toThrowError(); + it('upcasts when dtypes dont match', () => { + let res = + tf.squaredDifference(tf.scalar(5, 'int32'), tf.scalar(2, 'float32')); + expect(res.dtype).toBe('float32'); + expectArraysClose(res, [9]); + + res = tf.squaredDifference(tf.scalar(5, 'int32'), tf.scalar(true, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [16]); + + res = tf.squaredDifference(tf.scalar(5, 'int32'), tf.scalar(false, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [25]); }); it('propagates NaN', () => { @@ -514,11 +524,13 @@ describeWithFlags('minimum', ALL_ENVS, () => { expectArraysEqual(result, [0, 0, 0, 1]); }); - it('different dtypes throws error', () => { - const a = tf.tensor1d([true, false, false, true], 'float32'); - const b = tf.tensor1d([false, false, true, true], 'int32'); - // tslint:disable-next-line:no-any - expect(() => tf.minimum(a, b as any)).toThrowError(); + it('upcasts when dtypes dont match', () => { + const a = tf.tensor1d([1, 0, 0, 1], 'float32'); + const b = tf.tensor1d([0, 0, 1, 1], 'int32'); + const res = tf.minimum(a, b); + expect(res.shape).toEqual(a.shape); + expect(res.dtype).toBe('float32'); + expectArraysEqual(res, [0, 0, 0, 1]); }); it('propagates NaN', () => { @@ -682,11 +694,14 @@ describeWithFlags('mod', ALL_ENVS, () => { expectArraysEqual(result, [1, 2, 0, 3]); }); - it('different dtypes throws error', () => { - const a = tf.tensor1d([1.1, 2.2, 3.3, 4.4], 'float32'); - const b = tf.tensor1d([1, 2, 3, 4], 'int32'); - // tslint:disable-next-line:no-any - expect(() => tf.mod(a, b as any)).toThrowError(); + it('upcasts when dtypes dont match', () => { + let res = tf.mod(tf.scalar(5, 'int32'), tf.scalar(2, 'float32')); + expect(res.dtype).toBe('float32'); + expectArraysClose(res, [1]); + + res = tf.mod(tf.scalar(5, 'int32'), tf.scalar(true, 'bool')); + expect(res.dtype).toBe('int32'); + expectArraysClose(res, [0]); }); it('propagates NaN', () => { @@ -926,12 +941,22 @@ describeWithFlags('atan2', ALL_ENVS, () => { expect(() => tf.atan2(b, a)).toThrowError(); }); - it('throws when passed tensors of different types', () => { - const a = tf.tensor2d([1, 2, -3, -4, 5, 6], [2, 3]); - const b = tf.tensor2d([5.0, 3.0, 4.0, -7.0], [2, 2]); + it('upcasts when dtypes dont match', () => { + const aValues = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + const bValues = [1, 2, 3, 4, 2, 5]; - expect(() => tf.atan2(a, b)).toThrowError(); - expect(() => tf.atan2(b, a)).toThrowError(); + const a = tf.tensor2d(aValues, [2, 3], 'float32'); + const c = tf.tensor2d(bValues, [2, 3], 'int32'); + + const r = tf.atan2(a, c); + const expected = []; + + for (let i = 0; i < a.size; i++) { + expected[i] = Math.atan2(aValues[i], bValues[i]); + } + expect(r.shape).toEqual([2, 3]); + expect(r.dtype).toBe('float32'); + expectArraysClose(r, expected); }); it('atan2 of scalar and array propagates NaNs', () => { diff --git a/src/ops/compare.ts b/src/ops/compare.ts index 3bd1298f98..fa79359a0d 100644 --- a/src/ops/compare.ts +++ b/src/ops/compare.ts @@ -17,7 +17,7 @@ import {ENV} from '../environment'; import {Tensor} from '../tensor'; -import {assertTypesMatch} from '../tensor_util'; +import {makeTypesMatch} from '../tensor_util'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import {assertShapesMatch} from '../util'; @@ -43,9 +43,9 @@ import {zerosLike} from './tensor_ops'; /** @doc {heading: 'Operations', subheading: 'Logical'} */ function notEqual_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'notEqual'); - const $b = convertToTensor(b, 'b', 'notEqual'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'notEqual'); + let $b = convertToTensor(b, 'b', 'notEqual'); + [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); return ENV.engine.runKernel(backend => backend.notEqual($a, $b), {$a, $b}) as T; @@ -85,9 +85,9 @@ function notEqualStrict_( /** @doc {heading: 'Operations', subheading: 'Logical'} */ function less_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'less'); - const $b = convertToTensor(b, 'b', 'less'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'less'); + let $b = convertToTensor(b, 'b', 'less'); + [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); return ENV.engine.runKernel(backend => backend.less($a, $b), {$a, $b}) as T; @@ -127,9 +127,9 @@ function lessStrict_(a: T|TensorLike, b: T|TensorLike): T { /** @doc {heading: 'Operations', subheading: 'Logical'} */ function equal_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'equal'); - const $b = convertToTensor(b, 'b', 'equal'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'equal'); + let $b = convertToTensor(b, 'b', 'equal'); + [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); return ENV.engine.runKernel(backend => backend.equal($a, $b), {$a, $b}) as T; @@ -161,9 +161,9 @@ function equalStrict_(a: T|TensorLike, b: T|TensorLike): T { /** @doc {heading: 'Operations', subheading: 'Logical'} */ function lessEqual_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'lessEqual'); - const $b = convertToTensor(b, 'b', 'lessEqual'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'lessEqual'); + let $b = convertToTensor(b, 'b', 'lessEqual'); + [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); return ENV.engine.runKernel(backend => backend.lessEqual($a, $b), {$a, $b}) as @@ -197,9 +197,9 @@ function lessEqualStrict_( /** @doc {heading: 'Operations', subheading: 'Logical'} */ function greater_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'greater'); - const $b = convertToTensor(b, 'b', 'greater'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'greater'); + let $b = convertToTensor(b, 'b', 'greater'); + [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); return ENV.engine.runKernel(backend => backend.greater($a, $b), {$a, $b}) as @@ -232,9 +232,9 @@ function greaterStrict_(a: T|TensorLike, b: T|TensorLike): T { /** @doc {heading: 'Operations', subheading: 'Logical'} */ function greaterEqual_( a: Tensor|TensorLike, b: Tensor|TensorLike): T { - const $a = convertToTensor(a, 'a', 'greaterEqual'); - const $b = convertToTensor(b, 'b', 'greaterEqual'); - assertTypesMatch($a, $b); + let $a = convertToTensor(a, 'a', 'greaterEqual'); + let $b = convertToTensor(b, 'b', 'greaterEqual'); + [$a, $b] = makeTypesMatch($a, $b); assertAndGetBroadcastShape($a.shape, $b.shape); const grad = (dy: T) => { diff --git a/src/ops/compare_ops_test.ts b/src/ops/compare_ops_test.ts index 1e16cd97b0..9636e84a2c 100644 --- a/src/ops/compare_ops_test.ts +++ b/src/ops/compare_ops_test.ts @@ -48,6 +48,24 @@ describeWithFlags('equal', ALL_ENVS, () => { b = tf.tensor1d([3.123, 3.321], 'float32'); expectArraysClose(tf.equal(a, b), [0, 0]); }); + + it('upcasts when dtypes dont match', () => { + const a = [1.1, 4.1, 5]; + const b = [2.2, 3.2, 5]; + + let res = + tf.equal(tf.tensor(a, [3], 'float32'), tf.tensor(b, [3], 'int32')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [0, 0, 1]); + + res = + tf.equal(tf.tensor(a, [3], 'int32'), tf.tensor(b, [3], 'bool')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [1, 0, 0]); + }); + it('TensorLike', () => { const a = [1.1, 4.1, 5.1]; const b = [2.2, 3.2, 5.1]; @@ -533,6 +551,24 @@ describeWithFlags('notEqual', ALL_ENVS, () => { b = tf.tensor1d([3.123, 3.321], 'float32'); expectArraysClose(tf.notEqual(a, b), [1, 1]); }); + + it('upcasts when dtypes dont match', () => { + const a = [1.1, 4.1, 5]; + const b = [2.2, 3.2, 5]; + + let res = + tf.notEqual(tf.tensor(a, [3], 'float32'), tf.tensor(b, [3], 'int32')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [1, 1, 0]); + + res = + tf.notEqual(tf.tensor(a, [3], 'int32'), tf.tensor(b, [3], 'bool')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [0, 1, 1]); + }); + it('TensorLike', () => { const a = [1.1, 4.1, 5.1]; const b = [2.2, 3.2, 5.1]; @@ -1080,6 +1116,24 @@ describeWithFlags('less', ALL_ENVS, () => { expect(res.dtype).toBe('bool'); expectArraysClose(res, [1, 0, 0]); }); + + it('upcasts when dtypes dont match', () => { + const a = [1.1, 4.1, 5.2]; + const b = [2.2, 3.2, 5.1]; + + let res = + tf.less(tf.tensor(a, [3], 'float32'), tf.tensor(b, [3], 'int32')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [1, 0, 0]); + + res = + tf.less(tf.tensor(a, [3], 'int32'), tf.tensor(b, [3], 'bool')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [0, 0, 0]); + }); + it('mismatched Tensor1D shapes - int32', () => { const a = tf.tensor1d([1, 2], 'int32'); const b = tf.tensor1d([1, 2, 3], 'int32'); @@ -1448,6 +1502,24 @@ describeWithFlags('lessEqual', ALL_ENVS, () => { expect(res.dtype).toBe('bool'); expectArraysClose(res, [1, 1]); }); + + it('upcasts when dtypes dont match', () => { + const a = [1.1, 4.1, 5]; + const b = [2.2, 3.2, 5]; + + let res = + tf.lessEqual(tf.tensor(a, [3], 'float32'), tf.tensor(b, [3], 'int32')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [1, 0, 1]); + + res = + tf.lessEqual(tf.tensor(a, [3], 'int32'), tf.tensor(b, [3], 'bool')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [1, 0, 0]); + }); + it('TensorLike', () => { const a = [1.1, 4.1, 5.1]; const b = [2.2, 3.2, 5.1]; @@ -1830,6 +1902,24 @@ describeWithFlags('greater', ALL_ENVS, () => { expect(res.dtype).toBe('bool'); expectArraysClose(res, [1, 1]); }); + + it('upcasts when dtypes dont match', () => { + const a = [1.1, 4.1, 5.2]; + const b = [2.2, 3.2, 5.1]; + + let res = + tf.greater(tf.tensor(a, [3], 'float32'), tf.tensor(b, [3], 'int32')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [0, 1, 1]); + + res = + tf.greater(tf.tensor(a, [3], 'int32'), tf.tensor(b, [3], 'bool')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [0, 1, 1]); + }); + it('TensorLike', () => { const a = [1.1, 4.1, 5.1]; const b = [2.2, 3.2, 5.1]; @@ -2222,6 +2312,24 @@ describeWithFlags('greaterEqual', ALL_ENVS, () => { expect(res.dtype).toBe('bool'); expectArraysClose(res, [0, 0]); }); + + it('upcasts when dtypes dont match', () => { + const a = [1.1, 4.1, 5]; + const b = [2.2, 3.2, 5]; + + let res = + tf.greaterEqual(tf.tensor(a, [3], 'float32'), tf.tensor(b, [3], 'int32')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [0, 1, 1]); + + res = + tf.greaterEqual(tf.tensor(a, [3], 'int32'), tf.tensor(b, [3], 'bool')); + expect(res.dtype).toBe('bool'); + expect(res.shape).toEqual([3]); + expectArraysClose(res, [1, 1, 1]); + }); + it('mismatched Tensor1D shapes - int32', () => { const a = tf.tensor1d([1, 2], 'int32'); const b = tf.tensor1d([1, 2, 3], 'int32'); diff --git a/src/ops/matmul.ts b/src/ops/matmul.ts index 8273052be2..702a13c90c 100644 --- a/src/ops/matmul.ts +++ b/src/ops/matmul.ts @@ -17,6 +17,7 @@ import {ENV} from '../environment'; import {Tensor, Tensor1D, Tensor2D, Tensor3D} from '../tensor'; +import {makeTypesMatch} from '../tensor_util'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; import * as util from '../util'; @@ -40,8 +41,9 @@ import {op} from './operation'; function matMul_( a: T|TensorLike, b: T|TensorLike, transposeA = false, transposeB = false): T { - const $a = convertToTensor(a, 'a', 'matMul'); - const $b = convertToTensor(b, 'b', 'matMul'); + let $a = convertToTensor(a, 'a', 'matMul'); + let $b = convertToTensor(b, 'b', 'matMul'); + [$a, $b] = makeTypesMatch($a, $b); const innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; @@ -86,23 +88,23 @@ function matMul_( const grad = (dy: Tensor3D) => { if (!transposeA && !transposeB) { return { - $a: () => dy.matMul(b3D.toFloat(), false, true), - $b: () => a3D.toFloat().matMul(dy, true, false) + $a: () => dy.matMul(b3D, false, true), + $b: () => a3D.matMul(dy, true, false) }; } else if (!transposeA && transposeB) { return { - $a: () => dy.matMul(b3D.toFloat(), false, false), - $b: () => dy.matMul(a3D.toFloat(), true, false) + $a: () => dy.matMul(b3D, false, false), + $b: () => dy.matMul(a3D, true, false) }; } else if (transposeA && !transposeB) { return { - $a: () => b3D.toFloat().matMul(dy, false, true), - $b: () => a3D.toFloat().matMul(dy, false, false) + $a: () => b3D.matMul(dy, false, true), + $b: () => a3D.matMul(dy, false, false) }; } else { return { - $a: () => b3D.toFloat().matMul(dy, true, true), - $b: () => dy.matMul(a3D.toFloat(), true, true) + $a: () => b3D.matMul(dy, true, true), + $b: () => dy.matMul(a3D, true, true) }; } }; diff --git a/src/ops/matmul_test.ts b/src/ops/matmul_test.ts index d70fea694e..0550516962 100644 --- a/src/ops/matmul_test.ts +++ b/src/ops/matmul_test.ts @@ -166,6 +166,24 @@ describeWithFlags('matmul', ALL_ENVS, () => { expectArraysClose(c, [0, 8, -3, 20]); }); + it('upcasts when dtypes dont match', () => { + const a = [1, 2, 3, 4, 5, 6]; + const b = [0, 1, -3, 2, 2, 1]; + + let c = tf.matMul( + tf.tensor(a, [2, 3], 'float32'), tf.tensor(b, [3, 2], 'int32')); + + expect(c.shape).toEqual([2, 2]); + expect(c.dtype).toBe('float32'); + expectArraysClose(c, [0, 8, -3, 20]); + + c = tf.matMul(tf.tensor(a, [2, 3], 'int32'), tf.tensor(b, [3, 2], 'bool')); + + expect(c.shape).toEqual([2, 2]); + expect(c.dtype).toBe('int32'); + expectArraysClose(c, [5, 6, 11, 15]); + }); + it('A x B^t', () => { const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); const b = tf.tensor2d([1, 0, 2, 4, 3, 0], [2, 3]); diff --git a/src/tensor_util.ts b/src/tensor_util.ts index f4a8f7b831..77c7394377 100644 --- a/src/tensor_util.ts +++ b/src/tensor_util.ts @@ -17,8 +17,17 @@ import {Tensor} from './tensor'; import {NamedTensorMap, TensorContainer, TensorContainerArray} from './tensor_types'; +import {upcastType} from './types'; import {assert} from './util'; +export function makeTypesMatch(a: T, b: T): [T, T] { + if (a.dtype === b.dtype) { + return [a, b]; + } + const dtype = upcastType(a.dtype, b.dtype); + return [a.cast(dtype), b.cast(dtype)]; +} + export function assertTypesMatch(a: Tensor, b: Tensor): void { assert( a.dtype === b.dtype,