This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

make 'cached_path' work offline (#4253)
* make 'cached_path' work offline

* add test case

* Update allennlp/common/file_utils.py

Co-authored-by: Michael Schmitz <michael@schmitztech.com>

* log some more info

* else clause

Co-authored-by: Michael Schmitz <michael@schmitztech.com>
epwalsh and schmmd authored May 18, 2020
1 parent fc81067 commit 7d71398
Showing 2 changed files with 122 additions and 35 deletions.
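
For orientation before reading the diff, here is a minimal usage sketch of the behavior this commit enables. It is not part of the commit, and the URL is a placeholder for any remote resource:

    from allennlp.common.file_utils import cached_path

    url = "https://example.com/some-remote-resource.txt"  # placeholder URL

    # Online: the resource is downloaded (or an existing cache entry matching
    # its current ETag is reused) and the local cache path is returned.
    local_path = cached_path(url)

    # Offline: the ETag lookup fails with a connection error, a warning is
    # logged, and the most recently cached copy is returned instead of raising.
    # Only if the resource has never been cached does the error propagate.
    local_path = cached_path(url)
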
105 changes: 76 additions & 29 deletions allennlp/common/file_utils.py
@@ -2,22 +2,24 @@
Utilities for working with the local dataset cache.
"""

import glob
import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from typing import Optional, Tuple, Union, IO, Callable, Set, List
from hashlib import sha256
from functools import wraps

import boto3
import botocore
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, EndpointConnectionError
import requests
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from requests.packages.urllib3.util.retry import Retry

from allennlp.common.tqdm import Tqdm
@@ -124,7 +126,7 @@ def is_url_or_existing_file(url_or_filename: Union[str, Path, None]) -> bool:
return parsed.scheme in ("http", "https", "s3") or os.path.exists(url_or_filename)


def split_s3_path(url: str) -> Tuple[str, str]:
def _split_s3_path(url: str) -> Tuple[str, str]:
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
@@ -137,7 +139,7 @@ def split_s3_path(url: str) -> Tuple[str, str]:
return bucket_name, s3_path


def s3_request(func: Callable):
def _s3_request(func: Callable):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
@@ -156,7 +158,7 @@ def wrapper(url: str, *args, **kwargs):
return wrapper


def get_s3_resource():
def _get_s3_resource():
session = boto3.session.Session()
if session.get_credentials() is None:
# Use unsigned requests.
@@ -168,24 +170,24 @@ def get_s3_resource():
return s3_resource


@s3_request
def s3_etag(url: str) -> Optional[str]:
@_s3_request
def _s3_etag(url: str) -> Optional[str]:
"""Check ETag on S3 object."""
s3_resource = get_s3_resource()
bucket_name, s3_path = split_s3_path(url)
s3_resource = _get_s3_resource()
bucket_name, s3_path = _split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag


@s3_request
def s3_get(url: str, temp_file: IO) -> None:
@_s3_request
def _s3_get(url: str, temp_file: IO) -> None:
"""Pull a file directly from S3."""
s3_resource = get_s3_resource()
bucket_name, s3_path = split_s3_path(url)
s3_resource = _get_s3_resource()
bucket_name, s3_path = _split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def session_with_backoff() -> requests.Session:
def _session_with_backoff() -> requests.Session:
"""
We ran into an issue where http requests to s3 were timing out,
possibly because we were making too many requests too quickly.
@@ -201,8 +203,18 @@ def session_with_backoff() -> requests.Session:
return session


def http_get(url: str, temp_file: IO) -> None:
with session_with_backoff() as session:
def _http_etag(url: str) -> Optional[str]:
with _session_with_backoff() as session:
response = session.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError(
"HEAD request failed for url {} with status code {}".format(url, response.status_code)
)
return response.headers.get("ETag")


def _http_get(url: str, temp_file: IO) -> None:
with _session_with_backoff() as session:
req = session.get(url, stream=True)
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
@@ -214,6 +226,22 @@ def http_get(url: str, temp_file: IO) -> None:
progress.close()


def _find_latest_cached(url: str, cache_dir: str) -> Optional[str]:
filename = url_to_filename(url)
cache_path = os.path.join(cache_dir, filename)
candidates: List[Tuple[str, float]] = []
for path in glob.glob(cache_path + "*"):
if path.endswith(".json"):
continue
mtime = os.path.getmtime(path)
candidates.append((path, mtime))
# Sort candidates by modification time, newest first.
candidates.sort(key=lambda x: x[1], reverse=True)
if candidates:
return candidates[0][0]
return None


# TODO(joelgrus): do we want to do checksums or anything like that?
def get_from_cache(url: str, cache_dir: str = None) -> str:
"""
@@ -226,18 +254,37 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
os.makedirs(cache_dir, exist_ok=True)

# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
with session_with_backoff() as session:
response = session.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError(
"HEAD request failed for url {} with status code {}".format(
url, response.status_code
)
try:
if url.startswith("s3://"):
etag = _s3_etag(url)
else:
etag = _http_etag(url)
except (ConnectionError, EndpointConnectionError):
# We might be offline, in which case we don't want to throw an error
# just yet. Instead, we'll try to use the latest cached version of the
# target resource, if it exists. We'll only throw an exception if we
# haven't cached the resource at all yet.
logger.warning(
"Connection error occured while trying to fetch ETag for %s. "
"Will attempt to use latest cached version of resource",
url,
)
latest_cached = _find_latest_cached(url, cache_dir)
if latest_cached:
logger.info(
"ETag request failed with connection error, using latest cached "
"version of %s: %s",
url,
latest_cached,
)
return latest_cached
else:
logger.error(
"Connection failed while trying to fetch ETag, "
"and no cached version of %s could be found",
url,
)
etag = response.headers.get("ETag")
raise

filename = url_to_filename(url, etag)

@@ -252,9 +299,9 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:

# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
_s3_get(url, temp_file)
else:
http_get(url, temp_file)
_http_get(url, temp_file)

# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
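
To make the fallback concrete, here is a standalone sketch of how the newest cached copy of a URL is selected, mirroring the new _find_latest_cached helper above. It assumes the cache layout produced by url_to_filename, where every entry for a URL shares the same hashed-URL prefix and ".json" files hold metadata rather than payloads:

    import glob
    import os
    from typing import Optional

    from allennlp.common.file_utils import url_to_filename

    def newest_cached(url: str, cache_dir: str) -> Optional[str]:
        # All cache entries for this URL start with the hash-of-url prefix;
        # ".json" sidecar files are metadata, so they are skipped.
        prefix = os.path.join(cache_dir, url_to_filename(url))
        candidates = [p for p in glob.glob(prefix + "*") if not p.endswith(".json")]
        # The most recently modified payload wins; None means nothing is cached.
        return max(candidates, key=os.path.getmtime, default=None)

Taking the max over modification times is equivalent to the helper's sort-then-take-first approach; the name newest_cached exists only in this sketch.
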
52 changes: 46 additions & 6 deletions allennlp/tests/common/file_utils_test.py
@@ -5,13 +5,15 @@

import pytest
import responses
from requests.exceptions import ConnectionError

from allennlp.common import file_utils
from allennlp.common.file_utils import (
url_to_filename,
filename_to_url,
get_from_cache,
cached_path,
split_s3_path,
_split_s3_path,
open_compressed,
)
from allennlp.common.testing import AllenNlpTestCase
@@ -57,6 +59,44 @@ def setup_method(self):
with open(self.glove_file, "rb") as glove:
self.glove_bytes = glove.read()

def test_cached_path_offline(self, monkeypatch):
# Ensures `cached_path` just returns the path to the latest cached version
# of the resource when there's no internet connection.

# First we mock the `_http_etag` method so that it raises a `ConnectionError`,
# like it would if there was no internet connection.
def mocked_http_etag(url: str):
raise ConnectionError

monkeypatch.setattr(file_utils, "_http_etag", mocked_http_etag)

url = "/~https://github.com/allenai/allennlp/blob/master/some-fake-resource"

# We'll create two cached versions of this fake resource using two different etags.
etags = ['W/"3e5885bfcbf4c47bc4ee9e2f6e5ea916"', 'W/"3e5885bfcbf4c47bc4ee9e2f6e5ea918"']
filenames = [os.path.join(self.TEST_DIR, url_to_filename(url, etag)) for etag in etags]
for filename, etag in zip(filenames, etags):
meta_filename = filename + ".json"
with open(filename, "w") as f:
f.write("some random data")
with open(meta_filename, "w") as meta_f:
json.dump({"url": url, "etag": etag}, meta_f)

# The version corresponding to the last etag should be returned, since
# that one has the latest "last modified" time.
assert get_from_cache(url, cache_dir=self.TEST_DIR) == filenames[-1]

# We also want to make sure this works when the latest cached version doesn't
# have a corresponding etag.
filename = os.path.join(self.TEST_DIR, url_to_filename(url))
meta_filename = filename + ".json"
with open(filename, "w") as f:
f.write("some random data")
with open(meta_filename, "w") as meta_f:
json.dump({"url": url, "etag": etag}, meta_f)

assert get_from_cache(url, cache_dir=self.TEST_DIR) == filename

def test_url_to_filename(self):
for url in [
"http://allenai.org",
@@ -120,14 +160,14 @@ def test_url_to_filename_with_etags_eliminates_quotes(self):

def test_split_s3_path(self):
# Test splitting good urls.
assert split_s3_path("s3://my-bucket/subdir/file.txt") == ("my-bucket", "subdir/file.txt")
assert split_s3_path("s3://my-bucket/file.txt") == ("my-bucket", "file.txt")
assert _split_s3_path("s3://my-bucket/subdir/file.txt") == ("my-bucket", "subdir/file.txt")
assert _split_s3_path("s3://my-bucket/file.txt") == ("my-bucket", "file.txt")

# Test splitting bad urls.
with pytest.raises(ValueError):
split_s3_path("s3://")
split_s3_path("s3://myfile.txt")
split_s3_path("myfile.txt")
_split_s3_path("s3://")
_split_s3_path("s3://myfile.txt")
_split_s3_path("myfile.txt")

@responses.activate
def test_get_from_cache(self):
