diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 231fcfd44e20..1d8d6bb515ed 100644
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -923,20 +923,44 @@ def forward(
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
         if self.use_fused_rope:
             assert past_key_value is None, "fuse rotary not support cache kv for now"
+            batch_size, seq_length, num_heads, head_dim = query_states.shape
+            _, kv_seq_len, num_key_value_heads, _ = key_states.shape
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             if get_env_device() == "npu":
                 query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
                 key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
             else:
-                query_states, key_states, _ = fused_rotary_position_embedding(
-                    query_states,
-                    key_states,
-                    v=None,
-                    sin=sin,
-                    cos=cos,
-                    position_ids=position_ids,
-                    use_neox_rotary_style=False,
-                )
+                # paddle version > 2.6 or develop support q and k/v with different num_heads
+                paddle_version = float(paddle.__version__[:3])
+                if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
+                    query_states, _, _ = fused_rotary_position_embedding(
+                        query_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                    key_states, _, _ = fused_rotary_position_embedding(
+                        key_states,
+                        None,
+                        None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
+                else:
+                    query_states, key_states, _ = fused_rotary_position_embedding(
+                        query_states,
+                        key_states,
+                        v=None,
+                        sin=sin,
+                        cos=cos,
+                        position_ids=position_ids,
+                        use_neox_rotary_style=False,
+                    )
         else:
             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
@@ -955,8 +979,11 @@ def forward(
 
         # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
         # repeat k/v heads if n_kv_heads < n_heads
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # paddle version > 2.6 or develop support flash-attn with gqa/mqa
+        paddle_version = float(paddle.__version__[:3])
+        if (paddle_version != 0.0) and (paddle_version <= 2.6):
+            key_states = repeat_kv(key_states, self.num_key_value_groups)
+            value_states = repeat_kv(value_states, self.num_key_value_groups)
 
         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
         if (
diff --git a/tests/transformers/llama/test_modeling.py b/tests/transformers/llama/test_modeling.py
index aa24fea3e3bd..114f4d3ffadb 100644
--- a/tests/transformers/llama/test_modeling.py
+++ b/tests/transformers/llama/test_modeling.py
@@ -269,6 +269,24 @@ def check_model_position_ids(self, config, input_ids, input_mask, *args):
         else:
             self.parent.assertTrue((result_position_id[0] == result_no_position_id[0]).all())
 
+    def create_and_check_gqa_model(self, config, input_ids, input_mask, *args):
+        model = LlamaForCausalLM(config)
+        config.num_key_value_heads = 8  # gqa
+        config.use_fused_rope = True
+        model.eval()
+
+        result = model(
+            input_ids,
+            use_cache=True,
+            labels=input_ids if self.parent.use_labels else None,
+            return_dict=self.parent.return_dict,
+        )
+        if self.parent.use_labels:
+            self.parent.assertIsInstance(result[0].item(), float)
+            self.parent.assertEqual(result[1].shape, [self.batch_size, self.seq_length, self.vocab_size])
+        else:
+            self.parent.assertEqual(result[0].shape, [self.batch_size, self.seq_length, self.vocab_size])
+
 
 class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     base_model_class = LlamaModel
@@ -318,6 +336,10 @@ def test_llama_lm_head_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_lm_head_model(*config_and_inputs)
 
+    def test_llama_gqa_model(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_gqa_model(*config_and_inputs)
+
 
 class LlamaModelIntegrationTest(ModelTesterPretrainedMixin, unittest.TestCase):
     base_model_class = LlamaModel
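
For context on the version gate used in both hunks: `float(paddle.__version__[:3])` maps a release string such as "2.6.1" to 2.6, while develop builds of Paddle report "0.0.0" and therefore evaluate to 0.0, which the gate treats as "new enough" (matching the in-code comments that Paddle > 2.6 or develop handles GQA/MQA in fused rope and flash-attention natively). The sketch below restates that logic together with the shape transformation `repeat_kv` performs; it is illustrative only, under the assumptions just stated: the helper name `needs_repeat_kv` is hypothetical and this `repeat_kv` body mirrors the common GQA tiling pattern rather than quoting PaddleNLP's actual implementation.

```python
# Illustrative sketch only. Assumes paddle.__version__ is "major.minor.patch"
# and that develop builds report "0.0.0"; helper names are hypothetical.
import paddle


def needs_repeat_kv() -> bool:
    """Return True when k/v heads must be tiled manually before attention.

    Release builds up to 2.6 lack flash-attention support for GQA/MQA, so the
    key/value heads are repeated to match the query heads. Develop builds
    ("0.0.0" -> 0.0) and releases newer than 2.6 accept grouped heads directly.
    The [:3] slice assumes a single-digit minor version.
    """
    version = float(paddle.__version__[:3])  # "2.6.1" -> 2.6, "0.0.0" -> 0.0
    return version != 0.0 and version <= 2.6


def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor:
    """Tile [batch, seq_len, num_kv_heads, head_dim] key/value states n_rep
    times along the head axis so they line up with the query heads."""
    if n_rep == 1:
        return hidden_states
    batch, seq_len, num_kv_heads, head_dim = hidden_states.shape
    hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1])
    return hidden_states.reshape([batch, seq_len, num_kv_heads * n_rep, head_dim])
```

On this reading, skipping `repeat_kv` on newer Paddle avoids materializing the expanded key/value tensors and lets the fused kernels consume the grouped heads directly, which is the memory/bandwidth benefit of GQA the patch is after.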