Flatten DefaultClassifier interface #2978
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR attempts to flatten the
DefaultClassifier
interface. We remove the distinction between tensor and non-tensor operations, temporarily removing JIT support until we find a more comprehensive solution.Changes:
The
DefaultClassifier
now has anembeddings
argument in the init method, since all current implementations use embeddings. This removes the need for the_inner_embeddings
property which no longer needs to be implemented in the subclasses.All implementing classes now have their embeddings argument called
embeddings
(rather thanword_embeddings
/document_embeddings
etc.). Hopefully this introduces a bit more consistency across model classes.Previously, the
get_prediction_data_points
was called twice in the forward pass: once inforward_loss
, then again in_prepare_tensors
. It is now only called once, directly at the beginning of theforward_loss
.Renames
_embed_prediction_data_point
to_get_embedding_for_data_point
(in most cases, the embedding already exists)Changes
_get_prediction_data_points
to_get_data_points_from_sentence
, changing the scope to extracting points for a sinceSentence
rather than the full batchAdditional changes:
DataPoint
now has__len__