diff --git a/h5pyd/_hl/dataset.py b/h5pyd/_hl/dataset.py index 7782b65..332a80f 100644 --- a/h5pyd/_hl/dataset.py +++ b/h5pyd/_hl/dataset.py @@ -18,6 +18,9 @@ import time import base64 import numpy +import os +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import as_completed from .base import HLObject, jsonToArray, bytesToArray, arrayToBytes from .base import Empty, guess_dtype @@ -1742,3 +1745,147 @@ def toTuple(self, data): return tuple(self.toTuple(x) for x in data) else: return data + + +class MultiManager(): + """ + high-level object to support slicing operations + that map to H5Dread_multi/H5Dwrite_multi + """ + # Avoid overtaxing HSDS + max_workers = 16 + + def __init__(self, datasets=None): + if (datasets is None) or (len(datasets) == 0): + raise ValueError("MultiManager requires non-empty list of datasets") + self.datasets = datasets + + def read_dset_tl(self, args): + """ + Thread-local method to read from a single dataset + """ + dset = args[0] + idx = args[1] + try: + read_args = args[2] + except Exception as e: + raise e + return (idx, dset[read_args]) + + def write_dset_tl(self, args): + """ + Thread-local method to write to a single dataset + """ + dset = args[0] + idx = args[1] + write_args = args[2] + write_vals = args[3] + try: + dset[write_args] = write_vals + except Exception as e: + raise e + return + + def __getitem__(self, args): + """ + Read the same slice from each of the datasets + managed by this MultiManager. + """ + # Spread requests out evenly among all available SNs + + # TODO: This should eventually be handled at the config/HTTPConn level + try: + num_endpoints = int(os.environ["SN_CORES"]) + port_range = os.environ["SN_PORT_RANGE"] + ports = port_range.split('-') + + if len(ports) != 2: + raise ValueError("Malformed SN_PORT_RANGE") + + low_port = int(ports[0]) + high_port = int(ports[1]) + + except Exception as e: + msg = f"{e}: Defaulting Number of SN_COREs to 1" + self.log.warning(msg) + num_endpoints = 1 + + if (num_endpoints > 1): + next_port = low_port + port_len = len(ports[0]) + + for i, dset in enumerate(self.datasets): + endpt = dset.id.http_conn._endpoint + endpt = endpt[:len(endpt) - port_len] + str(next_port) + dset.id.http_conn._endpoint = endpt + next_port += 1 + + if next_port > high_port: + next_port = low_port + + # TODO: Handle the case where some or all datasets share an HTTPConn object + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + read_futures = [executor.submit(self.read_dset_tl, (self.datasets[i], i, args)) for i in range(len(self.datasets))] + ret_data = [None] * len(self.datasets) + + for future in as_completed(read_futures): + try: + result = future.result() + idx = result[0] + dset_data = result[1] + ret_data[idx] = dset_data + except Exception as exc: + executor.shutdown(wait=False) + raise ValueError(f"Error during multi-read: {exc}") + return ret_data + + def __setitem__(self, args, vals): + """ + Write to the provided slice of each dataset + managed by this MultiManager. + """ + # TODO: This should eventually be handled at the config/HTTPConn level + try: + num_endpoints = int(os.environ["SN_CORES"]) + port_range = os.environ["SN_PORT_RANGE"] + ports = port_range.split('-') + + if len(ports) != 2: + raise ValueError("Malformed SN_PORT_RANGE") + + low_port = int(ports[0]) + high_port = int(ports[1]) + + if (high_port - low_port) != num_endpoints - 1: + raise ValueError("Malformed port range specification; must be sequential ports") + + except Exception as e: + print(f"{e}: Defaulting Number of SNs to 1") + num_endpoints = 1 + + # TODO: Handle the case where some or all datasets share an HTTPConn object + # For now, assume each connection is distinct + if (num_endpoints > 1): + next_port = low_port + port_len = len(ports[0]) + + for i, dset in enumerate(self.datasets): + endpt = dset.id.http_conn._endpoint + endpt = endpt[:len(endpt) - port_len] + str(next_port) + dset.id.http_conn._endpoint = endpt + next_port += 1 + + if next_port > high_port: + next_port = low_port + + with ThreadPoolExecutor(max_workers=self.max_workers) as executor: + write_futures = [executor.submit(self.write_dset_tl, (self.datasets[i], i, args, vals[i])) for i in range(len(self.datasets))] + + for future in as_completed(write_futures): + try: + future.result() + except Exception as exc: + executor.shutdown(wait=False) + raise ValueError(f"Error during multi-write: {exc}") + return diff --git a/h5pyd/_hl/httpconn.py b/h5pyd/_hl/httpconn.py index 8219547..4f2a73e 100644 --- a/h5pyd/_hl/httpconn.py +++ b/h5pyd/_hl/httpconn.py @@ -753,11 +753,11 @@ def session(self): s.mount( "http://", - HTTPAdapter(max_retries=retry), + HTTPAdapter(max_retries=retry, pool_connections=16, pool_maxsize=16), ) s.mount( "https://", - HTTPAdapter(max_retries=retry), + HTTPAdapter(max_retries=retry, pool_connections=16, pool_maxsize=16), ) self._s = s else: diff --git a/test/hl/multi_benchmark.py b/test/hl/multi_benchmark.py new file mode 100644 index 0000000..44c210e --- /dev/null +++ b/test/hl/multi_benchmark.py @@ -0,0 +1,287 @@ +import numpy as np +import time + +from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import as_completed +import subprocess +import re + +from h5pyd._hl.dataset import MultiManager +import h5pyd as h5py + +# Flag to stop resource usage collection thread after a benchmark finishes +stop_stat_collection = False + + +def write_datasets_multi(datasets, num_iters=1): + mm = MultiManager(datasets) + data = np.reshape(np.arange(np.prod(datasets[0].shape)), datasets[0].shape) + + start = time.time() + for i in range(num_iters): + mm[...] = [data] * len(datasets) + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def write_datasets_serial(datasets, num_iters=1): + data = np.reshape(np.arange(np.prod(datasets[0].shape)), datasets[0].shape) + + start = time.time() + for i in range(num_iters): + for d in datasets: + d[...] = data + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def read_datasets_multi(datasets, num_iters=1): + mm = MultiManager(datasets) + + start = time.time() + for i in range(num_iters): + out = mm[...] + if out is None: + raise ValueError("Read failed!") + + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def read_datasets_serial(datasets, num_iters=1): + start = time.time() + for i in range(num_iters): + for d in datasets: + out = d[...] + if out is None: + raise ValueError("Read failed!") + + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def read_datasets_multi_selections(datasets, num_iters=1): + shape = datasets[0].shape + rank = len(shape) + mm = MultiManager(datasets=datasets) + + start = time.time() + for i in range(num_iters): + # Generate random selection + sel = np.random.randint(0, shape[0], size=rank * 2) + out = mm[sel[0]:sel[1], sel[2]:sel[3], sel[4]:sel[5]] + if out is None: + raise ValueError("Read failed!") + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def read_datasets_serial_selections(datasets, num_iters=1): + shape = datasets[0].shape + rank = len(shape) + + start = time.time() + for i in range(num_iters): + # Generate random selection + sel = np.random.randint(0, shape[0], size=rank * 2) + for d in datasets: + out = d[sel[0]:sel[1], sel[2]:sel[3], sel[4]:sel[5]] + if out is None: + raise ValueError("Read failed!") + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def write_datasets_multi_selections(datasets, num_iters=1): + shape = datasets[0].shape + rank = len(shape) + data_in = np.reshape(np.arange(np.prod(shape)), shape) + + mm = MultiManager(datasets=datasets) + + start = time.time() + for i in range(num_iters): + # Generate random selection + sel = np.random.randint(0, shape[0], size=rank * 2) + write_data = data_in[sel[0]:sel[1], sel[2]:sel[3], sel[4]:sel[5]] + mm[sel[0]:sel[1], sel[2]:sel[3], sel[4]:sel[5]] = [write_data] * count + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def write_datasets_serial_selections(datasets, num_iters=1): + shape = datasets[0].shape + rank = len(shape) + data_in = np.reshape(np.arange(np.prod(shape)), shape) + + start = time.time() + for i in range(num_iters): + # Generate random selection + sel = np.random.randint(0, shape[0], size=rank * 2) + write_data = data_in[sel[0]:sel[1], sel[2]:sel[3], sel[4]:sel[5]] + + for d in datasets: + d[sel[0]:sel[1], sel[2]:sel[3], sel[4]:sel[5]] = write_data + end = time.time() + avg_time = (end - start) / num_iters + + return avg_time + + +def test_thread_error(f): + dset1 = f.create_dataset("d1", data=np.arange(100), shape=(100,), dtype=np.int32) + dset2 = f.create_dataset("d2", data=np.reshape(np.arange(100), (10, 10)), shape=(10, 10), dtype=np.int32) + mm = MultiManager([dset1, dset2]) + out = mm[0:15, 0:15] # Only valid for dset 2 + print(out) + return out + + +def get_docker_stats(test_name): + global stop_stat_collection + sn_stat_instances = 0 + dn_stat_instances = 0 + sn_count = 0 + dn_count = 0 + + if test_name in stats: + raise ValueError(f"Test name conflict on name \"{test_name}\"") + + test_stats = {"time": 0.0, "dn_cpu": 0.0, "dn_mem": 0.0, "sn_cpu": 0.0, "sn_mem": 0.0} + + while True: + if stop_stat_collection: + stop_stat_collection = False + return test_stats + + stats_out = subprocess.check_output(['docker', 'stats', '--no-stream']) + + lines = stats_out.splitlines() + + # Count SNs and DNs on first stat check + if sn_count == 0: + for line in lines[1:]: + line = line.decode('utf-8') + # Replace all substrings of whitespace with single space + line = re.sub(" +", " ", line) + words = line.split(' ') + container_name = words[1] + + if "_dn_" in container_name: + dn_count += 1 + elif "_sn_" in container_name: + sn_count += 1 + + for line in lines[1:]: + line = line.decode('utf-8') + # Replace all substrings of whitespace with single space + line = re.sub(" +", " ", line) + words = line.split(' ') + + container_name = words[1] + cpu_percent = float((words[2])[:-1]) + mem_percent = float((words[6])[:-1]) + + # Update average usage values + if "_dn_" in container_name: + dn_stat_instances += 1 + ratio = (dn_stat_instances - 1) / dn_stat_instances + test_stats["dn_cpu"] = (test_stats["dn_cpu"] * ratio) + cpu_percent / dn_stat_instances + test_stats["dn_mem"] = (test_stats["dn_mem"] * ratio) + mem_percent / dn_stat_instances + elif "_sn_" in container_name: + sn_stat_instances += 1 + ratio = (sn_stat_instances - 1) / sn_stat_instances + test_stats["sn_cpu"] = (test_stats["sn_cpu"] * ratio) + cpu_percent / sn_stat_instances + test_stats["sn_mem"] = (test_stats["sn_mem"] * ratio) + mem_percent / sn_stat_instances + else: + # Ignore other docker containers + pass + + # Query docker for stats once per second + time.sleep(1) + + +def run_benchmark(test_name, test_func, stats, datasets, num_iters): + global stop_stat_collection + # For each section, execute docker resource usage readout at simultaneously on a second thread + with ThreadPoolExecutor(max_workers=2) as executor: + futures = [] + futures.append(executor.submit(test_func, datasets, num_iters)) + futures.append(executor.submit(get_docker_stats, test_name)) + time_elapsed = 0.0 + + for f in as_completed(futures): + try: + ret = f.result() + if isinstance(ret, float): + # Benchmark returned; terminate docker stats computation + time_elapsed = ret + stop_stat_collection = True + elif isinstance(ret, dict): + # Stat collection returned + stats[test_name] = ret + stats[test_name]["time"] = time_elapsed + + except Exception as exc: + executor.shutdown(wait=False) + raise ValueError(f"Error during benchmark threading for {test_name}: {exc}") + + +if __name__ == '__main__': + print("Executing multi read/write benchmark") + shape = (100, 100, 100) + count = 64 + num_iters = 50 + dt = np.int32 + stats = {} + + fs = [h5py.File("/home/test_user1/h5pyd_multi_bm_" + str(i), mode='w') for i in range(count)] + data_in = np.zeros(shape, dtype=dt) + datasets = [f.create_dataset("data", shape, dtype=dt, data=data_in) for f in fs] + + print(f"Created {count} datasets, each with {np.prod(shape)} elements") + print(f"Benchmarks will be repeated {num_iters} times") + + print("Testing with multiple HTTP Connections...") + + run_benchmark("Read Multi (Multiple HttpConn)", read_datasets_multi, stats, datasets, num_iters) + run_benchmark("Read Serial (Multiple HttpConn)", read_datasets_serial, stats, datasets, num_iters) + + run_benchmark("Write Multi (Multiple HttpConn)", write_datasets_multi, stats, datasets, num_iters) + run_benchmark("Write Serial (Multiple HttpConn)", write_datasets_serial, stats, datasets, num_iters) + + print("Testing with shared HTTP connection...") + + f = h5py.File("/home/test_user1/h5pyd_multi_bm_shared", mode='w') + datasets = [f.create_dataset("data" + str(i), data=data_in, dtype=dt) for i in range(count)] + + run_benchmark("Read Multi (Shared HttpConn)", read_datasets_multi, stats, datasets, num_iters) + run_benchmark("Read Serial (Shared HttpConn)", read_datasets_serial, stats, datasets, num_iters) + + run_benchmark("Write Multi (Shared HttpConn)", write_datasets_multi, stats, datasets, num_iters) + run_benchmark("Write Serial (Shared HttpConn)", write_datasets_serial, stats, datasets, num_iters) + + # Display results + for test_name in stats: + time_elapsed = stats[test_name]["time"] + dn_cpu = stats[test_name]["dn_cpu"] + dn_mem = stats[test_name]["dn_mem"] + sn_cpu = stats[test_name]["sn_cpu"] + sn_mem = stats[test_name]["sn_mem"] + + print(f"{test_name} - Time: {(time_elapsed):6.4f}, DN CPU%: {(dn_cpu):6.4f},\ + DN MEM%: {(dn_mem):6.4f}, SN CPU%: {(sn_cpu):6.4f}, SN MEM%: {(sn_mem):6.4f}") diff --git a/test/hl/test_dataset.py b/test/hl/test_dataset.py index 5190db1..f1435c8 100644 --- a/test/hl/test_dataset.py +++ b/test/hl/test_dataset.py @@ -26,6 +26,7 @@ import warnings from common import ut, TestCase +from h5pyd._hl.dataset import MultiManager import config if config.get("use_h5py"): @@ -1906,6 +1907,414 @@ def test_basetype_commutative(self,): assert (val == dset) == (dset == val) assert (val != dset) == (dset != val) +class TestMultiManager(BaseDataset): + def test_multi_read_scalar_dataspaces(self): + """ + Test reading from multiple datasets with scalar dataspaces + """ + shape = () + count = 3 + dt = np.int32 + + # Create datasets + data_in = np.array(1, dtype=dt) + datasets = [] + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dt, data=(data_in + i)) + datasets.append(dset) + + mm = MultiManager(datasets) + + # Select via empty tuple + data_out = mm[()] + + self.assertEqual(len(data_out), count) + + for i in range(count): + np.testing.assert_array_equal(data_out[i], data_in + i) + + # Select via Ellipsis + data_out = mm[...] + + self.assertEqual(len(data_out), count) + + for i in range(count): + np.testing.assert_array_equal(data_out[i], data_in + i) + + def test_multi_read_non_scalar_dataspaces(self): + """ + Test reading from multiple datasets with non-scalar dataspaces + """ + shape = (10, 10, 10) + count = 3 + dt = np.int32 + + # Create datasets + data_in = np.reshape(np.arange(np.prod(shape)), shape) + datasets = [] + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dt, data=(data_in + i)) + datasets.append(dset) + + mm = MultiManager(datasets) + data_out = mm[...] + + self.assertEqual(len(data_out), count) + + for i in range(count): + np.testing.assert_array_equal(data_out[i], data_in + i) + + # Partial Read + data_out = mm[:, :, 0] + + self.assertEqual(len(data_out), count) + + for i in range(count): + np.testing.assert_array_equal(data_out[i], (data_in + i)[:, :, 0]) + + def test_multi_read_mixed_dataspaces(self): + """ + Test reading from multiple datasets with scalar and + non-scalar dataspaces + """ + scalar_shape = () + shape = (10, 10, 10) + count = 3 + dt = np.int32 + + # Create datasets + data_scalar_in = np.array(1) + data_nonscalar_in = np.reshape(np.arange(np.prod(shape)), shape) + data_in = [data_scalar_in, data_nonscalar_in, + data_nonscalar_in, data_nonscalar_in] + datasets = [] + + for i in range(count): + if i == 0: + dset = self.f.create_dataset("data" + str(0), scalar_shape, + dtype=dt, data=data_scalar_in) + else: + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dt, data=(data_nonscalar_in + i)) + datasets.append(dset) + + # Set up MultiManager for read + mm = MultiManager(datasets=datasets) + + # Select via empty tuple + data_out = mm[()] + + self.assertEqual(len(data_out), count) + + for i in range(count): + if i == 0: + np.testing.assert_array_equal(data_out[i], data_in[i]) + else: + np.testing.assert_array_equal(data_out[i], data_in[i] + i) + + # Select via Ellipsis + data_out = mm[...] + + self.assertEqual(len(data_out), count) + + for i in range(count): + if i == 0: + np.testing.assert_array_equal(data_out[i], data_in[i]) + else: + np.testing.assert_array_equal(data_out[i], data_in[i] + i) + + def test_multi_read_mixed_types(self): + """ + Test reading from multiple datasets with different types + """ + shape = (10, 10, 10) + count = 4 + dts = [np.int32, np.int64, np.float64, np.dtype("S10")] + + # Create datasets + data_in = np.reshape(np.arange(np.prod(shape)), shape) + data_in_fixed_str = np.full(shape, "abcdefghij", dtype=dts[3]) + datasets = [] + + for i in range(count): + if i < 3: + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dts[i], data=(data_in + i)) + else: + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dts[i], data=data_in_fixed_str) + + datasets.append(dset) + + # Set up MultiManager for read + mm = MultiManager(datasets=datasets) + + # Perform read + data_out = mm[...] + + self.assertEqual(len(data_out), count) + + for i in range(count): + if i < 3: + np.testing.assert_array_equal(data_out[i], np.array(data_in + i, dtype=dts[i])) + else: + np.testing.assert_array_equal(data_out[i], data_in_fixed_str) + + self.assertEqual(data_out[i].dtype, dts[i]) + + def test_multi_read_vlen_str(self): + """ + Test reading from multiple datasets with a vlen string type + """ + shape = (10, 10, 10) + count = 3 + dt = h5py.string_dtype(encoding='utf-8') + data_in = np.full(shape, "abcdefghij", dt) + datasets = [] + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape=shape, + data=data_in, dtype=dt) + datasets.append(dset) + + mm = MultiManager(datasets=datasets) + out = mm[...] + + self.assertEqual(len(out), count) + + for i in range(count): + self.assertEqual(out[i].dtype, dt) + out[i] = np.reshape(out[i], newshape=np.prod(shape)) + out[i] = np.reshape(np.array([s.decode() for s in out[i]], dtype=dt), + newshape=shape) + np.testing.assert_array_equal(out[i], data_in) + + def test_multi_read_mixed_shapes(self): + """ + Test reading a selection from multiple datasets with different shapes + """ + shapes = [(150), (10, 15), (5, 5, 6)] + count = 3 + dt = np.int32 + data = np.arange(150, dtype=dt) + data_in = [np.reshape(data, newshape=s) for s in shapes] + datasets = [] + sel_idx = 2 + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape=shapes[i], + dtype=dt, data=data_in[i]) + datasets.append(dset) + + mm = MultiManager(datasets=datasets) + # Perform multi read with selection + out = mm[sel_idx] + + # Verify + for i in range(count): + np.testing.assert_array_equal(out[i], data_in[i][sel_idx]) + + def test_multi_write_scalar_dataspaces(self): + """ + Test writing to multiple scalar datasets + """ + shape = () + count = 3 + dt = np.int32 + + # Create datasets + zeros = np.zeros(shape, dtype=dt) + data_in = [] + datasets = [] + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dt, data=zeros) + datasets.append(dset) + + data_in.append(np.array([i])) + + mm = MultiManager(datasets) + # Perform write + mm[...] = data_in + + # Read back and check + for i in range(count): + data_out = self.f["data" + str(i)][...] + np.testing.assert_array_equal(data_out, data_in[i]) + + def test_multi_write_non_scalar_dataspaces(self): + """ + Test writing to multiple non-scalar datasets + """ + shape = (10, 10, 10) + count = 3 + dt = np.int32 + + # Create datasets + zeros = np.zeros(shape, dtype=dt) + data_in = [] + datasets = [] + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dt, data=zeros) + datasets.append(dset) + + d_in = np.array(np.reshape(np.arange(np.prod(shape)), shape) + i, dtype=dt) + data_in.append(d_in) + + mm = MultiManager(datasets) + # Perform write + mm[...] = data_in + + # Read back and check + for i in range(count): + data_out = np.array(self.f["data" + str(i)][...], dtype=dt) + np.testing.assert_array_equal(data_out, data_in[i]) + + def test_multi_write_mixed_dataspaces(self): + """ + Test writing to multiple scalar and non-scalar datasets + """ + scalar_shape = () + shape = (10, 10, 10) + count = 3 + dt = np.int32 + + # Create datasets + data_in = [] + data_scalar_in = np.array(1, dtype=dt) + data_nonscalar_in = np.array(np.reshape(np.arange(np.prod(shape)), shape), dtype=dt) + datasets = [] + + for i in range(count): + if i == 0: + dset = self.f.create_dataset("data" + str(0), scalar_shape, + dtype=dt, data=np.array(0, dtype=dt)) + data_in.append(data_scalar_in) + else: + dset = self.f.create_dataset("data" + str(i), shape, + dtype=dt, data=np.zeros(shape)) + data_in.append(data_nonscalar_in) + datasets.append(dset) + + # Set up MultiManager for write + mm = MultiManager(datasets=datasets) + + # Select via empty tuple + mm[()] = data_in + + for i in range(count): + data_out = self.f["data" + str(i)][...] + np.testing.assert_array_equal(data_out, data_in[i]) + + # Reset datasets + for i in range(count): + if i == 0: + zeros = np.array([0]) + else: + zeros = np.zeros(shape) + self.f["data" + str(i)][...] = zeros + + # Select via Ellipsis + mm[...] = data_in + + for i in range(count): + data_out = self.f["data" + str(i)][...] + + if i == 0: + np.testing.assert_array_equal(data_out, data_in[i]) + else: + np.testing.assert_array_equal(data_out, data_in[i]) + + def test_multi_write_vlen_str(self): + """ + Test writing to multiple datasets with a vlen string type + """ + shape = (10, 10, 10) + count = 3 + dt = h5py.string_dtype(encoding='utf-8') + data_initial_vlen = np.full(shape, "aaaabbbbcc", dtype=dt) + data_in_vlen = np.full(shape, "abcdefghij", dtype=dt) + datasets = [] + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape=shape, + data=data_initial_vlen, dtype=dt) + datasets.append(dset) + + mm = MultiManager(datasets=datasets) + # Perform write + mm[...] = [data_in_vlen, data_in_vlen, data_in_vlen] + + # Verify + for i in range(count): + out = self.f["data" + str(i)][...] + self.assertEqual(out.dtype, dt) + + out = np.reshape(out, newshape=np.prod(shape)) + out = np.reshape(np.array([s.decode() for s in out], dtype=dt), + newshape=shape) + np.testing.assert_array_equal(out, data_in_vlen) + + def test_multi_write_mixed_shapes(self): + """ + Test writing to a selection in multiple datasets with different shapes + """ + shapes = [(50, 5), (15, 10), (20, 15)] + count = 3 + dt = np.int32 + data_in = 99 + datasets = [] + sel_idx = 2 + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape=shapes[i], + dtype=dt, data=np.zeros(shapes[i], dtype=dt)) + datasets.append(dset) + + mm = MultiManager(datasets=datasets) + # Perform multi write with selection + mm[sel_idx, sel_idx] = [data_in, data_in + 1, data_in + 2] + + # Verify + for i in range(count): + out = self.f["data" + str(i)][...] + np.testing.assert_array_equal(out[sel_idx, sel_idx], data_in + i) + + """ + TBD - Field selection in h5pyd seems to work slightly different than in h5py + def test_multi_write_field_selection(self): + Testing writing to a field selection on multiple datasets + dt = np.dtype([('a', np.float32), ('b', np.int32), ('c', np.float32)]) + shape = (100,) + data = np.ones(shape, dtype=dt) + count = 3 + datasets = [] + + for i in range(count): + dset = self.f.create_dataset("data" + str(i), shape=shape, + data=np.zeros(shape, dtype=dt), + dtype=dt) + datasets.append(dset) + + # Perform write to field 'b' + mm = MultiManager(datasets=datasets) + mm[..., 'b'] = [data['b'], data['b'], data['b']] + + for i in range(count): + out = np.array(self.f["data" + str(i)], dtype=dt) + np.testing.assert_array_equal(out['a'], np.zeros(shape, dtype=dt['a'])) + np.testing.assert_array_equal(out['b'], data['b']) + np.testing.assert_array_equal(out['c'], np.zeros(shape, dtype=dt['c'])) + """ + + if __name__ == '__main__': loglevel = logging.ERROR logging.basicConfig(format='%(asctime)s %(message)s', level=loglevel)