Add compile flag to choose data type for tensor size due to performance degradation #371
Conversation
@TaoLv @eric-haibin-lin @tqchen Please help to review.
LGTM
I also had an issue with tensor size types and overflow.
My understanding is that the change only affects the total size of the tensor, not the size of each dimension.
@TaoLv It affects the size of each dimension as well.
@apeforest Do you mean the size of each dimension will also be defined as index_t?
LGTM
@TaoLv Updated the gemm function signature with index_t. Please help to check whether this addresses your concern.
Thank you @apeforest. Do we have any large dot/GEMM unit tests in MXNet for this change?
@TaoLv There is one in test_operator.py: /~https://github.com/apache/incubator-mxnet/blob/e2f5b47346e148c2376da7e6628750747f2d6a94/tests/python/unittest/test_operator.py#L5668
@szha @eric-haibin-lin @tqchen Could you please help to review/merge this PR? Thanks!
@apeforest Looks like M/N/K in that case are really small (e.g. 2/2/3). I'm afraid that's not enough to test the changes in this PR.
@TaoLv The main purpose of this PR is not to support large tensors in the gemm engine. It is to fall back to int32 by default with a compilation flag. We need a more thorough inspection of the performance impact of using int64 before we turn the flag on again.
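For context, a minimal sketch of what such a flag-gated type selection could look like. The macro name MSHADOW_INT64_TENSOR_SIZE is used here only for illustration and may not match the PR exactly:

// Illustrative sketch only: pick the tensor index type at compile time.
// The flag name below is an assumption, not necessarily the one used in this PR.
#include <cstdint>

#if defined(MSHADOW_INT64_TENSOR_SIZE) && MSHADOW_INT64_TENSOR_SIZE == 1
typedef int64_t index_t;   // opt-in: allows tensors with more than INT32_MAX elements
#else
typedef int32_t index_t;   // default: keeps the original performance characteristics
#endif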
mshadow/dot_engine-inl.h
index_t m, index_t n, index_t k, float alpha,
const float *A, index_t lda,
const float *B, index_t ldb, float beta,
float *C, index_t ldc) {
cublasStatus_t err = cublasSgemm(Stream<gpu>::GetBlasHandle(stream), |
Do lda, ldb, ldc support int64 in cuBLAS?
cublasStatus_t cublasSgemm(cublasHandle_t handle,
cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k,
const float *alpha,
const float *A, int lda,
const float *B, int ldb,
const float *beta,
float *C, int ldc)
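As a side note on the question above, a hedged sketch of how a wrapper could guard the index_t-to-int narrowing before calling a 32-bit BLAS entry point (the helper name CheckedNarrow is hypothetical):

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <string>

// Hypothetical helper: refuse to silently truncate a 64-bit extent when the
// underlying BLAS routine (e.g. cublasSgemm above) only accepts 32-bit ints.
inline int CheckedNarrow(int64_t value, const std::string &name) {
  if (value < 0 || value > std::numeric_limits<int>::max()) {
    throw std::range_error(name + " does not fit in the 32-bit int expected by the BLAS call");
  }
  return static_cast<int>(value);
}

// Usage sketch, following the parameter names in the signature above:
//   cublasSgemm(handle, transa, transb,
//               CheckedNarrow(m, "m"), CheckedNarrow(n, "n"), CheckedNarrow(k, "k"),
//               &alpha, A, CheckedNarrow(lda, "lda"),
//               B, CheckedNarrow(ldb, "ldb"),
//               &beta, C, CheckedNarrow(ldc, "ldc"));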
That's why I want to have a dot/GEMM unit test for large input tensors. The same problem may also exist for OpenBLAS and MKL.
Hi @apeforest, I'm okay if this flag only impacts the total tensor size, as BLAS libraries should work well with that. But if it also impacts the size of each dimension, then we need more changes and validation to accommodate that, and the flag is not ready to be exposed to users.
@TaoLv The data type of each dimension in the tensor will be defined as index_t (if the flag is on, it is int64). However, that does not mean we will support a size greater than INT32_MAX in any single dimension. We only support a total element count greater than INT32_MAX in the tensor.
Thank you for the explanation @apeforest. Then it seems we need to document this somewhere and prevent users from passing a large tensor (dim[x] > INT32_MAX) to operators. If so, I think there is no need to change the API in dot_engine.
Makes sense. I have reverted the API signature in dot_engine.
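To make the constraint discussed above concrete: the total element count may exceed INT32_MAX while every individual dimension stays within int32. For example, a (50000, 50000) tensor has 2.5e9 elements (above INT32_MAX ≈ 2.147e9) even though each dimension easily fits in 32 bits. A hedged sketch of such a check follows; the names are illustrative, not the actual mshadow API:

#include <cstdint>
#include <cassert>
#include <vector>

// Illustrative check only: each dimension must fit in int32 even when the
// total element count (with a 64-bit index_t) is allowed to exceed INT32_MAX.
inline int64_t CheckedTotalSize(const std::vector<int64_t> &shape) {
  int64_t total = 1;
  for (int64_t dim : shape) {
    assert(dim >= 0 && dim <= INT32_MAX && "per-dimension size above INT32_MAX is not supported");
    total *= dim;
  }
  return total;  // may legitimately exceed INT32_MAX
}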
Changing the data type for `index_t` from `uint32_t` to `int64_t` caused performance degradation in operators defined in the mshadow library. I can think of three solutions to this problem:
(1) Add a compilation flag to choose data types for tensor size (This PR)
(2) Add an environment variable to choose data type for tensor size at runtime
(3) Choose data type for tensor size at runtime based on the size of the tensor
Due to the urgency of the customer impact and the non-trivial changes required for approaches (2) and (3), this PR takes the quick fix of approach (1).
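For contrast, a rough sketch of what the runtime alternative, approach (3), might have involved. This is purely illustrative and not part of this PR; the kernel name is hypothetical:

#include <cstdint>
#include <cstdio>

// Hypothetical kernel templated on the index width.
template <typename IndexT>
void KernelImpl(const float *data, IndexT total_size) {
  // ... iterate over total_size elements using IndexT arithmetic ...
  std::printf("running with %zu-bit indices\n", sizeof(IndexT) * 8);
  (void)data;
}

// Approach (3): dispatch on the tensor size at runtime so that small tensors
// keep the faster 32-bit indexing and only large tensors pay for 64-bit.
inline void LaunchKernel(const float *data, int64_t total_size) {
  if (total_size <= INT32_MAX) {
    KernelImpl<int32_t>(data, static_cast<int32_t>(total_size));
  } else {
    KernelImpl<int64_t>(data, total_size);
  }
}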
For more information and performance analysis, please refer to PR:
apache/mxnet#14570