From 16fddf32a54d19913c372f4ab59c98ee507fd6ff Mon Sep 17 00:00:00 2001
From: Xinghai Sun
Date: Sun, 3 Sep 2017 17:51:25 +0800
Subject: [PATCH 1/2] Add broadcasting support (e.g. matrix-vector) for cos
 sim operator.

---
 paddle/operators/cos_sim_op.cc                   |  81 +++++++---
 paddle/operators/cos_sim_op.h                    | 142 +++++++++++-------
 .../v2/framework/tests/test_cos_sim_op.py       |  93 +++++++++++-
 3 files changed, 238 insertions(+), 78 deletions(-)

diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index c033af3b741ae2..428ee7d9d0318f 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -25,16 +25,29 @@ class CosSimOp : public framework::OperatorWithKernel {

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
+    // notnull check
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null.");
-    PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
-                      ctx.Input<Tensor>("Y")->dims(),
-                      "Dimensions of Input(X) and Input(Y) must be the same.");
-
-    auto dims = ctx.Input<Tensor>("X")->dims();
-    ctx.Output<Tensor>("Out")->Resize({dims[0], 1});
-    ctx.Output<Tensor>("XNorm")->Resize({dims[0], 1});
-    ctx.Output<Tensor>("YNorm")->Resize({dims[0], 1});
+
+    // shape check
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto y_dims = ctx.Input<Tensor>("Y")->dims();
+    PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
+                      "Ranks of Input(X) and Input(Y) must be equal.");
+    PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
+                      "Rank of Input(X) must not be less than 2.");
+    PADDLE_ENFORCE_EQ(
+        framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
+        framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
+        "All dimensions except 1st of Input(X) and Input(Y) must be equal.");
+    PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
+                   "1st dimension of Input(Y) must be equal to Input(X) or "
+                   "just 1 (which will be broadcasted to match Input(X)).");
+
+    // resize tensor
+    ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
+    ctx.Output<Tensor>("XNorm")->Resize({x_dims[0], 1});
+    ctx.Output<Tensor>("YNorm")->Resize({y_dims[0], 1});
  }
};

@@ -42,8 +55,8 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CosSimOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The first input of cos_sim op.");
-    AddInput("Y", "The second input of cos_sim op.");
+    AddInput("X", "The 1st input of cos_sim op.");
+    AddInput("Y", "The 2nd input of cos_sim op.");
    AddOutput("Out", "The output of cos_sim op.");
    AddOutput("XNorm", "Row norm of the first input.").AsIntermediate();
    AddOutput("YNorm", "Row norm of the second input.").AsIntermediate();
@@ -51,7 +64,12 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
    AddComment(R"DOC(
Cosine Similarity Operator.

-The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y))
+The equation is: Out = X^T * Y / (sqrt(X^T * X) * sqrt(Y^T * Y)).
+
+Input(X) and Input(Y) must have the same shape, except that the 1st dimension
+of Input(Y) could be just 1 (different from Input(X)), which will be
+broadcasted to match the shape of Input(X) before computing their cosine
+similarity.
)DOC");
  }
};
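For intuition, here is a minimal NumPy sketch (an editor's illustration, not part of the patch) of the forward semantics the DOC above describes, assuming all dimensions after the 1st are flattened into columns and Y is broadcast along rows when its 1st dimension is 1:

import numpy as np

def cos_sim_forward(x, y):
    # Flatten everything after the 1st dimension into columns.
    x2 = x.reshape(x.shape[0], -1)   # [rows_x, cols]
    y2 = y.reshape(y.shape[0], -1)   # [rows_y, cols], rows_y in {rows_x, 1}
    x_norm = np.linalg.norm(x2, axis=1, keepdims=True)  # XNorm: [rows_x, 1]
    y_norm = np.linalg.norm(y2, axis=1, keepdims=True)  # YNorm: [rows_y, 1]
    # NumPy broadcasting plays the role of Eigen's broadcast() in the kernel.
    out = (x2 * y2).sum(axis=1, keepdims=True) / x_norm / y_norm
    return out, x_norm, y_norm

out, xn, yn = cos_sim_forward(np.random.rand(32, 64), np.random.rand(1, 64))
print(out.shape, xn.shape, yn.shape)  # (32, 1) (32, 1) (1, 1)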
)DOC"); } }; @@ -62,32 +80,47 @@ class CosSimOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + // notnull check PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("XNorm"), "Input(XNorm) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("YNorm"), "Input(YNorm) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"), + "Input(Out) must not be null."); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) must not be null."); + // shape check auto x_dims = ctx.Input("X")->dims(); auto y_dims = ctx.Input("Y")->dims(); + PADDLE_ENFORCE_GE(framework::arity(x_dims), framework::arity(y_dims), + "Ranks of Input(X) and Input(Y) must be equal."); + PADDLE_ENFORCE_GE(framework::arity(x_dims), 2, + "Rank of Input(X) must not be less than 2."); + PADDLE_ENFORCE_EQ( + framework::slice_ddim(x_dims, 1, framework::arity(x_dims)), + framework::slice_ddim(y_dims, 1, framework::arity(y_dims)), + "All dimensions except 1st of Input(X) and Input(Y) must be equal."); + PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1, + "1st dimension of Input(Y) must be equal to Input(X) or " + "just 1 (which will be broadcasted to match Input(X))."); auto xnorm_dims = ctx.Input("XNorm")->dims(); + PADDLE_ENFORCE_EQ(xnorm_dims, framework::make_ddim({x_dims[0], 1}), + "Shape of Input(XNorm) must be [X.Dim(0), 1]."); auto ynorm_dims = ctx.Input("YNorm")->dims(); - auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); - PADDLE_ENFORCE_EQ(x_dims, y_dims, - "Dimensions of Input(X) and Input(Y) must be the same."); - PADDLE_ENFORCE_EQ(xnorm_dims[0], x_dims[0], - "1st dimension of XNorm must equal that of Input(X)."); - PADDLE_ENFORCE_EQ(xnorm_dims[1], 1, "2st dimension of XNorm must be one."); - PADDLE_ENFORCE_EQ(ynorm_dims[0], y_dims[0], - "1st dimension of YNorm must equal that of Input(Y)."); - PADDLE_ENFORCE_EQ(ynorm_dims[1], 1, "2st dimension of YNorm must be one."); - PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0], - "1st dimension of Out@GRAD must equal that of Input(X)"); - PADDLE_ENFORCE_EQ(out_dims[1], 1, "1st dimension of Out@GRAD must be one."); - + PADDLE_ENFORCE_EQ(ynorm_dims, framework::make_ddim({y_dims[0], 1}), + "Shape of Input(YNorm) must be [Y.Dim(0), 1]."); + auto out_dims = ctx.Input("Out")->dims(); + PADDLE_ENFORCE_EQ(out_dims, framework::make_ddim({x_dims[0], 1}), + "Shape of Input(Out) must be [X.Dim(0), 1]."); + auto out_grad_dims = + ctx.Input(framework::GradVarName("Out"))->dims(); + PADDLE_ENFORCE_EQ(out_grad_dims, framework::make_ddim({x_dims[0], 1}), + "Shape of Input(Out@Grad) must be [X.Dim(0), 1]."); + + // resize tensor auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); if (x_grad) x_grad->Resize(x_dims); diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h index 9e3ff26815644e..62298ccbcec3ed 100644 --- a/paddle/operators/cos_sim_op.h +++ b/paddle/operators/cos_sim_op.h @@ -28,30 +28,38 @@ template class CosSimKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* input_x = context.Input("X"); - auto* input_y = context.Input("Y"); - auto* output_z = context.Output("Out"); - auto* output_x_norm = context.Output("XNorm"); - auto* output_y_norm = context.Output("YNorm"); 
diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h
index 9e3ff26815644e..62298ccbcec3ed 100644
--- a/paddle/operators/cos_sim_op.h
+++ b/paddle/operators/cos_sim_op.h
@@ -28,30 +28,38 @@ template <typename Place, typename T>
 class CosSimKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
-    auto* input_x = context.Input<Tensor>("X");
-    auto* input_y = context.Input<Tensor>("Y");
-    auto* output_z = context.Output<Tensor>("Out");
-    auto* output_x_norm = context.Output<Tensor>("XNorm");
-    auto* output_y_norm = context.Output<Tensor>("YNorm");
+    // get Tensor
+    auto* in_x = context.Input<Tensor>("X");
+    auto* in_y = context.Input<Tensor>("Y");
+    auto* out_z = context.Output<Tensor>("Out");
+    auto* out_x_norm = context.Output<Tensor>("XNorm");
+    auto* out_y_norm = context.Output<Tensor>("YNorm");
+    out_z->mutable_data<T>(context.GetPlace());
+    out_x_norm->mutable_data<T>(context.GetPlace());
+    out_y_norm->mutable_data<T>(context.GetPlace());

-    output_z->mutable_data<T>(context.GetPlace());
-    output_x_norm->mutable_data<T>(context.GetPlace());
-    output_y_norm->mutable_data<T>(context.GetPlace());
-
-    auto dims = input_x->dims();
-    int size = static_cast<int>(framework::product(dims));
-    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
-    auto x = EigenMatrix<T>::From(*input_x, new_dims);
-    auto y = EigenMatrix<T>::From(*input_y, new_dims);
-    auto z = EigenMatrix<T>::From(*output_z);
-    auto x_norm = EigenMatrix<T>::From(*output_x_norm);
-    auto y_norm = EigenMatrix<T>::From(*output_y_norm);
+    // convert Tensor to Eigen Tensor
+    int rows_x = in_x->dims()[0];
+    int rows_y = in_y->dims()[0];
+    int cols = framework::product(in_x->dims()) / rows_x;
+    auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
+    auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
+    auto z = EigenMatrix<T>::From(*out_z);
+    auto x_norm = EigenMatrix<T>::From(*out_x_norm);
+    auto y_norm = EigenMatrix<T>::From(*out_y_norm);

+    // compute
    auto place = context.GetEigenDevice<Place>();
-    auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
    x_norm.device(place) = x.square().sum(Eigen::array<int, 1>({1})).sqrt();
    y_norm.device(place) = y.square().sum(Eigen::array<int, 1>({1})).sqrt();
-    z.device(place) = xy / x_norm / y_norm;
+    if (rows_x == rows_y) {
+      auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
+      z.device(place) = xy / x_norm / y_norm;
+    } else {
+      Eigen::DSizes<int, 2> bcast(rows_x, 1);
+      auto xy = (x * y.broadcast(bcast)).sum(Eigen::array<int, 1>({1}));
+      z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
+    }
  }
};
@@ -59,43 +67,75 @@ template <typename Place, typename T>
 class CosSimGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
-    auto* input_x = context.Input<Tensor>("X");
-    auto* input_y = context.Input<Tensor>("Y");
-    auto* input_z = context.Input<Tensor>("Out");
-    auto* input_x_norm = context.Input<Tensor>("XNorm");
-    auto* input_y_norm = context.Input<Tensor>("YNorm");
-    auto* output_grad_x = context.Output<Tensor>(framework::GradVarName("X"));
-    auto* output_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
-    auto* input_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
+    // get Tensor
+    auto* in_x = context.Input<Tensor>("X");
+    auto* in_y = context.Input<Tensor>("Y");
+    auto* in_z = context.Input<Tensor>("Out");
+    auto* in_x_norm = context.Input<Tensor>("XNorm");
+    auto* in_y_norm = context.Input<Tensor>("YNorm");
+    auto* out_grad_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
+    auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));

-    auto dims = input_x->dims();
-    int size = static_cast<int>(framework::product(dims));
-    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
-    auto x = EigenMatrix<T>::From(*input_x, new_dims);
-    auto y = EigenMatrix<T>::From(*input_y, new_dims);
-    auto z = EigenMatrix<T>::From(*input_z);
-    auto x_norm = EigenMatrix<T>::From(*input_x_norm);
-    auto y_norm = EigenMatrix<T>::From(*input_y_norm);
-    auto dz = EigenMatrix<T>::From(*input_grad_z);
+    // convert Tensor to Eigen Tensor
+    int rows_x = in_x->dims()[0];
+    int rows_y = in_y->dims()[0];
+    int cols = framework::product(in_x->dims()) / rows_x;
+    auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
+    auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
+    auto z = EigenMatrix<T>::From(*in_z);
+    auto x_norm = EigenMatrix<T>::From(*in_x_norm);
+    auto y_norm = EigenMatrix<T>::From(*in_y_norm);
+    auto dz = EigenMatrix<T>::From(*in_grad_z);

-    Eigen::DSizes<int, 2> bcast(1, new_dims[1]);
+    // compute gradient
+    Eigen::DSizes<int, 2> bcast(1, cols);
    auto z_bcast = z.broadcast(bcast);
    auto dz_bcast = dz.broadcast(bcast);
-    auto place = context.GetEigenDevice<Place>();
    auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast);
-    auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast);
-    auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast);
-    if (output_grad_x) {
-      output_grad_x->mutable_data<T>(context.GetPlace());
-      auto dx = EigenMatrix<T>::From(*output_grad_x, new_dims);
-      dx.device(place) =
-          dz_bcast * (y / norm_prod_bcast - z_bcast * x / x_snorm_bcast);
-    }
-    if (output_grad_y) {
-      output_grad_y->mutable_data<T>(context.GetPlace());
-      auto dy = EigenMatrix<T>::From(*output_grad_y, new_dims);
-      dy.device(place) =
-          dz_bcast * (x / norm_prod_bcast - z_bcast * y / y_snorm_bcast);
+    auto place = context.GetEigenDevice<Place>();
+    if (rows_x == rows_y) {
+      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast);
+      auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast);
+      // compute dx
+      if (out_grad_x) {
+        out_grad_x->mutable_data<T>(context.GetPlace());
+        auto dx = EigenMatrix<T>::From(*out_grad_x,
+                                       framework::make_ddim({rows_x, cols}));
+        auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
+        dx.device(place) = dz_bcast * grad;
+      }
+      // compute dy
+      if (out_grad_y) {
+        out_grad_y->mutable_data<T>(context.GetPlace());
+        auto dy = EigenMatrix<T>::From(*out_grad_y,
+                                       framework::make_ddim({rows_y, cols}));
+        auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
+        dy.device(place) = dz_bcast * grad;
+      }
+    } else {
+      Eigen::DSizes<int, 2> bcast_row(rows_x, 1);
+      auto y_bcast = y.broadcast(bcast_row);
+      auto y_snorm_bcast =
+          y_norm.square().eval().broadcast(bcast_row).eval().broadcast(bcast);
+      auto norm_prod_bcast =
+          (x_norm * y_norm.broadcast(bcast_row)).eval().broadcast(bcast);
+      // compute dx
+      if (out_grad_x) {
+        out_grad_x->mutable_data<T>(context.GetPlace());
+        auto dx = EigenMatrix<T>::From(
+            *out_grad_x, framework::make_ddim({rows_x, cols}));
+        auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
+        dx.device(place) = dz_bcast * grad;
+      }
+      // compute dy
+      if (out_grad_y) {
+        out_grad_y->mutable_data<T>(context.GetPlace());
+        auto dy = EigenMatrix<T>::From(
+            *out_grad_y, framework::make_ddim({rows_y, cols}));
+        auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
+        dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({0}));
+      }
    }
  }
};
diff --git a/python/paddle/v2/framework/tests/test_cos_sim_op.py b/python/paddle/v2/framework/tests/test_cos_sim_op.py
index 32013a7999a4be..3f2feaa9339be2 100644
--- a/python/paddle/v2/framework/tests/test_cos_sim_op.py
+++ b/python/paddle/v2/framework/tests/test_cos_sim_op.py
@@ -4,7 +4,7 @@ from op_test_util import OpTestMeta


-class TestCosSimOp(unittest.TestCase):
+class TestCosSimOpWithRank2(unittest.TestCase):
     __metaclass__ = OpTestMeta

     def setUp(self):
@@ -24,12 +24,72 @@ def setUp(self):
         }


+class TestCosSimOpWithRank2Bcast(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "cos_sim"
+        self.inputs = {
+            'X': np.random.random((32, 64)).astype("float32"),
+            'Y': np.random.random((1, 64)).astype("float32")
+        }
+        expect_x_norm = np.linalg.norm(self.inputs['X'], axis=1)
+        expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=1)
+        expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=1) / \
+            expect_x_norm / expect_y_norm
+        self.outputs = {
+            'XNorm': np.expand_dims(expect_x_norm, 1),
+            'YNorm': np.expand_dims(expect_y_norm, 1),
+            'Out': np.expand_dims(expect_out, 1)
+        }
+
+
+class TestCosSimOpWithRank3(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "cos_sim"
+        self.inputs = {
+            'X': np.random.random((32, 64, 10)).astype("float32"),
+            'Y': np.random.random((32, 64, 10)).astype("float32")
+        }
+        expect_x_norm = np.linalg.norm(self.inputs['X'], axis=(1, 2))
+        expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=(1, 2))
+        expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=(1, 2)) / \
+            expect_x_norm / expect_y_norm
+        self.outputs = {
+            'XNorm': np.expand_dims(expect_x_norm, 1),
+            'YNorm': np.expand_dims(expect_y_norm, 1),
+            'Out': np.expand_dims(expect_out, 1)
+        }
+
+
+class TestCosSimOpWithRank3Bcast(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "cos_sim"
+        self.inputs = {
+            'X': np.random.random((32, 64, 10)).astype("float32"),
+            'Y': np.random.random((1, 64, 10)).astype("float32")
+        }
+        expect_x_norm = np.linalg.norm(self.inputs['X'], axis=(1, 2))
+        expect_y_norm = np.linalg.norm(self.inputs['Y'], axis=(1, 2))
+        expect_out = (self.inputs['X'] * self.inputs['Y']).sum(axis=(1, 2)) / \
+            expect_x_norm / expect_y_norm
+        self.outputs = {
+            'XNorm': np.expand_dims(expect_x_norm, 1),
+            'YNorm': np.expand_dims(expect_y_norm, 1),
+            'Out': np.expand_dims(expect_out, 1)
+        }
+
+
 class TestCosSimGradOp(GradientChecker):
     def setUp(self):
         self.op = create_op("cos_sim")
         self.inputs = {
-            'X': np.random.random((10, 5)).astype("float32"),
-            'Y': np.random.random((10, 5)).astype("float32")
+            'X': np.random.random((6, 5)).astype("float32"),
+            'Y': np.random.random((6, 5)).astype("float32")
         }

     def test_cpu_gpu_compare(self):
@@ -56,5 +116,32 @@ def test_ignore_y(self):
                            no_grad_set={"Y"})


+class TestCosSimGradOpWithRank2Bcast(TestCosSimGradOp):
+    def setUp(self):
+        self.op = create_op("cos_sim")
+        self.inputs = {
+            'X': np.random.random((6, 5)).astype("float32"),
+            'Y': np.random.random((1, 5)).astype("float32")
+        }
+
+
+class TestCosSimGradOpWithRank3(TestCosSimGradOp):
+    def setUp(self):
+        self.op = create_op("cos_sim")
+        self.inputs = {
+            'X': np.random.random((6, 5, 2)).astype("float32"),
+            'Y': np.random.random((6, 5, 2)).astype("float32")
+        }
+
+
+class TestCosSimGradOpWithRank3Bcast(TestCosSimGradOp):
+    def setUp(self):
+        self.op = create_op("cos_sim")
+        self.inputs = {
+            'X': np.random.random((6, 5, 2)).astype("float32"),
+            'Y': np.random.random((1, 5, 2)).astype("float32")
+        }
+
+
 if __name__ == '__main__':
     unittest.main()

From 03ea7320d3de03dec3880bd1504db8d61ad06a0c Mon Sep 17 00:00:00 2001
From: Xinghai Sun
Date: Wed, 13 Sep 2017 13:23:03 +0800
Subject: [PATCH 2/2] Update cos_sim operator by following reviewer's comments.
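Editor's aside on the main API change in this commit (not part of the commit itself): the kernels switch from EigenMatrix<T>::From(t, make_ddim({rows, cols})) to EigenMatrix<T>::Reshape(t, 1). Assuming the integer argument marks the axis at which trailing dimensions are flattened into columns (which is how every call site below uses it), a hypothetical NumPy equivalent would be:

import numpy as np

def reshape_to_matrix(t, num_row_dims=1):
    # Assumed analogue of EigenMatrix<T>::Reshape(t, 1): keep the first
    # `num_row_dims` dimensions as rows, flatten the rest into columns.
    rows = int(np.prod(t.shape[:num_row_dims]))
    return t.reshape(rows, -1)

x = np.random.rand(32, 64, 10)
print(reshape_to_matrix(x).shape)  # (32, 640)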
---
 paddle/operators/cos_sim_op.cc | 66 +++++++++++++++++++---------------
 paddle/operators/cos_sim_op.h  | 58 ++++++++++++++----------------
 2 files changed, 65 insertions(+), 59 deletions(-)

diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc
index 428ee7d9d0318f..e3bee437921356 100644
--- a/paddle/operators/cos_sim_op.cc
+++ b/paddle/operators/cos_sim_op.cc
@@ -32,17 +32,18 @@ class CosSimOp : public framework::OperatorWithKernel {
    // shape check
    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto y_dims = ctx.Input<Tensor>("Y")->dims();
-    PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
+
+    PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
                      "Ranks of Input(X) and Input(Y) must be equal.");
-    PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
+    PADDLE_ENFORCE_GE(x_dims.size(), 2,
                      "Rank of Input(X) must not be less than 2.");
-    PADDLE_ENFORCE_EQ(
-        framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
-        framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
-        "All dimensions except 1st of Input(X) and Input(Y) must be equal.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
+                      framework::slice_ddim(y_dims, 1, y_dims.size()),
+                      "All dimensions except the 1st of Input(X) and Input(Y) "
+                      "must be equal.");
    PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
-                   "1st dimension of Input(Y) must be equal to Input(X) or "
-                   "just 1 (which will be broadcasted to match Input(X)).");
+                   "The 1st dimension of Input(Y) must be equal to Input(X) or"
+                   " just 1 (which will be broadcasted to match Input(X)).");

    // resize tensor
    ctx.Output<Tensor>("Out")->Resize({x_dims[0], 1});
@@ -58,8 +59,14 @@ class CosSimOpMaker : public framework::OpProtoAndCheckerMaker {
    AddInput("X", "The 1st input of cos_sim op.");
    AddInput("Y", "The 2nd input of cos_sim op.");
    AddOutput("Out", "The output of cos_sim op.");
-    AddOutput("XNorm", "Row norm of the first input.").AsIntermediate();
-    AddOutput("YNorm", "Row norm of the second input.").AsIntermediate();
+    AddOutput("XNorm",
+              "Norm of the first input, reduced along all dimensions "
+              "except the 1st.")
+        .AsIntermediate();
+    AddOutput("YNorm",
+              "Norm of the second input, reduced along all dimensions "
+              "except the 1st.")
+        .AsIntermediate();
    AddComment(R"DOC(
Cosine Similarity Operator.
@@ -95,29 +102,32 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
    // shape check
    auto x_dims = ctx.Input<Tensor>("X")->dims();
    auto y_dims = ctx.Input<Tensor>("Y")->dims();
-    PADDLE_ENFORCE_EQ(framework::arity(x_dims), framework::arity(y_dims),
-                      "Ranks of Input(X) and Input(Y) must be equal.");
-    PADDLE_ENFORCE_GE(framework::arity(x_dims), 2,
-                      "Rank of Input(X) must not be less than 2.");
-    PADDLE_ENFORCE_EQ(
-        framework::slice_ddim(x_dims, 1, framework::arity(x_dims)),
-        framework::slice_ddim(y_dims, 1, framework::arity(y_dims)),
-        "All dimensions except 1st of Input(X) and Input(Y) must be equal.");
-    PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
-                   "1st dimension of Input(Y) must be equal to Input(X) or "
-                   "just 1 (which will be broadcasted to match Input(X)).");
    auto xnorm_dims = ctx.Input<Tensor>("XNorm")->dims();
-    PADDLE_ENFORCE_EQ(xnorm_dims, framework::make_ddim({x_dims[0], 1}),
-                      "Shape of Input(XNorm) must be [X.Dim(0), 1].");
    auto ynorm_dims = ctx.Input<Tensor>("YNorm")->dims();
-    PADDLE_ENFORCE_EQ(ynorm_dims, framework::make_ddim({y_dims[0], 1}),
-                      "Shape of Input(YNorm) must be [Y.Dim(0), 1].");
    auto out_dims = ctx.Input<Tensor>("Out")->dims();
-    PADDLE_ENFORCE_EQ(out_dims, framework::make_ddim({x_dims[0], 1}),
-                      "Shape of Input(Out) must be [X.Dim(0), 1].");
    auto out_grad_dims =
        ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
-    PADDLE_ENFORCE_EQ(out_grad_dims, framework::make_ddim({x_dims[0], 1}),
+
+    PADDLE_ENFORCE_EQ(x_dims.size(), y_dims.size(),
+                      "Ranks of Input(X) and Input(Y) must be equal.");
+    PADDLE_ENFORCE_GE(x_dims.size(), 2,
+                      "Rank of Input(X) must not be less than 2.");
+    PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 1, x_dims.size()),
+                      framework::slice_ddim(y_dims, 1, y_dims.size()),
+                      "All dimensions except the 1st of Input(X) and Input(Y) "
+                      "must be equal.");
+    PADDLE_ENFORCE(x_dims[0] == y_dims[0] || y_dims[0] == 1,
+                   "The 1st dimension of Input(Y) must be equal to Input(X) or"
+                   " just 1 (which will be broadcasted to match Input(X)).");
+    auto target_xnorm_dims = framework::make_ddim({x_dims[0], 1});
+    auto target_ynorm_dims = framework::make_ddim({y_dims[0], 1});
+    PADDLE_ENFORCE_EQ(xnorm_dims, target_xnorm_dims,
+                      "Shape of Input(XNorm) must be [X.Dim(0), 1].");
+    PADDLE_ENFORCE_EQ(ynorm_dims, target_ynorm_dims,
+                      "Shape of Input(YNorm) must be [Y.Dim(0), 1].");
+    PADDLE_ENFORCE_EQ(out_dims, target_xnorm_dims,
+                      "Shape of Input(Out) must be [X.Dim(0), 1].");
+    PADDLE_ENFORCE_EQ(out_grad_dims, target_xnorm_dims,
                      "Shape of Input(Out@Grad) must be [X.Dim(0), 1].");

    // resize tensor
diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h
index 62298ccbcec3ed..4d03d5902d5866 100644
--- a/paddle/operators/cos_sim_op.h
+++ b/paddle/operators/cos_sim_op.h
@@ -42,22 +42,23 @@ class CosSimKernel : public framework::OpKernel {
    int rows_x = in_x->dims()[0];
    int rows_y = in_y->dims()[0];
    int cols = framework::product(in_x->dims()) / rows_x;
-    auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
-    auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
+    auto x = EigenMatrix<T>::Reshape(*in_x, 1);
+    auto y = EigenMatrix<T>::Reshape(*in_y, 1);
    auto z = EigenMatrix<T>::From(*out_z);
    auto x_norm = EigenMatrix<T>::From(*out_x_norm);
    auto y_norm = EigenMatrix<T>::From(*out_y_norm);

    // compute
    auto place = context.GetEigenDevice<Place>();
-    x_norm.device(place) = x.square().sum(Eigen::array<int, 1>({1})).sqrt();
-    y_norm.device(place) = y.square().sum(Eigen::array<int, 1>({1})).sqrt();
+    auto row_along = Eigen::array<int, 1>({{1}});
+    x_norm.device(place) = x.square().sum(row_along).sqrt();
+    y_norm.device(place) = y.square().sum(row_along).sqrt();
    if (rows_x == rows_y) {
      auto xy = (x * y).sum(Eigen::array<int, 1>({1}));
      z.device(place) = xy / x_norm / y_norm;
    } else {
      Eigen::DSizes<int, 2> bcast(rows_x, 1);
-      auto xy = (x * y.broadcast(bcast)).sum(Eigen::array<int, 1>({1}));
+      auto xy = (x * y.broadcast(bcast)).sum(row_along);
      z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
    }
  }
@@ -78,61 +79,56 @@ class CosSimGradKernel : public framework::OpKernel {
    auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));

    // convert Tensor to Eigen Tensor
-    int rows_x = in_x->dims()[0];
-    int rows_y = in_y->dims()[0];
-    int cols = framework::product(in_x->dims()) / rows_x;
-    auto x = EigenMatrix<T>::From(*in_x, framework::make_ddim({rows_x, cols}));
-    auto y = EigenMatrix<T>::From(*in_y, framework::make_ddim({rows_y, cols}));
+    auto x = EigenMatrix<T>::Reshape(*in_x, 1);
+    auto y = EigenMatrix<T>::Reshape(*in_y, 1);
    auto z = EigenMatrix<T>::From(*in_z);
    auto x_norm = EigenMatrix<T>::From(*in_x_norm);
    auto y_norm = EigenMatrix<T>::From(*in_y_norm);
    auto dz = EigenMatrix<T>::From(*in_grad_z);

-    // compute gradident
-    Eigen::DSizes<int, 2> bcast(1, cols);
-    auto z_bcast = z.broadcast(bcast);
-    auto dz_bcast = dz.broadcast(bcast);
-    auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast);
+    // compute gradient
+    int rows_x = in_x->dims()[0];
+    int rows_y = in_y->dims()[0];
+    int cols = framework::product(in_x->dims()) / rows_x;
+    Eigen::DSizes<int, 2> bcast_cols(1, cols);
+    auto z_bcast = z.broadcast(bcast_cols);
+    auto dz_bcast = dz.broadcast(bcast_cols);
+    auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
    auto place = context.GetEigenDevice<Place>();
    if (rows_x == rows_y) {
-      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast);
-      auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast);
+      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
+      auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
      // compute dx
      if (out_grad_x) {
        out_grad_x->mutable_data<T>(context.GetPlace());
-        auto dx = EigenMatrix<T>::From(*out_grad_x,
-                                       framework::make_ddim({rows_x, cols}));
+        auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
        auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
        dx.device(place) = dz_bcast * grad;
      }
      // compute dy
      if (out_grad_y) {
        out_grad_y->mutable_data<T>(context.GetPlace());
-        auto dy = EigenMatrix<T>::From(*out_grad_y,
-                                       framework::make_ddim({rows_y, cols}));
+        auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
        auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
        dy.device(place) = dz_bcast * grad;
      }
    } else {
-      Eigen::DSizes<int, 2> bcast_row(rows_x, 1);
-      auto y_bcast = y.broadcast(bcast_row);
-      auto y_snorm_bcast =
-          y_norm.square().eval().broadcast(bcast_row).eval().broadcast(bcast);
-      auto norm_prod_bcast =
-          (x_norm * y_norm.broadcast(bcast_row)).eval().broadcast(bcast);
+      Eigen::DSizes<int, 2> bcast_rows(rows_x, 1);
+      Eigen::DSizes<int, 2> bcast_rows_cols(rows_x, cols);
+      auto y_bcast = y.broadcast(bcast_rows);
+      auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols);
+      auto norm_prod_bcast =
+          (x_norm * y_norm.broadcast(bcast_rows)).eval().broadcast(bcast_cols);
      // compute dx
      if (out_grad_x) {
        out_grad_x->mutable_data<T>(context.GetPlace());
-        auto dx = EigenMatrix<T>::From(
-            *out_grad_x, framework::make_ddim({rows_x, cols}));
+        auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
        auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
        dx.device(place) = dz_bcast * grad;
      }
      // compute dy
      if (out_grad_y) {
        out_grad_y->mutable_data<T>(context.GetPlace());
-        auto dy = EigenMatrix<T>::From(
-            *out_grad_y, framework::make_ddim({rows_y, cols}));
+        auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
        auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
        dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({0}));
      }
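Editor's aside, not part of either patch: a quick NumPy sanity check (independent of Paddle's GradientChecker) of the broadcast-case gradients implemented above, in particular that Y@GRAD sums over the broadcast rows, matching the .sum(Eigen::array<int, 1>({0})) reduction:

import numpy as np

def cos_sim(x, y):
    # Forward pass; NumPy broadcasts y along rows when y.shape[0] == 1.
    xn = np.linalg.norm(x, axis=1, keepdims=True)
    yn = np.linalg.norm(y, axis=1, keepdims=True)
    return (x * y).sum(axis=1, keepdims=True) / xn / yn

x = np.random.rand(6, 5)
y = np.random.rand(1, 5)
z = cos_sim(x, y)
xn = np.linalg.norm(x, axis=1, keepdims=True)
yn = np.linalg.norm(y, axis=1, keepdims=True)

dz = np.random.rand(6, 1)  # upstream gradient Out@GRAD
# Analytic gradients, mirroring the kernel's `grad` expressions;
# dy accumulates over the broadcast rows (axis 0).
dx = dz * (y / (xn * yn) - z * x / xn**2)
dy = (dz * (x / (xn * yn) - z * y / yn**2)).sum(axis=0, keepdims=True)

# Numerical check of dy by central differences.
eps = 1e-6
dy_num = np.zeros_like(y)
for j in range(y.shape[1]):
    yp, ym = y.copy(), y.copy()
    yp[0, j] += eps
    ym[0, j] -= eps
    dy_num[0, j] = ((cos_sim(x, yp) - cos_sim(x, ym)) * dz).sum() / (2 * eps)
assert np.allclose(dy, dy_num, atol=1e-4)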