Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update training integration tests #113

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 141 additions & 3 deletions protopipe/pipeline/temp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,18 @@
"""

import os
import time
import logging
from pathlib import Path

import requests
from tqdm import tqdm
from urllib.parse import urlparse
import numpy as np
from numpy import nan
from scipy.sparse import lil_matrix, csr_matrix
import astropy.units as u
from astropy.coordinates import SkyCoord
from requests.exceptions import HTTPError

from ctapipe.core import Container, Field
from ctapipe.coordinates import CameraFrame, TelescopeFrame
Expand All @@ -28,6 +32,7 @@

logger = logging.getLogger(__name__)


class HillasParametersTelescopeFrameContainer(Container):
container_prefix = "hillas"

Expand Down Expand Up @@ -213,6 +218,139 @@ def initialize_hillas_planes(
except ImportError:
has_resources = False


def download_file(url, path, auth=None, chunk_size=10240, progress=False):
    """
    Download a file.

    Will write to ``path + '.part'`` while downloading
    and rename after successful download to the final name.

    Parameters
    ----------
    url: str or url
        The URL to download
    path: pathlib.Path or str
        Where to store the downloaded data.
    auth: None or tuple of (username, password) or a request.AuthBase instance.
        Passed unchanged to ``requests.get``.
    chunk_size: int
        Chunk size for writing the data file, 10 kB by default.
    progress: bool
        If True, display a tqdm progress bar while downloading.

    Raises
    ------
    requests.exceptions.HTTPError
        If the server answers with an error status
        (raised by ``raise_for_status``).
    """
    logger.info(f"Downloading {url} to {path}")
    name = urlparse(url).path.split("/")[-1]
    path = Path(path)

    # Compute the temporary name *before* entering the ``try`` block so the
    # cleanup handler below can never hit an unbound ``part_file``.
    part_file = path.with_suffix(path.suffix + ".part")

    with requests.get(url, stream=True, auth=auth, timeout=5) as r:
        # make sure the request is successful
        r.raise_for_status()

        # without a Content-Length header the total size is unknown
        total = float(r.headers.get("Content-Length", float("inf")))

        # use the progress bar as a context manager so it is always closed
        with tqdm(
            total=total,
            disable=not progress,
            unit="B",
            unit_scale=True,
            desc=f"Downloading {name}",
        ) as pbar:
            try:
                # write to a .part file to avoid creating
                # a broken file at the intended location
                part_file.parent.mkdir(parents=True, exist_ok=True)
                with part_file.open("wb") as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        pbar.update(len(chunk))
            except BaseException:
                # We really want to catch everything here (including
                # KeyboardInterrupt): clean up the part file if something
                # goes wrong, then re-raise the original error.
                if part_file.is_file():
                    part_file.unlink()
                raise

    # when successful, move to intended location
    part_file.rename(path)


def get_cache_path(url, cache_name="ctapipe", env_override="CTAPIPE_CACHE"):
    """
    Return the local cache location corresponding to ``url``.

    The path is built as ``<base>/<netloc>/<url path>``. The parent
    directory of the returned path is created if it does not exist.

    Parameters
    ----------
    url: str
        The URL of the remote file.
    cache_name: str
        Name of the directory below ``~/.cache`` used when the
        environment override is not set.
    env_override: str
        Name of an environment variable; if it is set to a non-empty value,
        that value is used as the cache base directory.

    Returns
    -------
    path: pathlib.Path
        Location of the file inside the cache.
    """
    # Read the variable that was actually asked for, not a hard-coded name.
    override = os.getenv(env_override)
    if override:
        base = Path(override)
    else:
        base = Path.home() / ".cache" / cache_name

    url = urlparse(url)

    path = os.path.join(url.netloc.rstrip("/"), url.path.lstrip("/"))
    path = base / path
    path.parent.mkdir(parents=True, exist_ok=True)
    return path


def download_file_cached(
    name,
    cache_name="ctapipe",
    auth=None,
    env_prefix="CTAPIPE_DATA_",
    default_url="http://cccta-dataserver.in2p3.fr/data/",
    progress=False,
):
    """
    Downloads a file from a dataserver and caches the result locally
    in ``$HOME/.cache/<cache_name>``.
    If the file is found in the cache, no new download is performed.

    If a ``.part`` file for the target exists, another download is assumed
    to be running and this function polls once per second until the
    ``.part`` file disappears (there is no timeout).

    Parameters
    ----------
    name: str or pathlib.Path
        the name of the file, relative to the data server url
    cache_name: str
        What name to use for the cache directory
    env_prefix: str
        Prefix for the environment variables used for overriding the URL,
        and providing username and password in case authentication is required.
    auth: True, None or tuple of (username, password)
        Authentication data for the request. Will be passed to ``requests.get``.
        If ``True``, read username and password for the request from
        the env variables ``env_prefix + 'USER'`` and ``env_prefix + 'PASSWORD'``
    default_url: str
        The default url from which to download ``name``, can be overridden
        by setting the env variable ``env_prefix + 'URL'``
    progress: bool
        Forwarded to ``download_file``.

    Returns
    -------
    path: pathlib.Path
        the full path to the downloaded data.

    Raises
    ------
    KeyError
        If ``auth`` is ``True`` and the USER / PASSWORD environment
        variables are not set.
    """
    base_url = os.environ.get(env_prefix + "URL", default_url).rstrip("/")
    url = base_url + "/" + str(name).lstrip("/")

    path = get_cache_path(url, cache_name=cache_name)
    part_file = path.with_suffix(path.suffix + ".part")

    if part_file.is_file():
        logger.warning("Another download for this file is already running, waiting.")
        while part_file.is_file():
            time.sleep(1)

    # if we already downloaded the file, just use it
    if path.is_file():
        logger.debug(f"File {name} is available in cache.")
        return path

    # Only report a cache miss once we actually know the file is missing.
    logger.debug(f"File {name} is not available in cache, downloading.")

    if auth is True:
        try:
            auth = (
                os.environ[env_prefix + "USER"],
                os.environ[env_prefix + "PASSWORD"],
            )
        except KeyError:
            raise KeyError(
                f'You need to set the env variables "{env_prefix}USER"'
                f' and "{env_prefix}PASSWORD" to download test files.'
            ) from None

    download_file(url=url, path=path, auth=auth, progress=progress)
    return path


DEFAULT_URL = "http://cccta-dataserver.in2p3.fr/data/ctapipe-extra/v0.3.3/"
def get_dataset_path(filename, url=DEFAULT_URL):
"""
Expand All @@ -230,7 +368,7 @@ def get_dataset_path(filename, url=DEFAULT_URL):
-------
string with full path to the given dataset
"""

searchpath = os.getenv("CTAPIPE_SVC_PATH")

if searchpath:
Expand All @@ -239,7 +377,7 @@ def get_dataset_path(filename, url=DEFAULT_URL):
if filepath:
return filepath

if has_resources:
if has_resources and (url is DEFAULT_URL):
logger.debug(
"Resource '{}' not found in CTAPIPE_SVC_PATH, looking in "
"ctapipe_resources...".format(filename)
Expand Down
14 changes: 8 additions & 6 deletions protopipe/scripts/tests/test_dataTraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@

# TEST FILES

URL = "http://cccta-dataserver.in2p3.fr/data/protopipe/testData/"

# PROD 3b

PROD3B_CTA_NORTH = get_dataset_path("gamma_LaPalma_baseline_20Zd_180Az_prod3b_test.simtel.gz")
PROD3B_CTA_SOUTH = get_dataset_path("gamma_Paranal_baseline_20Zd_180Az_prod3_test.simtel.gz")
PROD3B_CTA_NORTH = get_dataset_path("gamma1.simtel.gz",
url=f"{URL}/prod3_laPalma_baseline_Az180_Az20")
PROD3B_CTA_SOUTH = get_dataset_path("gamma1.simtel.gz",
url=f"{URL}/prod3_Paranal_baseline_Az180_Az20")


@pytest.mark.parametrize("input_file", [PROD3B_CTA_NORTH, PROD3B_CTA_SOUTH])
Expand All @@ -36,7 +40,7 @@ def test_dataTraining_noImages(input_file):
"""

