diff --git a/paddle/fluid/framework/data_type.h b/paddle/fluid/framework/data_type.h
index 15d45d8386dad..ec8284b825500 100644
--- a/paddle/fluid/framework/data_type.h
+++ b/paddle/fluid/framework/data_type.h
@@ -89,7 +89,7 @@ struct DataTypeTrait {
   _ForEachDataTypeHelper_(callback, int, INT32);                        \
   _ForEachDataTypeHelper_(callback, int64_t, INT64);
 
-// It's only for DataParallel in HIP
+// It's only for DataParallel in HIP; bf16 is not supported on HIP.
 #define _ForEachDataTypeForHIP_(callback)                               \
   _ForEachDataTypeHelper_(callback, float, FP32);                       \
   _ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
diff --git a/paddle/fluid/imperative/reducer.cc b/paddle/fluid/imperative/reducer.cc
index 746df54a7dc11..beddbd5d12008 100644
--- a/paddle/fluid/imperative/reducer.cc
+++ b/paddle/fluid/imperative/reducer.cc
@@ -49,6 +49,10 @@ void Group::DivNRanks(const platform::DeviceContext &context, int64_t nranks) {
     VLOG(4) << "before div 2" << *tensor;
     VLOG(4) << "NDiv for cpu devices : rank = " << nranks;
 #ifdef PADDLE_WITH_HIP
+    if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+      PADDLE_THROW(paddle::platform::errors::Fatal(
+          "BF16 is not supported in DataParallel for now"));
+    }
     framework::VisitDataTypeForHIP(
         dtype_, DivNRanksForAllReduce(
                     tensor, nranks, context));
diff --git a/paddle/fluid/imperative/reducer.cu b/paddle/fluid/imperative/reducer.cu
index 88326d66211cc..05453a61b7e39 100644
--- a/paddle/fluid/imperative/reducer.cu
+++ b/paddle/fluid/imperative/reducer.cu
@@ -21,6 +21,10 @@ namespace imperative {
 void Group::DivNRanks(framework::Tensor *tensor, int64_t nranks,
                       const platform::DeviceContext &context) {
 #ifdef PADDLE_WITH_HIP
+  if (dtype_ == paddle::framework::proto::VarType_Type_BF16) {
+    PADDLE_THROW(paddle::platform::errors::Fatal(
+        "BF16 is not supported in DataParallel for now"));
+  }
   framework::VisitDataTypeForHIP(
       dtype_, DivNRanksForAllReduce(tensor, nranks, context));
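
For context on why the guard is needed: `_ForEachDataTypeForHIP_` is an X-macro type list and `VisitDataTypeForHIP` dispatches a visitor only to the types in that list, so a dtype outside the list (such as BF16 on HIP) has to be rejected explicitly before the visit, which is what the new PADDLE_THROW check does. The following is a minimal, self-contained sketch of that dispatch pattern, not Paddle's real implementation; `VarType`, `FOR_EACH_HIP_TYPE`, `VisitDataTypeForHIPSketch`, and `PrintSizeVisitor` are made-up names for illustration, and the float16 entry from the real list is omitted to keep the sketch dependency-free.

// Minimal sketch of the X-macro + visitor dispatch pattern (illustrative only).
#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class VarType { FP32, FP64, INT32, INT64, BF16 };

// One callback invocation per type supported on HIP; BF16 is deliberately
// absent, mirroring the list in data_type.h above.
#define FOR_EACH_HIP_TYPE(callback)   \
  callback(float, VarType::FP32);     \
  callback(double, VarType::FP64);    \
  callback(int, VarType::INT32);      \
  callback(int64_t, VarType::INT64);

template <typename Visitor>
void VisitDataTypeForHIPSketch(VarType type, Visitor visitor) {
#define VISIT_CASE(cpp_type, proto_type)  \
  if (type == proto_type) {               \
    visitor.template apply<cpp_type>();   \
    return;                               \
  }
  FOR_EACH_HIP_TYPE(VISIT_CASE)
#undef VISIT_CASE
  // Anything not in the list (e.g. BF16 on HIP) falls through to an error,
  // which is why the reducer checks for BF16 before calling the visitor.
  throw std::runtime_error("data type not supported on HIP");
}

struct PrintSizeVisitor {
  template <typename T>
  void apply() const { std::printf("element size = %zu bytes\n", sizeof(T)); }
};

int main() {
  VisitDataTypeForHIPSketch(VarType::FP32, PrintSizeVisitor{});   // dispatches to float, prints 4
  // VisitDataTypeForHIPSketch(VarType::BF16, PrintSizeVisitor{});  // would throw: BF16 not in the HIP list
  return 0;
}

Under this scheme, extending HIP support to a new dtype is a one-line change to the macro list, which is why the diff only touches the type-list comment and adds an explicit BF16 guard instead of restructuring the reducer.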