Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
make Instance, Batch, and all field classes "slots" classes (#4313)
Browse files Browse the repository at this point in the history
* make field classes slots classes

* update CHANGELOG

* make Instance a slot class as well

* make Batch a slots class

* add test and fix more Matt's edge case

* add comment

* handle case where sub field is not a slots class

* fix comment

* safely get slots

* clean up

* make more robust
  • Loading branch information
epwalsh authored Jun 4, 2020
1 parent 2b2d141 commit 06bac68
Show file tree
Hide file tree
Showing 18 changed files with 125 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Similar to our caching mechanism, we introduced a lock file to the vocab to avoid race
conditions when saving/loading the vocab from/to the same serialization directory in different processes.
- Changed the `Token` class to a "slots" class, which dramatically reduces the size in memory of `Token` instances.
- Changed the `Token`, `Instance`, and `Batch` classes along with all `Field` classes to "slots" classes. This dramatically reduces the size in memory of instances.
- SimpleTagger will no longer calculate span-based F1 metric when `calculate_span_f1` is `False`.

## [v1.0.0rc5](/~https://github.com/allenai/allennlp/releases/tag/v1.0.0rc5) - 2020-05-26
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Batch(Iterable):
in a list.
"""

__slots__ = ["instances"]

def __init__(self, instances: Iterable[Instance]) -> None:
super().__init__()

Expand Down
9 changes: 9 additions & 0 deletions allennlp/data/fields/adjacency_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ class AdjacencyField(Field[torch.Tensor]):
The value to use as padding.
"""

__slots__ = [
"indices",
"labels",
"sequence_field",
"_label_namespace",
"_padding_value",
"_indexed_labels",
]

# It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
# This warning will be repeated for every instantiation of this class (i.e for every data
# instance), spewing a lot of warnings so this class variable is used to only log a single
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ class ArrayField(Field[numpy.ndarray]):
for each dimension.
"""

__slots__ = ["array", "padding_value", "dtype"]

def __init__(
self, array: numpy.ndarray, padding_value: int = 0, dtype: numpy.dtype = numpy.float32
) -> None:
Expand Down
16 changes: 15 additions & 1 deletion allennlp/data/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class Field(Generic[DataArray]):
then intelligently batch together instances and pad them into actual tensors.
"""

__slots__ = [] # type: ignore

def count_vocab_items(self, counter: Dict[str, Dict[str, int]]):
"""
If there are strings in this field that need to be converted into integers through a
Expand Down Expand Up @@ -116,7 +118,19 @@ def batch_tensors(self, tensor_list: List[DataArray]) -> DataArray: # type: ign

def __eq__(self, other) -> bool:
if isinstance(self, other.__class__):
return self.__dict__ == other.__dict__
# With the way "slots" classes work, self.__slots__ only gives the slots defined
# by the current class, but not any of its base classes. Therefore to truly
# check for equality we have to check through all of the slots in all of the
# base classes as well.
for class_ in self.__class__.mro():
for attr in getattr(class_, "__slots__", []):
if getattr(self, attr) != getattr(other, attr):
return False
# It's possible that a subclass was not defined as a slots class, in which
# case we'll need to check __dict__.
if hasattr(self, "__dict__"):
return self.__dict__ == other.__dict__
return True
return NotImplemented

def __len__(self):
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/flag_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class FlagField(Field[Any]):
This will be passed to a `forward` method as a single value of whatever type you pass in.
"""

__slots__ = ["flag_value"]

def __init__(self, flag_value: Any) -> None:
self.flag_value = flag_value

Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/index_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class IndexField(Field[torch.Tensor]):
A field containing the sequence that this `IndexField` is a pointer into.
"""

__slots__ = ["sequence_index", "sequence_field"]

def __init__(self, index: int, sequence_field: SequenceField) -> None:
self.sequence_index = index
self.sequence_field = sequence_field
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/label_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class LabelField(Field[torch.Tensor]):
step. If this is `False` and your labels are not strings, this throws a `ConfigurationError`.
"""

__slots__ = ["label", "_label_namespace", "_label_id", "_skip_indexing"]

# Most often, you probably don't want to have OOV/PAD tokens with a LabelField, so we warn you
# about it when you pick a namespace that will getting these tokens by default. It is
# possible, however, that you _do_ actually want OOV/PAD tokens with this Field. This class
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/list_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class ListField(SequenceField[DataArray]):
contained `Field` objects must be of the same type.
"""

__slots__ = ["field_list"]

def __init__(self, field_list: List[Field]) -> None:
field_class_set = {field.__class__ for field in field_list}
assert (
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/metadata_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class MetadataField(Field[DataArray], Mapping[str, Any]):
this to be a dictionary, but it could be anything you want.
"""

__slots__ = ["metadata"]

def __init__(self, metadata: Any) -> None:
self.metadata = metadata

Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/multilabel_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ class MultiLabelField(Field[torch.Tensor]):
"""

__slots__ = ["labels", "_label_namespace", "_label_ids", "_num_labels"]

# It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
# This warning will be repeated for every instantiation of this class (i.e for every data
# instance), spewing a lot of warnings so this class variable is used to only log a single
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/namespace_swapping_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class NamespaceSwappingField(Field[torch.Tensor]):
The namespace that the tokens from the source sentence will be mapped to.
"""

__slots__ = ["_source_tokens", "_target_namespace", "_mapping_array"]

def __init__(self, source_tokens: List[Token], target_namespace: str) -> None:
self._source_tokens = source_tokens
self._target_namespace = target_namespace
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/sequence_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class SequenceField(Field[DataArray]):
pointing to words in a `TextField`, items in a `ListField`, or something else.
"""

__slots__ = [] # type: ignore

def sequence_length(self) -> int:
"""
How many elements are there in this sequence?
Expand Down
8 changes: 8 additions & 0 deletions allennlp/data/fields/sequence_label_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ class SequenceLabelField(Field[torch.Tensor]):
strings to integers to use (so that "O" as a tag doesn't get the same id as "O" as a word).
"""

__slots__ = [
"labels",
"sequence_field",
"_label_namespace",
"_indexed_labels",
"_skip_indexing",
]

# It is possible that users want to use this field with a namespace which uses OOV/PAD tokens.
# This warning will be repeated for every instantiation of this class (i.e for every data
# instance), spewing a lot of warnings so this class variable is used to only log a single
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/span_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ class SpanField(Field[torch.Tensor]):
A field containing the sequence that this `SpanField` is a span inside.
"""

__slots__ = ["span_start", "span_end", "sequence_field"]

def __init__(self, span_start: int, span_end: int, sequence_field: SequenceField) -> None:
self.span_start = span_start
self.span_end = span_end
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/fields/text_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class TextField(SequenceField[TextFieldTensors]):
`TokenCharactersIndexer` produces an array of shape (num_tokens, num_characters).
"""

__slots__ = ["tokens", "_token_indexers", "_indexed_tokens"]

def __init__(self, tokens: List[Token], token_indexers: Dict[str, TokenIndexer]) -> None:
self.tokens = tokens
self._token_indexers = token_indexers
Expand Down
2 changes: 2 additions & 0 deletions allennlp/data/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Instance(Mapping[str, Field]):
The `Field` objects that will be used to produce data arrays for this instance.
"""

__slots__ = ["fields", "indexed"]

def __init__(self, fields: MutableMapping[str, Field]) -> None:
self.fields = fields
self.indexed = False
Expand Down
66 changes: 66 additions & 0 deletions tests/data/fields/field_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from allennlp.data.fields import Field


def test_eq_with_inheritance():
class SubField(Field):

__slots__ = ["a"]

def __init__(self, a):
self.a = a

class SubSubField(SubField):

__slots__ = ["b"]

def __init__(self, a, b):
super().__init__(a)
self.b = b

class SubSubSubField(SubSubField):

__slots__ = ["c"]

def __init__(self, a, b, c):
super().__init__(a, b)
self.c = c

assert SubField(1) == SubField(1)
assert SubField(1) != SubField(2)

assert SubSubField(1, 2) == SubSubField(1, 2)
assert SubSubField(1, 2) != SubSubField(1, 1)
assert SubSubField(1, 2) != SubSubField(2, 2)

assert SubSubSubField(1, 2, 3) == SubSubSubField(1, 2, 3)
assert SubSubSubField(1, 2, 3) != SubSubSubField(0, 2, 3)


def test_eq_with_inheritance_for_non_slots_field():
class SubField(Field):
def __init__(self, a):
self.a = a

assert SubField(1) == SubField(1)
assert SubField(1) != SubField(2)


def test_eq_with_inheritance_for_mixed_field():
class SubField(Field):

__slots__ = ["a"]

def __init__(self, a):
self.a = a

class SubSubField(SubField):
def __init__(self, a, b):
super().__init__(a)
self.b = b

assert SubField(1) == SubField(1)
assert SubField(1) != SubField(2)

assert SubSubField(1, 2) == SubSubField(1, 2)
assert SubSubField(1, 2) != SubSubField(1, 1)
assert SubSubField(1, 2) != SubSubField(2, 2)

0 comments on commit 06bac68

Please sign in to comment.