# the difference is only the 'site' key as a check for the user
if "Paranal" in str(input_file):
if input_file in [PROD3B_CTA_SOUTH]:
ana_config = resource_filename("protopipe", "scripts/tests/test_config_analysis_south.yaml")
else:
ana_config = resource_filename("protopipe", "scripts/tests/test_config_analysis_north.yaml")
Expand All @@ -45,7 +49,6 @@ def test_dataTraining_noImages(input_file):
f"python {data_training.__file__}\
--config_file {ana_config}\
-o test_training_noImages.h5\
-m 10\
-i {path.dirname(input_file)}\
-f {path.basename(input_file)}"
)
Expand All @@ -69,7 +72,7 @@ def test_dataTraining_withImages(input_file):
"""

# the difference is only the 'site' key as a check for the user
if "Paranal" in str(input_file):
if input_file in [PROD3B_CTA_SOUTH]:
ana_config = resource_filename("protopipe", "scripts/tests/test_config_analysis_south.yaml")
else:
ana_config = resource_filename("protopipe", "scripts/tests/test_config_analysis_north.yaml")
Expand All @@ -78,7 +81,6 @@ def test_dataTraining_withImages(input_file):
f"python {data_training.__file__}\
--config_file {ana_config}\
-o test_training_withImages.h5\
-m 10\
--save_images\
-i {path.dirname(input_file)}\
-f {path.basename(input_file)}"
Expand Down