Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GQA for auto parallel #8234

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions paddlenlp/transformers/llama/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,16 +434,40 @@ def forward(
if self.config.rope:
if self.use_fused_rope:
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states, _ = fused_rotary_position_embedding(
query_states,
key_states,
v=None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)

paddle_version = float(paddle.__version__[:3])
if ((paddle_version != 0.0) and (paddle_version <= 2.6)) and (num_heads != num_key_value_heads):
query_states, _, _ = fused_rotary_position_embedding(
query_states,
None,
None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
key_states, _, _ = fused_rotary_position_embedding(
key_states,
None,
None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
else:
query_states, key_states, _ = fused_rotary_position_embedding(
query_states,
key_states,
v=None,
sin=sin,
cos=cos,
position_ids=position_ids,
use_neox_rotary_style=False,
)
else:
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
# hack here, because elementwise infer spmd not support broadcast now
Expand All @@ -463,8 +487,11 @@ def forward(

# TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# paddle version > 2.6 or develop support flash-attn with gqa/mqa
paddle_version = float(paddle.__version__[:3])
if (paddle_version != 0.0) and (paddle_version <= 2.6):
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
if (
Expand Down
Loading