This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Multilingual parser and Cross-lingual ELMo #2628

Merged
merged 25 commits into from
Jun 12, 2019
Changes from 14 commits
Commits (25)
a4a6728
Multilingual version of biaffine dependency parser with elmo alignment
TalSchuster Mar 15, 2019
02660d6
Clean multilang biaffine
TalSchuster Mar 18, 2019
f62a0f6
Multi-lang dep config example
TalSchuster Mar 18, 2019
2101636
Larger softmax value
TalSchuster Mar 18, 2019
d8ab96f
Merge branch 'master' into multilingual_parser
TalSchuster Mar 29, 2019
7884b86
formating
TalSchuster Mar 30, 2019
c985456
Merge branch 'multilingual_parser' of github.com:TalSchuster/allennlp…
TalSchuster Mar 30, 2019
797489f
Merge branch 'master' into multilingual_parser
TalSchuster Mar 30, 2019
4c3ade3
reorganize multilang dataset reader to work with a pathname
TalSchuster Apr 12, 2019
5c77a40
Merge branch 'multilingual_parser' of github.com:TalSchuster/allennlp…
TalSchuster Apr 12, 2019
8b6094f
factoring biaffine parser to prevent duplicating code
TalSchuster Apr 15, 2019
a374561
multilangTokenEmbedder interface
TalSchuster Apr 15, 2019
3e7ce6d
fix parser factorization and pylint staff
TalSchuster Apr 15, 2019
84de432
more pylint
TalSchuster Apr 15, 2019
c02fac6
inspect embedder and fix params_test
TalSchuster Apr 17, 2019
b31aeba
make mypy happy
TalSchuster Apr 17, 2019
fc664a5
cr comments and doc
TalSchuster Apr 19, 2019
1bd162a
doc
TalSchuster Apr 19, 2019
fc4aac7
fix doc
TalSchuster Apr 19, 2019
c3dc864
Multilingual tests (#4)
TalSchuster Jun 5, 2019
2990045
multilingual embedder test
TalSchuster Jun 11, 2019
7b750d6
Merge branch 'master' into multilingual_parser
TalSchuster Jun 11, 2019
7bcaa8f
Merge branch 'master' into multilingual_parser
matt-gardner Jun 12, 2019
b362437
cr comments
TalSchuster Jun 12, 2019
e09770e
new link
TalSchuster Jun 12, 2019
1 change: 1 addition & 0 deletions allennlp/data/dataset_readers/__init__.py
@@ -26,6 +26,7 @@
from allennlp.data.dataset_readers.sequence_tagging import SequenceTaggingDatasetReader
from allennlp.data.dataset_readers.snli import SnliReader
from allennlp.data.dataset_readers.universal_dependencies import UniversalDependenciesDatasetReader
from allennlp.data.dataset_readers.universal_dependencies_multilang import UniversalDependenciesMultiLangDatasetReader
from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import (
StanfordSentimentTreeBankDatasetReader)
from allennlp.data.dataset_readers.quora_paraphrase import QuoraParaphraseDatasetReader
193 changes: 193 additions & 0 deletions allennlp/data/dataset_readers/universal_dependencies_multilang.py
@@ -0,0 +1,193 @@
from typing import Dict, Tuple, List
import logging
import itertools
import glob
import os
import numpy as np

from overrides import overrides

from allennlp.common.checks import ConfigurationError
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import Field, TextField, SequenceLabelField, MetadataField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.dataset_readers.universal_dependencies import lazy_parse

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def get_file_paths(pathname: str, languages: List[str]):
"""
Gets a list of all files that match the pathname and belong to one of the
given language ids. Filenames are assumed to be prefixed with the language
identifier followed by a dash (e.g. en-universal.conll).

Parameters
----------
pathname : ``str``, required.
An absolute or relative pathname (can contain shell-style wildcards)
languages : ``List[str]``, required
The language identifiers to use.

Returns
-------
A list of tuples (language id, file path).
"""
paths = []
for file_path in glob.glob(pathname):
base = os.path.splitext(os.path.basename(file_path))[0]
lang_id = base.split('-')[0]
if lang_id in languages:
paths.append((lang_id, file_path))

if not paths:
raise ConfigurationError("No dataset files to read")

return paths


@DatasetReader.register("universal_dependencies_multilang")
class UniversalDependenciesMultiLangDatasetReader(DatasetReader):
"""
Reads multiple files in the conllu Universal Dependencies format.
All files should be in the same directory, and the filenames should be
prefixed with the language identifier followed by a dash (e.g. en-universal.conll).
When the ``alternate`` option is used, the reader alternates randomly between
the files, reading ``instances_per_file`` instances at a time. The
``is_first_pass_for_vocab`` option disables this behaviour for the first pass
(useful for a single full pass over the dataset in order to generate a vocabulary).

Notice: when using the ``alternate`` option, one should also set the ``instances_per_epoch``
option for the iterator. Otherwise, each epoch will loop infinitely.

Parameters
----------
languages : ``List[str]``, required
The language identifiers to use.
token_indexers : ``Dict[str, TokenIndexer]``, optional (default=``{"tokens": SingleIdTokenIndexer()}``)
The token indexers to be applied to the words TextField.
use_language_specific_pos : ``bool``, optional (default = False)
Whether to use UD POS tags, or to use the language specific POS tags
provided in the conllu format.
alternate : ``bool``, optional (default = True)
Whether to alternate between input files.
is_first_pass_for_vocab : ``bool``, optional (default = True)
Whether the first pass will be for generating the vocab. If true,
the first pass will run over the entire dataset of each file (even if alternate is on).
instances_per_file : ``int``, optional (default = 32)
The number of consecutive instances to sample from each input file when alternating.
"""
def __init__(self,
languages: List[str],
token_indexers: Dict[str, TokenIndexer] = None,
use_language_specific_pos: bool = False,
lazy: bool = False,
alternate: bool = True,
is_first_pass_for_vocab: bool = True,
instances_per_file: int = 32) -> None:
super().__init__(lazy)
self._languages = languages
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
self._use_language_specific_pos = use_language_specific_pos

self._is_first_pass_for_vocab = is_first_pass_for_vocab
self._alternate = alternate
self._instances_per_file = instances_per_file

self._is_first_pass = True
self._iterators = None

def _read_one_file(self, lang: str, file_path: str):
with open(file_path, 'r') as conllu_file:
logger.info("Reading UD instances for %s language from conllu dataset at: %s", lang, file_path)

for annotation in lazy_parse(conllu_file.read()):
# CoNLLU annotations sometimes add back in words that have been elided
# in the original sentence; we remove these, as we're just predicting
# dependencies for the original sentence.
# We filter by None here as elided words have a non-integer word id,
# and are replaced with None by the conllu python library.
annotation = [x for x in annotation if x["id"] is not None]

heads = [x["head"] for x in annotation]
tags = [x["deprel"] for x in annotation]
words = [x["form"] for x in annotation]
if self._use_language_specific_pos:
pos_tags = [x["xpostag"] for x in annotation]
else:
pos_tags = [x["upostag"] for x in annotation]
yield self.text_to_instance(lang, words, pos_tags, list(zip(tags, heads)))

@overrides
def _read(self, file_path: str):
file_paths = get_file_paths(file_path, self._languages)
if (self._is_first_pass and self._is_first_pass_for_vocab) or (not self._alternate):
iterators = [(lang, iter(self._read_one_file(lang, file_path))) \
for (lang, file_path) in file_paths]
_, iterators = zip(*iterators)
self._is_first_pass = False
for inst in itertools.chain(*iterators):
yield inst

else:
if self._iterators is None:
self._iterators = [(lang, iter(self._read_one_file(lang, file_path))) \
for (lang, file_path) in file_paths]
num_files = len(file_paths)
while True:
ind = np.random.randint(num_files)
lang, lang_iter = self._iterators[ind]
for _ in range(self._instances_per_file):
try:
yield lang_iter.__next__()
except StopIteration:
lang, file_path = file_paths[ind]
lang_iter = iter(self._read_one_file(lang, file_path))
self._iterators[ind] = (lang, lang_iter)
yield lang_iter.__next__()

@overrides
def text_to_instance(self, # type: ignore
lang: str,
words: List[str],
upos_tags: List[str],
dependencies: List[Tuple[str, int]] = None) -> Instance:
# pylint: disable=arguments-differ
"""
Parameters
----------
lang : ``str``, required.
The language identifier.
words : ``List[str]``, required.
The words in the sentence to be encoded.
upos_tags : ``List[str]``, required.
The universal dependencies POS tags for each word.
dependencies : ``List[Tuple[str, int]]``, optional (default = None)
A list of (head tag, head index) tuples. Indices are 1 indexed,
meaning an index of 0 corresponds to that word being the root of
the dependency tree.

Returns
-------
An instance containing words, upos tags, dependency head tags and head
indices as fields. The language identifier is stored in the metadata.
"""
fields: Dict[str, Field] = {}

tokens = TextField([Token(w) for w in words], self._token_indexers)
fields["words"] = tokens
fields["pos_tags"] = SequenceLabelField(upos_tags, tokens, label_namespace="pos")
if dependencies is not None:
# We don't want to expand the label namespace with an additional dummy token, so we'll
# always give the 'ROOT_HEAD' token a label of 'root'.
fields["head_tags"] = SequenceLabelField([x[0] for x in dependencies],
tokens,
label_namespace="head_tags")
fields["head_indices"] = SequenceLabelField([int(x[1]) for x in dependencies],
tokens,
label_namespace="head_index_tags")

fields["metadata"] = MetadataField({"words": words, "pos": upos_tags, "lang": lang})
return Instance(fields)
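
For orientation, here is a minimal usage sketch of the new reader (not part of the diff). The directory layout, glob pattern and language ids are hypothetical; the class and its parameters are the ones defined above.

from allennlp.data.dataset_readers import UniversalDependenciesMultiLangDatasetReader

# Hypothetical layout: a directory holding en-universal.conll and fr-universal.conll.
reader = UniversalDependenciesMultiLangDatasetReader(languages=["en", "fr"])
instances = reader.read("data/*-universal.conll")  # the first pass reads every file in full

first = instances[0]
print(first.fields["words"])                      # TextField with the sentence tokens
print(first.fields["pos_tags"])                   # SequenceLabelField with the UPOS tags
print(first.fields["metadata"].metadata["lang"])  # language id recorded by the reader, e.g. "en"

With the default ``alternate=True`` and a lazy reader, any pass after this first one alternates between the files indefinitely, which is why the docstring asks for ``instances_per_epoch`` on the iterator.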
1 change: 1 addition & 0 deletions allennlp/data/iterators/__init__.py
@@ -8,3 +8,4 @@
from allennlp.data.iterators.bucket_iterator import BucketIterator
from allennlp.data.iterators.homogeneous_batch_iterator import HomogeneousBatchIterator
from allennlp.data.iterators.multiprocess_iterator import MultiprocessIterator
from allennlp.data.iterators.same_lang_iterator import SameLangIterator
45 changes: 45 additions & 0 deletions allennlp/data/iterators/same_lang_iterator.py
@@ -0,0 +1,45 @@
from collections import deque, defaultdict
from typing import Iterable, Deque
import logging
import random

from allennlp.common.util import lazy_groups_of
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator
from allennlp.data.dataset import Batch

logger = logging.getLogger(__name__) # pylint: disable=invalid-name

def split_by_lang(instance_list):
insts_by_lang = defaultdict(lambda: [])
for inst in instance_list:
inst_lang = inst.fields['metadata'].metadata['lang']
insts_by_lang[inst_lang].append(inst)

return iter(insts_by_lang.values())

@DataIterator.register("same_lang")
class SameLangIterator(DataIterator):
"""
Splits batches so that each batch contains instances from a single language only.
Based on the basic iterator.

It takes the same parameters as :class:`allennlp.data.iterators.DataIterator`.
"""
def _create_batches(self, instances: Iterable[Instance], shuffle: bool) -> Iterable[Batch]:
# First break the dataset into memory-sized lists:
for instance_list in self._memory_sized_lists(instances):
if shuffle:
random.shuffle(instance_list)
instance_list = split_by_lang(instance_list)
for same_lang_batch in instance_list:
iterator = iter(same_lang_batch)
excess: Deque[Instance] = deque()
# Then break each memory-sized list into batches.
for batch_instances in lazy_groups_of(iterator, self._batch_size):
for poss_smaller_batches in self._ensure_batch_is_sufficiently_small(batch_instances, excess):
batch = Batch(poss_smaller_batches)
yield batch
if excess:
yield Batch(excess)
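
A hedged sketch of wiring this iterator to the multilingual reader for training-style iteration; the batch size, ``instances_per_epoch`` value, and file pattern below are assumptions for illustration, not part of the PR.

from allennlp.data import Vocabulary
from allennlp.data.dataset_readers import UniversalDependenciesMultiLangDatasetReader
from allennlp.data.iterators import SameLangIterator

reader = UniversalDependenciesMultiLangDatasetReader(languages=["en", "fr"],
                                                     alternate=True, lazy=True)
instances = reader.read("data/*-universal.conll")

# The first (vocabulary) pass runs over every file in full...
vocab = Vocabulary.from_instances(instances)

# ...but later passes alternate between files endlessly, so instances_per_epoch
# bounds the epoch, as the reader's docstring warns.
iterator = SameLangIterator(batch_size=16, instances_per_epoch=320)
iterator.index_with(vocab)

for batch in iterator(instances, num_epochs=1, shuffle=True):
    # Each tensor batch contains sentences from one language only, so a
    # language-specific embedder (e.g. aligned ELMo) can run on the whole batch.
    pass

In a configuration file these components correspond to the registered names ``universal_dependencies_multilang`` and ``same_lang``.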
1 change: 1 addition & 0 deletions allennlp/models/__init__.py
@@ -8,6 +8,7 @@
from allennlp.models.biattentive_classification_network import BiattentiveClassificationNetwork
from allennlp.models.constituency_parser import SpanConstituencyParser
from allennlp.models.biaffine_dependency_parser import BiaffineDependencyParser
from allennlp.models.biaffine_dependency_parser_multilang import BiaffineDependencyParserMultiLang
from allennlp.models.coreference_resolution.coref import CoreferenceResolver
from allennlp.models.crf_tagger import CrfTagger
from allennlp.models.decomposable_attention import DecomposableAttention
104 changes: 59 additions & 45 deletions allennlp/models/biaffine_dependency_parser.py
@@ -209,6 +209,64 @@ def forward(self, # type: ignore
raise ConfigurationError("Model uses a POS embedding, but no POS tags were passed.")

mask = get_text_field_mask(words)

predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll = self._parse(
embedded_text_input, mask, head_tags, head_indices)

loss = arc_nll + tag_nll

if head_indices is not None and head_tags is not None:
evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
# We calculate attachment scores for the whole sentence
# but excluding the symbolic ROOT token at the start,
# which is why we start from the second element in the sequence.
self._attachment_scores(predicted_heads[:, 1:],
predicted_head_tags[:, 1:],
head_indices,
head_tags,
evaluation_mask)

output_dict = {
"heads": predicted_heads,
"head_tags": predicted_head_tags,
"arc_loss": arc_nll,
"tag_loss": tag_nll,
"loss": loss,
"mask": mask,
"words": [meta["words"] for meta in metadata],
"pos": [meta["pos"] for meta in metadata]
}

return output_dict

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
heads = output_dict.pop("heads").cpu().detach().numpy()
mask = output_dict.pop("mask")
lengths = get_lengths_from_binary_sequence_mask(mask)
head_tag_labels = []
head_indices = []
for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
instance_heads = list(instance_heads[1:length])
instance_tags = instance_tags[1:length]
labels = [self.vocab.get_token_from_index(label, "head_tags")
for label in instance_tags]
head_tag_labels.append(labels)
head_indices.append(instance_heads)

output_dict["predicted_dependencies"] = head_tag_labels
output_dict["predicted_heads"] = head_indices
return output_dict

def _parse(self,
embedded_text_input: torch.Tensor,
mask: torch.LongTensor,
head_tags: torch.LongTensor = None,
head_indices: torch.LongTensor = None
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

embedded_text_input = self._input_dropout(embedded_text_input)
encoded_text = self.encoder(embedded_text_input, mask)

@@ -258,59 +316,15 @@ def forward(self, # type: ignore
head_indices=head_indices,
head_tags=head_tags,
mask=mask)
loss = arc_nll + tag_nll

evaluation_mask = self._get_mask_for_eval(mask[:, 1:], pos_tags)
# We calculate attachment scores for the whole sentence
# but excluding the symbolic ROOT token at the start,
# which is why we start from the second element in the sequence.
self._attachment_scores(predicted_heads[:, 1:],
predicted_head_tags[:, 1:],
head_indices[:, 1:],
head_tags[:, 1:],
evaluation_mask)
else:
arc_nll, tag_nll = self._construct_loss(head_tag_representation=head_tag_representation,
child_tag_representation=child_tag_representation,
attended_arcs=attended_arcs,
head_indices=predicted_heads.long(),
head_tags=predicted_head_tags.long(),
mask=mask)
loss = arc_nll + tag_nll

output_dict = {
"heads": predicted_heads,
"head_tags": predicted_head_tags,
"arc_loss": arc_nll,
"tag_loss": tag_nll,
"loss": loss,
"mask": mask,
"words": [meta["words"] for meta in metadata],
"pos": [meta["pos"] for meta in metadata]
}

return output_dict

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:

head_tags = output_dict.pop("head_tags").cpu().detach().numpy()
heads = output_dict.pop("heads").cpu().detach().numpy()
mask = output_dict.pop("mask")
lengths = get_lengths_from_binary_sequence_mask(mask)
head_tag_labels = []
head_indices = []
for instance_heads, instance_tags, length in zip(heads, head_tags, lengths):
instance_heads = list(instance_heads[1:length])
instance_tags = instance_tags[1:length]
labels = [self.vocab.get_token_from_index(label, "head_tags")
for label in instance_tags]
head_tag_labels.append(labels)
head_indices.append(instance_heads)

output_dict["predicted_dependencies"] = head_tag_labels
output_dict["predicted_heads"] = head_indices
return output_dict
return predicted_heads, predicted_head_tags, mask, arc_nll, tag_nll

def _construct_loss(self,
head_tag_representation: torch.Tensor,
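
The change above factors the shared encode/score/loss logic of ``forward`` out into ``_parse`` so that the new ``BiaffineDependencyParserMultiLang`` (added elsewhere in this PR, not shown in this hunk) can reuse it. Below is a rough, hypothetical sketch of that reuse; the language-aware embedder call and the class body are assumptions, not the PR's actual subclass.

from typing import Any, Dict, List
import torch

from allennlp.models.biaffine_dependency_parser import BiaffineDependencyParser
from allennlp.nn.util import get_text_field_mask

class HypotheticalMultiLangParser(BiaffineDependencyParser):
    # Illustration only: assumes single-language batches (see SameLangIterator) and
    # a text field embedder that accepts a `lang` keyword argument.
    def forward(self,  # type: ignore
                words: Dict[str, torch.LongTensor],
                pos_tags: torch.LongTensor,
                metadata: List[Dict[str, Any]],
                head_tags: torch.LongTensor = None,
                head_indices: torch.LongTensor = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        lang = metadata[0]["lang"]
        embedded_text = self.text_field_embedder(words, lang=lang)
        mask = get_text_field_mask(words)
        heads, tags, mask, arc_nll, tag_nll = self._parse(
            embedded_text, mask, head_tags, head_indices)
        return {"heads": heads, "head_tags": tags, "mask": mask,
                "arc_loss": arc_nll, "tag_loss": tag_nll,
                "loss": arc_nll + tag_nll}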