Remove separate start type prediction in state machines (#3030)
* language change and no separate first action prediction

* no separate start state prediction for erm

* added metadata back

* remove separate start type prediction

* update remaining semantic parsers

* retrain fixtures and update predictor tests
pdasigi authored Jul 3, 2019
1 parent c2c4b64 commit 9e52e0f
Showing 22 changed files with 34 additions and 162 deletions.
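
The common thread across the model files below: every transition function loses its ``num_start_types`` and ``predict_start_type_separately`` constructor arguments, so the first decoding step now scores start productions exactly like any later action. A minimal sketch of a constructor call after this change (the dimensions and attention module are hypothetical; the argument names come from the diffs below):

    from allennlp.modules.attention import DotProductAttention
    from allennlp.nn import Activation
    from allennlp.state_machines.transition_functions import BasicTransitionFunction

    # Hypothetical dimensions, for illustration only. The two removed
    # arguments simply no longer exist in the signature.
    decoder_step = BasicTransitionFunction(encoder_output_dim=200,
                                           action_embedding_dim=100,
                                           input_attention=DotProductAttention(),
                                           activation=Activation.by_name('tanh')(),
                                           add_action_bias=False,
                                           dropout=0.2)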
@@ -3,7 +3,7 @@
 """

 import logging
-from typing import Dict, List
+from typing import Dict, List, Any
 import os
 import gzip
 import tarfile
@@ -208,7 +208,7 @@ def text_to_instance(self,  # type: ignore
         # pylint: disable=arguments-differ
         tokenized_question = self._tokenizer.tokenize(question.lower())
         question_field = TextField(tokenized_question, self._question_token_indexers)
-        # TODO(pradeep): We'll need a better way to input CoreNLP processed lines.
+        metadata: Dict[str, Any] = {"question_tokens": [x.text for x in tokenized_question]}
         table_context = TableQuestionContext.read_from_lines(table_lines, tokenized_question)
         target_values_field = MetadataField(target_values)
         world = WikiTablesLanguage(table_context)
@@ -230,6 +230,7 @@ def text_to_instance(self,  # type: ignore
         action_field = ListField(production_rule_fields)

         fields = {'question': question_field,
+                  'metadata': MetadataField(metadata),
                   'table': table_field,
                   'world': world_field,
                   'actions': action_field,
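
Why the reader change works end to end: a ``MetadataField`` is never converted to tensors, so the dictionary built in ``text_to_instance`` arrives in the model's ``forward`` untouched, batched into a list with one dictionary per instance. A minimal sketch of the round trip (the token list and the bare ``fields`` dict are illustrative stand-ins for the reader's real ones):

    from allennlp.data.fields import MetadataField

    # Reader side: stash the raw question tokens alongside the tensor fields.
    fields = {}  # stand-in for the reader's full field dict
    metadata = {"question_tokens": ["what", "was", "the", "last", "year", "?"]}  # hypothetical tokens
    fields["metadata"] = MetadataField(metadata)

    # Model side: forward(..., metadata: List[Dict[str, Any]]) receives one
    # dict per instance, so the original tokens are recovered with:
    #     [x["question_tokens"] for x in metadata]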

@@ -116,7 +116,6 @@ def __init__(self,
         self._transition_function = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                               action_embedding_dim=action_embedding_dim,
                                                               input_attention=input_attention,
-                                                              predict_start_type_separately=False,
                                                               add_action_bias=self._add_action_bias,
                                                               dropout=dropout,
                                                               num_layers=self._decoder_num_layers)

@@ -107,9 +107,7 @@ def __init__(self,
         self._decoder_step = CoverageTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                         action_embedding_dim=action_embedding_dim,
                                                         input_attention=attention,
-                                                        num_start_types=1,
                                                         activation=Activation.by_name('tanh')(),
-                                                        predict_start_type_separately=False,
                                                         add_action_bias=False,
                                                         dropout=dropout)
         self._checklist_cost_weight = checklist_cost_weight

@@ -67,9 +67,7 @@ def __init__(self,
         self._decoder_step = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                      action_embedding_dim=action_embedding_dim,
                                                      input_attention=attention,
-                                                     num_start_types=1,
                                                      activation=Activation.by_name('tanh')(),
-                                                     predict_start_type_separately=False,
                                                      add_action_bias=False,
                                                      dropout=dropout)
         self._decoder_beam_search = decoder_beam_search

@@ -110,7 +110,6 @@ def __init__(self,
         # Note: there's only one non-trivial entity type in QuaRel for now, so most of the
         # entity_type stuff is irrelevant
         self._num_entity_types = 4  # TODO(mattg): get this in a more principled way somehow?
-        self._num_start_types = 1  # Hardcoded until we feed lf syntax into the model
         self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
         self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)

@@ -171,8 +170,6 @@ def __init__(self,
         self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder_output_dim,
                                                        action_embedding_dim=action_embedding_dim,
                                                        input_attention=attention,
-                                                       num_start_types=self._num_start_types,
-                                                       predict_start_type_separately=False,
                                                        add_action_bias=self._add_action_bias,
                                                        mixture_feedforward=mixture_feedforward,
                                                        dropout=dropout)

1 change: 0 additions & 1 deletion allennlp/models/semantic_parsing/text2sql_parser.py
@@ -98,7 +98,6 @@ def __init__(self,
         self._transition_function = BasicTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                             action_embedding_dim=action_embedding_dim,
                                                             input_attention=input_attention,
-                                                            predict_start_type_separately=False,
                                                             add_action_bias=self._add_action_bias,
                                                             dropout=dropout)
         initializer(self)

@@ -122,8 +122,6 @@ def __init__(self,
         self._decoder_step = LinkingCoverageTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                                action_embedding_dim=action_embedding_dim,
                                                                input_attention=attention,
-                                                               num_start_types=self._num_start_types,
-                                                               predict_start_type_separately=True,
                                                                add_action_bias=self._add_action_bias,
                                                                mixture_feedforward=mixture_feedforward,
                                                                dropout=dropout)
@@ -193,7 +191,8 @@ def forward(self,  # type: ignore
                 world: List[WikiTablesLanguage],
                 actions: List[List[ProductionRule]],
                 agenda: torch.LongTensor,
-                target_values: List[List[str]] = None) -> Dict[str, torch.Tensor]:
+                target_values: List[List[str]] = None,
+                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
         # pylint: disable=arguments-differ
         """
         Parameters
@@ -220,6 +219,8 @@ def forward(self,  # type: ignore
         target_values : ``List[List[str]]``, optional (default = None)
             For each instance, a list of target values taken from the example lisp string. We pass
             this list to the evaluator along with logical forms to compute denotation accuracy.
+        metadata : ``List[Dict[str, Any]]``, optional (default = None)
+            Metadata containing the original tokenized question within a 'question_tokens' field.
         """
         batch_size = list(question.values())[0].size(0)
         # Each instance's agenda is of size (agenda_size, 1)
@@ -299,7 +300,6 @@ def forward(self,  # type: ignore
             in_agenda_ratio = sum(actions_in_agenda) / len(actions_in_agenda)
             self._agenda_coverage(in_agenda_ratio)

-        metadata = None
         self._compute_validation_outputs(actions,
                                          best_final_states,
                                          world,

@@ -110,8 +110,6 @@ def __init__(self,
         self._decoder_step = LinkingTransitionFunction(encoder_output_dim=self._encoder.get_output_dim(),
                                                        action_embedding_dim=action_embedding_dim,
                                                        input_attention=attention,
-                                                       num_start_types=self._num_start_types,
-                                                       predict_start_type_separately=True,
                                                        add_action_bias=self._add_action_bias,
                                                        mixture_feedforward=mixture_feedforward,
                                                        dropout=dropout)
@@ -123,7 +121,8 @@ def forward(self,  # type: ignore
                 world: List[WikiTablesLanguage],
                 actions: List[List[ProductionRuleArray]],
                 target_values: List[List[str]] = None,
-                target_action_sequences: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
+                target_action_sequences: torch.LongTensor = None,
+                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
         # pylint: disable=arguments-differ
         """
         In this method we encode the table entities, link them to words in the question, then
@@ -156,6 +155,8 @@ def forward(self,  # type: ignore
             A list of possibly valid action sequences, where each action is an index into the list
             of possible actions. This tensor has shape ``(batch_size, num_action_sequences,
             sequence_length)``.
+        metadata : ``List[Dict[str, Any]]``, optional (default = None)
+            Metadata containing the original tokenized question within a 'question_tokens' field.
         """
         outputs: Dict[str, Any] = {}
         rnn_state, grammar_state = self._get_initial_rnn_and_grammar_state(question,
@@ -212,7 +213,6 @@ def forward(self,  # type: ignore
             sequence_in_targets = self._action_history_match(best_action_indices, targets)
             self._action_sequence_accuracy(sequence_in_targets)

-        metadata = None
         self._compute_validation_outputs(actions,
                                          best_final_states,
                                          world,
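
Note the one-line deletion near the end of ``forward`` in both WikiTables parsers: previously ``metadata`` was hard-coded to ``None`` right before ``_compute_validation_outputs``, so the question tokens stored by the dataset reader could never reach the model's outputs. With the new ``metadata`` parameter passed through instead, ``outputs["question_tokens"]`` is populated whenever the reader supplies a metadata field, which is presumably what the updated predictor tests mentioned in the commit message exercise.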

@@ -111,7 +111,6 @@ def __init__(self,
                                "entity word average embedding dim", "question embedding dim")

         self._num_entity_types = 5  # TODO(mattg): get this in a more principled way somehow?
-        self._num_start_types = 3  # TODO(mattg): get this in a more principled way somehow?
         self._embedding_dim = question_embedder.get_output_dim()
         self._entity_type_encoder_embedding = Embedding(self._num_entity_types, self._embedding_dim)
         self._entity_type_decoder_embedding = Embedding(self._num_entity_types, action_embedding_dim)
@@ -685,7 +684,6 @@ def _compute_validation_outputs(self,

         if metadata is not None:
             outputs["question_tokens"] = [x["question_tokens"] for x in metadata]
-            outputs["original_table"] = [x["original_table"] for x in metadata]

     @overrides
     def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
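
Dropping the ``original_table`` line follows from the reader change above: the new metadata dictionary contains only ``question_tokens``, so keeping the ``x["original_table"]`` lookup would raise a ``KeyError`` on the first batch that actually carries metadata.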

29 changes: 18 additions & 11 deletions allennlp/semparse/domain_languages/wikitables_language.py
@@ -3,7 +3,7 @@
 # Unfortunately, mypy doesn't like this very much, so we have to "type: ignore" a bunch of things.
 # But it makes for a nicer induced grammar, so it's worth it.
 from numbers import Number
-from typing import Dict, List, NamedTuple, Set, Tuple, Any
+from typing import Dict, List, NamedTuple, Set, Type, Tuple, Any
 import logging
 import re

@@ -54,13 +54,7 @@ class WikiTablesLanguage(DomainLanguage):
     the language using ``add_predicate`` if, e.g., there is a column with dates in it.
     """
     def __init__(self, table_context: TableQuestionContext) -> None:
-        # TODO (pradeep): We do not want the start types to be a static set. We want it to depend on the table
-        # context instead, and include start types only when columns of a given type are present in the table.
-        # Currently, allowing all start types lets the parser produce action sequences like ["@start -> Date"],
-        # with just one action, when there are no dates in the table. These will obviously just be parsing errors.
-        # However, changing this set to contain just the required start types is not straightforward because that
-        # messes up the start action prediction in the parser.
-        super().__init__(start_types={Date, Number, List[str]})
+        super().__init__(start_types=self._get_start_types_in_context(table_context))
         self.table_context = table_context
         self.table_data = [Row(row) for row in table_context.table_data]

@@ -145,6 +139,16 @@ def __init__(self, table_context: TableQuestionContext) -> None:
         for name, types in self._function_types.items():
             self.terminal_productions[name] = "%s -> %s" % (types[0], name)

+    def _get_start_types_in_context(self, table_context: TableQuestionContext) -> Set[Type]:
+        start_types: Set[Type] = set()
+        if "string" in table_context.column_types:
+            start_types.add(List[str])
+        if "number" in table_context.column_types or "num2" in table_context.column_types:
+            start_types.add(Number)
+        if "date" in table_context.column_types:
+            start_types.add(Date)
+        return start_types
+
     def get_agenda(self,
                    conservative: bool = False):
         """
@@ -346,8 +350,8 @@ def evaluate_logical_form(self, logical_form: str, target_list: List[str]) -> bool:
         """
         try:
             denotation = self.execute(logical_form)
-        except ExecutionError:
-            logger.warning(f'Failed to execute: {logical_form}')
+        except ExecutionError as error:
+            logger.warning(f'Failed to execute: {logical_form}. Error: {error}')
             return False
         return self.evaluate_denotation(denotation, target_list)

@@ -434,7 +438,10 @@ def same_as(self, rows: List[Row], column: Column) -> List[Row]:
             return return_list
         cell_value = rows[0].values[column.name]
         for table_row in self.table_data:
-            if table_row.values[column.name] == cell_value:
+            new_cell_value = table_row.values[column.name]
+            if new_cell_value is None or not isinstance(new_cell_value, type(cell_value)):
+                continue
+            if new_cell_value == cell_value:
                 return_list.append(table_row)
         return return_list
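
Two behavioural changes in the language itself deserve a closer look. First, start types are now derived from the table, so a table without, say, a date column no longer admits the one-action parse ``["@start -> Date"]`` that the deleted TODO complained about. A toy restatement of the new mapping (plain-Python stand-ins, not the real ``TableQuestionContext`` or ``Date`` classes):

    from numbers import Number
    from typing import List, Set, Type

    class Date:  # hypothetical stand-in for the language's Date class
        pass

    def start_types_for(column_types: Set[str]) -> Set[Type]:
        # Mirrors _get_start_types_in_context, but over a bare set of
        # column-type strings instead of a TableQuestionContext.
        start_types: Set[Type] = set()
        if "string" in column_types:
            start_types.add(List[str])
        if "number" in column_types or "num2" in column_types:
            start_types.add(Number)
        if "date" in column_types:
            start_types.add(Date)
        return start_types

    print(start_types_for({"string", "date"}))  # Number absent: no numeric columns

Second, ``same_as`` now skips rows whose cell is ``None`` or of a different type than the query cell, so a ``None`` query value no longer spuriously matches every other empty cell, and cross-type comparisons never reach ``==`` at all.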