Skip to content

Commit

Permalink
Merge pull request optuna#5630 from kAIto47802/fix-plot-contour
Browse files Browse the repository at this point in the history
Fix the error caused in `plot_contour()` with an impossible pair of variables
  • Loading branch information
not522 authored Aug 20, 2024
2 parents 953ddf4 + 76e6c64 commit 1f38f73
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
13 changes: 10 additions & 3 deletions optuna/visualization/_contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any
from typing import Callable
from typing import NamedTuple
import warnings

import numpy as np

Expand Down Expand Up @@ -208,6 +209,15 @@ def _get_contour_subplot(
x_indices = info.xaxis.indices
y_indices = info.yaxis.indices

if len(x_indices) < 2 or len(y_indices) < 2:
return go.Contour(), go.Scatter(), go.Scatter()
if len(info.z_values) == 0:
warnings.warn(
f"Contour plot will not be displayed because `{info.xaxis.name}` and "
f"`{info.yaxis.name}` cannot co-exist in `trial.params`."
)
return go.Contour(), go.Scatter(), go.Scatter()

feasible = _PlotValues([], [])
infeasible = _PlotValues([], [])

Expand All @@ -227,9 +237,6 @@ def _get_contour_subplot(

z_values[xys[:, 1], xys[:, 0]] = zs

if len(x_indices) < 2 or len(y_indices) < 2:
return go.Contour(), go.Scatter(), go.Scatter()

contour = go.Contour(
x=x_indices,
y=y_indices,
Expand Down
4 changes: 0 additions & 4 deletions tests/visualization_tests/test_visualizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,6 @@ def test_visualizations_with_single_objectives(
study = optuna.create_study(sampler=optuna.samplers.RandomSampler())
study.optimize(objective_func, n_trials=20)

# TODO(c-bata): Fix a bug to remove `pytest.xfail`.
if plot_func is plot_contour and objective_func is objective_single_dynamic_with_categorical:
pytest.xfail("There is a bug that IndexError is raised in plot_contour")

# TODO(c-bata): Fix a bug to remove `pytest.xfail`.
if plot_func is matplotlib_plot_rank and objective_func is objective_single_none_categorical:
pytest.xfail("There is a bug that TypeError is raised in matplotlib.plot_rank")
Expand Down

0 comments on commit 1f38f73

Please sign in to comment.