Skip to content

Commit

Permalink
softmax for fp16 with fp32 accumulator (apache#14098)
Browse files Browse the repository at this point in the history
* softmax for fp16 with fp32 accumulator

* return AType in kernel

* add dtype

* kernel

* grad use in-out only when dtype override

* simplify infer type

* address comments
  • Loading branch information
szha authored and drivanov committed Mar 4, 2019
1 parent adb3ead commit eef3b52
Show file tree
Hide file tree
Showing 4 changed files with 326 additions and 71 deletions.
42 changes: 42 additions & 0 deletions src/operator/mxnet_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,48 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_REAL_ACC_TYPE_SWITCH(type, DType, AType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
typedef double AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
typedef float AType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kUint8: \
LOG(FATAL) << "This operation only support " \
"floating point types not uint8"; \
break; \
case mshadow::kInt8: \
LOG(FATAL) << "This operation only support " \
"floating point types not int8"; \
break; \
case mshadow::kInt32: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int32"; \
break; \
case mshadow::kInt64: \
LOG(FATAL) << "This operation only support " \
"floating point types, not int64"; \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

/*!
* \brief assign the val to out according
Expand Down
Loading

0 comments on commit eef3b52

Please sign in to comment.