Skip to content

Commit

Permalink
Merge pull request #208 from clara-parabricks/ntadimeti/fix_train_dat…
Browse files Browse the repository at this point in the history
…aloader

Fix train dataloader
  • Loading branch information
ntadimeti authored Aug 25, 2020
2 parents e9fa1e9 + 0c7dc74 commit 45cd565
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 4 deletions.
2 changes: 1 addition & 1 deletion atacworks/dl4atac/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _get_generator(self):
rec['input'],
hrecs[layer_key][file_id][local_idx]))
rec['input'] = np.swapaxes(rec['input'], 0, 1)
yield rec
idx = yield rec


class DatasetInfer(DatasetBase):
Expand Down
2 changes: 0 additions & 2 deletions configs/infer_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ regions: "None"
# Data Pre-processing parameters
interval_size: 50000
genome: "hg19"
nonpeak: "None"

# Experiment args
peaks: False
Expand All @@ -34,7 +33,6 @@ pad: 5000
layers: "None"

#Infer args
input_files: "None"
intervals_file: "None"
genome: "None"
reg_rounding: 0
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def get_installation_requirments(file_path):


setup(name='atacworks',
version='0.3.0',
version='0.3.1',
description='NVIDIA genomics python libraries and utiliites',
author='NVIDIA Corporation',
url="/~https://github.com/clara-genomics/AtacWorks",
Expand Down
61 changes: 61 additions & 0 deletions tests/dl4atac/test_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
#
"""Unit tests for atacworks/dl4atac/dataset.py module."""
import os

import numpy as np
import pytest

from atacworks.io import h5io
from atacworks.dl4atac.dataset import DatasetInfer, DatasetTrain


@pytest.mark.cpu
def test_dataset_train(tmpdir):
"""Create a dict of numpy arrays and write it to a file using \
the dict_to_h5 API. Send the file to DatasetTrain API and verify \
that samples are being read as expected."""
input_dict = {"input": np.array([0, 1, 3, 4]),
"label_reg": np.array([1, 2, 10, 1]),
"label_cla": np.array([1, 1, 0, 0])}
h5file = os.path.join(tmpdir, "h5file.h5")
h5io.dict_to_h5(input_dict, h5file)
train_dataset = DatasetTrain(files=[h5file], layers=None)
for idx in range(0, len(train_dataset)):
expected_dict = {"idx": idx,
"input": input_dict["input"][idx],
"label_reg": input_dict["label_reg"][idx],
"label_cla": input_dict["label_cla"][idx]}
result_dict = train_dataset[idx]

assert expected_dict.keys() == result_dict.keys()
for key, value in expected_dict.items():
assert result_dict[key] == value


@pytest.mark.cpu
def test_dataset_infer(tmpdir):
"""Create a dict of numpy arrays and write it to a file using \
the dict_to_h5 API. Send the file to DatasetInfer API and verify \
that samples are being read as expected."""
input_dict = {"input": np.array([0, 1, 3, 4]),
"label_reg": np.array([1, 2, 10, 1]),
"label_cla": np.array([1, 1, 0, 0])}
h5file = os.path.join(tmpdir, "h5file.h5")
h5io.dict_to_h5(input_dict, h5file)
infer_dataset = DatasetInfer(files=[h5file], layers=None)
for idx in range(0, len(infer_dataset)):
expected_dict = {"idx": idx,
"input": input_dict["input"][idx]}
result_dict = infer_dataset[idx]

assert expected_dict.keys() == result_dict.keys()
for key, value in expected_dict.items():
assert result_dict[key] == value

0 comments on commit 45cd565

Please sign in to comment.