Skip to content

Commit

Permalink
Fixed numba incompatibility. Force covariance matrix input parameter …
Browse files Browse the repository at this point in the history
…to be of floating data type.
  • Loading branch information
johannvk committed Jan 26, 2025
1 parent 373fd4c commit b689a43
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
10 changes: 9 additions & 1 deletion skchange/costs/multivariate_t_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,10 +1060,14 @@ def _check_fixed_param(
) -> np.ndarray:
"""Check if the fixed mean parameter is valid.
The covariance matrix is checked for positive definiteness,
and forced to a floating point representation for numba compatibility.
Parameters
----------
param : 2-tuple of float or np.ndarray
Fixed mean and covariance matrix for the cost calculation.
Both are converted to float values or float arrays.
X : np.ndarray
Input data.
Expand All @@ -1074,7 +1078,11 @@ def _check_fixed_param(
"""
mean, cov = param
mean = check_mean(mean, X)
cov = check_cov(cov, X)

# Require floating point representation of
# the covariance matrix for numba compatibility:
cov = check_cov(cov, X, force_float=True)

return mean, cov

@property
Expand Down
4 changes: 3 additions & 1 deletion skchange/costs/tests/test_multivariate_t_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ def test_scale_matrix_numba_benchmark(
t_dof,
abs_tol=1.0e-9,
rel_tol=0.0,
max_iter=100,
)

# Time numba version
Expand All @@ -426,6 +427,7 @@ def test_scale_matrix_numba_benchmark(
t_dof,
abs_tol=1.0e-9,
rel_tol=0.0,
max_iter=100,
)
end = perf_counter()
times_njit.append(end - start)
Expand Down Expand Up @@ -820,7 +822,7 @@ def test_iterative_mv_t_dof_estimate_returns_inf_for_high_initial_dof():

def test_multivariate_t_log_likelihood_returns_nan_for_non_pos_def_scale_matrix():
"""Test that log likelihood returns np.nan for non-pos. def. scale matrix."""
non_positive_definite_matrix = np.array([[1, 2], [2, 1]])
non_positive_definite_matrix = np.array([[1, 2], [2, 1]], dtype=np.float64)
centered_samples = np.random.randn(100, 2)
dof = 5.0

Expand Down
9 changes: 8 additions & 1 deletion skchange/costs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def check_var(var: VarType, X: np.ndarray) -> np.ndarray:
return var


def check_cov(cov: CovType, X: np.ndarray) -> np.ndarray:
def check_cov(cov: CovType, X: np.ndarray, force_float: bool = False) -> np.ndarray:
"""Check if the fixed covariance matrix parameter is valid.
Parameters
Expand All @@ -65,6 +65,9 @@ def check_cov(cov: CovType, X: np.ndarray) -> np.ndarray:
Fixed covariance matrix for the cost calculation.
X : np.ndarray
2d input data.
force_float : bool, default=False
If True, force the covariance matrix to be of
floating point data type.
Returns
-------
Expand All @@ -83,4 +86,8 @@ def check_cov(cov: CovType, X: np.ndarray) -> np.ndarray:
)
if not np.all(np.linalg.eigvals(cov) > 0):
raise ValueError("covariance matrix must be positive definite.")

if force_float:
cov = cov.astype(float)

return cov

0 comments on commit b689a43

Please sign in to comment.