From bdfd8bf0b73fda57a61cac4f32f47bafecb0d670 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 22 Feb 2022 09:59:06 +0000
Subject: [PATCH 1/4] refine bf16 amp-o1 logic

---
 paddle/fluid/imperative/amp_auto_cast.cc     | 30 +++++++++++++++++---
 python/paddle/fluid/dygraph/amp/auto_cast.py |  2 +-
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc
index 94c6d0a4d569a1..6e8bfbb4a77610 100644
--- a/paddle/fluid/imperative/amp_auto_cast.cc
+++ b/paddle/fluid/imperative/amp_auto_cast.cc
@@ -273,8 +273,9 @@ static inline std::shared_ptr<VarType> CastToBF16(
 
 template <typename VarType>
 static inline framework::proto::VarType::Type GetPromoteType(
-    const std::string& op_type, const NameVarMap<VarType>& ins) {
-  auto dst_type = framework::proto::VarType::FP16;
+    const std::string& op_type, const NameVarMap<VarType>& ins,
+    const framework::proto::VarType::Type amp_dtype) {
+  auto dst_type = amp_dtype;
   for (const auto& pair : ins) {
     for (const auto& var : pair.second) {
       if (GetDataType(var) == framework::proto::VarType::FP32) {
@@ -337,7 +338,8 @@ NameVarMap<VarType> AutoCastInputs(const std::string& op_type,
     }
     return new_ins;
   } else {
-    auto dst_type = GetPromoteType(op_type, ins);
+    auto dst_type =
+        GetPromoteType(op_type, ins, framework::proto::VarType::FP16);
 
     // NOTE(zhiqiu): if the op has op fp16 kernel, fall back to fp32.
     if (dst_type == framework::proto::VarType::FP16 &&
@@ -435,7 +437,7 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
       }
     }
     return new_ins;
-  } else {
+  } else if (AmpOperators::Instance().GetMutableBlockOps()->count(op_type)) {
     for (auto& pair : new_ins) {
       VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
               << GetDtypeStr(*pair.second.cbegin()) << " to float";
@@ -444,6 +446,26 @@ NameVarMap<VarType> AutoCastBF16Inputs(const std::string& op_type,
       }
     }
     return new_ins;
+  } else {
+    auto dst_type =
+        GetPromoteType(op_type, ins, framework::proto::VarType::BF16);
+    // NOTE(zhangbo): if the op has no bf16 kernel, fall back to fp32.
+    if (dst_type == framework::proto::VarType::BF16 &&
+        AmpOperators::Instance().GetMutableUnsupportedBf16Ops()->count(
+            op_type)) {
+      dst_type = framework::proto::VarType::FP32;
+    }
+    for (auto& pair : new_ins) {
+      VLOG(5) << "Op(" << op_type << "): Cast " << pair.first << " from "
+              << GetDtypeStr(*pair.second.cbegin()) << " to "
+              << framework::DataTypeToString(dst_type);
+      for (auto& var : pair.second) {
+        var = (dst_type == framework::proto::VarType::FP32
+                   ? CastToFP32(var)
+                   : CastToBF16(var));
+      }
+    }
+    return new_ins;
   }
   return new_ins;
 }
diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index 37134764e9d1c8..82c534e00219db 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -75,7 +75,7 @@
     'lookup_table', 'lookup_table_v2', 'scatter', 'scatter_grad'
 }
 
-BF16_WHITE_LIST = {'conv2d'}
+BF16_WHITE_LIST = {'conv2d', 'matmul_v2'}
 BF16_BLACK_LIST = {' '}
 
 

From e63836cd7b08e19cf70944e43a3f288c5d7549b9 Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Tue, 22 Feb 2022 09:59:30 +0000
Subject: [PATCH 2/4] refine amp GLOG

---
 paddle/fluid/imperative/tracer.cc | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc
index 03811ac778779c..c832787d989062 100644
--- a/paddle/fluid/imperative/tracer.cc
+++ b/paddle/fluid/imperative/tracer.cc
@@ -205,17 +205,19 @@ void Tracer::TraceOp(const std::string& type, const NameVarMap<VarType>& ins,
 
   NameVarMap<VarType> new_ins = ins;
   if (amp_level_ == AmpLevel::O1) {
-    VLOG(5) << "Auto mixed precision run operator: " << type;
     if (amp_dtype_ == phi::DataType::FLOAT16) {
+      VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
       new_ins = AutoCastInputs(type, ins);
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
+      VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
       new_ins = AutoCastBF16Inputs(type, ins);
     }
   } else if (amp_level_ == AmpLevel::O2) {
-    VLOG(5) << "Pure fp16 run operator: " << type;
     if (amp_dtype_ == phi::DataType::FLOAT16) {
+      VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
       new_ins = CastPureFp16Inputs(type, ins);
     } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
+      VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
       new_ins = CastPureBf16Inputs(type, ins);
     }
   }

From 18795f9194aaea4e3ef1caa0bb494182ca0144aa Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 23 Feb 2022 06:34:14 +0000
Subject: [PATCH 3/4] refine unittest

---
 .../test_imperative_auto_mixed_precision.py   | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 306c6b4707e8a3..9c7e9e96f2015c 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -1130,20 +1130,26 @@ class TestBf16(unittest.TestCase):
     test amp for BF16
     '''
 
-    def train(self, enable_amp=True):
+    def train(self, enable_amp=True, amp_level='O1'):
         paddle.seed(100)
         input = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
         conv = paddle.nn.Conv2D(4, 6, (3, 3))
         with paddle.amp.auto_cast(
-                enable=enable_amp, level='O2', dtype='bfloat16'):
+                enable=enable_amp, level=amp_level, dtype='bfloat16'):
             output = conv(input)
         output = output.cast('float32')
         return output.numpy()
 
     def test_bf16(self):
         out_fp32 = self.train(enable_amp=False)
-        out_bf16 = self.train(enable_amp=True)
-        self.assertTrue(np.allclose(out_fp32, out_bf16, rtol=1.e-3, atol=1.e-2))
+        out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
+        out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
+        self.assertTrue(
+            np.allclose(
+                out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
+        self.assertTrue(
+            np.allclose(
+                out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
 
 
 if __name__ == '__main__':

From 5cf63d1ee2f7521d89cded3f30f7acec4d0d37ec Mon Sep 17 00:00:00 2001
From: zhangbo9674
Date: Wed, 23 Feb 2022 11:44:10 +0000
Subject: [PATCH 4/4] refine unittest

---
 .../test_imperative_auto_mixed_precision.py   | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 9c7e9e96f2015c..1dfbe901e1843f 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -1141,15 +1141,18 @@ def train(self, enable_amp=True, amp_level='O1'):
         return output.numpy()
 
     def test_bf16(self):
-        out_fp32 = self.train(enable_amp=False)
-        out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
-        out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
-        self.assertTrue(
-            np.allclose(
-                out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
-        self.assertTrue(
-            np.allclose(
-                out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
+        if fluid.core.is_compiled_with_cuda():
+            cudnn_version = paddle.device.get_cudnn_version()
+            if cudnn_version is not None and cudnn_version >= 8100:
+                out_fp32 = self.train(enable_amp=False)
+                out_bf16_O1 = self.train(enable_amp=True, amp_level='O1')
+                out_bf16_O2 = self.train(enable_amp=True, amp_level='O2')
+                self.assertTrue(
+                    np.allclose(
+                        out_fp32, out_bf16_O1, rtol=1.e-3, atol=1.e-1))
+                self.assertTrue(
+                    np.allclose(
+                        out_fp32, out_bf16_O2, rtol=1.e-3, atol=1.e-1))
 
 
 if __name__ == '__main__':
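
Usage sketch (illustrative only, not part of the patch series): the snippet below shows how the bfloat16 O1 path touched above is driven from user code. It mirrors the updated TestBf16 unittest; the layer, input shape, and auto_cast arguments are taken from that test, and running it end to end assumes a CUDA build with cuDNN >= 8100, as the test now checks.

    # Run a conv layer under bfloat16 AMP O1 (sketch, mirrors TestBf16).
    import paddle

    conv = paddle.nn.Conv2D(4, 6, (3, 3))
    data = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
    with paddle.amp.auto_cast(enable=True, level='O1', dtype='bfloat16'):
        # conv2d (and now matmul_v2) are on BF16_WHITE_LIST, so their inputs
        # are cast to bfloat16; ops without a bf16 kernel fall back to fp32
        # through the new GetPromoteType(..., VarType::BF16) branch.
        output = conv(data)
    output = output.cast('float32')  # cast back for float32 post-processing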