Skip to content

Commit

Permalink
Feat/add tests (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
dirmeier authored Nov 13, 2022
1 parent f9a5f52 commit d589fd5
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 40 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,30 @@ jobs:
run: |
poetry build
tests:
runs-on: ubuntu-latest
needs:
- precommit
strategy:
matrix:
python-version: [3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install poetry
- name: Build package
run: |
pip install jaxlib jax
poetry install
- name: Run examples
run: |
poetry run pytest
examples:
runs-on: ubuntu-latest
needs:
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,6 @@ cython_debug/

# vscode
.vscode/

# dont att poetry lock to package
poetry.lock
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ repos:
language: python
language_version: python3
types: [python]
files: "(reconcile|examples|tests)"
files: "(reconcile|examples)"

- repo: /~https://github.com/PyCQA/flake8
rev: 5.0.1
Expand Down
30 changes: 15 additions & 15 deletions examples/reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
import numpy as np
import optax
from chex import Array, PRNGKey
from data import sample_hierarchical_timeseries
from jax import numpy as jnp
from jax import random
from jax.config import config

from reconcile.data import sample_hierarchical_timeseries
from reconcile.forecast import Forecaster
from reconcile.grouping import Grouping
from reconcile.probabilistic_reconciliation import ProbabilisticReconciliation

config.update("jax_enable_x64", True)


class GPForecaster(Forecaster):
"""Example implementation of a forecaster"""
Expand All @@ -29,7 +32,7 @@ def data(self):
"""Returns the data"""
return self._ys, self._xs

def fit(self, rng_key: PRNGKey, ys: Array, xs: Array):
def fit(self, rng_key: PRNGKey, ys: Array, xs: Array, niter=2000):
"""Fit a model to each of the time series"""

self._xs = xs
Expand All @@ -42,11 +45,11 @@ def fit(self, rng_key: PRNGKey, ys: Array, xs: Array):
for i in np.arange(p):
x, y = xs[:, [i], :], ys[:, [i], :]
# fit a model for each time series
learned_params, _, D = self._fit_one(rng_key, x, y)
learned_params, _, D = self._fit_one(rng_key, x, y, niter)
# save the learned parameters and the original data
self._models[i] = learned_params, D

def _fit_one(self, rng_key, x, y):
def _fit_one(self, rng_key, x, y, niter):
# here we use GPs to model the time series
D = gpx.Dataset(X=x.reshape(-1, 1), y=y.reshape(-1, 1))
sgpr, q, likelihood = self._model(rng_key, D.n)
Expand All @@ -58,15 +61,15 @@ def _fit_one(self, rng_key, x, y):
objective=negative_elbo,
parameter_state=parameter_state,
optax_optim=optimiser,
n_iters=2000,
n_iters=niter,
)
learned_params, training_history = inference_state.unpack()
return learned_params, training_history, D

@staticmethod
def _model(rng_key, n):
z = random.uniform(rng_key, (20, 1))
prior = gpx.Prior(kernel=gpx.RBF())
prior = gpx.Prior(mean_function=gpx.Constant(), kernel=gpx.RBF())
likelihood = gpx.Gaussian(num_datapoints=n)
posterior = prior * likelihood
q = gpx.CollapsedVariationalGaussian(
Expand Down Expand Up @@ -113,11 +116,8 @@ def predictive_posterior_probability(
chex.assert_equal_shape([ys_test, xs_test])

preds = self.posterior_predictive(rng_key, xs_test)
y_test_pred = jnp.zeros(ys_test.shape[1])
for i, pred in enumerate(preds):
lp = preds[i].log_prob(jnp.squeeze(ys_test[:, i, :]))
y_test_pred.at[i].set(lp)
return jnp.asarray(y_test_pred)
lp = preds.log_prob(ys_test)
return lp


def run():
Expand All @@ -127,10 +127,10 @@ def run():
all_features = jnp.tile(x, [1, all_timeseries.shape[1], 1])

forecaster = GPForecaster()
forecaster.fit(random.PRNGKey(1), all_timeseries, all_features)
forecaster.posterior_predictive(random.PRNGKey(1), all_features)
forecaster.predictive_posterior_probability(
random.PRNGKey(1), all_timeseries, all_features
forecaster.fit(
random.PRNGKey(1),
all_timeseries[:, :, :90],
all_features[:, :, :90],
)

recon = ProbabilisticReconciliation(grouping, forecaster)
Expand Down
9 changes: 4 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ chex = "^0.1.5"
flax = "^0.6.1"
optax = "^0.1.3"
pandas = "^1.5.1"
scipy = "^1.9.3"
statsmodels = "^0.13.2"
numpy = "^1.23.4"

scipy = "^1.9.3"

[tool.poetry.group.dev.dependencies]
pre-commit = "^2.20.0"
Expand All @@ -30,11 +30,10 @@ flake8-pyproject = "^1.1.0.post0"
isort = "^5.10.1"
mypy = "^0.971"
bandit = "^1.7.4"

gpjax = "^0.5.0"
pytest = "^7.2.0"

[tool.poetry.group.examples.dependencies]
statsmodels = "^0.13.2"
matplotlib = "^3.6.1"
gpjax = "^0.5.0"


Expand Down
35 changes: 22 additions & 13 deletions examples/data.py → reconcile/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def _sample_timeseries(N, D):


def sample_grouped_timeseries():
"""
Sample a grouped timeseries from an ARMA
Returns
-------
Tuple
a tuple where the first element is a matrix of time series measurements
and the second one is a pd.DataFrame of groups
"""

def _group_names():
group_one = [
"VIC:Mel",
Expand Down Expand Up @@ -62,20 +72,19 @@ def _group_names():


def sample_hierarchical_timeseries():
"""
Sample a hierarchical timeseries from an ARMA
Returns
-------
Tuple
a tuple where the first element is a matrix of time series measurements
and the second one is a pd.DataFrame of groups
"""

def _group_names():
hierarchy = [
"A:10:A",
"A:10:B",
"A:10:C",
"A:20:A",
"A:20:B",
"B:30:A",
"B:30:B",
"B:30:C",
"B:40:A",
"B:40:B",
]
hierarchy = ["A:10", "A:20", "B:10", "B:20", "B:30"]

return pd.DataFrame.from_dict({"h1": hierarchy})

return _sample_timeseries(100, 10), _group_names()
return _sample_timeseries(100, 5), _group_names()
29 changes: 24 additions & 5 deletions reconcile/grouping.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import warnings
from itertools import chain

import numpy as np
import pandas as pd
from jax import numpy as jnp
from scipy import sparse


def as_list(maybe_list):
return maybe_list if isinstance(maybe_list, list) else [maybe_list]


class Grouping:
"""
Class that represents a grouping/hierarchy of a grouped or hierarchical
Expand All @@ -24,14 +31,28 @@ def __init__(self, groups: pd.DataFrame):
self._group_names = list(groups.columns)

if len(self._group_names) > 1:
warnings.warn("Grouped timeseries is poorly tested. Use with care!")
gmat = self._gts_create_g_mat()
gmat = self._gts_gmat_as_integer(gmat)
self._labels = None
else:
out_edges_per_level, _, idxs = self._hts_create_nodes()
out_edges_per_level, labels, _ = self._hts_create_nodes()
gmat = self._hts_create_g_mat(out_edges_per_level)
labels = [as_list(labels[key]) for key in sorted(labels.keys())]
self._labels = list(chain(*labels))
self._s_matrix = self._smatrix(gmat)
self._n_all_timeseries = self._s_matrix.shape[0]

def all_timeseries_column_names(self):
return self._labels

def bottom_timeseries_column_names(self):
return self._labels[self.n_upper_timeseries :]

@property
def n_groups(self):
return self._groups.shape[1]

@property
def n_all_timeseries(self):
return self._n_all_timeseries
Expand All @@ -54,10 +75,8 @@ def extract_bottom_timeseries(self, y):
return y[:, self.n_upper_timeseries :, :]

def upper_time_series(self, b):
sub = self._s_matrix[
: (self.n_all_timeseries - self.n_bottom_timeseries), :
].T
return jnp.einsum("ijk,jl->ilk", b, sub.toarray())
y = self.all_timeseries(b)
return y[:, : self.n_upper_timeseries, :]

@staticmethod
def _paste0(a, b):
Expand Down
2 changes: 1 addition & 1 deletion reconcile/probabilistic_reconciliation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

class ProbabilisticReconciliation:
"""
Probabilistic reconcilation of hierarchical/grouped time series class
Probabilistic reconcilation of hierarchical time series class
"""

def __init__(self, grouping: Grouping, forecaster: Forecaster):
Expand Down
Empty file added tests/__init__.py
Empty file.
34 changes: 34 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
from jax import numpy as jnp
from jax import random

from examples.reconciliation import GPForecaster
from reconcile import ProbabilisticReconciliation
from reconcile.data import sample_hierarchical_timeseries
from reconcile.grouping import Grouping


@pytest.fixture()
def grouping():
_, groups = sample_hierarchical_timeseries()
grouping = Grouping(groups)
return grouping


@pytest.fixture()
def reconciliator():
(b, x), groups = sample_hierarchical_timeseries()
grouping = Grouping(groups)
all_timeseries = grouping.all_timeseries(b)
all_features = jnp.tile(x, [1, all_timeseries.shape[1], 1])

forecaster = GPForecaster()
forecaster.fit(
random.PRNGKey(1),
all_timeseries[:, :90, :],
all_features[:, :90, :],
100,
)

recon = ProbabilisticReconciliation(grouping, forecaster)
return (all_timeseries, all_features), recon
18 changes: 18 additions & 0 deletions tests/test_grouping.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import chex


def test_grouping_size(grouping):
assert grouping.n_groups == 1


def test_grouping_colnames(grouping):
for e, f in zip(
grouping.all_timeseries_column_names(),
["Total", "A", "B", "A:10", "A:20", "B:10", "B:20", "B:30"],
):
assert e == f


def test_grouping_summing_matrix(grouping):
chex.assert_axis_dimension(grouping.summing_matrix().toarray(), 0, 8)
chex.assert_axis_dimension(grouping.summing_matrix().toarray(), 1, 5)
18 changes: 18 additions & 0 deletions tests/test_reconciliation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import chex
from jax import random


def test_fit_reconciliation(reconciliator):
(_, all_features), recon = reconciliator
fit_recon = recon.fit_reconciled_posterior_predictive(
random.PRNGKey(1), all_features, n_samples=100
)
chex.assert_shape(fit_recon, (100, 5, 100))


def test_sample_reconciliation(reconciliator):
(_, all_features), recon = reconciliator
fit_recon = recon.sample_reconciled_posterior_predictive(
random.PRNGKey(1), all_features, n_warmup=50, n_iter=100
)
chex.assert_shape(fit_recon, (50, 4, 5, 100))

0 comments on commit d589fd5

Please sign in to comment.