Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Data V2 #3700

Merged
merged 59 commits into from
Feb 26, 2020
Merged

Data V2 #3700

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
0c42cb9
example for feedback
DeNeutoy Jan 30, 2020
5ffedfc
Merge branch 'master' into data-v2
DeNeutoy Feb 19, 2020
80049f8
remove all existing multiprocessing
DeNeutoy Feb 19, 2020
6f58c2a
sneak torch datasets inside DatasetReader
DeNeutoy Feb 19, 2020
1b3ad9a
lint
DeNeutoy Feb 19, 2020
effc445
trainer_v2, We Love To See It
DeNeutoy Feb 19, 2020
9d44ad6
datasets have index_with now, not iterators
DeNeutoy Feb 19, 2020
7e89ea6
use iter, custom collate function in allennlp wrapper
DeNeutoy Feb 19, 2020
883b6d7
we don't even need the data in the trainer anymore
DeNeutoy Feb 19, 2020
56d022a
all trainer tests passing
DeNeutoy Feb 20, 2020
01e12f5
black
DeNeutoy Feb 20, 2020
5aea291
make find learning rate work
DeNeutoy Feb 20, 2020
f026946
update test fixtures to new config
DeNeutoy Feb 20, 2020
5973b50
get train command tests mostly working
DeNeutoy Feb 20, 2020
a23f47a
lazily construct samplers, index lazy datasets
DeNeutoy Feb 20, 2020
a76ea0a
Merge branch 'master' into data-v2
DeNeutoy Feb 20, 2020
ebf3854
update some fixtures
DeNeutoy Feb 20, 2020
57a67e5
evaluate tests passing
DeNeutoy Feb 20, 2020
7d21ed8
all command tests passing
DeNeutoy Feb 20, 2020
24a500c
lint
DeNeutoy Feb 20, 2020
fb13769
update model test case, common and module tests passing
DeNeutoy Feb 20, 2020
ef5187f
fix test interdependence introduced by #3762
DeNeutoy Feb 21, 2020
b1ea845
more test interdependence
DeNeutoy Feb 21, 2020
0231616
tests tests tests
DeNeutoy Feb 21, 2020
01d76bb
remove unnecessary brackets
DeNeutoy Feb 21, 2020
12b6efb
Merge branch 'master' into data-v2
DeNeutoy Feb 21, 2020
859d3ca
update a chunk of the configs
DeNeutoy Feb 21, 2020
c22dee3
fix archival test, couple more configs
DeNeutoy Feb 21, 2020
fe5b470
rm pointless gan test
DeNeutoy Feb 21, 2020
7533c91
more tests passing
DeNeutoy Feb 21, 2020
ad45659
add current state of from params changes
DeNeutoy Feb 21, 2020
f944840
Revert "add current state of from params changes"
DeNeutoy Feb 21, 2020
3b12a2f
Merge branch 'master' into data-v2
DeNeutoy Feb 21, 2020
be1f58c
updated understanding of Lazy
DeNeutoy Feb 21, 2020
ebdabe0
add discussion of None comparison to Lazy
DeNeutoy Feb 21, 2020
8693739
lint
DeNeutoy Feb 21, 2020
b9b0650
it's a hard doc life
DeNeutoy Feb 21, 2020
88314c7
pull samplers into separate file
DeNeutoy Feb 21, 2020
14296a1
more docs updates
DeNeutoy Feb 22, 2020
8a08899
fold in #3812
DeNeutoy Feb 22, 2020
3520280
remove torch dataset
DeNeutoy Feb 22, 2020
0f1d8a4
add example to lazy
DeNeutoy Feb 22, 2020
93e1e89
rename to collate
DeNeutoy Feb 22, 2020
40dd695
no kwargs
DeNeutoy Feb 23, 2020
da3b1b4
Revert "fold in #3812"
DeNeutoy Feb 23, 2020
801a8f5
don't break up dataset
DeNeutoy Feb 23, 2020
007fd0c
add comment to iterable dataset len
DeNeutoy Feb 23, 2020
d00e1a9
Merge branch 'master' into data-v2
DeNeutoy Feb 23, 2020
c066804
improve docstrings, build dataloader using partial_objects
DeNeutoy Feb 23, 2020
61c7b14
flake
DeNeutoy Feb 23, 2020
2b56b14
give dataloader a default implementation
DeNeutoy Feb 24, 2020
354010a
safer default for DataLoader init
DeNeutoy Feb 24, 2020
568291d
more coherent dir structure
DeNeutoy Feb 24, 2020
a016103
update imports
DeNeutoy Feb 24, 2020
47db16a
Merge branch 'master' into data-v2
DeNeutoy Feb 24, 2020
04fdb70
add a test for the BucketBatchSampler
DeNeutoy Feb 24, 2020
d1d5c4a
split bucket sampler into own file, tests
DeNeutoy Feb 24, 2020
5f0c8db
PR comments
DeNeutoy Feb 26, 2020
6f63a53
Merge branch 'master' into data-v2
DeNeutoy Feb 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions allennlp/commands/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common.util import dump_metrics, prepare_environment
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.iterators import DataIterator
from allennlp.data import DataLoader
from allennlp.models.archival import load_archive
from allennlp.training.util import evaluate

