diff --git a/mshadow/tensor_blob.h b/mshadow/tensor_blob.h index 7b9e45be8c08..1d8ca52c0be6 100644 --- a/mshadow/tensor_blob.h +++ b/mshadow/tensor_blob.h @@ -509,7 +509,7 @@ class TBlob { } /*! * \brief fetch a tensor in given shape - * if size do not match the stored dimension, an error will be issued + * If size do not match the stored size, an error will be issued * \return the tensor requested * \param shape the shape required * \param stream the possible stream target tensor should reside on @@ -518,7 +518,7 @@ class TBlob { * \tparam DType the type of elements in the tensor */ template - inline Tensor get_with_shape(const TShape &shape, + inline Tensor get_with_shape(const Shape &shape, Stream *stream = NULL) const { CHECK(Device::kDevMask == dev_mask_ && DataType::kFlag == type_flag_) << "TBlob.get_with_shape: device type do not match specified type"; @@ -526,7 +526,7 @@ class TBlob { CHECK_EQ(this->shape_.Size(), shape.Size()) << "TBlob.get_with_shape: new and old shape do not match total elements"; return Tensor(static_cast(dptr_), - shape.get(), + shape, shape[dim - 1], stream); }