Skip to content

Commit

Permalink
[llm]fix bloom tensor parallelism (#7065)
Browse files Browse the repository at this point in the history
* fix bloom tp

* add comment
  • Loading branch information
lugimzzz authored Sep 18, 2023
1 parent 3bd4bc3 commit da02add
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions paddlenlp/transformers/bloom/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def get_input_embeddings(self):
return self.word_embeddings

def _prepare_attn_mask(
self, attention_mask: Tensor, input_shape: Tuple[int, int], past_key_values_length: int, num_heads: int, dtype
self, attention_mask: Tensor, input_shape: Tuple[int, int], past_key_values_length: int, num_heads: int
) -> Tensor:
# create causal mask
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
Expand All @@ -819,8 +819,9 @@ def _prepare_attn_mask(

mask_shape = expanded_attn_mask.shape
expanded_attn_mask = expanded_attn_mask.expand([mask_shape[0], num_heads, mask_shape[2], mask_shape[3]])
zero = paddle.zeros(expanded_attn_mask.shape, dtype=dtype)
neg_inf = paddle.full(expanded_attn_mask.shape, paddle.finfo(dtype).min, dtype=dtype)
# Attention score will be cast to float32 in the following calculation, therefore we set attention_mask dtype as float32
zero = paddle.zeros(expanded_attn_mask.shape, dtype=paddle.float32)
neg_inf = paddle.full(expanded_attn_mask.shape, paddle.finfo(paddle.float32).min, dtype=paddle.float32)
expanded_attn_mask = paddle.where(expanded_attn_mask, zero, neg_inf)
batch_size, num_heads, sq_len, kv_len = expanded_attn_mask.shape
return expanded_attn_mask.reshape([batch_size * num_heads, sq_len, kv_len])
Expand Down Expand Up @@ -929,7 +930,6 @@ def forward(
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
num_heads=block_size,
dtype=hidden_states.dtype,
)
else:
alibi = alibi.reshape([batch_size * self.config.n_head, 1, seq_length_with_past])
Expand All @@ -938,7 +938,6 @@ def forward(
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
num_heads=self.config.n_head,
dtype=hidden_states.dtype,
)

for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
Expand Down Expand Up @@ -1088,7 +1087,7 @@ def __init__(self, config):
self.lm_head = BloomLMHead(config, self.bloom.word_embeddings.weight)
self.criterion = BloomPretrainingCriterion(
tensor_parallel_degree=config.tensor_parallel_degree,
tensor_parallel_output=True,
tensor_parallel_output=config.tensor_parallel_output,
)

def get_output_embeddings(self):
Expand Down

0 comments on commit da02add

Please sign in to comment.