Skip to content

Commit

Permalink
Merge pull request optuna#5432 from eukaryo/suppress-numpy-invalid
Browse files Browse the repository at this point in the history
Suppress warnings from `numpy` in hypervolume computation
  • Loading branch information
not522 authored May 21, 2024
2 parents fe400d1 + 875fbc1 commit 3aac18f
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
6 changes: 4 additions & 2 deletions optuna/_hypervolume/hssp.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ def _solve_hssp_on_unique_loss_vals(
subset_size: int,
reference_point: np.ndarray,
) -> np.ndarray:
assert not np.any(reference_point - rank_i_loss_vals <= 0)
if not np.isfinite(reference_point).all():
return rank_i_indices[:subset_size]
diff_of_loss_vals_and_ref_point = reference_point - rank_i_loss_vals
assert subset_size <= rank_i_indices.size
n_objectives = reference_point.size
contribs = np.prod(reference_point - rank_i_loss_vals, axis=-1)
contribs = np.prod(diff_of_loss_vals_and_ref_point, axis=-1)
selected_indices = np.zeros(subset_size, dtype=int)
selected_vecs = np.empty((subset_size, n_objectives))
indices = np.arange(rank_i_loss_vals.shape[0], dtype=int)
Expand Down
2 changes: 2 additions & 0 deletions optuna/_hypervolume/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def _compute_2d(solution_set: np.ndarray, reference_point: np.ndarray) -> float:
The reference point to compute the hypervolume.
"""
assert solution_set.shape[1] == 2 and reference_point.shape[0] == 2
if not np.isfinite(reference_point).all():
return float("inf")

# Ascending order in the 1st objective, and descending order in the 2nd objective.
sorted_solution_set = solution_set[np.lexsort((-solution_set[:, 1], solution_set[:, 0]))]
Expand Down
2 changes: 2 additions & 0 deletions optuna/_hypervolume/wfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def __init__(self) -> None:
self._reference_point: np.ndarray | None = None

def _compute(self, solution_set: np.ndarray, reference_point: np.ndarray) -> float:
if not np.isfinite(reference_point).all():
return float("inf")
self._reference_point = reference_point.astype(np.float64)
if self._reference_point.shape[0] == 2:
return _compute_2d(solution_set, self._reference_point)
Expand Down
26 changes: 11 additions & 15 deletions tests/hypervolume_tests/test_hssp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import math
from typing import Tuple

import numpy as np
Expand All @@ -11,11 +12,14 @@ def _compute_hssp_truth_and_approx(test_case: np.ndarray, subset_size: int) -> T
r = 1.1 * np.max(test_case, axis=0)
truth = 0.0
for subset in itertools.permutations(test_case, subset_size):
truth = max(truth, optuna._hypervolume.WFG().compute(np.asarray(subset), r))
hv = optuna._hypervolume.WFG().compute(np.asarray(subset), r)
assert not math.isnan(hv)
truth = max(truth, hv)
indices = optuna._hypervolume.hssp._solve_hssp(
test_case, np.arange(len(test_case)), subset_size, r
)
approx = optuna._hypervolume.WFG().compute(test_case[indices], r)
assert not math.isnan(approx)
return truth, approx


Expand All @@ -30,32 +34,24 @@ def test_solve_hssp(dim: int) -> None:
assert approx / truth > 0.6321 # 1 - 1/e


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_solve_hssp_infinite_loss() -> None:
rng = np.random.RandomState(128)

subset_size = 4
test_case = rng.rand(9, 2)
test_case[-1].fill(float("inf"))
truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size)
assert np.isinf(truth)
assert np.isinf(approx)

test_case = rng.rand(9, 3)
test_case[-1].fill(float("inf"))
truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size)
assert truth == 0
assert np.isnan(approx)

for dim in range(2, 4):
test_case = rng.rand(9, dim)
test_case[-1].fill(float("inf"))
truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size)
assert np.isinf(truth)
assert np.isinf(approx)

test_case = rng.rand(9, dim)
test_case[-1].fill(-float("inf"))
truth, approx = _compute_hssp_truth_and_approx(test_case, subset_size)
assert np.isinf(truth)
assert np.isinf(approx)


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_solve_hssp_duplicated_infinite_loss() -> None:
test_case = np.array([[np.inf, 0, 0], [np.inf, 0, 0], [0, np.inf, 0], [0, 0, np.inf]])
r = np.full(3, np.inf)
Expand Down

0 comments on commit 3aac18f

Please sign in to comment.