Support Llama3.1 8B 128K generation on a single 80GB GPU #8811

Merged
merged 6 commits into from
Jul 28, 2024
6 changes: 6 additions & 0 deletions paddlenlp/transformers/llama/configuration.py
@@ -155,6 +155,9 @@ def __init__(
long_sequence_strategy_name=None,
long_sequence_init_args=None,
use_long_sequence_strategies=False,
use_flash_attention_for_generation=False,
use_last_token_for_generation=False,
immediate_clear_past_key_value=False,
**kwargs,
):
self.vocab_size = vocab_size
@@ -189,6 +192,9 @@ def __init__(
self.long_sequence_strategy_name = long_sequence_strategy_name
self.long_sequence_init_args = {} if long_sequence_init_args is None else long_sequence_init_args
self.use_long_sequence_strategies = use_long_sequence_strategies
self.use_flash_attention_for_generation = use_flash_attention_for_generation
self.use_last_token_for_generation = use_last_token_for_generation
self.immediate_clear_past_key_value = immediate_clear_past_key_value

super().__init__(
pad_token_id=pad_token_id,
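For context, a minimal sketch of how these new flags might be switched on when loading a checkpoint for long-context generation. The flag names come from the diff above; the checkpoint id, dtype, generation arguments, and the assumption that `from_pretrained` forwards extra keyword arguments into the config are illustrative, not part of this PR.

```python
# Hedged sketch: enabling the flags added in configuration.py above.
# Checkpoint name and generation arguments are assumptions for illustration.
from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # hypothetical checkpoint id

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype="bfloat16",
    use_flash_attention=True,
    use_flash_attention_for_generation=True,  # skip dense attention masks during decoding
    use_last_token_for_generation=True,       # keep only the last position before the LM head
    immediate_clear_past_key_value=True,      # free stale KV tensors right after concatenation
)

inputs = tokenizer("A very long prompt ...", return_tensors="pd")
output_ids, _ = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```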
2 changes: 1 addition & 1 deletion paddlenlp/transformers/llama/fusion_ops.py
@@ -230,7 +230,7 @@ def fusion_flash_attention(
key_states,
value_states,
attn_mask=attention_mask,
- is_causal=attention_mask is None,
+ is_causal=attention_mask is None and query_states.shape[1] != 1,
)
attn_weights = None

70 changes: 65 additions & 5 deletions paddlenlp/transformers/llama/modeling.py
@@ -516,6 +516,43 @@
return scale_cos_sin


class Llama3RotaryEmbedding(LlamaRotaryEmbedding):
def __init__(
self,
dim,
max_position_embeddings=8192,
base=500000,
factor=8.0,
low_freq_factor=1.0,
high_freq_factor=4.0,
original_max_position_embeddings=8192,
):
self.factor = factor
self.low_freq_factor = low_freq_factor
self.high_freq_factor = high_freq_factor
self.original_max_position_embeddings = original_max_position_embeddings
super().__init__(dim, max_position_embeddings, base)

    def _set_cos_sin_cache(self, seq_len):
        low_freq_wavelen = self.original_max_position_embeddings / self.low_freq_factor
        high_freq_wavelen = self.original_max_position_embeddings / self.high_freq_factor
        new_freqs = []
        for freq in self.inv_freq:
            wavelen = 2 * math.pi / freq
            if wavelen < high_freq_wavelen:
                new_freqs.append(freq)
            elif wavelen > low_freq_wavelen:
                new_freqs.append(freq / self.factor)
            else:
                assert low_freq_wavelen != high_freq_wavelen
                smooth = (self.original_max_position_embeddings / wavelen - self.low_freq_factor) / (
                    self.high_freq_factor - self.low_freq_factor
                )
                new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq)
        self.inv_freq = paddle.to_tensor(new_freqs, dtype=self.inv_freq.dtype)
        super()._set_cos_sin_cache(seq_len=seq_len)
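To make the three wavelength regimes in `_set_cos_sin_cache` concrete, here is a self-contained sketch that counts how the rotary frequencies split across them. The head dimension and the Llama 3.1 default scaling parameters are assumptions for illustration, not values read from this diff.

```python
import math

# Illustrative parameters, mirroring the Llama3RotaryEmbedding defaults above.
dim = 128                                  # assumed head_dim for Llama 3.1 8B
base = 500000.0
factor = 8.0
low_freq_factor = 1.0
high_freq_factor = 4.0
original_max_position_embeddings = 8192

inv_freq = [1.0 / (base ** (i / dim)) for i in range(0, dim, 2)]
low_freq_wavelen = original_max_position_embeddings / low_freq_factor    # 8192
high_freq_wavelen = original_max_position_embeddings / high_freq_factor  # 2048

kept = scaled = interpolated = 0
for freq in inv_freq:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:      # short wavelengths: frequency kept as-is
        kept += 1
    elif wavelen > low_freq_wavelen:     # long wavelengths: frequency divided by `factor`
        scaled += 1
    else:                                # in between: smooth interpolation of the two
        interpolated += 1

print(kept, interpolated, scaled)  # split of the 64 inverse frequencies across regimes
```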


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -779,7 +816,21 @@
self.config = config

def _init_rope(self):
- if self.config.rope_scaling_type is None:
if (
hasattr(self.config, "rope_scaling")
and self.config.rope_scaling is not None
and self.config.rope_scaling.get("rope_type", None) == "llama3"
):
self.rotary_emb = Llama3RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
factor=self.config.rope_scaling["factor"],
high_freq_factor=self.config.rope_scaling["high_freq_factor"],
low_freq_factor=self.config.rope_scaling["low_freq_factor"],
original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"],
)
elif self.config.rope_scaling_type is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
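For reference, the `rope_scaling` entry this branch checks for (via `rope_scaling.get("rope_type") == "llama3"`) typically looks like the sketch below. The values shown are the commonly published Llama 3.1 defaults and should be treated as illustrative rather than as part of this diff.

```python
# Illustrative rope_scaling block matching the check above.
rope_scaling = {
    "rope_type": "llama3",
    "factor": 8.0,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
    "original_max_position_embeddings": 8192,
}
```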
@@ -987,6 +1038,9 @@
# reuse k, v, self_attention
key_states = paddle.concat([past_key_value[0], key_states], axis=1)
value_states = paddle.concat([past_key_value[1], value_states], axis=1)
if self.config.immediate_clear_past_key_value:
past_key_value[0]._clear_data()
past_key_value[1]._clear_data()

past_key_value = (key_states, value_states) if use_cache else None
if self.kv_indices is not None:
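A rough sizing sketch of why `immediate_clear_past_key_value` helps at this scale: `paddle.concat` materializes a new, larger KV tensor, so the old `past_key_value` buffers briefly coexist with it unless they are cleared right away. The layer, head, and dtype figures below are commonly cited Llama 3.1 8B values, used here as assumptions rather than values from this diff.

```python
# Approximate KV-cache size for a Llama 3.1 8B-like model at 128K context.
num_layers = 32          # assumption: Llama 3.1 8B
num_kv_heads = 8         # assumption: grouped-query attention
head_dim = 128           # assumption
seq_len = 128 * 1024
bytes_per_elem = 2       # bfloat16

kv_bytes = 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem  # K and V
print(f"KV cache: {kv_bytes / 1024**3:.1f} GiB")  # roughly 16 GiB per sequence

# paddle.concat allocates a new, larger tensor, so without clearing the old
# past_key_value buffers immediately, peak memory briefly holds about 2x this amount.
```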
@@ -1547,8 +1601,11 @@

if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi):
raise NotImplementedError("Ring FlashAttention doesn't support attention_mask or alibi")

# embed positions
- if attn_mask_startend_row_indices is None and attention_mask is None:
if self.config.use_flash_attention_for_generation:
attention_mask = None
elif attn_mask_startend_row_indices is None and attention_mask is None:
# [bs, seq_len]
attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool)
if attn_mask_startend_row_indices is None and self.config.alibi:
@@ -1580,7 +1637,7 @@

use_casual_mask = get_use_casual_mask() and not self.config.alibi

- if use_casual_mask:
+ if self.config.use_flash_attention_for_generation or use_casual_mask:
attention_mask = None
elif attn_mask_startend_row_indices is None:
attention_mask = self._prepare_decoder_attention_mask(
@@ -1590,7 +1647,7 @@
is_casual = False

if attn_mask_startend_row_indices is None and self.config.use_flash_attention and get_env_device() != "gcu":
- if use_casual_mask:
+ if self.config.use_flash_attention_for_generation or use_casual_mask:
is_casual = True
else:
is_casual = is_casual_mask(attention_mask)
@@ -1654,6 +1711,9 @@
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

if self.config.use_last_token_for_generation:
hidden_states = paddle.unsqueeze(hidden_states[:, -1, :], 1)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
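To see why `use_last_token_for_generation` matters at 128K context, a back-of-the-envelope sketch: projecting every position through the LM head would materialize a `[batch, seq_len, vocab]` logits tensor, while slicing to the last position shrinks it to `[batch, 1, vocab]`. The sequence length, vocabulary size, and dtype width below are assumptions about a Llama 3.1 8B setup, not values taken from this diff.

```python
# Rough logits-memory estimate, assuming Llama 3.1 8B-like sizes.
seq_len = 128 * 1024        # 128K context (assumption)
vocab_size = 128_256        # Llama 3 vocabulary size (assumption)
bytes_per_elem = 2          # bfloat16

full_logits_gib = seq_len * vocab_size * bytes_per_elem / 1024**3
last_token_gib = 1 * vocab_size * bytes_per_elem / 1024**3

print(f"all positions: {full_logits_gib:.1f} GiB, last token only: {last_token_gib * 1024:.2f} MiB")
```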
@@ -1879,7 +1939,7 @@
position_ids = model_kwargs["position_ids"]
model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1)

- if not is_encoder_decoder and "attention_mask" in model_kwargs:
+ if not is_encoder_decoder and "attention_mask" in model_kwargs and model_kwargs["attention_mask"] is not None:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = paddle.concat(
[attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], axis=-1