Patch moonshine (huggingface#35731)
* update expected logits for T4 runners

* update doc

* correct order of the args for better readability

* remove generate wrap

* convert modular
eustlb authored and elvircrn committed Feb 13, 2025
1 parent f17a8d0 commit 33cbbab
Showing 4 changed files with 40 additions and 74 deletions.
4 changes: 2 additions & 2 deletions docs/source/en/_toctree.yml
@@ -508,8 +508,6 @@
title: MobileBERT
- local: model_doc/modernbert
title: ModernBert
- local: model_doc/moonshine
title: moonshine
- local: model_doc/mpnet
title: MPNet
- local: model_doc/mpt
@@ -774,6 +772,8 @@
title: Mimi
- local: model_doc/mms
title: MMS
- local: model_doc/moonshine
title: Moonshine
- local: model_doc/moshi
title: Moshi
- local: model_doc/musicgen
41 changes: 12 additions & 29 deletions src/transformers/models/moonshine/modeling_moonshine.py
@@ -1169,10 +1169,9 @@ def compute_num_masked_span(input_length):
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
attention_mask (`torch.Tensor`, *optional*):
Moonshine does not support masking of the `input_values`; this argument is preserved for compatibility
but is not used.
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
@@ -1181,9 +1180,6 @@ def compute_num_masked_span(input_length):
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor`, *optional*):
Moonshine does not support masking of the `input_values`; this argument is preserved for compatibility
but is not used.
decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -1204,11 +1200,10 @@ def compute_num_masked_span(input_length):
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
@@ -1231,6 +1226,11 @@ def compute_num_masked_span(input_length):
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@@ -1552,22 +1552,5 @@ def forward(
encoder_attentions=outputs.encoder_attentions,
)

def generate(self, *args, **kwargs):
# TODO: @eustlb do it rather with a custom logits processor
token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second
if kwargs.get("max_new_tokens") is None and kwargs.get("max_length") is None:
if kwargs.get("attention_mask") is not None:
seq_lens = kwargs["attention_mask"].sum(dim=-1)
else:
seq_lens = kwargs["input_values"].shape[-1]
max_length = int(seq_lens.max().item() * token_limit_factor)
logger.warning_once(
f"Based on the input length, Moonshine will generate up to {max_length} tokens (ratio of 6.5 tokens/second). "
"To specify a different length, set either `max_new_tokens` or `max_length`."
)
kwargs["max_length"] = max_length

return super().generate(*args, **kwargs)


__all__ = ["MoonshineModel", "MoonshinePreTrainedModel", "MoonshineForConditionalGeneration"]
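Note on the removed `generate` wrapper above: it silently derived `max_length` from the audio length at 6.5 tokens per second of 16 kHz input. With the override gone, callers who relied on that cap can compute an equivalent budget themselves. A minimal sketch, assuming the `UsefulSensors/moonshine-tiny` checkpoint, that `AutoProcessor` resolves for it, and a dummy LibriSpeech sample; none of these specifics are taken from the commit itself:

```python
from datasets import load_dataset
from transformers import AutoProcessor, MoonshineForConditionalGeneration

# Checkpoint id and dataset are illustrative, not part of this commit.
processor = AutoProcessor.from_pretrained("UsefulSensors/moonshine-tiny")
model = MoonshineForConditionalGeneration.from_pretrained("UsefulSensors/moonshine-tiny")

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
audio = ds[0]["audio"]["array"]

inputs = processor(audio, sampling_rate=16_000, return_tensors="pt")

# Reproduce the removed heuristic on the caller side:
# at most 6.5 tokens per second of 16 kHz audio.
token_limit_factor = 6.5 / 16_000.0
num_samples = inputs["input_values"].shape[-1]
max_new_tokens = max(1, int(num_samples * token_limit_factor))

generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```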
41 changes: 12 additions & 29 deletions src/transformers/models/moonshine/modular_moonshine.py
@@ -816,10 +816,9 @@ def forward(
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_values`, the [`AutoFeatureExtractor`] should be used for padding
and conversion into a tensor of type `torch.FloatTensor`.
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
attention_mask (`torch.Tensor`, *optional*):
Moonshine does not support masking of the `input_values`; this argument is preserved for compatibility
but is not used.
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
@@ -828,9 +827,6 @@ def forward(
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor`, *optional*):
Moonshine does not support masking of the `input_values`; this argument is preserved for compatibility
but is not used.
decoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -851,11 +847,10 @@ def forward(
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`).
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*, is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
@@ -878,6 +873,11 @@ def forward(
Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
decoder_position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
@@ -1109,23 +1109,6 @@ def forward(
encoder_attentions=outputs.encoder_attentions,
)

def generate(self, *args, **kwargs):
# TODO: @eustlb do it rather with a custom logits processor
token_limit_factor = 6.5 / 16000.0 # Maximum of 6.5 tokens per second
if kwargs.get("max_new_tokens") is None and kwargs.get("max_length") is None:
if kwargs.get("attention_mask") is not None:
seq_lens = kwargs["attention_mask"].sum(dim=-1)
else:
seq_lens = kwargs["input_values"].shape[-1]
max_length = int(seq_lens.max().item() * token_limit_factor)
logger.warning_once(
f"Based on the input length, Moonshine will generate up to {max_length} tokens (ratio of 6.5 tokens/second). "
"To specify a different length, set either `max_new_tokens` or `max_length`."
)
kwargs["max_length"] = max_length

return super().generate(*args, **kwargs)


__all__ = [
"MoonshineConfig",
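The TODO in the removed wrapper suggests handling the length cap with a custom logits processor instead. A rough sketch of that direction, using the public `LogitsProcessor` interface from `transformers`; the class name and the way the budget is wired in are hypothetical, not from this commit:

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class MaxAudioTokensLogitsProcessor(LogitsProcessor):
    """Hypothetical processor: force EOS once `max_tokens` decoder tokens exist."""

    def __init__(self, max_tokens: int, eos_token_id: int):
        self.max_tokens = max_tokens
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids.shape[-1] >= self.max_tokens:
            # Mask every continuation except EOS so generation stops here.
            scores = torch.full_like(scores, float("-inf"))
            scores[:, self.eos_token_id] = 0.0
        return scores


# Usage sketch (budget derived as in the removed wrapper: 6.5 tokens/second at 16 kHz):
# processors = LogitsProcessorList(
#     [MaxAudioTokensLogitsProcessor(max_tokens, model.generation_config.eos_token_id)]
# )
# model.generate(**inputs, logits_processor=processors)
```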
28 changes: 14 additions & 14 deletions tests/models/moonshine/test_modeling_moonshine.py
@@ -484,9 +484,9 @@ def test_tiny_logits_single(self):

# fmt: off
EXPECTED_LOGITS = torch.tensor([
-9.1107, 4.5538, 6.3902, -6.8141, -7.2459, -7.9077, -7.2842, -7.6045, -8.0387, -7.8354,
-7.3870, -7.2453, -7.7423, -7.3914, -7.3869, -7.6982, -7.6422, -7.0507, -7.3982, -7.2486,
-8.0799, -7.3303, -7.3675, -6.8769, -7.6879, -7.2684, -6.9868, -6.7459, -7.6858, -7.3052,
-9.1106, 4.5542, 6.3892, -6.8139, -7.2456, -7.9074, -7.2839, -7.6043, -8.0384, -7.8351,
-7.3867, -7.2450, -7.7420, -7.3912, -7.3866, -7.6979, -7.6420, -7.0504, -7.3979, -7.2483,
-8.0796, -7.3300, -7.3672, -6.8765, -7.6876, -7.2682, -6.9866, -6.7457, -7.6855, -7.3050,
])
# fmt: on
self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
@@ -502,9 +502,9 @@ def test_base_logits_single(self):

# fmt: off
EXPECTED_LOGITS = torch.tensor([
-6.7340, 1.9483, 5.2449, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616,
-8.1070, -7.7696, -7.8809, -7.9451, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657,
-7.9310, -8.1024, -7.8698, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9289,
-6.7336, 1.9482, 5.2448, -8.0277, -7.9167, -7.8956, -7.9649, -7.9348, -8.1312, -8.0616,
-8.1070, -7.7696, -7.8809, -7.9450, -8.1013, -7.8177, -7.8598, -7.8257, -7.8729, -7.9657,
-7.9310, -8.1024, -7.8699, -7.8231, -8.0752, -7.9764, -7.8127, -8.0536, -7.9492, -7.9290,
])
# fmt: on
self.assertTrue(torch.allclose(outputs.logits[0][0, :30].cpu(), EXPECTED_LOGITS, atol=1e-4))
@@ -519,10 +519,10 @@ def test_tiny_logits_batch(self):
outputs = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True, output_logits=True)
# fmt: off
EXPECTED_LOGITS = torch.tensor([
[-8.0098, 5.0239, 4.5986, -6.8125, -7.1676, -7.8782, -7.2152, -7.5188, -7.9078, -7.7394],
[-4.4394, -1.4429, 6.6715, -6.8927, -7.3748, -7.0967, -6.5255, -7.0255, -7.2583, -7.0007],
[-10.0088, 3.2862, 0.7342, -6.5558, -6.8514, -6.5309, -6.4173, -6.9485, -6.6215, -6.6230],
[-10.8083, 4.0034, -0.0635, -5.0501, -5.3903, -5.4587, -5.2416, -5.4742, -5.2662, -5.3154]
[-8.0109, 5.0241, 4.5979, -6.8125, -7.1675, -7.8783, -7.2152, -7.5188, -7.9077, -7.7394],
[-4.4399, -1.4422, 6.6710, -6.8929, -7.3751, -7.0969, -6.5257, -7.0257, -7.2585, -7.0008],
[-10.0086, 3.2859, 0.7345, -6.5557, -6.8514, -6.5308, -6.4172, -6.9484, -6.6214, -6.6229],
[-10.8078, 4.0030, -0.0633, -5.0505, -5.3906, -5.4590, -5.2420, -5.4746, -5.2665, -5.3158]
])
# fmt: on
self.assertTrue(torch.allclose(outputs.logits[0][:, :10].cpu(), EXPECTED_LOGITS, atol=1e-4))
@@ -538,10 +538,10 @@ def test_base_logits_batch(self):

# fmt: off
EXPECTED_LOGITS = torch.tensor([
[-7.7288, 1.4636, 5.2273, -7.7310, -7.6249, -7.6009, -7.6786, -7.6438, -7.8450, -7.7546],
[-6.2161, -0.5891, 7.9489, -7.0693, -6.9996, -6.9980, -7.0952, -7.0830, -7.1685, -7.0136],
[-7.3186, 3.1192, 3.8938, -5.7208, -5.8429, -5.7610, -5.9997, -5.8213, -5.8616, -5.8720],
[-9.5488, 1.0147, 4.1174, -5.9972, -6.0616, -6.0331, -6.2105, -6.0320, -6.0791, -6.0875]
[-7.7272, 1.4630, 5.2294, -7.7313, -7.6252, -7.6011, -7.6788, -7.6441, -7.8452, -7.7549],
[-6.2173, -0.5891, 7.9493, -7.0694, -6.9997, -6.9982, -7.0953, -7.0831, -7.1686, -7.0137],
[-7.3184, 3.1192, 3.8937, -5.7206, -5.8428, -5.7609, -5.9996, -5.8212, -5.8615, -5.8719],
[-9.5475, 1.0146, 4.1179, -5.9971, -6.0614, -6.0329, -6.2103, -6.0318, -6.0789, -6.0873]
])

# fmt: on
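For context on why these expectations needed refreshing: the element-wise drift between the old and new values is on the order of 1e-3, well above the tests' `atol=1e-4`. A minimal illustration with the first three entries of the batched tiny-model expectations above:

```python
import torch

old = torch.tensor([-8.0098, 5.0239, 4.5986])  # previous expected values
new = torch.tensor([-8.0109, 5.0241, 4.5979])  # values observed on T4 runners

print((old - new).abs().max())              # tensor(0.0011)
print(torch.allclose(old, new, atol=1e-4))  # False -> the old expectations fail on T4
```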
