From 9188a562def4ce92fd1498431a9bebe50b2bdeeb Mon Sep 17 00:00:00 2001
From: zhangting2020
Date: Tue, 20 Feb 2024 11:12:05 +0800
Subject: [PATCH 1/7] fa and fuse rope support gqa/mqa

---
 paddlenlp/transformers/llama/modeling.py | 49 ++++++++++++++++++------
 1 file changed, 38 insertions(+), 11 deletions(-)

diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 231fcfd44e20..1d8d6bb515ed 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -923,20 +923,44 @@ def forward(
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
         if self.use_fused_rope:
             assert past_key_value is None, "fuse rotary not support cache kv for now"
+            batch_size, seq_length, num_heads, head_dim = query_states.shape
+            _, kv_seq_len, num_key_value_heads, _ = key_states.shape
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             if get_env_device() == "npu":
                 query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
                 key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
             else:
-                query_states, key_states, _ = fused_rotary_position_embedding(
-                    query_states,
-                    key_states,
-                    v=None,
-                    sin=sin,
-                    cos=cos,
-                    position_ids=position_ids,
-                    use_neox_rotary_style=False,
-                )
+                # Paddle > 2.6 (and the develop branch) supports q and k/v with different num_heads.
+                paddle_version = float(paddle.__version__[:3])
+                if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
+                    query_states, _, _ = fused_rotary_position_embedding(
+                        query_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                    key_states, _, _ = fused_rotary_position_embedding(
+                        key_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                else:
+                    query_states, key_states, _ = fused_rotary_position_embedding(
+                        query_states,
+                        key_states,
+                        v=None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -955,8 +979,11 @@ def forward(
 
         # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
         # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # Paddle > 2.6 (and the develop branch) runs flash attention with GQA/MQA directly.
+        paddle_version = float(paddle.__version__[:3])
+        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
         if (
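Note on the version gate above (illustration only, not part of the patch): `paddle.__version__[:3]` evaluates to "0.0" on develop builds, which the patch treats as new enough; released Paddle up to 2.6 cannot pass q and k with different head counts to `fused_rotary_position_embedding`, so the code falls back to two separate rope calls and keeps the `repeat_kv` expansion before flash attention. A minimal sketch of the rope-fallback condition, using the hypothetical helper name `needs_gqa_fallback`:

```python
import paddle


def needs_gqa_fallback(num_heads: int, num_key_value_heads: int) -> bool:
    # "0.0" marks a develop build, which already handles GQA/MQA natively.
    version = float(paddle.__version__[:3])
    is_release = version != 0.0
    # Released Paddle <= 2.6 with fewer kv heads than query heads needs the
    # fallback: apply rope to q and k separately, then expand kv via repeat_kv.
    return is_release and version <= 2.6 and num_heads != num_key_value_heads


# Example: a Llama-2-70B style layout, 64 query heads sharing 8 kv heads.
print(needs_gqa_fallback(num_heads=64, num_key_value_heads=8))
```

Note that the second hunk applies `repeat_kv` on old releases regardless of the head counts; when the counts are equal, `num_key_value_groups` is 1 and the expansion is a no-op.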
From 5f0adfa6a917d060ca7fdde9a98ac02c475004e2 Mon Sep 17 00:00:00 2001
From: zhangting2020
Date: Thu, 22 Feb 2024 10:22:12 +0800
Subject: [PATCH 2/7] add tests for llama with GQA

---
 llm/llama/tests/test_GQA.py | 108 ++++++++++++++++++++++++++++++++++++
 1 file changed, 108 insertions(+)
 create mode 100644 llm/llama/tests/test_GQA.py

diff --git a/llm/llama/tests/test_GQA.py b/llm/llama/tests/test_GQA.py
new file mode 100644
index 000000000000..e3870b485e75
--- /dev/null
+++ b/llm/llama/tests/test_GQA.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+import paddle.distributed.fleet as fleet
+from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel
+
+from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoModelForCausalLMPipe,AutoTokenizer
+
+
+class TestLlama(unittest.TestCase):
+    def test_sequence_model(self):
+        world_size = paddle.distributed.get_world_size()
+        pp_degree = world_size
+        tp_degree = 1
+
+        if world_size > 2:
+            pp_degree = 2
+            assert world_size % pp_degree == 0
+            tp_degree = world_size // pp_degree
+
+        strategy = fleet.DistributedStrategy()
+        strategy.hybrid_configs = {
+            "dp_degree": 1,
+            "mp_degree": tp_degree,
+            "pp_degree": pp_degree,
+            "sharding_degree": 1,
+        }
+        #strategy.pipeline_configs = {"enable_partial_send_recv": False if pp_degree > 1 else True}
+        fleet.init(is_collective=True, strategy=strategy)
+        hcg = fleet.get_hybrid_communicate_group()
+        mp_group = hcg.get_model_parallel_group()
+        tensor_parallel_rank = mp_group.rank
+
+        if pp_degree > 1:
+            model_class = AutoModelForCausalLMPipe
+        else:
+            model_class = AutoModelForCausalLM
+
+        model_name_or_path = "meta-llama/Llama-2-7b"
+
+        seq_len = 2048
+        batch_size = 2
+
+        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
+        config = AutoConfig.from_pretrained(model_name_or_path)
+        config.seq_length = seq_len
+        config.num_key_value_heads = 8 # gqa
+        config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
+        config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
+        config.use_flash_attention = True
+        config.use_fused_rope = True
+        config.use_fused_rms_norm = True
+        config.fuse_attention_qkv = True
+        config.recompute_granularity = "full"
+        config.virtual_pp_degree = 1
+        config.use_recompute = False
+
+        config.tensor_parallel_degree = tp_degree
+        config.tensor_parallel_rank = tensor_parallel_rank
+        config.tensor_parallel_output = False
+        config.sequence_parallel = False
+
+        config.fuse_sequence_parallel_allreduce = False
+
+        # hidden_size = 4096
+        model = model_class.from_config(
+            config,
+            dtype="float16",
+        )
+
+        model.eval()
+
+        input_ids = paddle.arange(100, 100 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
+        labels = paddle.arange(101, 101 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
+
+        attention_mask = None
+        if pp_degree > 1:
+            pp_model = PipelineParallel(layers=model, hcg=hcg, strategy=strategy)
+            pp_model.accumulate_steps = batch_size # for micro_batch_size * acc_steps == batch_size
+            ret = pp_model.eval_batch(data=[input_ids, labels], compute_loss=True)
+        else:
+            ret = model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
+            ret = ret[0]
+
+        print(f"ret mp{tp_degree} pp{pp_degree}", ret.item())
+        ret_mp_pp = ret.item()
+
+
+
+
+
+if __name__ == "__main__":
+    TestLlama().test_sequence_model()
\ No newline at end of file
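Note on the GQA setup above (illustration only, not part of the patch): with `num_key_value_heads = 8` and Llama-2-7b's 32 query heads, every 4 query heads share one kv head. On Paddle <= 2.6 the modeling code expands k/v to the full head count with `repeat_kv` before flash attention; the sketch below is a rough re-implementation of that expansion just to show the shape change (the real helper lives in paddlenlp/transformers/llama/modeling.py and may differ in detail):

```python
import paddle


def repeat_kv_sketch(x, n_rep):
    # x: [batch, seq_len, num_key_value_heads, head_dim]
    batch, seq_len, num_key_value_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    # Repeat each kv head n_rep times on a new axis, then fold it into the head axis.
    x = x.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return x.reshape([batch, seq_len, num_key_value_heads * n_rep, head_dim])


num_heads, num_key_value_heads, head_dim = 32, 8, 128
key_states = paddle.randn([2, 16, num_key_value_heads, head_dim])
expanded = repeat_kv_sketch(key_states, num_heads // num_key_value_heads)
print(expanded.shape)  # [2, 16, 32, 128] -- now matches the query head count
```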
From 42e2969d600b70259ebb33b1a601d059862dccb2 Mon Sep 17 00:00:00 2001
From: zhangting2020
Date: Thu, 22 Feb 2024 10:24:21 +0800
Subject: [PATCH 3/7] rename test

---
 llm/llama/tests/test_GQA.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llm/llama/tests/test_GQA.py b/llm/llama/tests/test_GQA.py
index e3870b485e75..7242dd5d1d9f 100644
--- a/llm/llama/tests/test_GQA.py
+++ b/llm/llama/tests/test_GQA.py
@@ -23,7 +23,7 @@
 
 
 class TestLlama(unittest.TestCase):
-    def test_sequence_model(self):
+    def test_GQA(self):
         world_size = paddle.distributed.get_world_size()
         pp_degree = world_size
         tp_degree = 1
@@ -105,4 +105,4 @@
 
 
 if __name__ == "__main__":
-    TestLlama().test_sequence_model()
\ No newline at end of file
+    TestLlama().test_GQA()

From 6fd2b5088df095d6101767556818b764910af5a1 Mon Sep 17 00:00:00 2001
From: zhangting2020
Date: Thu, 22 Feb 2024 14:57:54 +0800
Subject: [PATCH 4/7] add test

---
 tests/transformers/llama/test_modeling.py | 29 +++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/tests/transformers/llama/test_modeling.py b/tests/transformers/llama/test_modeling.py
index aa24fea3e3bd..4c96b2e415d5 100644
--- a/tests/transformers/llama/test_modeling.py
+++ b/tests/transformers/llama/test_modeling.py
@@ -269,6 +269,31 @@ def check_model_position_ids(self, config, input_ids, input_mask, *args):
         else:
             self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all())
 
+    def create_and_check_gqa_model(self, config, input_ids, input_mask, *args):
+        model = LlamaForCausalLM(config)
+        config.num_key_value_heads = 8 # gqa
+        #config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
+        #config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
+        config.use_flash_attention = True
+        config.use_fused_rope = True
+        model = model.from_config(
+            config,
+            dtype="bfloat16",
+        )
+        model.eval()
+
+        result = model(
+            input_ids,
+            use_cache=True,
+            labels=input_ids if self.parent.use_labels else None,
+            return_dict=self.parent.return_dict,
+        )
+        if self.parent.use_labels:
+            self.parent.assertIsInstance(result[0].item(), float)
+            self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size])
+        else:
+            self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size])
+
 
 class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     base_model_class = LlamaModel
@@ -318,6 +343,10 @@ def test_llama_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
 
+    def test_llama_gqa_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gqa_model(*config_and_inputs)
+
 
 class LlamaModelIntegrationTest(ModelTesterPretrainedMixin, unittest.TestCase):
     base_model_class = LlamaModel
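Note on the unit test above (illustration only, not part of the patch): `create_and_check_gqa_model` reuses the tester's prepared config and only switches on the GQA-related fields. Outside the test harness, a tiny standalone smoke test of the same code path could look like the sketch below; the config values are made up for speed, flash attention and fused rope are left at their defaults because they need GPU kernels, and it assumes `LlamaConfig`/`LlamaForCausalLM` accept these fields as in current PaddleNLP:

```python
import paddle
from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM

# Tiny GQA config: 8 query heads, 2 kv heads, so every 4 query heads share one kv head.
config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=2,
)
model = LlamaForCausalLM(config)
model.eval()

input_ids = paddle.randint(0, config.vocab_size, shape=[2, 16], dtype="int64")
with paddle.no_grad():
    outputs = model(input_ids, return_dict=True)
print(outputs.logits.shape)  # expected: [2, 16, 128]
```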
From eb6dd7bd97588ec020a18aadbf3c4dee8cd4b6e4 Mon Sep 17 00:00:00 2001
From: zhangting2020
Date: Thu, 22 Feb 2024 14:59:49 +0800
Subject: [PATCH 5/7] remove test

---
 llm/llama/tests/test_GQA.py | 108 ------------------------------------
 1 file changed, 108 deletions(-)
 delete mode 100644 llm/llama/tests/test_GQA.py

diff --git a/llm/llama/tests/test_GQA.py b/llm/llama/tests/test_GQA.py
deleted file mode 100644
index 7242dd5d1d9f..000000000000
--- a/llm/llama/tests/test_GQA.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-import numpy as np
-import paddle
-import paddle.distributed.fleet as fleet
-from paddle.distributed.fleet.meta_parallel.pipeline_parallel import PipelineParallel
-
-from paddlenlp.transformers import AutoConfig, AutoModelForCausalLM, AutoModelForCausalLMPipe,AutoTokenizer
-
-
-class TestLlama(unittest.TestCase):
-    def test_GQA(self):
-        world_size = paddle.distributed.get_world_size()
-        pp_degree = world_size
-        tp_degree = 1
-
-        if world_size > 2:
-            pp_degree = 2
-            assert world_size % pp_degree == 0
-            tp_degree = world_size // pp_degree
-
-        strategy = fleet.DistributedStrategy()
-        strategy.hybrid_configs = {
-            "dp_degree": 1,
-            "mp_degree": tp_degree,
-            "pp_degree": pp_degree,
-            "sharding_degree": 1,
-        }
-        #strategy.pipeline_configs = {"enable_partial_send_recv": False if pp_degree > 1 else True}
-        fleet.init(is_collective=True, strategy=strategy)
-        hcg = fleet.get_hybrid_communicate_group()
-        mp_group = hcg.get_model_parallel_group()
-        tensor_parallel_rank = mp_group.rank
-
-        if pp_degree > 1:
-            model_class = AutoModelForCausalLMPipe
-        else:
-            model_class = AutoModelForCausalLM
-
-        model_name_or_path = "meta-llama/Llama-2-7b"
-
-        seq_len = 2048
-        batch_size = 2
-
-        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
-        config = AutoConfig.from_pretrained(model_name_or_path)
-        config.seq_length = seq_len
-        config.num_key_value_heads = 8 # gqa
-        config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
-        config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
-        config.use_flash_attention = True
-        config.use_fused_rope = True
-        config.use_fused_rms_norm = True
-        config.fuse_attention_qkv = True
-        config.recompute_granularity = "full"
-        config.virtual_pp_degree = 1
-        config.use_recompute = False
-
-        config.tensor_parallel_degree = tp_degree
-        config.tensor_parallel_rank = tensor_parallel_rank
-        config.tensor_parallel_output = False
-        config.sequence_parallel = False
-
-        config.fuse_sequence_parallel_allreduce = False
-
-        # hidden_size = 4096
-        model = model_class.from_config(
-            config,
-            dtype="float16",
-        )
-
-        model.eval()
-
-        input_ids = paddle.arange(100, 100 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
-        labels = paddle.arange(101, 101 + batch_size * seq_len, dtype="int64").reshape([batch_size, seq_len])
-
-        attention_mask = None
-        if pp_degree > 1:
-            pp_model = PipelineParallel(layers=model, hcg=hcg, strategy=strategy)
-            pp_model.accumulate_steps = batch_size # for micro_batch_size * acc_steps == batch_size
-            ret = pp_model.eval_batch(data=[input_ids, labels], compute_loss=True)
-        else:
-            ret = model(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
-            ret = ret[0]
-
-        print(f"ret mp{tp_degree} pp{pp_degree}", ret.item())
-        ret_mp_pp = ret.item()
-
-
-
-
-
-if __name__ == "__main__":
-    TestLlama().test_GQA()

From c929040e92214876cb79ba41eac4793a1ce9c5a7 Mon Sep 17 00:00:00 2001
From: zhangting2020
Date: Thu, 22 Feb 2024 15:00:45 +0800
Subject: [PATCH 6/7] remove comments

---
 tests/transformers/llama/test_modeling.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/transformers/llama/test_modeling.py b/tests/transformers/llama/test_modeling.py
index 4c96b2e415d5..c1e96d2fdc4d 100644
--- a/tests/transformers/llama/test_modeling.py
+++ b/tests/transformers/llama/test_modeling.py
@@ -272,8 +272,6 @@ def check_model_position_ids(self, config, input_ids, input_mask, *args):
     def create_and_check_gqa_model(self, config, input_ids, input_mask, *args):
         model = LlamaForCausalLM(config)
         config.num_key_value_heads = 8 # gqa
-        #config.max_position_embeddings = max(config.max_position_embeddings, seq_len)
-        #config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
         config.use_flash_attention = True
         config.use_fused_rope = True
         model = model.from_config(

From 1596b372448dc3ae282a8e274a944abcb3b1c0d7 Mon Sep 17 00:00:00 2001
From: zhangting2020
Date: Thu, 22 Feb 2024 16:42:35 +0800
Subject: [PATCH 7/7] fix dtype

---
 tests/transformers/llama/test_modeling.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/tests/transformers/llama/test_modeling.py b/tests/transformers/llama/test_modeling.py
index c1e96d2fdc4d..114f4d3ffadb 100644
--- a/tests/transformers/llama/test_modeling.py
+++ b/tests/transformers/llama/test_modeling.py
@@ -272,12 +272,7 @@ def check_model_position_ids(self, config, input_ids, input_mask, *args):
     def create_and_check_gqa_model(self, config, input_ids, input_mask, *args):
         model = LlamaForCausalLM(config)
         config.num_key_value_heads = 8 # gqa
-        config.use_flash_attention = True
         config.use_fused_rope = True
-        model = model.from_config(
-            config,
-            dtype="bfloat16",
-        )
         model.eval()
 
         result = model(