change check and shape_is_known
roywei committed May 3, 2019
1 parent 36c3306 commit 5502fa0
Showing 23 changed files with 97 additions and 69 deletions.
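
The pattern repeated across these files is twofold. First, shape-inference guards are relaxed from shape_is_known (every dimension already inferred) to ndim_is_known (only the rank inferred), so inference can proceed on partially known shapes instead of bailing out. Second, the unsigned suffix is dropped from ndim() comparisons, because ndim() is now a signed value that can legitimately be -1. The sketch below models the two helpers on a plain struct; it is a simplified illustration of their behaviour, not MXNet's actual TShape implementation.

    // Simplified model of the two helpers (the real ones operate on mxnet::TShape,
    // where -1 encodes "unknown"); illustration only.
    #include <vector>

    struct ShapeModel {
      int ndim;                  // -1: the rank itself is unknown
      std::vector<long> dims;    // entries may be -1 (dimension unknown)
    };

    bool ndim_is_known(const ShapeModel &s) {
      return s.ndim != -1;       // only the rank has to be known
    }

    bool shape_is_known(const ShapeModel &s) {
      if (s.ndim == -1) return false;
      for (long d : s.dims) {
        if (d == -1) return false;  // any unknown dimension disqualifies the shape
      }
      return true;                  // fully known shape
    }

A guard that only validates the rank, such as CHECK_EQ(dshape.ndim(), 4), therefore needs ndim_is_known rather than the stricter shape_is_known; the stricter check would abort inference for inputs whose rank is known but whose dimensions are not yet.
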
4 changes: 2 additions & 2 deletions src/operator/bilinear_sampler-inl.h
@@ -149,10 +149,10 @@ class BilinearSamplerProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, grid]";
const mxnet::TShape &dshape = (*in_shape)[bs::kData];
const mxnet::TShape &lshape = (*in_shape)[bs::kGrid];
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;
CHECK_EQ(dshape.ndim(), 4U) \
<< "input data should be 4D in batch-num_filter-y-x";
- if (!shape_is_known(lshape)) return false;
+ if (!ndim_is_known(lshape)) return false;
CHECK_EQ(lshape.ndim(), 4U) \
<< "Sampler grid should be 4D in batch-2-y-x";
CHECK_EQ(dshape[0], lshape[0]);
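
Note the division of labour in the rewritten guard above: returning false defers shape inference until more information is available, while the CHECK_* macros reject shapes that can never become valid. A self-contained sketch of that control flow (hypothetical function, not the operator's actual code):

    #include <cassert>

    // Hypothetical bilinear-sampler-style rank check; an ndim of -1 means "unknown".
    bool RanksLookValid(int data_ndim, int grid_ndim) {
      if (data_ndim == -1 || grid_ndim == -1) return false;  // defer: try again later
      assert(data_ndim == 4 && "input data should be 4D in batch-num_filter-y-x");
      assert(grid_ndim == 4 && "sampler grid should be 4D in batch-2-y-x");
      return true;  // ranks are valid; individual dims may still be unknown
    }
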
2 changes: 1 addition & 1 deletion src/operator/image/image_random-inl.h
@@ -93,7 +93,7 @@ inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);

mxnet::TShape &shp = (*in_attrs)[0];
- if (!shape_is_known(shp)) return false;
+ if (!ndim_is_known(shp)) return false;

CHECK((shp.ndim() == 3) || (shp.ndim() == 4))
<< "Input image must have shape (height, width, channels), or "
2 changes: 1 addition & 1 deletion src/operator/instance_norm-inl.h
@@ -113,7 +113,7 @@ class InstanceNormOp : public Operator {
CHECK_EQ(in_data.size(), 3U);
CHECK_EQ(out_data.size(), 3U);

- CHECK_GE(in_data[instance_norm::kData].ndim(), 3U)
+ CHECK_GE(in_data[instance_norm::kData].ndim(), 3)
<< "InstanceNorm only supports input tensors of rank > 2.";

Stream<xpu> *s = ctx.get_stream<xpu>();
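
The 3U → 3 change here, and the analogous 1U/2U/4U/5U changes in the files below, is about signedness: ndim() is now a signed integer that can be -1 when the rank is unknown. Compared against an unsigned literal, -1 is converted to a huge unsigned value under the usual arithmetic conversions, so the old form of the check could silently pass. A small stand-alone illustration of the hazard (not MXNet code):

    #include <iostream>

    int main() {
      int ndim = -1;                       // "rank unknown" in the new encoding
      std::cout << (ndim >= 3) << '\n';    // 0: the signed comparison correctly fails
      std::cout << (ndim >= 3U) << '\n';   // 1: -1 wraps to 4294967295, so the check "passes"
      return 0;
    }
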
12 changes: 6 additions & 6 deletions src/operator/l2_normalization-inl.h
@@ -102,7 +102,7 @@ class L2NormalizationOp : public Operator {
norm = F<mxnet::op::mshadow_op::square_root>(norm);
out = data / broadcast<0>(norm, out.shape_);
} else if (param_.mode == l2_normalization::kChannel) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -120,7 +120,7 @@ class L2NormalizationOp : public Operator {
norm = F<mxnet::op::mshadow_op::square_root>(norm);
out = data / broadcast_with_axis(norm, 0, orig_shape[1]);
} else if (param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -174,7 +174,7 @@ class L2NormalizationOp : public Operator {
(grad_out - data * broadcast<0>(temp, data.shape_)) /
broadcast<0>(norm, data.shape_));
} else if (param_.mode == l2_normalization::kChannel) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
@@ -193,7 +193,7 @@ class L2NormalizationOp : public Operator {
(grad_out - data * broadcast_with_axis(temp, 0, orig_shape[1])) /
broadcast_with_axis(norm, 0, orig_shape[1]));
} else if (param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<xpu, 3, DType> data = out_data[l2_normalization::kOut]
@@ -273,12 +273,12 @@ class L2NormalizationProp : public OperatorProperty {
if (param_.mode == l2_normalization::kInstance) {
out_shape->push_back(Shape1(dshape[0]));
} else if (param_.mode == l2_normalization::kChannel) {
- CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in channel mode";
+ CHECK_GE(dshape.ndim(), 3) << "At lease 3 dimensions required in channel mode";
mxnet::TShape norm_shape = dshape;
norm_shape[1] = 1;
out_shape->push_back(norm_shape);
} else if (param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(dshape.ndim(), 3U) << "At lease 3 dimensions required in spatial mode";
+ CHECK_GE(dshape.ndim(), 3) << "At lease 3 dimensions required in spatial mode";
out_shape->push_back(Shape2(dshape[0], dshape[1]));
} else {
return false;
4 changes: 2 additions & 2 deletions src/operator/l2_normalization.cc
@@ -70,7 +70,7 @@ class L2NormalizationOpCPU : public L2NormalizationOp<cpu, DType> {
}
}
} else if (this->param_.mode == l2_normalization::kChannel) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
@@ -94,7 +94,7 @@ class L2NormalizationOpCPU : public L2NormalizationOp<cpu, DType> {
}
}
} else if (this->param_.mode == l2_normalization::kSpatial) {
- CHECK_GE(orig_shape.ndim(), 3U);
+ CHECK_GE(orig_shape.ndim(), 3);
Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
orig_shape.ProdShape(2, orig_shape.ndim()));
Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
2 changes: 1 addition & 1 deletion src/operator/nn/lrn.cc
@@ -40,7 +40,7 @@ bool LRNShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
const mxnet::TShape &dshape = in_shape->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;
out_shape->clear();
out_shape->push_back(dshape);
out_shape->push_back(dshape);
36 changes: 18 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -63,15 +63,15 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
mkldnn::memory::dims strides(param.conv_param.kernel.ndim());
mkldnn::memory::dims padding(param.conv_param.kernel.ndim());
if (param.conv_param.kernel.ndim() == 1) {
- CHECK_GE(param.conv_param.stride.ndim(), 1U);
- CHECK_GE(param.conv_param.pad.ndim(), 1U);
- CHECK_GE(param.conv_param.dilate.ndim(), 1U);
+ CHECK_GE(param.conv_param.stride.ndim(), 1);
+ CHECK_GE(param.conv_param.pad.ndim(), 1);
+ CHECK_GE(param.conv_param.dilate.ndim(), 1);
strides[0] = param.conv_param.stride[0];
padding[0] = param.conv_param.pad[0];
} else if (param.conv_param.kernel.ndim() == 2) {
- CHECK_GE(param.conv_param.stride.ndim(), 2U);
- CHECK_GE(param.conv_param.pad.ndim(), 2U);
- CHECK_GE(param.conv_param.dilate.ndim(), 2U);
+ CHECK_GE(param.conv_param.stride.ndim(), 2);
+ CHECK_GE(param.conv_param.pad.ndim(), 2);
+ CHECK_GE(param.conv_param.dilate.ndim(), 2);
strides[0] = param.conv_param.stride[0];
strides[1] = param.conv_param.stride[1];
padding[0] = param.conv_param.pad[0];
@@ -173,15 +173,15 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
mkldnn::memory::dims strides(param.kernel.ndim());
mkldnn::memory::dims padding(param.kernel.ndim());
if (param.kernel.ndim() == 1) {
- CHECK_GE(param.stride.ndim(), 1U);
- CHECK_GE(param.pad.ndim(), 1U);
- CHECK_GE(param.dilate.ndim(), 1U);
+ CHECK_GE(param.stride.ndim(), 1);
+ CHECK_GE(param.pad.ndim(), 1);
+ CHECK_GE(param.dilate.ndim(), 1);
strides[0] = param.stride[0];
padding[0] = param.pad[0];
} else if (param.kernel.ndim() == 2) {
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
padding[0] = param.pad[0];
@@ -241,15 +241,15 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
mkldnn::memory::dims strides(param.kernel.ndim());
mkldnn::memory::dims padding(param.kernel.ndim());
if (param.kernel.ndim() == 1) {
- CHECK_GE(param.stride.ndim(), 1U);
- CHECK_GE(param.pad.ndim(), 1U);
- CHECK_GE(param.dilate.ndim(), 1U);
+ CHECK_GE(param.stride.ndim(), 1);
+ CHECK_GE(param.pad.ndim(), 1);
+ CHECK_GE(param.dilate.ndim(), 1);
strides[0] = param.stride[0];
padding[0] = param.pad[0];
} else if (param.kernel.ndim() == 2) {
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
strides[0] = param.stride[0];
strides[1] = param.stride[1];
padding[0] = param.pad[0];
18 changes: 9 additions & 9 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -90,9 +90,9 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
@@ -128,9 +128,9 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdDataImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
@@ -153,9 +153,9 @@ GetDeconvBwdWeightsImpl(
auto weight_md = GetWeightDesc(weights, param.num_group);
auto out_md = GetMemDesc(output);
auto engine = CpuEngine::Get()->get_engine();
- CHECK_GE(param.stride.ndim(), 2U);
- CHECK_GE(param.pad.ndim(), 2U);
- CHECK_GE(param.dilate.ndim(), 2U);
+ CHECK_GE(param.stride.ndim(), 2);
+ CHECK_GE(param.pad.ndim(), 2);
+ CHECK_GE(param.dilate.ndim(), 2);
mkldnn::memory::dims strides{0, 0};
strides[0] = param.stride[0];
strides[1] = param.stride[1];
6 changes: 3 additions & 3 deletions src/operator/nn/pooling.cc
@@ -100,17 +100,17 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
}
const mxnet::TShape &dshape = (*in_shape)[0];
if (param.pooling_convention == pool_enum::kSame) {
- CHECK_EQ(dshape.ndim(), 3U)
+ CHECK_EQ(dshape.ndim(), 3)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< ". Currently 'same' supports Max Pooling 1-D";
CHECK(param.pad[0] == 0 && param.pad[1] == 0 && param.pad[2] == 0)
<< "Same pooling convention disables the use of pad parameter.";
}
- CHECK_GE(dshape.ndim(), 3U)
+ CHECK_GE(dshape.ndim(), 3)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
- CHECK_LE(dshape.ndim(), 5U)
+ CHECK_LE(dshape.ndim(), 5)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
2 changes: 1 addition & 1 deletion src/operator/nn/upsampling.cc
@@ -60,7 +60,7 @@ static bool UpSamplingShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
CHECK_EQ(dshape.ndim(), 4U) << \
"UpSamplingBilinear: Input data should be 4D in (batch, channel, y, x)";
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;
int kernel = 2 * param_.scale - param_.scale % 2;
SHAPE_ASSIGN_CHECK(*in_shape,
up_enum::kWeight,
2 changes: 1 addition & 1 deletion src/operator/operator_util.cc
@@ -774,7 +774,7 @@ class SimpleUnaryOpProp : public SimpleOpPropBase {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 1) << "Input:[data]";
const mxnet::TShape &dshape = in_shape->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;
out_shape->clear();
if (source->unary_shape_ == nullptr) {
out_shape->push_back(dshape);
4 changes: 2 additions & 2 deletions src/operator/pooling_v1-inl.h
@@ -243,9 +243,9 @@ class PoolingV1Prop : public OperatorProperty {
mxnet::ShapeVector *aux_shape) const override {
CHECK_EQ(in_shape->size(), 1U);
const mxnet::TShape &dshape = (*in_shape)[0];
- CHECK_GE(dshape.ndim(), 4U) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
+ CHECK_GE(dshape.ndim(), 4) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
<< "Or 5D in (batch, channel, d, y, x)";
- CHECK_LE(dshape.ndim(), 5U) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
+ CHECK_LE(dshape.ndim(), 5) << "Pooling: Input data should be 4D in (batch, channel, y, x) "
<< "Or 5D in (batch, channel, d, y, x)";
mxnet::TShape oshape = dshape;
if (dshape.ndim() == -1) return false;
2 changes: 1 addition & 1 deletion src/operator/quantization/quantized_flatten-inl.h
@@ -86,7 +86,7 @@ inline bool QuantizedFlattenShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 3U);

const mxnet::TShape &dshape = (*in_attrs)[0];
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;

dim_t target_dim = 1;
for (int i = 1; i < dshape.ndim(); ++i) {
2 changes: 1 addition & 1 deletion src/operator/random/sample_multinomial_op.h
@@ -68,7 +68,7 @@ inline bool SampleMultinomialOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), param.get_prob ? 2U : 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
- if (!shape_is_known(ishape)) return false;
+ if (!ndim_is_known(ishape)) return false;

MSHADOW_TYPE_SWITCH(param.dtype, DType, {
CHECK_LE(ishape[ishape.ndim() - 1], mxnet::common::MaxIntegerValue<DType>())
2 changes: 1 addition & 1 deletion src/operator/regression_output-inl.h
@@ -57,7 +57,7 @@ inline bool RegressionOpShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_attrs->size(), 2U) << "Input:[data, label]";
const mxnet::TShape &dshape = in_attrs->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;
auto &lshape = (*in_attrs)[1];
if (lshape.ndim() == 0) {
// special treatment for 1D output, to allow 1D label by default.
2 changes: 1 addition & 1 deletion src/operator/softmax_output-inl.h
@@ -337,7 +337,7 @@ class SoftmaxOutputProp : public OperatorProperty {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, label]";
const mxnet::TShape &dshape = in_shape->at(0);
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;

// label.shape == data.shape: use probability as label
if (dshape != (*in_shape)[softmaxout_enum::kLabel]) {
4 changes: 2 additions & 2 deletions src/operator/spatial_transformer-inl.h
@@ -190,10 +190,10 @@ class SpatialTransformerProp : public OperatorProperty {
CHECK_EQ(param_.sampler_type, st::kBilinear) << "only supports bilinear sampling currently";
const mxnet::TShape &dshape = (*in_shape)[st::kData];
const mxnet::TShape &lshape = (*in_shape)[st::kLoc];
- if (!shape_is_known(dshape)) return false;
+ if (!ndim_is_known(dshape)) return false;
CHECK_EQ(dshape.ndim(), 4U) \
<< "input data should be 4D in batch-num_filter-y-x";
- if (!shape_is_known(lshape)) return false;
+ if (!ndim_is_known(lshape)) return false;
CHECK_EQ(lshape.ndim(), 2U) \
<< "locolisation paramter should be 4D in batch-num_hidden";
if (param_.transform_type == st::kAffine) {
10 changes: 5 additions & 5 deletions src/operator/tensor/broadcast_reduce_op.h
@@ -222,7 +222,7 @@ inline bool ReduceAxisShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
- if (!shape_is_known(ishape)) return false;
+ if (!ndim_is_known(ishape)) return false;

const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
@@ -304,7 +304,7 @@ inline bool ReduceAxesShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- if (!shape_is_known((*in_attrs)[0])) return false;
+ if (!ndim_is_known((*in_attrs)[0])) return false;
const ReduceAxesParam& param = nnvm::get<ReduceAxesParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
ReduceAxesShapeImpl((*in_attrs)[0], param.axis,
@@ -317,7 +317,7 @@ inline bool ReduceMinMaxAxesShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- if (!shape_is_known((*in_attrs)[0])) return false;
+ if (!ndim_is_known((*in_attrs)[0])) return false;
CHECK_GT((*in_attrs)[0].Size(), 0U)
<< "Reduction input's size should > 0 "
<< (*in_attrs)[0];
@@ -351,7 +351,7 @@ inline bool NormShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- if (!shape_is_known((*in_attrs)[0])) return false;
+ if (!ndim_is_known((*in_attrs)[0])) return false;
const NormParam& param = nnvm::get<NormParam>(attrs.parsed);
SHAPE_ASSIGN_CHECK(*out_attrs, 0,
ReduceAxesShapeImpl((*in_attrs)[0], param.axis,
@@ -364,7 +364,7 @@ inline bool BroadcastAxesShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- if (!shape_is_known((*in_attrs)[0])) return false;
+ if (!ndim_is_known((*in_attrs)[0])) return false;
const BroadcastAxesParam& param = nnvm::get<BroadcastAxesParam>(attrs.parsed);
CHECK_EQ(param.axis.ndim() , param.size.ndim());
mxnet::TShape &ishape = (*in_attrs)[0];
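
Requiring only the rank is sufficient in these reduce/broadcast shape functions because the output rank of an axis reduction depends only on the input rank and the axes; unknown input dimensions simply propagate as unknown and can be refined on a later inference pass. A toy illustration of the idea (hypothetical helper, not MXNet's ReduceAxesShapeImpl):

    #include <vector>

    // Drop one axis from a partially known shape; -1 entries (unknown dims) carry through.
    std::vector<long> ReduceOneAxis(const std::vector<long> &in_shape, int axis) {
      std::vector<long> out_shape;
      for (int i = 0; i < static_cast<int>(in_shape.size()); ++i) {
        if (i != axis) out_shape.push_back(in_shape[i]);
      }
      return out_shape;
    }

    // ReduceOneAxis({-1, 3, -1}, 1) yields {-1, -1}: the output rank (and any known
    // dimensions) can be inferred even though some input dimensions are still unknown.
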
2 changes: 2 additions & 0 deletions src/operator/tensor/dot-inl.h
@@ -1207,6 +1207,7 @@ inline bool DotShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
+ if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false;
if (lshape.ndim() == 1 && rshape.ndim() == 1) {
CHECK(!param.transpose_a && !param.transpose_b) << "Cannot transpose vectors";
CHECK_EQ(lshape[0], rshape[0]) << "dot shape error: " << lshape << " X " << rshape;
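
The guard added above, and the identical one added to BatchDotShape in the next hunk, are the commit's only pure additions: both functions immediately branch on ndim() and index specific dimensions, so they now bail out until both operand ranks are known instead of reading dimensions of a shape whose rank is still -1. A minimal sketch of the pattern (hypothetical function, not the actual DotShape):

    // Defer dot-product shape inference until both operand ranks are known.
    bool DotRanksReady(int lhs_ndim, int rhs_ndim) {
      if (lhs_ndim == -1 || rhs_ndim == -1) return false;  // cannot branch on ndim() yet
      // rank-specific handling (vector dot, matrix product, batched matmul, ...) follows
      return true;
    }
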
@@ -1479,6 +1480,7 @@ inline bool BatchDotShape(const nnvm::NodeAttrs& attrs,
const DotParam& param = nnvm::get<DotParam>(attrs.parsed);
mxnet::TShape& lshape = (*in_attrs)[0];
mxnet::TShape& rshape = (*in_attrs)[1];
+ if (!ndim_is_known(lshape) || !ndim_is_known(rshape)) return false;
if (lshape.ndim() == 3 && rshape.ndim() == 3) {
CHECK(lshape[0] == rshape[0])
<< "batch_dot shape error(batch_size must be equal): " << lshape << " X " << rshape
4 changes: 2 additions & 2 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -452,7 +452,7 @@ void BinaryBroadcastComputeSparseEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
- CHECK_LE(inputs[1].shape().ndim(), 2U)
+ CHECK_LE(inputs[1].shape().ndim(), 2)
<< "input dense matrix should have less than or equal to 2 dimensions";
if (req[0] == kNullOp) return;
const NDArray& lhs = inputs[0];
@@ -488,7 +488,7 @@ void BinaryBroadcastComputeDenseEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
- CHECK_LE(inputs[1].shape().ndim(), 2U)
+ CHECK_LE(inputs[1].shape().ndim(), 2)
<< "input dense matrix should have less than or equal to 2 dimensions";
if (req[0] == kNullOp) return;
const NDArray& lhs = inputs[0];