Skip to content

Commit

Permalink
Merge pull request #57 from DynamicsAndNeuralSystems/jmoo2880-add-new…
Browse files Browse the repository at this point in the history
…-spi-testing

New benchmarking dataset and dependency updates
  • Loading branch information
joshuabmoore authored Feb 14, 2024
2 parents 9da4258 + 409208a commit 3f9c0b5
Show file tree
Hide file tree
Showing 11 changed files with 135 additions and 71 deletions.
1 change: 0 additions & 1 deletion .github/workflows/run_unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install .
pip install pandas==1.3.3 numpy==1.22.0
- name: Run pyspi calculator unit tests
run: |
pytest -v ./tests/test_calc.py
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<p align="center">
<picture>
<source srcset="img/pyspi_logo_dark.png" media="(prefers-color-scheme: dark)">
<source srcset="img/pyspi_logo_darkmode.png" media="(prefers-color-scheme: dark)">
<img src="img/pyspi_logo.png" alt="pyspi logo" height="200"/>
</picture>
</p>
Expand Down
Binary file removed img/pyspi_logo_dark.png
Binary file not shown.
Binary file added img/pyspi_logo_darkmode.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added pyspi/data/cml7.npy
Binary file not shown.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pytest
scikit-learn==0.24.1
scikit-learn==1.0.1
scipy==1.7.3
numpy>=1.21.1
pandas>=1.3.3
pandas==1.5.0
statsmodels==0.12.1
pyyaml==5.4
tqdm==4.50.2
Expand Down
11 changes: 6 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
# http://www.diveintopython3.net/packaging.html
# https://pypi.python.org/pypi?:action=list_classifiers

with open('README.md') as file:
with open('README.md', 'r', encoding='utf-8') as file:
long_description = file.read()


install_requires = [
'scikit-learn==0.24.1',
'scikit-learn==1.0.1',
'scipy==1.7.3',
'numpy>=1.21.1',
'pandas>=1.3.3',
'pandas==1.5.0',
'statsmodels==0.12.1',
'pyyaml==5.4',
'tqdm==4.50.2',
Expand Down Expand Up @@ -59,9 +59,10 @@
'lib/PhiToolbox/utility/Gauss/logdet.m',
'data/cml.npy',
'data/forex.npy',
'data/standard_normal.npy']},
'data/standard_normal.npy',
'data/cml7.npy']},
include_package_data=True,
version='0.4.1',
version='0.4.2',
description='Library for pairwise analysis of time series data.',
author='Oliver M. Cliff',
author_email='oliver.m.cliff@gmail.com',
Expand Down
Binary file added tests/CML7_benchmark_tables.pkl
Binary file not shown.
Binary file removed tests/calc_standard_normal.pkl
Binary file not shown.
50 changes: 50 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import pytest

@pytest.fixture(scope="session")
def spi_warning_logger(request):
    """Session-scoped fixture that hands tests a callable for recording SPI warnings.

    A single list is created for the fixture's lifetime and attached to the
    pytest session as ``request.session.spi_warnings``; ``pytest_sessionfinish``
    in this conftest reads that attribute back to print its summary.

    Returns:
        A function ``add_warning(spi, module_name, max_z, num_exceed,
        num_iteractions)`` that appends its arguments, as one tuple, to the
        shared list.
    """
    collected = []
    # expose the list on the session object so it outlives the individual tests
    request.session.spi_warnings = collected

    def add_warning(spi, module_name, max_z, num_exceed, num_iteractions):
        """Append one warning record to the session-wide list."""
        record = (spi, module_name, max_z, num_exceed, num_iteractions)
        collected.append(record)

    return add_warning

