Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Modified ActionSpaceWalker to use DomainLanguage (#3006)
Browse files Browse the repository at this point in the history
* moved action space walker and lf search script from iterative search

* action space walker uses DomainLanguage

* modified script for searching nlvr logical forms

* addressed PR comments

* mypy fix
  • Loading branch information
pdasigi authored Jun 25, 2019
1 parent 9a13ab5 commit 0a26739
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 120 deletions.
100 changes: 65 additions & 35 deletions allennlp/semparse/action_space_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
import logging

from allennlp.common.util import START_SYMBOL
from allennlp.semparse.worlds.world import World
from allennlp.semparse.type_declarations import type_declaration as types
from allennlp.semparse.domain_languages.domain_language import DomainLanguage


logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand All @@ -19,13 +18,13 @@ class also has some utilities for indexing logical forms to efficiently retrieve
Parameters
----------
world : ``World``
The world from which valid actions will be taken.
world : ``DomainLanguage``
The world (domain language instantiation) from which valid actions will be taken.
max_path_length : ``int``
The maximum path length till which the action space will be explored. Paths longer than this
length will be discarded.
"""
def __init__(self, world: World, max_path_length: int) -> None:
def __init__(self, world: DomainLanguage, max_path_length: int) -> None:
self._world = world
self._max_path_length = max_path_length
self._completed_paths: List[List[str]] = None
Expand All @@ -36,14 +35,12 @@ def _walk(self) -> None:
"""
Walk over action space to collect completed paths of at most ``self._max_path_length`` steps.
"""
actions = self._world.get_nonterminal_productions()
start_productions = actions[START_SYMBOL]
# Buffer of NTs to expand, previous actions
incomplete_paths = [([str(type_)], [f"{START_SYMBOL} -> {type_}"]) for type_ in
self._world.get_valid_starting_types()]

incomplete_paths = [([start_production.split(' -> ')[-1]], [start_production])
for start_production in start_productions]
self._completed_paths = []
actions = self._world.get_valid_actions()
# Keeps track of `MultiMatchNamedBasicTypes` to substitute them with appropriate types.
multi_match_substitutions = self._world.get_multi_match_mapping()
# Overview: We keep track of the buffer of non-terminals to expand, and the action history
# for each incomplete path. At every iteration in the while loop below, we iterate over all
# incomplete paths, expand one non-terminal from the buffer in a depth-first fashion, get
Expand All @@ -60,11 +57,7 @@ def _walk(self) -> None:
# Taking the last non-terminal added to the buffer. We're going depth-first.
nonterminal = nonterminal_buffer.pop()
next_actions = []
if nonterminal in multi_match_substitutions:
for current_nonterminal in [nonterminal] + multi_match_substitutions[nonterminal]:
if current_nonterminal in actions:
next_actions.extend(actions[current_nonterminal])
elif nonterminal not in actions:
if nonterminal not in actions:
# This happens when the nonterminal corresponds to a type that does not exist in
# the context. For example, in the variable free variant of the WikiTables
# world, there are nonterminals for specific column types (like date). Say we
Expand All @@ -81,7 +74,7 @@ def _walk(self) -> None:
# Since we expand the last action added to the buffer, the left child should be
# added after the right child.
for right_side_part in reversed(self._get_right_side_parts(action)):
if types.is_nonterminal(right_side_part):
if self._world.is_nonterminal(right_side_part):
new_nonterminal_buffer.append(right_side_part)
next_paths.append((new_nonterminal_buffer, new_history))
incomplete_paths = []
Expand All @@ -92,7 +85,7 @@ def _walk(self) -> None:
next_path_index = len(self._completed_paths)
for action in path:
for value in self._get_right_side_parts(action):
if not types.is_nonterminal(value):
if not self._world.is_nonterminal(value):
self._terminal_path_index[action].add(next_path_index)
self._completed_paths.append(path)
# We're adding to incomplete_paths for the next iteration, only those paths that are
Expand All @@ -103,42 +96,79 @@ def _walk(self) -> None:
@staticmethod
def _get_right_side_parts(action: str) -> List[str]:
_, right_side = action.split(" -> ")
if "[" in right_side:
if right_side.startswith("["):
right_side_parts = right_side[1:-1].split(", ")
else:
right_side_parts = [right_side]
return right_side_parts

def get_logical_forms_with_agenda(self,
agenda: List[str],
max_num_logical_forms: int = None) -> List[str]:
max_num_logical_forms: int = None,
allow_partial_match: bool = False) -> List[str]:
"""
Parameters
----------
agenda : ``List[str]``
max_num_logical_forms : ``int`` (optional)
allow_partial_match : ``bool`` (optional, defaul=False)
If set, this method will return logical forms which contain not necessarily all the
items on the agenda. The returned list will be sorted by how many items the logical
forms match.
"""
if not agenda:
logger.warning("Agenda is empty! Returning all paths instead.")
return self.get_all_logical_forms(max_num_logical_forms)
if allow_partial_match:
logger.warning("Agenda is empty! Returning all paths instead.")
return self.get_all_logical_forms(max_num_logical_forms)
return []
if self._completed_paths is None:
self._walk()
agenda_path_indices = [self._terminal_path_index[action] for action in agenda]
if all([not path_indices for path_indices in agenda_path_indices]):
logger.warning("""None of the agenda items is in any of the paths found. Returning all
paths.""")
return self.get_all_logical_forms(max_num_logical_forms)
# We omit any agenda items that are not in any of the paths, since they would cause the
# final intersection to be null.
if allow_partial_match:
logger.warning("""Agenda items not in any of the paths found. Returning all paths.""")
return self.get_all_logical_forms(max_num_logical_forms)
return []
# TODO (pradeep): Sort the indices and do intersections in order, so that we can return the
# set with maximal coverage if the full intersection is null.
filtered_path_indices = []

# This list contains for each agenda item the list of indices of paths that contain that agenda item. Note
# that we omit agenda items that are not in any paths to avoid the final intersection being null. So there
# will not be any empty sub-lists in the list below.
filtered_path_indices: List[Set[int]] = []
for agenda_item, path_indices in zip(agenda, agenda_path_indices):
if not path_indices:
logger.warning(f"{agenda_item} is not in any of the paths found! Ignoring it.")
continue
filtered_path_indices.append(path_indices)
return_set = filtered_path_indices[0]
for next_set in filtered_path_indices[1:]:
return_set = return_set.intersection(next_set)
paths = [self._completed_paths[index] for index in return_set]

# This mapping is from a path index to the number of items in the agenda that the path contains.
index_to_num_items: Dict[int, int] = defaultdict(int)
for indices in filtered_path_indices:
for index in indices:
index_to_num_items[index] += 1
if allow_partial_match:
# We group the paths based on how many agenda items they contain, and output them in a sorted order.
num_items_grouped_paths: Dict[int, List[List[str]]] = defaultdict(list)
for index, num_items in index_to_num_items.items():
num_items_grouped_paths[num_items].append(self._completed_paths[index])
paths = []
# Sort by number of agenda items present in the paths.
for num_items, corresponding_paths in sorted(num_items_grouped_paths.items(),
reverse=True):
# Given those paths, sort them by length, so that the first path in ``paths`` will
# be the shortest path with the most agenda items.
paths.extend(sorted(corresponding_paths, key=len))
else:
indices_to_return = []
for index, num_items in index_to_num_items.items():
if num_items == len(filtered_path_indices):
indices_to_return.append(index)
# Sort all the paths by length
paths = sorted([self._completed_paths[index] for index in indices_to_return], key=len)
if max_num_logical_forms is not None:
paths = sorted(paths, key=len)[:max_num_logical_forms]
logical_forms = [self._world.get_logical_form(path) for path in paths]
paths = paths[:max_num_logical_forms]
logical_forms = [self._world.action_sequence_to_logical_form(path) for path in paths]
return logical_forms

def get_all_logical_forms(self,
Expand All @@ -150,5 +180,5 @@ def get_all_logical_forms(self,
if self._length_sorted_paths is None:
self._length_sorted_paths = sorted(self._completed_paths, key=len)
paths = self._length_sorted_paths[:max_num_logical_forms]
logical_forms = [self._world.get_logical_form(path) for path in paths]
logical_forms = [self._world.action_sequence_to_logical_form(path) for path in paths]
return logical_forms
1 change: 0 additions & 1 deletion allennlp/semparse/domain_languages/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from allennlp.semparse.domain_languages.domain_language import (DomainLanguage, START_SYMBOL,
predicate, predicate_with_side_args)
from allennlp.semparse.common.errors import ParsingError, ExecutionError
from allennlp.semparse.domain_languages.nlvr_language import NlvrLanguage
from allennlp.semparse.domain_languages.quarel_language import QuaRelLanguage
from allennlp.semparse.domain_languages.wikitables_language import WikiTablesLanguage
Loading

0 comments on commit 0a26739

Please sign in to comment.