diff --git a/src/ops/loss_ops.ts b/src/ops/loss_ops.ts index ab93017110..c54325c619 100644 --- a/src/ops/loss_ops.ts +++ b/src/ops/loss_ops.ts @@ -101,4 +101,33 @@ export class LossOps { const losses = labels.sub(predictions).abs(); return LossOps.computeWeightedLoss(losses, weights, reduction); } + + /** + * Computes the mean squared error between two tensors. + * + * @param labels The ground truth output tensor, same dimensions as + * 'predictions'. + * @param predictions The predicted outputs. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + @doc({heading: 'Training', subheading: 'Losses', namespace: 'losses'}) + @operation + static meanSquaredError( + labels: T, predictions: T, weights?: Tensor, + reduction = Reduction.SUM_BY_NONZERO_WEIGHTS): O { + util.assertArgumentsAreTensors({labels, predictions}, 'meanSquaredError'); + if (weights != null) { + util.assertArgumentsAreTensors({weights}, 'meanSquaredError'); + } + util.assertShapesMatch( + labels.shape, predictions.shape, 'Error in meanSquaredError: '); + + const losses = labels.squaredDifference(predictions); + return LossOps.computeWeightedLoss(losses, weights, reduction); + } } diff --git a/src/ops/loss_ops_test.ts b/src/ops/loss_ops_test.ts index 0235c2e7ae..e56766cdf9 100644 --- a/src/ops/loss_ops_test.ts +++ b/src/ops/loss_ops_test.ts @@ -16,9 +16,9 @@ */ import * as tf from '../index'; +import {describeWithFlags} from '../jasmine_util'; // tslint:disable-next-line:max-line-length import {ALL_ENVS, expectArraysClose, expectNumbersClose} from '../test_util'; -import {describeWithFlags} from '../jasmine_util'; describeWithFlags('computeWeightedLoss', ALL_ENVS, () => { it('1D - no weights', () => { @@ -391,3 +391,196 @@ describeWithFlags('absoluteDifference', ALL_ENVS, () => { .toThrowError(e); }); }); + +describeWithFlags('meanSquaredError', ALL_ENVS, () => { + it('1D', () => { + const predictions = tf.tensor1d([1, 2, 3]); + const label = tf.tensor1d([0.3, -0.6, -0.1]); + + const y = tf.losses.meanSquaredError(label, predictions); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + ((1 - 0.3) * (1 - 0.3) + (2 - (-0.6)) * (2 - (-0.6)) + + (3 - (-0.1)) * (3 - (-0.1))) / + 3); + }); + + it('1D - weighted - Reduction.SUM_BY_NONZERO_WEIGHTS', () => { + const predictions = tf.tensor1d([1, 2, 3]); + const label = tf.tensor1d([0.3, -0.6, -0.1]); + const weights = tf.tensor1d([0.1, 0.2, 0.3]); + + const y = tf.losses.meanSquaredError(label, predictions, weights); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + ((1 - 0.3) * (1 - 0.3) * 0.1 + (2 - (-0.6)) * (2 - (-0.6)) * 0.2 + + (3 - (-0.1)) * (3 - (-0.1)) * 0.3) / + 3); + }); + + it('1D - weighted - Reduction.NONE', () => { + const predictions = tf.tensor1d([1, 2, 3]); + const label = tf.tensor1d([0.3, -0.6, -0.1]); + const weights = tf.tensor1d([0.1, 0.2, 0.3]); + + const y = tf.losses.meanSquaredError( + label, predictions, weights, tf.Reduction.NONE); + + expect(y.shape).toEqual([3]); + expectArraysClose(y, [ + (1 - 0.3) * (1 - 0.3) * 0.1, (2 - (-0.6)) * (2 - (-0.6)) * 0.2, + (3 - (-0.1)) * (3 - (-0.1)) * 0.3 + ]); + }); + + it('1D - Reduction.MEAN', () => { + const predictions = tf.tensor1d([1, 2, 3]); + const label = tf.tensor1d([0.3, -0.6, -0.1]); + + const y = tf.losses.meanSquaredError( + label, predictions, undefined, tf.Reduction.MEAN); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + ((1 - 0.3) * (1 - 0.3) + (2 - (-0.6)) * (2 - (-0.6)) + + (3 - (-0.1)) * (3 - (-0.1))) / + 3); + }); + + it('1D - weighted - Reduction.MEAN', () => { + const predictions = tf.tensor1d([1, 2, 3]); + const label = tf.tensor1d([0.3, -0.6, -0.1]); + const weights = tf.tensor1d([0.1, 0.2, 0.3]); + + const y = tf.losses.meanSquaredError( + label, predictions, weights, tf.Reduction.MEAN); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + (((1 - 0.3) * (1 - 0.3) * 0.1) + ((2 - (-0.6)) * (2 - (-0.6)) * 0.2) + + ((3 - (-0.1)) * (3 - (-0.1)) * 0.3)) / + 0.6); + }); + + it('2D', () => { + const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]); + const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]); + + const y = tf.losses.meanSquaredError(label, predictions); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + ((4 - 1) * (4 - 1) + (8 - 9) * (8 - 9) + (12 - 2) * (12 - 2) + + (8 - (-5)) * (8 - (-5)) + (1 - (-2)) * (1 - (-2)) + + (3 - 6) * (3 - 6)) / + 6); + }); + + it('2D - weighted - Reduction.SUM_BY_NONZERO_WEIGHTS', () => { + const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]); + const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]); + const weights = tf.tensor2d([3, 0, 5, 0, 4, 2], [2, 3]); + + const y = tf.losses.meanSquaredError(label, predictions, weights); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + ((4 - 1) * (4 - 1) * 3 + (8 - 9) * (8 - 9) * 0 + + (12 - 2) * (12 - 2) * 5 + (8 - (-5)) * (8 - (-5)) * 0 + + (1 - (-2)) * (1 - (-2)) * 4 + (3 - 6) * (3 - 6) * 2) / + 4); + }); + + it('2D - weighted - Reduction.NONE', () => { + const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]); + const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]); + const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]); + + const y = tf.losses.meanSquaredError( + label, predictions, weights, tf.Reduction.NONE); + + expect(y.shape).toEqual([2, 3]); + expectArraysClose(y, [ + (4 - 1) * (4 - 1) * 3, (8 - 9) * (8 - 9) * 6, (12 - 2) * (12 - 2) * 5, + (8 - (-5)) * (8 - (-5)) * 0, (1 - (-2)) * (1 - (-2)) * 4, + (3 - 6) * (3 - 6) * 2 + ]); + }); + + it('2D - Reduction.MEAN', () => { + const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]); + const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]); + + const y = tf.losses.meanSquaredError( + label, predictions, undefined, tf.Reduction.MEAN); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + ((4 - 1) * (4 - 1) + (8 - 9) * (8 - 9) + (12 - 2) * (12 - 2) + + (8 - (-5)) * (8 - (-5)) + (1 - (-2)) * (1 - (-2)) + + (3 - 6) * (3 - 6)) / + 6); + }); + + it('2D - weighted - Reduction.MEAN', () => { + const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]); + const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]); + const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]); + + const y = tf.losses.meanSquaredError( + label, predictions, weights, tf.Reduction.MEAN); + + expect(y.shape).toEqual([]); + expectNumbersClose( + y.get(), + ((4 - 1) * (4 - 1) * 3 + (8 - 9) * (8 - 9) * 6 + + (12 - 2) * (12 - 2) * 5 + (8 - (-5)) * (8 - (-5)) * 0 + + (1 - (-2)) * (1 - (-2)) * 4 + (3 - 6) * (3 - 6) * 2) / + 20); + }); + + it('throws when passed label as a non-tensor', () => { + const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]); + const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]); + + const e = /Argument 'labels' passed to 'meanSquaredError' must be a Tensor/; + expect( + () => tf.losses.meanSquaredError( + {} as tf.Tensor, predictions, weights, tf.Reduction.MEAN)) + .toThrowError(e); + }); + + it('throws when passed label as a non-tensor', () => { + const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]); + const weights = tf.tensor2d([3, 6, 5, 0, 4, 2], [2, 3]); + + const e = new RegExp( + 'Argument \'predictions\' passed to \'meanSquaredError\' ' + + 'must be a Tensor'); + expect( + () => tf.losses.meanSquaredError( + label, {} as tf.Tensor, weights, tf.Reduction.MEAN)) + .toThrowError(e); + }); + + it('throws when passed weights as a non-tensor', () => { + const predictions = tf.tensor2d([4, 8, 12, 8, 1, 3], [2, 3]); + const label = tf.tensor2d([1, 9, 2, -5, -2, 6], [2, 3]); + + const e = + /Argument 'weights' passed to 'meanSquaredError' must be a Tensor/; + expect( + () => tf.losses.meanSquaredError( + label, predictions, {} as tf.Tensor, tf.Reduction.MEAN)) + .toThrowError(e); + }); +}); diff --git a/src/ops/ops.ts b/src/ops/ops.ts index d7edd98a99..132d86b227 100644 --- a/src/ops/ops.ts +++ b/src/ops/ops.ts @@ -225,6 +225,7 @@ import {Rank} from '../types'; export const losses = { softmaxCrossEntropy: SoftmaxOps.softmaxCrossEntropy, absoluteDifference: LossOps.absoluteDifference, + meanSquaredError: LossOps.meanSquaredError, computeWeightedLoss: LossOps.computeWeightedLoss };