From 4305465e80769ba1f738ae110284de297ed8668c Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Mar 2019 10:00:16 -0700 Subject: [PATCH 01/11] add a compiler flag to select int64 type --- mshadow/base.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mshadow/base.h b/mshadow/base.h index 4cdab74d..81e43aea 100755 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -285,7 +285,11 @@ const unsigned kRandBufferSize = 1000000; /*! \brief pi */ const float kPi = 3.1415926f; /*! \brief type that will be used for index */ -typedef int64_t index_t; +#if MSHADOW_INT64_TENSOR_SIZE=1 + typedef int64_t index_t; +#else + typedef uint32_t index_t; +#endif #ifdef _WIN32 /*! \brief openmp index for windows */ From 44a5348acc5ec8be9974c99f0c7975fba97369e6 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Mar 2019 10:02:24 -0700 Subject: [PATCH 02/11] fix typo --- mshadow/base.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mshadow/base.h b/mshadow/base.h index 81e43aea..9b1ddcc3 100755 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -285,7 +285,7 @@ const unsigned kRandBufferSize = 1000000; /*! \brief pi */ const float kPi = 3.1415926f; /*! \brief type that will be used for index */ -#if MSHADOW_INT64_TENSOR_SIZE=1 +#if MSHADOW_INT64_TENSOR_SIZE == 1 typedef int64_t index_t; #else typedef uint32_t index_t; From 0ea71658816fa039b312984d186190a6c4e764c8 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Mar 2019 10:38:14 -0700 Subject: [PATCH 03/11] fix compilation error --- mshadow/base.h | 2 +- mshadow/tensor.h | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mshadow/base.h b/mshadow/base.h index 9b1ddcc3..d08efd38 100755 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -288,7 +288,7 @@ const float kPi = 3.1415926f; #if MSHADOW_INT64_TENSOR_SIZE == 1 typedef int64_t index_t; #else - typedef uint32_t index_t; + typedef int32_t index_t; #endif #ifdef _WIN32 diff --git a/mshadow/tensor.h b/mshadow/tensor.h index b04b154d..53a1f02a 100755 --- a/mshadow/tensor.h +++ b/mshadow/tensor.h @@ -1069,13 +1069,19 @@ inline void BatchGEMM(Tensor dst, #define MSHADOW_SCALAR_ double #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ int +#define MSHADOW_SCALAR_ int16_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ mshadow::index_t +#define MSHADOW_SCALAR_ uint16_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ mshadow::half::half_t +#define MSHADOW_SCALAR_ int32_t +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ uint32_t +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ int64_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ #endif // MSHADOW_TENSOR_H_ From 359959e72db8c7c722f4f551ef9b7b47de137f4d Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Mar 2019 10:00:16 -0700 Subject: [PATCH 04/11] add a compiler flag to select int64 type --- mshadow/base.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mshadow/base.h b/mshadow/base.h index 4cdab74d..81e43aea 100755 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -285,7 +285,11 @@ const unsigned kRandBufferSize = 1000000; /*! \brief pi */ const float kPi = 3.1415926f; /*! \brief type that will be used for index */ -typedef int64_t index_t; +#if MSHADOW_INT64_TENSOR_SIZE=1 + typedef int64_t index_t; +#else + typedef uint32_t index_t; +#endif #ifdef _WIN32 /*! \brief openmp index for windows */ From 69459137e53797bc6affbb03fb46e44f7565e306 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Mar 2019 10:02:24 -0700 Subject: [PATCH 05/11] fix typo --- mshadow/base.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mshadow/base.h b/mshadow/base.h index 81e43aea..9b1ddcc3 100755 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -285,7 +285,7 @@ const unsigned kRandBufferSize = 1000000; /*! \brief pi */ const float kPi = 3.1415926f; /*! \brief type that will be used for index */ -#if MSHADOW_INT64_TENSOR_SIZE=1 +#if MSHADOW_INT64_TENSOR_SIZE == 1 typedef int64_t index_t; #else typedef uint32_t index_t; From f618c0bd71d5447490c05b61794f2846e7278207 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 29 Mar 2019 10:38:14 -0700 Subject: [PATCH 06/11] fix compilation error --- mshadow/base.h | 2 +- mshadow/tensor.h | 12 +++++++++--- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/mshadow/base.h b/mshadow/base.h index 9b1ddcc3..d08efd38 100755 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -288,7 +288,7 @@ const float kPi = 3.1415926f; #if MSHADOW_INT64_TENSOR_SIZE == 1 typedef int64_t index_t; #else - typedef uint32_t index_t; + typedef int32_t index_t; #endif #ifdef _WIN32 diff --git a/mshadow/tensor.h b/mshadow/tensor.h index b04b154d..53a1f02a 100755 --- a/mshadow/tensor.h +++ b/mshadow/tensor.h @@ -1069,13 +1069,19 @@ inline void BatchGEMM(Tensor dst, #define MSHADOW_SCALAR_ double #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ int +#define MSHADOW_SCALAR_ int16_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ mshadow::index_t +#define MSHADOW_SCALAR_ uint16_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ mshadow::half::half_t +#define MSHADOW_SCALAR_ int32_t +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ uint32_t +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ int64_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ #endif // MSHADOW_TENSOR_H_ From 655a9923e40bbd60e0a640b0b115fabf8bf75e49 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 2 Apr 2019 10:16:23 -0700 Subject: [PATCH 07/11] fix type in gemm functions --- mshadow/dot_engine-inl.h | 408 +++++++++++++++++++-------------------- mshadow/tensor.h | 9 +- 2 files changed, 207 insertions(+), 210 deletions(-) diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h index 21816f20..e3399cae 100644 --- a/mshadow/dot_engine-inl.h +++ b/mshadow/dot_engine-inl.h @@ -65,49 +65,49 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, DType alpha, - const DType *A, int lda, const DType *B, int ldb, - DType beta, DType *C, int ldc) { + index_t m, index_t n, index_t k, DType alpha, + const DType *A, index_t lda, const DType *B, index_t ldb, + DType beta, DType *C, index_t ldc) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, DType alpha, - const DType *A, int lda, const DType *B, int ldb, - DType beta, DType *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, DType alpha, + const DType *A, index_t lda, const DType *B, index_t ldb, + DType beta, DType *C, index_t ldc, index_t batch_count, DType **workspace) { LOG(FATAL) << "Not implmented!"; } inline static void gemv(Stream *stream, - bool trans, int m, int n, - DType alpha, const DType *A, int lda, - const DType *X, int incX, - DType beta, DType *Y, int incY) { + bool trans, index_t m, index_t n, + DType alpha, const DType *A, index_t lda, + const DType *X, index_t incX, + DType beta, DType *Y, index_t incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - DType alpha, const DType *A, int lda, - const DType *X, int incX, - DType beta, DType *Y, int incY, int batch_count) { + bool trans, index_t m, index_t n, + DType alpha, const DType *A, index_t lda, + const DType *X, index_t incX, + DType beta, DType *Y, index_t incY, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - int m, int n, DType alpha, - const DType *X, int incX, - const DType *Y, int incY, DType *A, int lda) { + index_t m, index_t n, DType alpha, + const DType *X, index_t incX, + const DType *Y, index_t incY, DType *A, index_t lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - int m, int n, DType alpha, - const DType *X, int incX, - const DType *Y, int incY, DType *A, int lda, int batch_count) { + index_t m, index_t n, DType alpha, + const DType *X, index_t incX, + const DType *Y, index_t incY, DType *A, index_t lda, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - int n, - const DType* X, int incX, - const DType* Y, int incY, + index_t n, + const DType* X, index_t incX, + const DType* Y, index_t incY, DType* ret) { LOG(FATAL) << "Not implmented!"; } @@ -123,9 +123,9 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc) { if (alpha == 1.0f && beta == 0.0f) { bool transpose_left = transb; bool transpose_right = transa; @@ -147,46 +147,46 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc, index_t batch_count, float **workspace) { - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY) { + bool trans, index_t m, index_t n, + float alpha, const float *A, index_t lda, + const float *X, index_t incX, + float beta, float *Y, index_t incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY, int batch_count) { + bool trans, index_t m, index_t n, + float alpha, const float *A, index_t lda, + const float *X, index_t incX, + float beta, float *Y, index_t incY, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda, int batch_count) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - int n, - const float* X, int incX, - const float* Y, int incY, + index_t n, + const float* X, index_t incX, + const float* Y, index_t incY, float* ret) { LOG(FATAL) << "Not implmented!"; } @@ -201,9 +201,9 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc) { + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, const double *B, index_t ldb, + double beta, double *C, index_t ldc) { if (alpha == 1.0f && beta == 0.0f) { bool transpose_left = transb; bool transpose_right = transa; @@ -225,46 +225,46 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, const double *B, index_t ldb, + double beta, double *C, index_t ldc, index_t batch_count, double **workspace) { - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY) { + bool trans, index_t m, index_t n, + double alpha, const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY, int batch_count) { + bool trans, index_t m, index_t n, + double alpha, const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda, int batch_count) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - int n, - const double* X, int incX, - const double* Y, int incY, + index_t n, + const double* X, index_t incX, + const double* Y, index_t incY, double* ret) { LOG(FATAL) << "Not implmented!"; } @@ -280,17 +280,17 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc) { + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc) { cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc, index_t batch_count, float **workspace) { #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) // since same m/n/k is used for all single gemms, so we put all gemms into one group @@ -323,7 +323,7 @@ struct BLASEngine { auto k_n = k * n; auto m_n = m * n; - for (int i = 0; i < batch_count; i++) { + for (index_t i = 0; i < batch_count; i++) { pp_A[i] = A + i * m_k; pp_B[i] = B + i * k_n; pp_C[i] = C + i * m_n; @@ -333,7 +333,7 @@ struct BLASEngine { p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(), p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb); #else - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -341,43 +341,43 @@ struct BLASEngine { #endif } inline static void gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY) { + bool trans, index_t m, index_t n, + float alpha, const float *A, index_t lda, + const float *X, index_t incX, + float beta, float *Y, index_t incY) { cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + bool trans, index_t m, index_t n, + float alpha, const float *A, index_t lda, + const float *X, index_t incX, + float beta, float *Y, index_t incY, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda) { cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); } inline static void batched_ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - int n, - const float* X, int incX, - const float* Y, int incY, + index_t n, + const float* X, index_t incX, + const float* Y, index_t incY, float* ret) { *ret = cblas_sdot(n, X, incX, Y, incY); } @@ -392,17 +392,17 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc) { + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, const double *B, index_t ldb, + double beta, double *C, index_t ldc) { cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, const double *B, index_t ldb, + double beta, double *C, index_t ldc, index_t batch_count, double **workspace) { #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) // since same m/n/k is used for all single gemms, so we put all gemms into one group @@ -435,7 +435,7 @@ struct BLASEngine { auto k_n = k * n; auto m_n = m * n; - for (int i = 0; i < batch_count; i++) { + for (index_t i = 0; i < batch_count; i++) { pp_A[i] = A + i * m_k; pp_B[i] = B + i * k_n; pp_C[i] = C + i * m_n; @@ -445,7 +445,7 @@ struct BLASEngine { p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(), p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb); #else - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -453,43 +453,43 @@ struct BLASEngine { #endif } inline static void gemv(Stream *stream, - bool trans, int m, int n, double alpha, - const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY) { + bool trans, index_t m, index_t n, double alpha, + const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY) { cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + bool trans, index_t m, index_t n, + double alpha, const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda) { cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); } inline static void batched_ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - int n, - const double* X, int incX, - const double* Y, int incY, + index_t n, + const double* X, index_t incX, + const double* Y, index_t incY, double* ret) { *ret = cblas_ddot(n, X, incX, Y, incY); } @@ -510,10 +510,10 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, half::half_t alpha, - const half::half_t *A, int lda, - const half::half_t *B, int ldb, half::half_t beta, - half::half_t *C, int ldc) { + index_t m, index_t n, index_t k, half::half_t alpha, + const half::half_t *A, index_t lda, + const half::half_t *B, index_t ldb, half::half_t beta, + half::half_t *C, index_t ldc) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 // Always use pseudo-fp16: fp32 compute with fp16 I/O. float alpha_f = float(alpha); // NOLINT(*) @@ -537,9 +537,9 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, half::half_t alpha, - const half::half_t *A, int lda, const half::half_t *B, int ldb, - half::half_t beta, half::half_t *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, half::half_t alpha, + const half::half_t *A, index_t lda, const half::half_t *B, index_t ldb, + half::half_t beta, half::half_t *C, index_t ldc, index_t batch_count, half::half_t **workspace) { #if defined(__CUDACC__) && CUDA_VERSION >= 9000 int major = stream->prop.major; @@ -561,42 +561,42 @@ struct BLASEngine { return; } #endif - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, int m, int n, half::half_t alpha, - const half::half_t *A, int lda, - const half::half_t *X, int incX, half::half_t beta, - half::half_t *Y, int incY) { + bool trans, index_t m, index_t n, half::half_t alpha, + const half::half_t *A, index_t lda, + const half::half_t *X, index_t incX, half::half_t beta, + half::half_t *Y, index_t incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - half::half_t alpha, const half::half_t *A, int lda, - const half::half_t *X, int incX, - half::half_t beta, half::half_t *Y, int incY, int batch_count) { + bool trans, index_t m, index_t n, + half::half_t alpha, const half::half_t *A, index_t lda, + const half::half_t *X, index_t incX, + half::half_t beta, half::half_t *Y, index_t incY, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - int m, int n, half::half_t alpha, - const half::half_t *X, int incX, - const half::half_t *Y, int incY, half::half_t *A, int lda) { + index_t m, index_t n, half::half_t alpha, + const half::half_t *X, index_t incX, + const half::half_t *Y, index_t incY, half::half_t *A, index_t lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - int m, int n, half::half_t alpha, - const half::half_t *X, int incX, const half::half_t *Y, int incY, - half::half_t *A, int lda, int batch_count) { + index_t m, index_t n, half::half_t alpha, + const half::half_t *X, index_t incX, const half::half_t *Y, index_t incY, + half::half_t *A, index_t lda, index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - int n, - const half::half_t* X, int incX, - const half::half_t* Y, int incY, + index_t n, + const half::half_t* X, index_t incX, + const half::half_t* Y, index_t incY, half::half_t *ret) { LOG(FATAL) << "Not implmented!"; } @@ -614,10 +614,10 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, - const float *B, int ldb, float beta, - float *C, int ldc) { + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, + const float *B, index_t ldb, float beta, + float *C, index_t ldc) { cublasStatus_t err = cublasSgemm(Stream::GetBlasHandle(stream), GetT(transa), GetT(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc); @@ -625,9 +625,9 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, float alpha, - const float *A, int lda, const float *B, int ldb, - float beta, float *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, float alpha, + const float *A, index_t lda, const float *B, index_t ldb, + float beta, float *C, index_t ldc, index_t batch_count, float **workspace) { #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 // Cast DType* to DType** using workspace as a buffer @@ -660,7 +660,7 @@ struct BLASEngine { batch_count); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmStridedBatched fail"; #else - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -668,46 +668,46 @@ struct BLASEngine { #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 } inline static void gemv(Stream *stream, - bool trans, int m, int n, float alpha, - const float *A, int lda, - const float *X, int incX, float beta, - float *Y, int incY) { + bool trans, index_t m, index_t n, float alpha, + const float *A, index_t lda, + const float *X, index_t incX, float beta, + float *Y, index_t incY) { cublasStatus_t err = cublasSgemv(Stream::GetBlasHandle(stream), GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - float alpha, const float *A, int lda, - const float *X, int incX, - float beta, float *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + bool trans, index_t m, index_t n, + float alpha, const float *A, index_t lda, + const float *X, index_t incX, + float beta, float *Y, index_t incY, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda) { cublasStatus_t err = cublasSger(Stream::GetBlasHandle(stream), m, n, &alpha, X, incX, Y, incY, A, lda); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail"; } inline static void batched_ger(Stream *stream, - int m, int n, float alpha, - const float *X, int incX, - const float *Y, int incY, float *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + index_t m, index_t n, float alpha, + const float *X, index_t incX, + const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - int n, - const float* X, int incX, - const float* Y, int incY, + index_t n, + const float* X, index_t incX, + const float* Y, index_t incY, float *ret) { cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_DEVICE); @@ -731,10 +731,10 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, - const double *B, int ldb, - double beta, double *C, int ldc) { + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, + const double *B, index_t ldb, + double beta, double *C, index_t ldc) { cublasStatus_t err = cublasDgemm(Stream::GetBlasHandle(stream), GetT(transa), GetT(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc); @@ -742,9 +742,9 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - int m, int n, int k, double alpha, - const double *A, int lda, const double *B, int ldb, - double beta, double *C, int ldc, int batch_count, + index_t m, index_t n, index_t k, double alpha, + const double *A, index_t lda, const double *B, index_t ldb, + double beta, double *C, index_t ldc, index_t batch_count, double **workspace) { #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 // Cast DType* to DType** using workspace as a buffer @@ -777,7 +777,7 @@ struct BLASEngine { batch_count); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmStridedBatched fail"; #else - for (int i = 0; i < batch_count; ++i) { + for (index_t i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -785,46 +785,46 @@ struct BLASEngine { #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 } inline static void gemv(Stream *stream, - bool trans, int m, int n, double alpha, - const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY) { + bool trans, index_t m, index_t n, double alpha, + const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY) { cublasStatus_t err = cublasDgemv(Stream::GetBlasHandle(stream), GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail"; } inline static void batched_gemv(Stream *stream, - bool trans, int m, int n, - double alpha, const double *A, int lda, - const double *X, int incX, - double beta, double *Y, int incY, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + bool trans, index_t m, index_t n, + double alpha, const double *A, index_t lda, + const double *X, index_t incX, + double beta, double *Y, index_t incY, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda) { cublasStatus_t err = cublasDger(Stream::GetBlasHandle(stream), m, n, &alpha, X, incX, Y, incY, A, lda); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail"; } inline static void batched_ger(Stream *stream, - int m, int n, double alpha, - const double *X, int incX, - const double *Y, int incY, double *A, int lda, int batch_count) { - for (int i = 0; i < batch_count; ++i) { + index_t m, index_t n, double alpha, + const double *X, index_t incX, + const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) { + for (index_t i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - int n, - const double* X, int incX, - const double* Y, int incY, + index_t n, + const double* X, index_t incX, + const double* Y, index_t incY, double *ret) { cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_DEVICE); diff --git a/mshadow/tensor.h b/mshadow/tensor.h index 53a1f02a..4f690712 100755 --- a/mshadow/tensor.h +++ b/mshadow/tensor.h @@ -1069,12 +1069,6 @@ inline void BatchGEMM(Tensor dst, #define MSHADOW_SCALAR_ double #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ int16_t -#include "./expr_scalar-inl.h" -#undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ uint16_t -#include "./expr_scalar-inl.h" -#undef MSHADOW_SCALAR_ #define MSHADOW_SCALAR_ int32_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ @@ -1084,4 +1078,7 @@ inline void BatchGEMM(Tensor dst, #define MSHADOW_SCALAR_ int64_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ +#define MSHADOW_SCALAR_ mshadow::half::half_t +#include "./expr_scalar-inl.h" +#undef MSHADOW_SCALAR_ #endif // MSHADOW_TENSOR_H_ From a0c5721a7ed1ebf89632f841586acdeff23d944d Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 2 Apr 2019 10:53:16 -0700 Subject: [PATCH 08/11] fix lint --- mshadow/dot_engine-inl.h | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h index e3399cae..149dfc4f 100644 --- a/mshadow/dot_engine-inl.h +++ b/mshadow/dot_engine-inl.h @@ -258,7 +258,8 @@ struct BLASEngine { inline static void batched_ger(Stream *stream, index_t m, index_t n, double alpha, const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) { + const double *Y, index_t incY, double *A, index_t lda, + index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, @@ -368,7 +369,8 @@ struct BLASEngine { inline static void batched_ger(Stream *stream, index_t m, index_t n, float alpha, const float *X, index_t incX, - const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) { + const float *Y, index_t incY, float *A, index_t lda, + index_t batch_count) { for (index_t i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); @@ -480,7 +482,8 @@ struct BLASEngine { inline static void batched_ger(Stream *stream, index_t m, index_t n, double alpha, const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) { + const double *Y, index_t incY, double *A, index_t lda, + index_t batch_count) { for (index_t i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); @@ -538,9 +541,9 @@ struct BLASEngine { inline static void batched_gemm(Stream *stream, bool transa, bool transb, index_t m, index_t n, index_t k, half::half_t alpha, - const half::half_t *A, index_t lda, const half::half_t *B, index_t ldb, - half::half_t beta, half::half_t *C, index_t ldc, index_t batch_count, - half::half_t **workspace) { + const half::half_t *A, index_t lda, const half::half_t *B, + index_t ldb, half::half_t beta, half::half_t *C, index_t ldc, + index_t batch_count, half::half_t **workspace) { #if defined(__CUDACC__) && CUDA_VERSION >= 9000 int major = stream->prop.major; int minor = stream->prop.minor; @@ -578,7 +581,8 @@ struct BLASEngine { bool trans, index_t m, index_t n, half::half_t alpha, const half::half_t *A, index_t lda, const half::half_t *X, index_t incX, - half::half_t beta, half::half_t *Y, index_t incY, index_t batch_count) { + half::half_t beta, half::half_t *Y, index_t incY, + index_t batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, @@ -815,7 +819,8 @@ struct BLASEngine { inline static void batched_ger(Stream *stream, index_t m, index_t n, double alpha, const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda, index_t batch_count) { + const double *Y, index_t incY, double *A, index_t lda, + index_t batch_count) { for (index_t i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); From 8b28974eedec52777cb19162712f99ac0f0b8f55 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 9 Apr 2019 10:00:54 -0700 Subject: [PATCH 09/11] revert type change in gemm function call --- mshadow/dot_engine-inl.h | 533 ++++++++++++++++++--------------------- 1 file changed, 249 insertions(+), 284 deletions(-) diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h index 149dfc4f..5363974f 100644 --- a/mshadow/dot_engine-inl.h +++ b/mshadow/dot_engine-inl.h @@ -65,49 +65,49 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, DType alpha, - const DType *A, index_t lda, const DType *B, index_t ldb, - DType beta, DType *C, index_t ldc) { + int m, int n, int k, DType alpha, + const DType *A, int lda, const DType *B, int ldb, + DType beta, DType *C, int ldc) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, DType alpha, - const DType *A, index_t lda, const DType *B, index_t ldb, - DType beta, DType *C, index_t ldc, index_t batch_count, + int m, int n, int k, DType alpha, + const DType *A, int lda, const DType *B, int ldb, + DType beta, DType *C, int ldc, int batch_count, DType **workspace) { LOG(FATAL) << "Not implmented!"; } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, - DType alpha, const DType *A, index_t lda, - const DType *X, index_t incX, - DType beta, DType *Y, index_t incY) { + bool trans, int m, int n, + DType alpha, const DType *A, int lda, + const DType *X, int incX, + DType beta, DType *Y, int incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - DType alpha, const DType *A, index_t lda, - const DType *X, index_t incX, - DType beta, DType *Y, index_t incY, index_t batch_count) { + bool trans, int m, int n, + DType alpha, const DType *A, int lda, + const DType *X, int incX, + DType beta, DType *Y, int incY, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - index_t m, index_t n, DType alpha, - const DType *X, index_t incX, - const DType *Y, index_t incY, DType *A, index_t lda) { + int m, int n, DType alpha, + const DType *X, int incX, + const DType *Y, int incY, DType *A, int lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - index_t m, index_t n, DType alpha, - const DType *X, index_t incX, - const DType *Y, index_t incY, DType *A, index_t lda, index_t batch_count) { + int m, int n, DType alpha, + const DType *X, int incX, + const DType *Y, int incY, DType *A, int lda, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - index_t n, - const DType* X, index_t incX, - const DType* Y, index_t incY, + int n, + const DType* X, int incX, + const DType* Y, int incY, DType* ret) { LOG(FATAL) << "Not implmented!"; } @@ -123,9 +123,9 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, float alpha, - const float *A, index_t lda, const float *B, index_t ldb, - float beta, float *C, index_t ldc) { + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { if (alpha == 1.0f && beta == 0.0f) { bool transpose_left = transb; bool transpose_right = transa; @@ -147,46 +147,46 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, float alpha, - const float *A, index_t lda, const float *B, index_t ldb, - float beta, float *C, index_t ldc, index_t batch_count, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc, int batch_count, float **workspace) { - for (index_t i = 0; i < batch_count; ++i) { + for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, - float alpha, const float *A, index_t lda, - const float *X, index_t incX, - float beta, float *Y, index_t incY) { + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - float alpha, const float *A, index_t lda, - const float *X, index_t incX, - float beta, float *Y, index_t incY, index_t batch_count) { + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - index_t m, index_t n, float alpha, - const float *X, index_t incX, - const float *Y, index_t incY, float *A, index_t lda) { + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - index_t m, index_t n, float alpha, - const float *X, index_t incX, - const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) { + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - index_t n, - const float* X, index_t incX, - const float* Y, index_t incY, + int n, + const float* X, int incX, + const float* Y, int incY, float* ret) { LOG(FATAL) << "Not implmented!"; } @@ -201,9 +201,9 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, double alpha, - const double *A, index_t lda, const double *B, index_t ldb, - double beta, double *C, index_t ldc) { + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc) { if (alpha == 1.0f && beta == 0.0f) { bool transpose_left = transb; bool transpose_right = transa; @@ -225,47 +225,46 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, double alpha, - const double *A, index_t lda, const double *B, index_t ldb, - double beta, double *C, index_t ldc, index_t batch_count, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc, int batch_count, double **workspace) { - for (index_t i = 0; i < batch_count; ++i) { + for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, - double alpha, const double *A, index_t lda, - const double *X, index_t incX, - double beta, double *Y, index_t incY) { + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - double alpha, const double *A, index_t lda, - const double *X, index_t incX, - double beta, double *Y, index_t incY, index_t batch_count) { + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - index_t m, index_t n, double alpha, - const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda) { + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - index_t m, index_t n, double alpha, - const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda, - index_t batch_count) { + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - index_t n, - const double* X, index_t incX, - const double* Y, index_t incY, + int n, + const double* X, int incX, + const double* Y, int incY, double* ret) { LOG(FATAL) << "Not implmented!"; } @@ -281,60 +280,55 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, float alpha, - const float *A, index_t lda, const float *B, index_t ldb, - float beta, float *C, index_t ldc) { + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc) { cblas_sgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, float alpha, - const float *A, index_t lda, const float *B, index_t ldb, - float beta, float *C, index_t ldc, index_t batch_count, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc, int batch_count, float **workspace) { #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) - // since same m/n/k is used for all single gemms, so we put all gemms into one group - const int GROUP_SIZE = 1; - MKL_INT p_m[GROUP_SIZE] = {m}; - MKL_INT p_n[GROUP_SIZE] = {n}; - MKL_INT p_k[GROUP_SIZE] = {k}; - MKL_INT p_lda[GROUP_SIZE] = {lda}; - MKL_INT p_ldb[GROUP_SIZE] = {ldb}; - MKL_INT p_ldc[GROUP_SIZE] = {ldc}; - - float p_alpha[GROUP_SIZE] = {alpha}; - float p_beta[GROUP_SIZE] = {beta}; + std::vector p_m(batch_count, m); + std::vector p_n(batch_count, n); + std::vector p_k(batch_count, k); + std::vector p_lda(batch_count, lda); + std::vector p_ldb(batch_count, ldb); + std::vector p_ldc(batch_count, ldc); + std::vector p_alpha(batch_count, alpha); + std::vector p_beta(batch_count, beta); + std::vector pp_A; + std::vector pp_B; + std::vector pp_C; CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); - MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count}; - CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans}; - CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans}; - - std::vector pp_A; - std::vector pp_B; - std::vector pp_C; - pp_A.reserve(batch_count); - pp_B.reserve(batch_count); - pp_C.reserve(batch_count); + std::vector p_group_sizeb(batch_count, batch_count); + std::vector p_transa(batch_count, cblas_a_trans); + std::vector p_transb(batch_count, cblas_b_trans); auto m_k = m * k; auto k_n = k * n; auto m_n = m * n; - for (index_t i = 0; i < batch_count; i++) { - pp_A[i] = A + i * m_k; - pp_B[i] = B + i * k_n; - pp_C[i] = C + i * m_n; + for (int i = 0; i < batch_count; i++) { + pp_A.push_back(A + i * m_k); + pp_B.push_back(B + i * k_n); + pp_C.push_back(C + i * m_n); } - cblas_sgemm_batch(CblasColMajor, p_transa, p_transb, - p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(), - p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb); + cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), + p_m.data(), p_n.data(), p_k.data(), + p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), + p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), + 1, p_group_sizeb.data()); #else - for (index_t i = 0; i < batch_count; ++i) { + for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -342,44 +336,43 @@ struct BLASEngine { #endif } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, - float alpha, const float *A, index_t lda, - const float *X, index_t incX, - float beta, float *Y, index_t incY) { + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY) { cblas_sgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - float alpha, const float *A, index_t lda, - const float *X, index_t incX, - float beta, float *Y, index_t incY, index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - index_t m, index_t n, float alpha, - const float *X, index_t incX, - const float *Y, index_t incY, float *A, index_t lda) { + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda) { cblas_sger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); } inline static void batched_ger(Stream *stream, - index_t m, index_t n, float alpha, - const float *X, index_t incX, - const float *Y, index_t incY, float *A, index_t lda, - index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - index_t n, - const float* X, index_t incX, - const float* Y, index_t incY, + int n, + const float* X, int incX, + const float* Y, int incY, float* ret) { *ret = cblas_sdot(n, X, incX, Y, incY); } @@ -394,60 +387,55 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, double alpha, - const double *A, index_t lda, const double *B, index_t ldb, - double beta, double *C, index_t ldc) { + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc) { cblas_dgemm(CblasColMajor, GetT(transa), GetT(transb), m, n, k, alpha, A, lda, B, ldb, beta, C, ldc); } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, double alpha, - const double *A, index_t lda, const double *B, index_t ldb, - double beta, double *C, index_t ldc, index_t batch_count, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc, int batch_count, double **workspace) { #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) - // since same m/n/k is used for all single gemms, so we put all gemms into one group - const int GROUP_SIZE = 1; - MKL_INT p_m[GROUP_SIZE] = {m}; - MKL_INT p_n[GROUP_SIZE] = {n}; - MKL_INT p_k[GROUP_SIZE] = {k}; - MKL_INT p_lda[GROUP_SIZE] = {lda}; - MKL_INT p_ldb[GROUP_SIZE] = {ldb}; - MKL_INT p_ldc[GROUP_SIZE] = {ldc}; - - double p_alpha[GROUP_SIZE] = {alpha}; - double p_beta[GROUP_SIZE] = {beta}; + std::vector p_m(batch_count, m); + std::vector p_n(batch_count, n); + std::vector p_k(batch_count, k); + std::vector p_lda(batch_count, lda); + std::vector p_ldb(batch_count, ldb); + std::vector p_ldc(batch_count, ldc); + std::vector p_alpha(batch_count, alpha); + std::vector p_beta(batch_count, beta); + std::vector pp_A; + std::vector pp_B; + std::vector pp_C; CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); - MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count}; - CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans}; - CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans}; - - std::vector pp_A; - std::vector pp_B; - std::vector pp_C; - pp_A.reserve(batch_count); - pp_B.reserve(batch_count); - pp_C.reserve(batch_count); + std::vector p_group_sizeb(batch_count, batch_count); + std::vector p_transa(batch_count, cblas_a_trans); + std::vector p_transb(batch_count, cblas_b_trans); auto m_k = m * k; auto k_n = k * n; auto m_n = m * n; - for (index_t i = 0; i < batch_count; i++) { - pp_A[i] = A + i * m_k; - pp_B[i] = B + i * k_n; - pp_C[i] = C + i * m_n; + for (int i = 0; i < batch_count; i++) { + pp_A.push_back(A + i * m_k); + pp_B.push_back(B + i * k_n); + pp_C.push_back(C + i * m_n); } - cblas_dgemm_batch(CblasColMajor, p_transa, p_transb, - p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(), - p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb); + cblas_dgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), + p_m.data(), p_n.data(), p_k.data(), + p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), + p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), + 1, p_group_sizeb.data()); #else - for (index_t i = 0; i < batch_count; ++i) { + for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -455,44 +443,43 @@ struct BLASEngine { #endif } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, double alpha, - const double *A, index_t lda, - const double *X, index_t incX, - double beta, double *Y, index_t incY) { + bool trans, int m, int n, double alpha, + const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY) { cblas_dgemv(CblasColMajor, GetT(trans), m, n, alpha, A, lda, X, incX, beta, Y, incY); } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - double alpha, const double *A, index_t lda, - const double *X, index_t incX, - double beta, double *Y, index_t incY, index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - index_t m, index_t n, double alpha, - const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda) { + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda) { cblas_dger(CblasColMajor, m, n, alpha, X, incX, Y, incY, A, lda); } inline static void batched_ger(Stream *stream, - index_t m, index_t n, double alpha, - const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda, - index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - index_t n, - const double* X, index_t incX, - const double* Y, index_t incY, + int n, + const double* X, int incX, + const double* Y, int incY, double* ret) { *ret = cblas_ddot(n, X, incX, Y, incY); } @@ -513,10 +500,10 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, half::half_t alpha, - const half::half_t *A, index_t lda, - const half::half_t *B, index_t ldb, half::half_t beta, - half::half_t *C, index_t ldc) { + int m, int n, int k, half::half_t alpha, + const half::half_t *A, int lda, + const half::half_t *B, int ldb, half::half_t beta, + half::half_t *C, int ldc) { #if defined(CUDA_VERSION) && CUDA_VERSION >= 7050 // Always use pseudo-fp16: fp32 compute with fp16 I/O. float alpha_f = float(alpha); // NOLINT(*) @@ -540,67 +527,46 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, half::half_t alpha, - const half::half_t *A, index_t lda, const half::half_t *B, - index_t ldb, half::half_t beta, half::half_t *C, index_t ldc, - index_t batch_count, half::half_t **workspace) { -#if defined(__CUDACC__) && CUDA_VERSION >= 9000 - int major = stream->prop.major; - int minor = stream->prop.minor; - // fp16 is not supported before ARCH 53 - if ((major > 5) || (major == 5 && minor >= 3)) { - const __half* A_h = reinterpret_cast(A); - const __half* B_h = reinterpret_cast(B); - __half* alpha_h = reinterpret_cast<__half*>(&alpha); - __half* beta_h = reinterpret_cast<__half*>(&beta); - __half* C_h = reinterpret_cast<__half*>(C); - cublasStatus_t err = cublasHgemmStridedBatched(Stream::GetBlasHandle(stream), - GetT(transa), GetT(transb), m, n, k, alpha_h, - A_h, lda, m * k, - B_h, ldb, k * n, - beta_h, C_h, ldc, m * n, - batch_count); - CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: HgemmStridedBatched fail"; - return; - } -#endif - for (index_t i = 0; i < batch_count; ++i) { + int m, int n, int k, half::half_t alpha, + const half::half_t *A, int lda, const half::half_t *B, int ldb, + half::half_t beta, half::half_t *C, int ldc, int batch_count, + half::half_t **workspace) { + for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); } } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, half::half_t alpha, - const half::half_t *A, index_t lda, - const half::half_t *X, index_t incX, half::half_t beta, - half::half_t *Y, index_t incY) { + bool trans, int m, int n, half::half_t alpha, + const half::half_t *A, int lda, + const half::half_t *X, int incX, half::half_t beta, + half::half_t *Y, int incY) { LOG(FATAL) << "Not implmented!"; } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - half::half_t alpha, const half::half_t *A, index_t lda, - const half::half_t *X, index_t incX, - half::half_t beta, half::half_t *Y, index_t incY, - index_t batch_count) { + bool trans, int m, int n, + half::half_t alpha, const half::half_t *A, int lda, + const half::half_t *X, int incX, + half::half_t beta, half::half_t *Y, int incY, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void ger(Stream *stream, - index_t m, index_t n, half::half_t alpha, - const half::half_t *X, index_t incX, - const half::half_t *Y, index_t incY, half::half_t *A, index_t lda) { + int m, int n, half::half_t alpha, + const half::half_t *X, int incX, + const half::half_t *Y, int incY, half::half_t *A, int lda) { LOG(FATAL) << "Not implmented!"; } inline static void batched_ger(Stream *stream, - index_t m, index_t n, half::half_t alpha, - const half::half_t *X, index_t incX, const half::half_t *Y, index_t incY, - half::half_t *A, index_t lda, index_t batch_count) { + int m, int n, half::half_t alpha, + const half::half_t *X, int incX, const half::half_t *Y, int incY, + half::half_t *A, int lda, int batch_count) { LOG(FATAL) << "Not implmented!"; } inline static void dot(Stream *stream, - index_t n, - const half::half_t* X, index_t incX, - const half::half_t* Y, index_t incY, + int n, + const half::half_t* X, int incX, + const half::half_t* Y, int incY, half::half_t *ret) { LOG(FATAL) << "Not implmented!"; } @@ -618,10 +584,10 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, float alpha, - const float *A, index_t lda, - const float *B, index_t ldb, float beta, - float *C, index_t ldc) { + int m, int n, int k, float alpha, + const float *A, int lda, + const float *B, int ldb, float beta, + float *C, int ldc) { cublasStatus_t err = cublasSgemm(Stream::GetBlasHandle(stream), GetT(transa), GetT(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc); @@ -629,9 +595,9 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, float alpha, - const float *A, index_t lda, const float *B, index_t ldb, - float beta, float *C, index_t ldc, index_t batch_count, + int m, int n, int k, float alpha, + const float *A, int lda, const float *B, int ldb, + float beta, float *C, int ldc, int batch_count, float **workspace) { #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 // Cast DType* to DType** using workspace as a buffer @@ -664,7 +630,7 @@ struct BLASEngine { batch_count); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmStridedBatched fail"; #else - for (index_t i = 0; i < batch_count; ++i) { + for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -672,46 +638,46 @@ struct BLASEngine { #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, float alpha, - const float *A, index_t lda, - const float *X, index_t incX, float beta, - float *Y, index_t incY) { + bool trans, int m, int n, float alpha, + const float *A, int lda, + const float *X, int incX, float beta, + float *Y, int incY) { cublasStatus_t err = cublasSgemv(Stream::GetBlasHandle(stream), GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sgemv fail"; } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - float alpha, const float *A, index_t lda, - const float *X, index_t incX, - float beta, float *Y, index_t incY, index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + bool trans, int m, int n, + float alpha, const float *A, int lda, + const float *X, int incX, + float beta, float *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - index_t m, index_t n, float alpha, - const float *X, index_t incX, - const float *Y, index_t incY, float *A, index_t lda) { + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda) { cublasStatus_t err = cublasSger(Stream::GetBlasHandle(stream), m, n, &alpha, X, incX, Y, incY, A, lda); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Sger fail"; } inline static void batched_ger(Stream *stream, - index_t m, index_t n, float alpha, - const float *X, index_t incX, - const float *Y, index_t incY, float *A, index_t lda, index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + int m, int n, float alpha, + const float *X, int incX, + const float *Y, int incY, float *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - index_t n, - const float* X, index_t incX, - const float* Y, index_t incY, + int n, + const float* X, int incX, + const float* Y, int incY, float *ret) { cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_DEVICE); @@ -735,10 +701,10 @@ struct BLASEngine { } inline static void gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, double alpha, - const double *A, index_t lda, - const double *B, index_t ldb, - double beta, double *C, index_t ldc) { + int m, int n, int k, double alpha, + const double *A, int lda, + const double *B, int ldb, + double beta, double *C, int ldc) { cublasStatus_t err = cublasDgemm(Stream::GetBlasHandle(stream), GetT(transa), GetT(transb), m, n, k, &alpha, A, lda, B, ldb, &beta, C, ldc); @@ -746,9 +712,9 @@ struct BLASEngine { } inline static void batched_gemm(Stream *stream, bool transa, bool transb, - index_t m, index_t n, index_t k, double alpha, - const double *A, index_t lda, const double *B, index_t ldb, - double beta, double *C, index_t ldc, index_t batch_count, + int m, int n, int k, double alpha, + const double *A, int lda, const double *B, int ldb, + double beta, double *C, int ldc, int batch_count, double **workspace) { #if defined(__CUDACC__) && CUDA_VERSION >= 4010 && CUDA_VERSION < 8000 // Cast DType* to DType** using workspace as a buffer @@ -781,7 +747,7 @@ struct BLASEngine { batch_count); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmStridedBatched fail"; #else - for (index_t i = 0; i < batch_count; ++i) { + for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, beta, C + i * m * n, ldc); @@ -789,47 +755,46 @@ struct BLASEngine { #endif // defined(__CUDACC__) && CUDA_VERSION >= 4010 } inline static void gemv(Stream *stream, - bool trans, index_t m, index_t n, double alpha, - const double *A, index_t lda, - const double *X, index_t incX, - double beta, double *Y, index_t incY) { + bool trans, int m, int n, double alpha, + const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY) { cublasStatus_t err = cublasDgemv(Stream::GetBlasHandle(stream), GetT(trans), m, n, &alpha, A, lda, X, incX, &beta, Y, incY); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dgemv fail"; } inline static void batched_gemv(Stream *stream, - bool trans, index_t m, index_t n, - double alpha, const double *A, index_t lda, - const double *X, index_t incX, - double beta, double *Y, index_t incY, index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + bool trans, int m, int n, + double alpha, const double *A, int lda, + const double *X, int incX, + double beta, double *Y, int incY, int batch_count) { + for (int i = 0; i < batch_count; ++i) { gemv(stream, trans, m, n, alpha, A + i * m * n, lda, X + i * (trans ? m : n) * incX, incX, beta, Y + i * (trans ? n : m) * incY, incY); } } inline static void ger(Stream *stream, - index_t m, index_t n, double alpha, - const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda) { + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda) { cublasStatus_t err = cublasDger(Stream::GetBlasHandle(stream), m, n, &alpha, X, incX, Y, incY, A, lda); CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: Dger fail"; } inline static void batched_ger(Stream *stream, - index_t m, index_t n, double alpha, - const double *X, index_t incX, - const double *Y, index_t incY, double *A, index_t lda, - index_t batch_count) { - for (index_t i = 0; i < batch_count; ++i) { + int m, int n, double alpha, + const double *X, int incX, + const double *Y, int incY, double *A, int lda, int batch_count) { + for (int i = 0; i < batch_count; ++i) { ger(stream, m, n, alpha, X + i * m * incX, incX, Y + i * n * incY, incY, A + i * lda * n, lda); } } inline static void dot(Stream *stream, - index_t n, - const double* X, index_t incX, - const double* Y, index_t incY, + int n, + const double* X, int incX, + const double* Y, int incY, double *ret) { cublasSetPointerMode(Stream::GetBlasHandle(stream), CUBLAS_POINTER_MODE_DEVICE); From bb6a03716ed8c898e6ce149bfb769d8783c2b29b Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 9 Apr 2019 10:04:33 -0700 Subject: [PATCH 10/11] sync change with upstream --- mshadow/dot_engine-inl.h | 118 ++++++++++++++++++++++++--------------- 1 file changed, 74 insertions(+), 44 deletions(-) diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h index 5363974f..21816f20 100644 --- a/mshadow/dot_engine-inl.h +++ b/mshadow/dot_engine-inl.h @@ -293,40 +293,45 @@ struct BLASEngine { float beta, float *C, int ldc, int batch_count, float **workspace) { #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) - std::vector p_m(batch_count, m); - std::vector p_n(batch_count, n); - std::vector p_k(batch_count, k); - std::vector p_lda(batch_count, lda); - std::vector p_ldb(batch_count, ldb); - std::vector p_ldc(batch_count, ldc); - std::vector p_alpha(batch_count, alpha); - std::vector p_beta(batch_count, beta); - std::vector pp_A; - std::vector pp_B; - std::vector pp_C; + // since same m/n/k is used for all single gemms, so we put all gemms into one group + const int GROUP_SIZE = 1; + MKL_INT p_m[GROUP_SIZE] = {m}; + MKL_INT p_n[GROUP_SIZE] = {n}; + MKL_INT p_k[GROUP_SIZE] = {k}; + MKL_INT p_lda[GROUP_SIZE] = {lda}; + MKL_INT p_ldb[GROUP_SIZE] = {ldb}; + MKL_INT p_ldc[GROUP_SIZE] = {ldc}; + + float p_alpha[GROUP_SIZE] = {alpha}; + float p_beta[GROUP_SIZE] = {beta}; CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); - std::vector p_group_sizeb(batch_count, batch_count); - std::vector p_transa(batch_count, cblas_a_trans); - std::vector p_transb(batch_count, cblas_b_trans); + MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count}; + CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans}; + CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans}; + + std::vector pp_A; + std::vector pp_B; + std::vector pp_C; + pp_A.reserve(batch_count); + pp_B.reserve(batch_count); + pp_C.reserve(batch_count); auto m_k = m * k; auto k_n = k * n; auto m_n = m * n; for (int i = 0; i < batch_count; i++) { - pp_A.push_back(A + i * m_k); - pp_B.push_back(B + i * k_n); - pp_C.push_back(C + i * m_n); + pp_A[i] = A + i * m_k; + pp_B[i] = B + i * k_n; + pp_C[i] = C + i * m_n; } - cblas_sgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), - p_m.data(), p_n.data(), p_k.data(), - p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), - p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), - 1, p_group_sizeb.data()); + cblas_sgemm_batch(CblasColMajor, p_transa, p_transb, + p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(), + p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb); #else for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, @@ -400,40 +405,45 @@ struct BLASEngine { double beta, double *C, int ldc, int batch_count, double **workspace) { #if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) - std::vector p_m(batch_count, m); - std::vector p_n(batch_count, n); - std::vector p_k(batch_count, k); - std::vector p_lda(batch_count, lda); - std::vector p_ldb(batch_count, ldb); - std::vector p_ldc(batch_count, ldc); - std::vector p_alpha(batch_count, alpha); - std::vector p_beta(batch_count, beta); - std::vector pp_A; - std::vector pp_B; - std::vector pp_C; + // since same m/n/k is used for all single gemms, so we put all gemms into one group + const int GROUP_SIZE = 1; + MKL_INT p_m[GROUP_SIZE] = {m}; + MKL_INT p_n[GROUP_SIZE] = {n}; + MKL_INT p_k[GROUP_SIZE] = {k}; + MKL_INT p_lda[GROUP_SIZE] = {lda}; + MKL_INT p_ldb[GROUP_SIZE] = {ldb}; + MKL_INT p_ldc[GROUP_SIZE] = {ldc}; + + double p_alpha[GROUP_SIZE] = {alpha}; + double p_beta[GROUP_SIZE] = {beta}; CBLAS_TRANSPOSE cblas_a_trans = GetT(transa); CBLAS_TRANSPOSE cblas_b_trans = GetT(transb); - std::vector p_group_sizeb(batch_count, batch_count); - std::vector p_transa(batch_count, cblas_a_trans); - std::vector p_transb(batch_count, cblas_b_trans); + MKL_INT p_group_sizeb[GROUP_SIZE] = {batch_count}; + CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans}; + CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans}; + + std::vector pp_A; + std::vector pp_B; + std::vector pp_C; + pp_A.reserve(batch_count); + pp_B.reserve(batch_count); + pp_C.reserve(batch_count); auto m_k = m * k; auto k_n = k * n; auto m_n = m * n; for (int i = 0; i < batch_count; i++) { - pp_A.push_back(A + i * m_k); - pp_B.push_back(B + i * k_n); - pp_C.push_back(C + i * m_n); + pp_A[i] = A + i * m_k; + pp_B[i] = B + i * k_n; + pp_C[i] = C + i * m_n; } - cblas_dgemm_batch(CblasColMajor, p_transa.data(), p_transb.data(), - p_m.data(), p_n.data(), p_k.data(), - p_alpha.data(), pp_A.data(), p_lda.data(), pp_B.data(), - p_ldb.data(), p_beta.data(), pp_C.data(), p_ldc.data(), - 1, p_group_sizeb.data()); + cblas_dgemm_batch(CblasColMajor, p_transa, p_transb, + p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(), + p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb); #else for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, @@ -531,6 +541,26 @@ struct BLASEngine { const half::half_t *A, int lda, const half::half_t *B, int ldb, half::half_t beta, half::half_t *C, int ldc, int batch_count, half::half_t **workspace) { +#if defined(__CUDACC__) && CUDA_VERSION >= 9000 + int major = stream->prop.major; + int minor = stream->prop.minor; + // fp16 is not supported before ARCH 53 + if ((major > 5) || (major == 5 && minor >= 3)) { + const __half* A_h = reinterpret_cast(A); + const __half* B_h = reinterpret_cast(B); + __half* alpha_h = reinterpret_cast<__half*>(&alpha); + __half* beta_h = reinterpret_cast<__half*>(&beta); + __half* C_h = reinterpret_cast<__half*>(C); + cublasStatus_t err = cublasHgemmStridedBatched(Stream::GetBlasHandle(stream), + GetT(transa), GetT(transb), m, n, k, alpha_h, + A_h, lda, m * k, + B_h, ldb, k * n, + beta_h, C_h, ldc, m * n, + batch_count); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: HgemmStridedBatched fail"; + return; + } +#endif for (int i = 0; i < batch_count; ++i) { gemm(stream, transa, transb, m, n, k, alpha, A + i * m * k, lda, B + i * k * n, ldb, From 9057bf9b537ca36ace1f56631ec39de06e3bfbf6 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Wed, 10 Apr 2019 18:51:23 -0700 Subject: [PATCH 11/11] remove uint32 --- mshadow/tensor.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/mshadow/tensor.h b/mshadow/tensor.h index 4f690712..0d662621 100755 --- a/mshadow/tensor.h +++ b/mshadow/tensor.h @@ -1072,9 +1072,6 @@ inline void BatchGEMM(Tensor dst, #define MSHADOW_SCALAR_ int32_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_ -#define MSHADOW_SCALAR_ uint32_t -#include "./expr_scalar-inl.h" -#undef MSHADOW_SCALAR_ #define MSHADOW_SCALAR_ int64_t #include "./expr_scalar-inl.h" #undef MSHADOW_SCALAR_