pieces for multitask learning #2369
Conversation
LGTM
from allennlp.data.instance import Instance
from allennlp.data.iterators.data_iterator import DataIterator


@DataIterator.register("homogeneous-batch")
By convention this should be registered as `homogeneous_batch`.
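That is, something like this (a sketch of just the decorator change; the class body is elided):

```python
from allennlp.data.iterators.data_iterator import DataIterator


@DataIterator.register("homogeneous_batch")  # underscore, per convention
class HomogeneousBatchIterator(DataIterator):
    ...
```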
    If false, it will do the tensorization anew each iteration.
track_epoch : ``bool``, optional, (default = False)
    If true, each instance will get a ``MetadataField`` containing the epoch number.
partition_key : ``str``, optional, (default = "dataset")
This is a bit gross; I wonder if it's better to allow setting an "origin" attribute on an `Instance` or something. Maybe not for this PR.
It needs to make it into the model, so if you did that you'd have to change all the "batch to tensor" logic to account for it, and then make sure it doesn't somehow collide with other tensors, and so on.
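For context, a rough sketch of how the iterator and its `partition_key` would be constructed, assuming it ends up exported from `allennlp.data.iterators` like the other iterators and inherits `batch_size` from the base class:

```python
from allennlp.data.iterators import HomogeneousBatchIterator

# Group instances by the value of their "dataset" MetadataField, so that
# every batch contains instances from a single wrapped dataset reader.
iterator = HomogeneousBatchIterator(batch_size=32, partition_key="dataset")
```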
As discussed at happy hour, this includes an `InterleavingDatasetReader`, which wraps multiple other dataset readers and interleaves their instances (adding a `MetadataField` indicating which dataset each instance came from), and a `HomogeneousBatchIterator`, which assumes such a `MetadataField` exists and constructs batches that are homogeneous with respect to its value.

The only "weird" thing is that the `file_path` passed to `InterleavingDatasetReader.read()` needs to be a JSON-serialized dict `{wrapped_reader_key -> file_path}`. We discussed alternative designs, like passing in a directory and requiring each wrapped reader to know to look for a specific file under the provided directory; I felt that seemed a little too prescriptive about data layout and harder to configure.

I believe that with these pieces, most of the multitask things that S2 research wants to do should be relatively easy. (Notably, with the file-path-as-JSON-dict innovation, we can just use the usual `Trainer` 😬)

FYI @amandalynne
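To make the `file_path`-as-JSON-dict part concrete, here is a rough usage sketch. The reader keys and file paths are made up, and I'm assuming the constructor takes a dict of wrapped readers keyed by the same names used in the JSON:

```python
import json

from allennlp.data.dataset_readers import (
    InterleavingDatasetReader,
    SequenceTaggingDatasetReader,
    SnliReader,
)

# Two wrapped readers, keyed by name (the keys here are illustrative).
reader = InterleavingDatasetReader(readers={
    "snli": SnliReader(),
    "tagging": SequenceTaggingDatasetReader(),
})

# The file_path is a JSON-serialized dict {wrapped_reader_key -> file_path};
# the paths below are placeholders.
file_path = json.dumps({
    "snli": "/path/to/snli_train.jsonl",
    "tagging": "/path/to/tags_train.tsv",
})

# Each yielded instance carries a MetadataField (under "dataset" by default)
# naming the wrapped reader it came from.
instances = reader.read(file_path)
```

A `HomogeneousBatchIterator` with a matching `partition_key` can then batch these instances without mixing datasets.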