diff --git a/mshadow/base.h b/mshadow/base.h index 29d83c25..0a696509 100644 --- a/mshadow/base.h +++ b/mshadow/base.h @@ -671,11 +671,46 @@ struct minimum { "floating point types not uint8"; \ break; \ case mshadow::kInt32: \ - LOG(FATAL) << "This operation only support " \ - "floating point types, not int32"; \ - break; \ - default: \ - LOG(FATAL) << "Unknown type enum " << type; \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int32";\ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type; \ + } + +#define MSHADOW_REAL_TYPE_SWITCH_EX(type$, DType$, DLargeType$, ...) \ + switch (type$) { \ + case mshadow::kFloat32: \ + { \ + typedef float DType$; \ + typedef float DLargeType$; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat64: \ + { \ + typedef double DType$; \ + typedef double DLargeType$; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kFloat16: \ + { \ + typedef mshadow::half::half_t DType$; \ + typedef float DLargeType$; \ + {__VA_ARGS__} \ + } \ + break; \ + case mshadow::kUint8: \ + LOG(FATAL) << "This operation only support " \ + "floating point types not uint8"; \ + break; \ + case mshadow::kInt32: \ + LOG(FATAL) << "This operation only support " \ + "floating point types, not int32";\ + break; \ + default: \ + LOG(FATAL) << "Unknown type enum " << type$; \ } #define MSHADOW_LAYOUT_SWITCH(layout, Layout, ...) \ diff --git a/mshadow/dot_engine-inl.h b/mshadow/dot_engine-inl.h index 249f43b5..18b94f7b 100644 --- a/mshadow/dot_engine-inl.h +++ b/mshadow/dot_engine-inl.h @@ -34,6 +34,7 @@ inline void GetBatchedView(DType **dst, DType *src, int num, int stride, } } #ifdef __CUDACC__ +namespace cuda {}; template inline void GetBatchedView(DType **dst, DType *src, int num, int stride, Stream *stream) {