Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Open Compressed (#5578)
Browse files Browse the repository at this point in the history
* functionality to accept compressed files as input to predict

* test for lzma format included

* Fixed a minor logical error

* suggested changes incorporated

* Compression is always auto-detected

* Import the open_compressed() function from Tango

* Changelog

* The canonical extension for lzma is xz

* Updated file name

Co-authored-by: Dbhasin@1 <drishti_b@me.iitr.ac.in>
  • Loading branch information
dirkgr and Dbhasin1 authored Feb 24, 2022
1 parent 5b3352c commit 92e54cc
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 21 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Running the test suite out-of-tree (e.g. after installation) is now possible by pointing the environment variable `ALLENNLP_SRC_DIR` to the sources.
- Silenced a warning that happens when you inappropriately clone a tensor.

### Added

- We can now transparently read compressed input files during prediction.
- LZMA compression is now supported.


## [v2.9.0](/~https://github.com/allenai/allennlp/releases/tag/v2.9.0) - 2022-01-27

### Added
Expand Down
5 changes: 2 additions & 3 deletions allennlp/commands/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from allennlp.commands.subcommand import Subcommand
from allennlp.common import logging as common_logging
from allennlp.common.checks import check_for_gpu, ConfigurationError
from allennlp.common.file_utils import cached_path
from allennlp.common.file_utils import cached_path, open_compressed
from allennlp.common.util import lazy_groups_of
from allennlp.data.dataset_readers import MultiTaskDatasetReader
from allennlp.models.archival import load_archive
Expand Down Expand Up @@ -158,7 +158,6 @@ def __init__(
self._batch_size = batch_size
self._print_to_console = print_to_console
self._dataset_reader = None if not has_dataset_reader else predictor._dataset_reader

self._multitask_head = multitask_head
if self._multitask_head is not None:
if self._dataset_reader is None:
Expand Down Expand Up @@ -210,7 +209,7 @@ def _get_json_data(self) -> Iterator[JsonDict]:
yield self._predictor.load_line(line)
else:
input_file = cached_path(self._input_file)
with open(input_file, "r") as file_input:
with open_compressed(input_file) as file_input:
for line in file_input:
if not line.isspace():
yield self._predictor.load_line(line)
Expand Down
31 changes: 22 additions & 9 deletions allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""
Utilities for working with the local dataset cache.
"""
import bz2
import gzip
import lzma
import weakref
from contextlib import contextmanager
import glob
Expand Down Expand Up @@ -443,21 +446,31 @@ def get_file_extension(path: str, dot=True, lower: bool = True):
return ext.lower() if lower else ext


_SUFFIXES: Dict[Callable, str] = {
    open: "",
    gzip.open: ".gz",
    bz2.open: ".bz2",
    lzma.open: ".xz",
}


def open_compressed(
    filename: Union[str, PathLike],
    mode: str = "rt",
    encoding: Optional[str] = "UTF-8",
    **kwargs,
):
    """
    Open a (possibly compressed) file for reading, auto-detecting the compression
    scheme from the filename suffix.

    Supported suffixes are ``.gz`` (gzip), ``.bz2`` (bzip2), and ``.xz`` (lzma);
    any other filename is opened as a plain, uncompressed file. The path is first
    resolved through `cached_path`, so remote URLs are handled as well.

    # Parameters

    filename : `Union[str, PathLike]`
        Path (or URL) of the file to open.
    mode : `str`, optional (default = `"rt"`)
        Mode passed through to the selected open function.
    encoding : `Optional[str]`, optional (default = `"UTF-8"`)
        Text encoding passed through to the selected open function; only
        meaningful in text mode.
    """
    filename = str(filename)
    # Pick the opener whose (non-empty) suffix matches the filename; fall back
    # to the builtin open() for uncompressed files. The `suffix and` guard skips
    # the plain-open entry, whose empty suffix would match every filename.
    open_fn: Callable = next(
        (fn for fn, suffix in _SUFFIXES.items() if suffix and filename.endswith(suffix)),
        open,
    )
    return open_fn(cached_path(filename), mode=mode, encoding=encoding, **kwargs)


Expand Down
2 changes: 1 addition & 1 deletion allennlp/modules/token_embedders/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,7 @@ def __init__(
import bz2

package = bz2
elif extension == ".lzma":
elif extension == ".xz":
import lzma

package = lzma
Expand Down
Binary file not shown.
12 changes: 6 additions & 6 deletions tests/common/file_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,16 @@ def test_extract_with_external_symlink(self):
with pytest.raises(ValueError):
cached_path(dangerous_file, extract_archive=True)

@pytest.mark.parametrize("suffix", ["bz2", "gz", "xz"])
def test_open_compressed(self, suffix: str):
    """open_compressed() must yield the same lines for a compressed copy as for the plain file."""
    plain_path = self.FIXTURES_ROOT / "embeddings/fake_embeddings.5d.txt"

    def read_stripped(path):
        # Read every line through open_compressed, stripped of surrounding whitespace.
        with open_compressed(path) as handle:
            return [row.strip() for row in handle]

    expected = read_stripped(plain_path)
    assert read_stripped(f"{plain_path}.{suffix}") == expected

def test_meta_backwards_compatible(self):
url = "http://fake.datastore.com/glove.txt.gz"
Expand Down
2 changes: 1 addition & 1 deletion tests/data/vocabulary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def test_read_pretrained_words(self):

# Reading from a single (compressed) file or a single-file archive
base_path = str(self.FIXTURES_ROOT / "embeddings/fake_embeddings.5d.txt")
for ext in ["", ".gz", ".lzma", ".bz2", ".zip", ".tar.gz"]:
for ext in ["", ".gz", ".xz", ".bz2", ".zip", ".tar.gz"]:
file_path = base_path + ext
words_read = set(_read_pretrained_tokens(file_path))
assert words_read == words, (
Expand Down
2 changes: 1 addition & 1 deletion tests/modules/token_embedders/embedding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def test_embeddings_text_file(self):
assert text == correct_text, "Test failed for file: " + path

# Check for a file contained inside an archive with multiple files
for ext in [".zip", ".tar.gz", ".tar.bz2", ".tar.lzma"]:
for ext in [".zip", ".tar.gz", ".tar.bz2", ".tar.xz"]:
archive_path = str(self.FIXTURES_ROOT / "utf-8_sample/archives/utf-8") + ext
file_uri = format_embeddings_file_uri(archive_path, "folder/utf-8_sample.txt")
with EmbeddingsTextFile(file_uri) as f:
Expand Down

0 comments on commit 92e54cc

Please sign in to comment.