Expand Down Expand Up @@ -173,15 +173,15 @@ def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
model.vocab.extend_from_instances(instances=instances)
model.extend_embedder_vocab(embedding_sources)

iterator_params = config.pop("validation_iterator", None)
if iterator_params is None:
iterator_params = config.pop("iterator")
instances.index_with(model.vocab)
data_loader_params = config.pop("validation_data_loader", None)
if data_loader_params is None:
data_loader_params = config.pop("data_loader")
if args.batch_size:
iterator_params["batch_size"] = args.batch_size
iterator = DataIterator.from_params(iterator_params)
iterator.index_with(model.vocab)
data_loader_params["batch_size"] = args.batch_size
data_loader = DataLoader.from_params(dataset=instances, params=data_loader_params)

metrics = evaluate(model, instances, iterator, args.cuda_device, args.batch_weight_key)
metrics = evaluate(model, data_loader, args.cuda_device, args.batch_weight_key)

logger.info("Finished evaluating.")

Expand Down
20 changes: 9 additions & 11 deletions allennlp/commands/find_learning_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,16 @@
import os
import re
from typing import List, Tuple
import itertools

from overrides import overrides

from allennlp.commands.subcommand import Subcommand
from allennlp.common import Params, Tqdm
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.common.util import prepare_environment
from allennlp.data import DataIterator, Vocabulary
from allennlp.data import Vocabulary
from allennlp.data import DataLoader
from allennlp.models import Model
from allennlp.training import Trainer, TrainerBase
from allennlp.training.util import create_serialization_dir, datasets_from_params
Expand Down Expand Up @@ -211,11 +213,10 @@ def find_learning_rate_model(
),
)

model = Model.from_params(vocab=vocab, params=params.pop("model"))
iterator = DataIterator.from_params(params.pop("iterator"))
iterator.index_with(vocab)

train_data = all_datasets["train"]
train_data.index_with(vocab)
model = Model.from_params(vocab=vocab, params=params.pop("model"))
data_loader = DataLoader.from_params(dataset=train_data, params=params.pop("data_loader"))

trainer_params = params.pop("trainer")