def pytest_sessionfinish(session, exitstatus):
    """Print the SPI benchmarking summary after the test session has finished.

    The warnings recorded through the ``spi_warning_logger`` fixture are read
    from ``session.spi_warnings`` (treated as empty when the attribute is
    missing) and printed as an 80-column table, followed by a footer that
    echoes the exit status.

    Args:
        session: the pytest session object; only its optional
            ``spi_warnings`` attribute is read. Each entry is expected to be a
            ``(spi, module_name, max_z, num_exceed, num_iteractions)`` tuple.
        exitstatus: the session exit status, shown in the footer.
    """
    # retrieve the spi warnings from the session object
    spi_warnings = getattr(session, 'spi_warnings', [])

    # styling
    header_line = "=" * 80
    content_line = "-" * 80
    footer_line = "=" * 80
    header = " SPI BENCHMARKING SUMMARY"
    footer = f" Session completed with exit status: {exitstatus} "
    padded_header = f"{header:^80}"
    padded_footer = f"{footer:^80}"

    # Column widths (summing to 80) shared by the table header and every row.
    # Previously the rows used 15/20 for the last two columns while the header
    # used 20/15, so the numbers did not line up under their headings.
    spi_w, cat_w, z_w, exceed_w, pairs_w = 25, 10, 10, 20, 15

    print("\n")
    print(header_line)
    print(padded_header)
    print(header_line)

    # print problematic SPIs in table format
    if spi_warnings:
        print(f"\nDetected {len(spi_warnings)} SPI(s) with outputs exceeding the specified 2 sigma threshold.\n")

        # table header
        print(
            f"{'SPI':<{spi_w}}{'Cat':<{cat_w}}{'Max ZSc.':>{z_w}}"
            f"{'# Exceed. Pairs':>{exceed_w}}{'Unq. Pairs':>{pairs_w}}"
        )
        print(content_line)

        # table content
        for est, module_name, max_z, num_exceed, num_iteractions in spi_warnings:
            # flag very large z-scores (> 10) with a trailing " **" marker
            error = " **" if max_z > 10 else ""
            print(
                f"{est + error:<{spi_w}}{module_name:<{cat_w}}{max_z:>{z_w}.4g}"
                f"{num_exceed:>{exceed_w}}{num_iteractions:>{pairs_w}}"
            )
    else:
        print("\n\nNo SPIs exceeded the sigma threshold.\n")

    print(footer_line)
    print(padded_footer)
138 changes: 76 additions & 62 deletions tests/test_SPIs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,93 +3,107 @@
import dill
import pyspi
import numpy as np
from copy import deepcopy
import warnings




############# Fixtures and helper functions #########

def load_benchmark_calcs():
benchmark_calcs = dict()
calcs = ['calc_standard_normal.pkl'] # follow this naming convention -> calc_{name}.pkl
for calc in calcs:
# extract the calculator name from the filename
calc_name = calc[len("calc_"):-len(".pkl")]

# Load the calculator
with open(f"tests/{calc}", "rb") as f:
loaded_calc = dill.load(f)
benchmark_calcs[calc_name] = loaded_calc
def load_benchmark_tables():
"""Function to load the mean and standard deviation tables for each MPI."""
table_fname = 'CML7_benchmark_tables.pkl'
with open(f"tests/{table_fname}", "rb") as f:
loaded_tables = dill.load(f)

return benchmark_calcs

def load_benchmark_datasets():
benchmark_datasets = dict()
dataset_names = ['standard_normal.npy']
for dname in dataset_names:
dataset = np.load(f"pyspi/data/{dname}")
dataset = dataset.T
benchmark_datasets[dname.strip('.npy')] = dataset
return loaded_tables

return benchmark_datasets
def load_benchmark_dataset():
    """Load ``pyspi/data/cml7.npy`` and return the transposed array.

    The path is resolved relative to the current working directory.
    """
    fname = 'cml7.npy'
    raw = np.load(f"pyspi/data/{fname}")
    # the caller receives the transpose of what is stored on disk
    return np.transpose(raw)

def compute_new_tables():
"""Compute new tables using the same benchmark dataset(s)."""
benchmark_datasets = load_benchmark_datasets()
# Compute new tables on the benchmark datasets
new_calcs = dict()

calc_base = Calculator() # create base calculator object

for dataset in benchmark_datasets.keys():
calc = deepcopy(calc_base) # make a copy of the base calculator
calc.load_dataset(dataset=benchmark_datasets[dataset])
calc.compute()
new_calcs[dataset] = calc
benchmark_dataset = load_benchmark_dataset()
# Compute new tables on the benchmark dataset
np.random.seed(42)
calc = Calculator(dataset=benchmark_dataset)
calc.compute()
table_dict = dict()
for spi in calc.spis:
table_dict[spi] = calc.table[spi]

return new_calcs
return table_dict

