-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add function to get element count from tensor. #3958
Conversation
Why not use |
@Canpio |
It seems that the size of tensor is often used, so make a cache is necessary. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
As a cache of element number, numel
had appeared in a very early design of Tensor
. However, then it was removed because we were not sure whether it is necessary. During the recent developing, many colleagues have voiced that writing framework::product(tensor.dims())
is tedious and inefficient, for tensor's dims
never change in the majority of cases. So, I think adding numel
is reasonable. @wangkuiyi
Great, I think this PR is very helpful. When converting a tensor to a matrix, we often calculate the column number using |
paddle/framework/tensor_impl.h
Outdated
@@ -131,7 +131,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const { | |||
PADDLE_ENFORCE_LT(begin_idx, end_idx, | |||
"Begin index must be less than end index."); | |||
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1."); | |||
size_t base = product(dims_) / dims_[0]; | |||
size_t base = numel_ / dims_[0]; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that we add Tensor::numel()
as it is such a commonly used feature.
I noticed that in the tensor implementation, we use Tensor::numel_
, but in other places, we call Tensor::numel()
. This implies that that is a cache -- the former, and the latter returns the former. It would be great if we don't expose such implementation details to code readers. It looks to me that we can just call Tensor::numel()
everywhere.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
paddle/framework/tensor.h
Outdated
@@ -162,6 +165,9 @@ class Tensor { | |||
/*! points to dimensions of memory block. */ | |||
DDim dims_; | |||
|
|||
/*! the element count of tensor. */ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/*! A cache of the number of elements in a tensor. Would be 0 for an uninitialized tensor. */
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
9622bd1
to
dbe0598
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Fix #3914