Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix infershape bug #16907

Merged
merged 7 commits into from
Apr 17, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 30 additions & 16 deletions paddle/fluid/operators/linear_chain_crf_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,21 +152,28 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
auto transition_dims = ctx->GetInputDim("Transition");
PADDLE_ENFORCE_EQ(transition_dims.size(), 2,
"The Input(Transition) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
PADDLE_ENFORCE_EQ(
emission_dims[1], transition_dims[1],
bool check = true;
if ((!ctx->IsRuntime()) &&
(transition_dims[0] <= 0 || transition_dims[1] <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
transition_dims[0] - 2, transition_dims[1],
"An invalid dimension for the Input(Transition), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[1], transition_dims[1],
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
"should be equal to the tag number.");

auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_ENFORCE_EQ(
emission_dims[0], label_dims[0],
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_dims[0], label_dims[0],
"The height of Input(Emission) and the height of Input(Label) "
"should be the same.");

Expand Down Expand Up @@ -211,21 +218,28 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
"The Input(TransitionExps) should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(
transition_exps_dims[0] - 2, transition_exps_dims[1],
"An invalid dimension for the Input(TransitionExps), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
PADDLE_ENFORCE_EQ(
emission_exps_dims[1], transition_exps_dims[1],
bool check = true;
if ((!ctx->IsRuntime()) &&
(transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
transition_exps_dims[0] - 2, transition_exps_dims[1],
"An invalid dimension for the Input(TransitionExps), which should "
"be a 2-D tensor with shape [(D + 2) x D].");
}
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[1], transition_exps_dims[1],
"The 2nd dimension of the Input(EmissionExps) and the "
"Input(TransitionExps) should be equal to the tag number.");

auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
"The Input(Label) should be a 2-D tensor with the 2nd "
"dimensions fixed to 1.");
PADDLE_ENFORCE_EQ(
emission_exps_dims[0], label_dims[0],
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, emission_exps_dims[0], label_dims[0],
"The height of Input(EmissionExps) and the height of Input(Label) "
"should be the same.");

Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/operators/metrics/accuracy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ class AccuracyOp : public framework::OperatorWithKernel {
// it's the output of topk.

PADDLE_ENFORCE_EQ(label_dim.size(), 2, "label's rank must be 2.");
PADDLE_ENFORCE_EQ(label_dim[1], 1, "label's second dimension must be 1");
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"the inference tensor's num_rows must be"
" the same as label.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, label_dim[1], 1,
"label's second dimension must be 1");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, inference_dim[0], label_dim[0],
"the inference tensor's num_rows must be"
" the same as label.");

ctx->SetOutputDim("Accuracy", {1});
ctx->SetOutputDim("Correct", {1});
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/metrics/auc_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ class AucOp : public framework::OperatorWithKernel {
auto predict_height = ctx->GetInputDim("Predict")[0];
auto label_height = ctx->GetInputDim("Label")[0];

PADDLE_ENFORCE_EQ(predict_height, label_height,
"Out and Label should have same height.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, predict_height, label_height,
"Out and Label should have same height.");

int num_pred_buckets = ctx->Attrs().Get<int>("num_thresholds") + 1;
int slide_steps = ctx->Attrs().Get<int>("slide_steps");
Expand Down
8 changes: 6 additions & 2 deletions paddle/fluid/operators/sample_logits_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sample_logits_op.h"
#include "paddle/fluid/operators/math/sample_prob.h"

#include <memory>

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -132,7 +133,10 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
"The labels should be a 2-D tensor.");

const int num_samples = ctx->Attrs().Get<int>("num_samples");
const int num_sampled_classes = labels_dims[1] + num_samples;
int num_sampled_classes = labels_dims[1] + num_samples;
if ((!ctx->IsRuntime()) && labels_dims[1] <= 0) {
num_sampled_classes = -1;
}
ctx->SetOutputDim("Samples", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("Probabilities", {logits_dims[0], num_sampled_classes});
ctx->SetOutputDim("SampledLogits", {logits_dims[0], num_sampled_classes});
Expand Down
21 changes: 15 additions & 6 deletions paddle/fluid/operators/smooth_l1_loss_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ limitations under the License. */

#include "paddle/fluid/operators/smooth_l1_loss_op.h"

#include <memory>

namespace paddle {
namespace operators {

Expand All @@ -27,7 +29,14 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {

auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(x_dims, y_dims);
bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(x_dims, y_dims);
}
PADDLE_ENFORCE_GE(x_dims.size(), 2,
"The tensor rank of Input(X) should not be less than 2.");
if (ctx->HasInput("InsideWeight")) {
Expand Down Expand Up @@ -110,11 +119,11 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {

PADDLE_ENFORCE_GE(out_dims.size(), 2,
"The tensor rank of Input(Out@Grad) should be 2.");
PADDLE_ENFORCE_EQ(out_dims[0], in_dims[0],
"The 1st dimension of Input(Out@Grad) must be "
"same as input.");
PADDLE_ENFORCE_EQ(out_dims[1], 1,
"The 2nd dimension of Input(Out@Grad) must be 1.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], in_dims[0],
"The 1st dimension of Input(Out@Grad) must be "
"same as input.");
PADDLE_INFERSHAPE_ENFORCE_EQ(
ctx, out_dims[1], 1, "The 2nd dimension of Input(Out@Grad) must be 1.");

auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
Expand Down
39 changes: 26 additions & 13 deletions paddle/fluid/operators/squared_l2_distance_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,26 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {

int rank = framework::arity(x_dims);
PADDLE_ENFORCE_GE(rank, 2, "Tensor rank should be at least equal to 2.");
PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0], product(y_dims) / y_dims[0],
"Product of dimensions expcet the first dimension of "
"input and target must be equal.");
PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0],
"First dimension of target must be equal to input "
"or to 1.");

bool check = true;
if ((!ctx->IsRuntime()) &&
(framework::product(x_dims) <= 0 || framework::product(y_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(product(x_dims) / x_dims[0],
product(y_dims) / y_dims[0],
"Product of dimensions expcet the first dimension of "
"input and target must be equal.");
}
check = true;
if ((!ctx->IsRuntime()) && (y_dims[0] <= 0 || x_dims[0] <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE(y_dims[0] == 1 || y_dims[0] == x_dims[0],
"First dimension of target must be equal to input "
"or to 1.");
}
ctx->SetOutputDim("sub_result", {x_dims[0], product(x_dims) / x_dims[0]});
ctx->SetOutputDim("Out", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Out");
Expand Down Expand Up @@ -124,12 +137,12 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
PADDLE_ENFORCE_EQ(out_dims[0], x_dims[0],
"First dimension of output gradient and "
"input value must be equal.");
PADDLE_ENFORCE_EQ(out_dims[1], 1,
"Second dimension of output gradient "
"must be 1.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[0], x_dims[0],
"First dimension of output gradient and "
"input value must be equal.");
PADDLE_INFERSHAPE_ENFORCE_EQ(ctx, out_dims[1], 1,
"Second dimension of output gradient "
"must be 1.");
auto x_grad_name = framework::GradVarName("X");
auto y_grad_name = framework::GradVarName("Y");
if (ctx->HasOutput(x_grad_name)) ctx->SetOutputDim(x_grad_name, x_dims);
Expand Down
41 changes: 41 additions & 0 deletions paddle/fluid/platform/enforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -356,5 +356,46 @@ using CommonType2 = typename std::add_lvalue_reference<
#define PADDLE_ENFORCE_LE(__VAL0, __VAL1, ...) \
__PADDLE_BINARY_COMPARE(__VAL0, __VAL1, <=, >, __VA_ARGS__)

#define __PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL1, __VAL2, __CMP, \
__INV_CMP, ...) \
do { \
auto __val1 = (__VAL1); \
auto __val2 = (__VAL2); \
if (!__CTX->IsRuntime()) { \
if (__val1 == -1 || __val2 == -1) { \
break; \
} \
} \
using __TYPE1__ = decltype(__val1); \
using __TYPE2__ = decltype(__val2); \
using __COMMON_TYPE1__ = \
::paddle::platform::details::CommonType1<__TYPE1__, __TYPE2__>; \
using __COMMON_TYPE2__ = \
::paddle::platform::details::CommonType2<__TYPE1__, __TYPE2__>; \
bool __is_not_error = (static_cast<__COMMON_TYPE1__>(__val1))__CMP( \
static_cast<__COMMON_TYPE2__>(__val2)); \
if (UNLIKELY(!__is_not_error)) { \
PADDLE_THROW("Enforce failed. Expected %s " #__CMP \
" %s, but received %s:%s " #__INV_CMP " %s:%s.\n%s", \
#__VAL1, #__VAL2, #__VAL1, \
::paddle::string::to_string(__val1), #__VAL2, \
::paddle::string::to_string(__val2), \
::paddle::string::Sprintf(__VA_ARGS__)); \
} \
} while (0)

#define PADDLE_INFERSHAPE_ENFORCE_EQ(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, ==, !=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_NE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, !=, ==, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_GT(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >, <=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_GE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, >=, <, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_LT(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <, >=, __VA_ARGS__)
#define PADDLE_INFERSHAPE_ENFORCE_LE(__CTX, __VAL0, __VAL1, ...) \
__PADDLE_INFERSHAPE_BINARY_COMPARE(__CTX, __VAL0, __VAL1, <=, >, __VA_ARGS__)

} // namespace platform
} // namespace paddle