
Commit 9f03803

Add "sub_module" argument in PretrainedTransformerMismatchedEmbedder (#5580)

* Add "submodule" argument in PretrainedTransformerMismatchedEmbedder

* Update CHANGELOG.md
sythello authored Feb 28, 2022
1 parent 92e54cc commit 9f03803
Showing 2 changed files with 7 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - We can now transparently read compressed input files during prediction.
 - LZMA compression is now supported.
+- Added the argument `sub_module` in `PretrainedTransformerMismatchedEmbedder`


 ## [v2.9.0](/~https://github.com/allenai/allennlp/releases/tag/v2.9.0) - 2022-01-27
6 changes: 6 additions & 0 deletions allennlp/modules/token_embedders/pretrained_transformer_mismatched_embedder.py
@@ -26,6 +26,10 @@ class PretrainedTransformerMismatchedEmbedder(TokenEmbedder):
     through the transformer model independently, and concatenate the final representations.
     Should be set to the same value as the `max_length` option on the
     `PretrainedTransformerMismatchedIndexer`.
+    sub_module: `str`, optional (default = `None`)
+        The name of a submodule of the transformer to use as the embedder. Some transformers, such as
+        BERT, naturally act as embedders. Others, however, consist of an encoder and a decoder, in
+        which case we only want to use the encoder.
     train_parameters: `bool`, optional (default = `True`)
         If this is `True`, the transformer weights get updated during training.
     last_layer_only: `bool`, optional (default = `True`)
@@ -65,6 +69,7 @@ def __init__(
         self,
         model_name: str,
         max_length: int = None,
+        sub_module: str = None,
         train_parameters: bool = True,
         last_layer_only: bool = True,
         override_weights_file: Optional[str] = None,
@@ -80,6 +85,7 @@ def __init__(
         self._matched_embedder = PretrainedTransformerEmbedder(
             model_name,
             max_length=max_length,
+            sub_module=sub_module,
             train_parameters=train_parameters,
             last_layer_only=last_layer_only,
             override_weights_file=override_weights_file,
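For illustration (not part of the commit): a minimal sketch of how the new argument can be used once this change is applied. The model name "t5-small", the `max_length` value, and the variable name are illustrative assumptions, not taken from the diff; the sketch assumes a checkpoint whose top-level module exposes an `encoder` submodule, as T5 does.

```python
# Minimal sketch of the new `sub_module` argument (assumes this commit is
# applied; "t5-small" and max_length=512 are illustrative choices).
from allennlp.modules.token_embedders import PretrainedTransformerMismatchedEmbedder

# T5 is an encoder-decoder model; for embedding word pieces we only want
# its encoder, so we name that submodule here. The mismatched embedder then
# pools word-piece vectors back into word-level vectors as before.
embedder = PretrainedTransformerMismatchedEmbedder(
    "t5-small",
    max_length=512,
    sub_module="encoder",
)
```

Without `sub_module`, the whole pretrained model is used; naming a submodule restricts embedding to that part of the network, which is the point of this change for encoder-decoder models.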
