Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR] Resolve #106 Masking system #111

Merged
merged 34 commits into from
Jan 17, 2024
Merged

[PR] Resolve #106 Masking system #111

merged 34 commits into from
Jan 17, 2024

Conversation

soran-ghaderi
Copy link
Member

This PR integrates the new masking system and now supports lookahead and padding mask using this new method.

It also resolves tests related to these two masking classes.

There remains attention masks such as dilated and other attention masks which will be resolved separately.

Reversioned and updated to 0.0.1

soran-ghaderi and others added 30 commits May 31, 2023 00:51
Global Attention mask
Get q_len and k_len separately
Add support for mixed precision
Implement new PaddingMask class
1. Explicit Query and Key Lengths: Instead of relying on the input dimensions, it provides the option to pass the query and key lengths explicitly as arguments (query_len and key_len). This improves flexibility and removes the need for conditional handling of input dimensions.

Advanced Mask Handling:
- The mask value is multiplied by a scalar constant (mask_value) to create the masking effect. This scalar can be customized to control the strength of the mask.
- The code now supports two types of masks: tf.Tensor and tf.SparseTensor. It handles each type separately to ensure correct masking.
- If the mask is a tf.Tensor, it casts it to the same dtype as the inputs and apply element-wise multiplication with the mask_value.
- If the mask is a tf.SparseTensor, it uses tf.sparse.TensorSparseValue to create a sparse tensor with masked values. This is useful when dealing with large sparse tensors efficiently.

3. Error Handling:
- Added error handling to ensure that required input arguments (query_len for 3D inputs) are provided.
- Added an error for unsupported mask types to enforce type safety.
It is implemented and now supports creating a mask from the inputs based on the values equal to the padding_value
Other minor modifications to the BaseMask and PaddingMask
this needs further testing and modifications
After releasing the first stable version this features can be added
It should be added in the next versions
Move generic.py to sequence masks
Move core.py to masks package
Separate test package for the generic.py
It passes tests with different inputs including padding_mask, valid_lens, and scores with padding_value
The scores are 2d (seq_len, seq_len) and their shapes do not change for different heads
Now it uses the new masking API
the multihead for the lookahead works fine but the multihead is not affected yet.
…ntput

This needs to support valid_lens as well
…ved the padding and lookahead masks to their respective files under the main masks package.

This helps the api cleaner and simpler
All tests are passing
@soran-ghaderi soran-ghaderi added enhancement New feature or request tests Related to tests tensorflow Related to Tensorflow labels Jan 15, 2024
@soran-ghaderi
Copy link
Member Author

resolve #106

@soran-ghaderi soran-ghaderi self-assigned this Jan 15, 2024
@soran-ghaderi soran-ghaderi changed the title Resolve #106 Masking system [PR] Resolve #106 Masking system Jan 15, 2024
@soran-ghaderi soran-ghaderi requested a review from Oreanu January 15, 2024 21:27
Copy link
Member Author

@soran-ghaderi soran-ghaderi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be merged.

@soran-ghaderi soran-ghaderi removed the request for review from Oreanu January 17, 2024 16:26
@soran-ghaderi soran-ghaderi merged commit 80343bb into master Jan 17, 2024
3 checks passed
@soran-ghaderi soran-ghaderi deleted the atomic_masks branch January 17, 2024 16:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request tensorflow Related to Tensorflow tests Related to tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant