From f8fad9fc440d4634c412d7bb1a3bea66a7ce58f5 Mon Sep 17 00:00:00 2001 From: John Giorgi Date: Mon, 19 Jul 2021 17:39:58 -0400 Subject: [PATCH] Provide vocab as param to constraints (#5321) * Provide vocab as param to constraints * Update changelog --- CHANGELOG.md | 1 + allennlp/nn/beam_search.py | 25 ++++++++++++++++++++----- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef6ce00bfbd..bb8d28ba95f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `TransformerModule._post_load_pretrained_state_dict_hook()` method. Can be used to modify `missing_keys` and `unexpected_keys` after loading a pretrained state dictionary. This is useful when tying weights, for example. - Added an end-to-end test for the Transformer Toolkit. +- Added `vocab` argument to `BeamSearch`, which is passed to each constraint in `constraints` (if provided). ### Fixed diff --git a/allennlp/nn/beam_search.py b/allennlp/nn/beam_search.py index 3d0d3ae38b3..26d129f5708 100644 --- a/allennlp/nn/beam_search.py +++ b/allennlp/nn/beam_search.py @@ -6,8 +6,9 @@ from overrides import overrides import torch -from allennlp.common import Registrable +from allennlp.common import Lazy, Registrable from allennlp.common.checks import ConfigurationError +from allennlp.data import Vocabulary from allennlp.nn.util import min_value_of_dtype @@ -568,6 +569,9 @@ class Constraint(Registrable): """ + def __init__(self, vocab: Optional[Vocabulary] = None) -> None: + self.vocab = vocab + def init_state( self, batch_size: int, @@ -625,8 +629,8 @@ def _update_state( @Constraint.register("repeated-ngram-blocking") class RepeatedNGramBlockingConstraint(Constraint): - def __init__(self, ngram_size: int) -> None: - super().__init__() + def __init__(self, ngram_size: int, **kwargs) -> None: + super().__init__(**kwargs) self.ngram_size = ngram_size @overrides @@ -729,6 +733,15 @@ 
class BeamSearch(Registrable): constraints: `List[Constraint]`, optional (default = `None`) An optional list of `Constraint`s which should be applied during beam search. If not provided, no constraints will be enforced. + + vocab: `Vocabulary`, optional (default = `None`) + If `constraints` is not `None`, then `Vocabulary` will be passed to each constraint + during its initialization. Having access to the vocabulary may be useful for certain + constraints, e.g., to mask out invalid predictions during structured prediction. + + In a typical AllenNLP configuration file, this parameter does not get an entry under the + "model", it gets specified as a top-level parameter, then is passed in to the model + separately. """ default_implementation = "beam_search" @@ -742,7 +755,8 @@ def __init__( sampler: Sampler = None, min_steps: Optional[int] = None, final_sequence_scorer: FinalSequenceScorer = None, - constraints: Optional[List[Constraint]] = None, + constraints: Optional[List[Lazy[Constraint]]] = None, + vocab: Optional[Vocabulary] = None, ) -> None: if not max_steps > 0: raise ValueError("max_steps must be positive") @@ -763,7 +777,8 @@ def __init__( self.sampler = sampler or DeterministicSampler() self.min_steps = min_steps or 0 self.final_sequence_scorer = final_sequence_scorer or SequenceLogProbabilityScorer() - self.constraints = constraints or [] + # Lazily build the constraints with the vocab (if provided). + self.constraints = [constraint.construct(vocab=vocab) for constraint in constraints or []] @staticmethod def _reconstruct_sequences(predictions, backpointers):