Expand All @@ -230,11 +231,8 @@ def find_learning_rate_model(
trainer: Trainer = TrainerBase.from_params( # type: ignore
model=model,
serialization_dir=serialization_dir,
iterator=iterator,
train_data=train_data,
validation_data=None,
data_loader=data_loader,
params=trainer_params,
validation_iterator=None,
)

logger.info(
Expand Down Expand Up @@ -292,8 +290,8 @@ def search_learning_rate(

trainer.model.train()

train_generator = trainer.iterator(trainer.train_data, shuffle=trainer.shuffle)
train_generator_tqdm = Tqdm.tqdm(train_generator, total=num_batches)
infinite_generator = itertools.cycle(trainer.data_loader)
train_generator_tqdm = Tqdm.tqdm(infinite_generator, total=num_batches)

learning_rates = []
losses = []
Expand Down
71 changes: 42 additions & 29 deletions allennlp/commands/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
import argparse
import logging
import os
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, List, Optional

import torch
import torch.distributed as dist
Expand All @@ -53,7 +53,8 @@
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.common import util as common_util
from allennlp.common.plugins import import_plugins
from allennlp.data import DataIterator, DatasetReader, Instance, Vocabulary
from allennlp.data import DatasetReader, Vocabulary
from allennlp.data import DataLoader
from allennlp.models.archival import archive_model, CONFIG_NAME
from allennlp.models.model import _DEFAULT_WEIGHTS, Model
from allennlp.training.trainer_base import TrainerBase
Expand Down Expand Up @@ -296,7 +297,7 @@ def train_model(
)

# Creating `Vocabulary` objects from workers could be problematic since
# the data iterators in each worker will yield only `rank` specific
# the data loaders in each worker will yield only `rank` specific
# instances. Hence it is safe to construct the vocabulary and write it
# to disk before initializing the distributed context. The workers will
# load the vocabulary from the path specified.
Expand Down Expand Up @@ -504,36 +505,33 @@ def __init__(
serialization_dir: str,
model: Model,
trainer: TrainerBase,
evaluation_dataset: Iterable[Instance] = None,
evaluation_iterator: DataIterator = None,
evaluation_data_loader: DataLoader = None,
evaluate_on_test: bool = False,
batch_weight_key: str = "",
) -> None:
self.serialization_dir = serialization_dir
self.model = model
self.trainer = trainer
self.evaluation_dataset = evaluation_dataset
self.evaluation_iterator = evaluation_iterator
self.evaluation_data_loader = evaluation_data_loader
self.evaluate_on_test = evaluate_on_test
self.batch_weight_key = batch_weight_key

def run(self) -> Dict[str, Any]:
return self.trainer.train()

def finish(self, metrics: Dict[str, Any]):
if self.evaluation_dataset and self.evaluate_on_test:
if self.evaluation_data_loader and self.evaluate_on_test:
logger.info("The model will be evaluated using the best epoch weights.")
test_metrics = training_util.evaluate(
self.model,
self.evaluation_dataset,
self.evaluation_iterator,
self.evaluation_data_loader,
cuda_device=self.trainer.cuda_device,
batch_weight_key=self.batch_weight_key,
)

for key, value in test_metrics.items():
metrics["test_" + key] = value
elif self.evaluation_dataset:
elif self.evaluation_data_loader:
logger.info(
"To evaluate on the test set after training, pass the "
"'evaluate_on_test' flag, or use the 'allennlp evaluate' command."
Expand All @@ -551,13 +549,13 @@ def from_partial_objects(
dataset_reader: DatasetReader,
train_data_path: str,
model: Lazy[Model],
iterator: DataIterator,
data_loader: Lazy[DataLoader],
trainer: Lazy[TrainerBase],
vocabulary: Lazy[Vocabulary] = None,
datasets_for_vocab_creation: List[str] = None,
validation_dataset_reader: DatasetReader = None,
validation_data_path: str = None,
validation_iterator: DataIterator = None,
validation_data_loader: Lazy[DataLoader] = None,
test_data_path: str = None,
evaluate_on_test: bool = False,
) -> "TrainModel":
Expand Down Expand Up @@ -595,9 +593,9 @@ def from_partial_objects(
model: `Lazy[Model]`
The model that we will train. This is lazy because it depends on the `Vocabulary`;
after constructing the vocabulary we call `model.construct(vocab=vocabulary)`.
iterator: `DataIterator`
The iterator we use to batch instances from the dataset reader at training and (by
default) validation time.
data_loader: `Lazy[DataLoader]`
The data_loader we use to batch instances from the dataset reader at training and (by
default) validation time. This is lazy because it takes a dataset in it's constructor.
trainer: `Lazy[TrainerBase]`
The `Trainer` that actually implements the training loop. This is a lazy object because
it depends on the model that's going to be trained.
Expand All @@ -614,9 +612,9 @@ def from_partial_objects(
`dataset_reader`.
validation_data_path: `str`, optional (default=None)
If given, we will use this data for computing validation metrics and early stopping.
validation_iterator: `DataIterator`, optional (default=None)
If given, we will use this iterator for batching and scheduling instances for the
validation data, instead of `iterator`.
validation_data_loader: `Lazy[DataLoader]`, optional (default=None)
If given, the data_loader we use to batch instances from the dataset reader at
validation and test time. This is lazy because it takes a dataset in it's constructor.
test_data_path: `str`, optional (default=None)
If given, we will use this as test data. This makes it available for vocab creation by
default, but nothing else.
Expand Down Expand Up @@ -658,27 +656,42 @@ def from_partial_objects(
vocabulary_path = os.path.join(serialization_dir, "vocabulary")
vocabulary_.save_to_files(vocabulary_path)

iterator.index_with(model_.vocab)
validation_iterator = validation_iterator or iterator
validation_iterator.index_with(model_.vocab) # it is ok to call this twice
for dataset in datasets.values():
dataset.index_with(model_.vocab)

data_loader_ = data_loader.construct(dataset=datasets["train"])
validation_data = datasets.get("validation")
if validation_data is not None:
# Because of the way Lazy[T] works, we can't check it's existence
# _before_ we've tried to construct it. It returns None if it is not
# present, so we try to construct it first, and then afterward back off
# to the data_loader configuration used for training if it returns None.
validation_data_loader_ = validation_data_loader.construct(dataset=validation_data)
if validation_data_loader_ is None:
validation_data_loader_ = data_loader.construct(dataset=validation_data)
else:
validation_data_loader_ = None

test_data = datasets.get("test")
if test_data is not None:
test_data_loader = validation_data_loader.construct(dataset=test_data)
if test_data_loader is None:
test_data_loader = data_loader.construct(dataset=test_data)
else:
test_data_loader = None

# We don't need to pass serialization_dir and local_rank here, because they will have been
# passed through the trainer by from_params already, because they were keyword arguments to
# construct this class in the first place.
trainer_ = trainer.construct(
model=model_,
iterator=iterator,
train_data=datasets["train"],
validation_iterator=validation_iterator,
validation_data=datasets.get("validation"),
model=model_, data_loader=data_loader_, validation_data_loader=validation_data_loader_,
)

return cls(
serialization_dir=serialization_dir,
model=model_,
trainer=trainer_,
evaluation_dataset=datasets.get("test"),
evaluation_iterator=validation_iterator,
evaluation_data_loader=test_data_loader,
evaluate_on_test=evaluate_on_test,
batch_weight_key=batch_weight_key,
)
Expand Down
21 changes: 19 additions & 2 deletions allennlp/common/lazy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Generic, TypeVar
from typing import Callable, Generic, TypeVar, Optional

T = TypeVar("T")

Expand All @@ -20,10 +20,27 @@ class Lazy(Generic[T]):
The actual implementation here is incredibly simple; the logic that handles the lazy
construction is actually found in `FromParams`, where we have a special case for a `Lazy` type
annotation.

!!! Warning
The way this class is used in from_params means that optional constructor arguments CANNOT
be compared to `None` _before_ it is constructed. See the example below for correct usage.

```
@classmethod
def my_constructor(cls, some_object: Lazy[MyObject] = None) -> MyClass:
...
# WRONG! some_object will never be None at this point, it will be
# a Lazy[] that returns None
obj = some_object or MyObjectDefault()
# CORRECT:
obj = some_object.construct(kwarg=kwarg) or MyObjectDefault()
...
```

"""

def __init__(self, constructor: Callable[..., T]):
self._constructor = constructor

def construct(self, **kwargs) -> T:
def construct(self, **kwargs) -> Optional[T]:
return self._constructor(**kwargs)
37 changes: 21 additions & 16 deletions allennlp/common/testing/model_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from allennlp.commands.train import train_model_from_file
from allennlp.common import Params
from allennlp.common.testing.test_case import AllenNlpTestCase
from allennlp.data import DataIterator, DatasetReader, Vocabulary
from allennlp.data import DatasetReader, Vocabulary
from allennlp.data import DataLoader
from allennlp.data.batch import Batch
from allennlp.models import load_archive, Model

Expand All @@ -25,7 +26,7 @@ def set_up_model(self, param_file, dataset_file):

reader = DatasetReader.from_params(params["dataset_reader"])
# The dataset reader might be lazy, but a lazy list here breaks some of our tests.
instances = list(reader.read(str(dataset_file)))
instances = reader.read(str(dataset_file))
# Use parameters for vocabulary if they are present in the config file, so that choices like
# "non_padded_namespaces", "min_count" etc. can be set if needed.
if "vocabulary" in params:
Expand All @@ -35,11 +36,12 @@ def set_up_model(self, param_file, dataset_file):
vocab = Vocabulary.from_instances(instances)
self.vocab = vocab
self.instances = instances
self.instances.index_with(vocab)
self.model = Model.from_params(vocab=self.vocab, params=params["model"])

# TODO(joelgrus) get rid of these
# (a lot of the model tests use them, so they'll have to be changed)
self.dataset = Batch(self.instances)
self.dataset = Batch(list(self.instances))
self.dataset.index_instances(self.vocab)

def ensure_model_can_train_save_and_load(
Expand Down Expand Up @@ -93,24 +95,27 @@ def ensure_model_can_train_save_and_load(
params = Params.from_file(param_file, params_overrides=overrides)
reader = DatasetReader.from_params(params["dataset_reader"])

# Need to duplicate params because Iterator.from_params will consume.
iterator_params = params["iterator"]
iterator_params2 = Params(copy.deepcopy(iterator_params.as_dict()))

iterator = DataIterator.from_params(iterator_params)
iterator2 = DataIterator.from_params(iterator_params2)

# We'll check that even if we index the dataset with each model separately, we still get
# the same result out.
print("Reading with original model")
model_dataset = reader.read(params["validation_data_path"])
iterator.index_with(model.vocab)
model_batch = next(iterator(model_dataset, shuffle=False))
model_dataset.index_with(model.vocab)

print("Reading with loaded model")
loaded_dataset = reader.read(params["validation_data_path"])
iterator2.index_with(loaded_model.vocab)
loaded_batch = next(iterator2(loaded_dataset, shuffle=False))
loaded_dataset.index_with(loaded_model.vocab)

# Need to duplicate params because DataLoader.from_params will consume.
data_loader_params = params["data_loader"]
data_loader_params["shuffle"] = False
data_loader_params2 = Params(copy.deepcopy(data_loader_params.as_dict()))

data_loader = DataLoader.from_params(dataset=model_dataset, params=data_loader_params)
data_loader2 = DataLoader.from_params(dataset=loaded_dataset, params=data_loader_params2)

# We'll check that even if we index the dataset with each model separately, we still get
# the same result out.
model_batch = next(iter(data_loader))

loaded_batch = next(iter(data_loader2))

# Check gradients are None for non-trainable parameters and check that
# trainable parameters receive some gradient if they are trainable.
Expand Down
1 change: 1 addition & 0 deletions allennlp/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from allennlp.data.dataloader import DataLoader, allennlp_collate
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields.field import DataArray, Field
from allennlp.data.fields.text_field import TextFieldTensors
Expand Down
Loading