Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
experimental: add backoff (#2968)
Browse files Browse the repository at this point in the history
* add backoff

* use for etag

* use backoff everywhere

* pylint
  • Loading branch information
joelgrus authored Jun 18, 2019
1 parent 2a59be3 commit a655ad5
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 33 deletions.
39 changes: 28 additions & 11 deletions allennlp/common/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import botocore
from botocore.exceptions import ClientError
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry # pylint: disable=import-error

from allennlp.common.tqdm import Tqdm

Expand Down Expand Up @@ -78,7 +80,6 @@ def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:

return url, etag


def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
"""
Given something that might be a URL (or might be a local path),
Expand Down Expand Up @@ -176,17 +177,32 @@ def s3_get(url: str, temp_file: IO) -> None:
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)

def session_with_backoff() -> requests.Session:
    """
    Return a ``requests.Session`` that transparently retries failed
    HTTP(S) requests with exponential backoff.

    We ran into an issue where http requests to s3 were timing out,
    possibly because we were making too many requests too quickly.
    Mounting retry-enabled adapters on both schemes makes every call
    through this session retry up to 5 times (1s backoff factor) on
    502/503/504 responses.
    See stackoverflow.com/questions/23267409/how-to-implement-retry-mechanism-into-python-requests-library
    """
    retry_policy = Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
    session = requests.Session()
    # Cover both plain and TLS URLs so every request gets the retry behavior.
    for prefix in ('http://', 'https://'):
        session.mount(prefix, HTTPAdapter(max_retries=retry_policy))
    return session

def http_get(url: str, temp_file: IO) -> None:
    """
    Stream-download ``url`` into the already-open (binary) ``temp_file``.

    Uses a retry-with-backoff session (see ``session_with_backoff``) so
    transient 502/503/504 failures do not abort the download, and shows a
    byte-count tqdm progress bar sized from the ``Content-Length`` header
    when the server provides one.
    """
    with session_with_backoff() as session:
        req = session.get(url, stream=True)
        content_length = req.headers.get('Content-Length')
        # Content-Length may be absent (chunked responses); tqdm accepts None.
        total = int(content_length) if content_length is not None else None
        progress = Tqdm.tqdm(unit="B", total=total)
        try:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            # Always close the bar, even if the download raises mid-stream,
            # so a partial failure doesn't leave a dangling progress display.
            progress.close()


# TODO(joelgrus): do we want to do checksums or anything like that?
Expand All @@ -204,7 +220,8 @@ def get_from_cache(url: str, cache_dir: str = None) -> str:
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
with session_with_backoff() as session:
response = session.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
Expand Down
23 changes: 12 additions & 11 deletions allennlp/predictors/wikitables_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from subprocess import run
from typing import List
import shutil
import requests

from overrides import overrides
import torch

from allennlp.common.file_utils import cached_path
from allennlp.common.file_utils import cached_path, session_with_backoff
from allennlp.common.util import JsonDict, sanitize
from allennlp.data import DatasetReader, Instance
from allennlp.models import Model
Expand Down Expand Up @@ -40,16 +39,18 @@ def __init__(self, model: Model, dataset_reader: DatasetReader) -> None:
# Load auxiliary sempre files during startup for faster logical form execution.
os.makedirs(SEMPRE_DIR, exist_ok=True)
abbreviations_path = os.path.join(SEMPRE_DIR, 'abbreviations.tsv')
if not os.path.exists(abbreviations_path):
result = requests.get(ABBREVIATIONS_FILE)
with open(abbreviations_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

grammar_path = os.path.join(SEMPRE_DIR, 'grow.grammar')
if not os.path.exists(grammar_path):
result = requests.get(GROW_FILE)
with open(grammar_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

with session_with_backoff() as session:
if not os.path.exists(abbreviations_path):
result = session.get(ABBREVIATIONS_FILE)
with open(abbreviations_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

if not os.path.exists(grammar_path):
result = session.get(GROW_FILE)
with open(grammar_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
Expand Down
23 changes: 12 additions & 11 deletions allennlp/semparse/executors/wikitables_sempre_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
import pathlib
import shutil
import subprocess
import requests

from allennlp.common.file_utils import cached_path
from allennlp.common.file_utils import cached_path, session_with_backoff
from allennlp.common.checks import check_for_java

logger = logging.getLogger(__name__) # pylint: disable=invalid-name
Expand Down Expand Up @@ -74,16 +73,18 @@ def _create_sempre_executor(self) -> None:
# sure we put the files in that location.
os.makedirs(SEMPRE_DIR, exist_ok=True)
abbreviations_path = os.path.join(SEMPRE_DIR, 'abbreviations.tsv')
if not os.path.exists(abbreviations_path):
result = requests.get(ABBREVIATIONS_FILE)
with open(abbreviations_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

grammar_path = os.path.join(SEMPRE_DIR, 'grow.grammar')
if not os.path.exists(grammar_path):
result = requests.get(GROW_FILE)
with open(grammar_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

with session_with_backoff() as session:
if not os.path.exists(abbreviations_path):
result = session.get(ABBREVIATIONS_FILE)
with open(abbreviations_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

if not os.path.exists(grammar_path):
result = session.get(GROW_FILE)
with open(grammar_path, 'wb') as downloaded_file:
downloaded_file.write(result.content)

if not check_for_java():
raise RuntimeError('Java is not installed properly.')
Expand Down

0 comments on commit a655ad5

Please sign in to comment.