def generate_SPI_test_params():
"""Generate combinations of calculator, dataset and SPI for the fixture."""
benchmark_calcs = load_benchmark_calcs()
new_calcs = compute_new_tables()
"""Function to generate combinations of benchmark table,
new table for each MPI"""
benchmark_tables = load_benchmark_tables()
new_tables = compute_new_tables()
params = []
for calc_name, benchmark_calc in benchmark_calcs.items():
spi_dict = benchmark_calc.spis
for spi_est in spi_dict.keys():
params.append((calc_name, spi_est, benchmark_calc.table[spi_est], new_calcs[calc_name].table[spi_est]))
calc = Calculator()
spis = list(calc.spis.keys())
spi_ob = list(calc.spis.values())
for spi_est, spi_ob in zip(spis, spi_ob):
params.append((spi_est, spi_ob, benchmark_tables[spi_est], new_tables[spi_est].to_numpy()))

return params

params = generate_SPI_test_params()
def pytest_generate_tests(metafunc):
"""Create a hook to generate parameter combinations for parameterised test"""
if "calc_name" in metafunc.fixturenames:
metafunc.parametrize("calc_name,est,mpi_benchmark,mpi_new", params)
if "est" in metafunc.fixturenames:
metafunc.parametrize("est, est_ob, mpi_benchmark,mpi_new", params)


def test_mpi(calc_name, est, mpi_benchmark, mpi_new):
def test_mpi(est, est_ob, mpi_benchmark, mpi_new, spi_warning_logger):
"""Run the benchmarking tests."""

"""First check to see if any SPIs are 'broken', as would be the case if
the benchmark table contains values for certain SPIs whereas the new table for the same
SPI does not (NaN). Also, if all values are NaNs for one SPI and not for the same SPI in the
newly computed table. """
zscore_threshold = 2 # 2 sigma

mismatched_nans = (mpi_benchmark.isna() != mpi_new.isna())
assert not mismatched_nans.any().any(), f"SPI: {est} | Dataset: {calc_name}. Mismatched NaNs."
# separate the mean and std. dev tables for the benchmark
mean_table = mpi_benchmark['mean']
std_table = mpi_benchmark['std']

# check that the shapes are equal
assert mpi_benchmark.shape == mpi_new.shape, f"SPI: {est}| Dataset: {calc_name}. Different table shapes. "
# check std table for zeros and impute with the smallest non-zero value
min_nonzero_std = np.nanmin(std_table[std_table > 0])
std_table[std_table == 0] = min_nonzero_std


# Now quantify the difference between tables (if a diff exists)
epsilon = np.finfo(float).eps
# check that the shapes are equal
assert mean_table.shape == mpi_new.shape, f"SPI: {est}| Different table shapes. "

# convert NaNs to zeros before proceeding - this will take care of diagonal and any null outputs
mpi_new = np.nan_to_num(mpi_new)
mpi_mean = np.nan_to_num(mean_table)

# check if matrix is symmetric (undirected SPI) for num exceed correction
isSymmetric = "undirected" in est_ob.labels

# get the module name for easy reference
module_name = est_ob.__module__.split(".")[-1]

if (mpi_new == mpi_mean).all() == False:
# tables are not equivalent, quantify the difference by z-scoring.
diff = abs(mpi_new - mpi_mean)
zscores = diff/std_table
idxs_greater_than_thresh = np.argwhere(zscores > zscore_threshold)
if len(idxs_greater_than_thresh) > 0:
sigs = list()
for idx in idxs_greater_than_thresh:
sigs.append(zscores[idx[0], idx[1]])
# get the max
max_z = max(sigs)
# number of interactions
num_iteractions = (mpi_new.shape[0] * mpi_new.shape[1]) - mpi_new.shape[0]
# count exceedances
num_exceed = len(sigs)
if isSymmetric:
# number of unique exceedances is half
num_exceed /= 2
num_iteractions /= 2

spi_warning_logger(est, module_name, max_z, int(num_exceed), int(num_iteractions))

if not mpi_benchmark.equals(mpi_new):
diff = abs(mpi_benchmark - mpi_new)
max_diff = diff.max().max()
if max_diff > epsilon:
warnings.warn(f"SPI: {est} | Dataset: {calc_name} | Max difference: {max_diff}")




0 comments on commit 3f9c0b5

Please sign in to comment.