Skip to content

Commit

Permalink
Merge pull request #834 from MilesCranmer/fix-sklearn-test
Browse files Browse the repository at this point in the history
test: update deprecated sklearn test syntax
  • Loading branch information
MilesCranmer authored Feb 23, 2025
2 parents b03e02b + c904de1 commit 47d19e2
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions pysr/test/test_main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import importlib
import os
import pickle as pkl
Expand All @@ -12,7 +13,13 @@
import numpy as np
import pandas as pd
import sympy # type: ignore
from sklearn.utils.estimator_checks import check_estimator

try:
from sklearn.utils.estimator_checks import estimator_checks_generator
except ImportError:
from sklearn.utils.estimator_checks import check_estimator

estimator_checks_generator = functools.partial(check_estimator, generate_only=True)

from pysr import (
ParametricExpressionSpec,
Expand Down Expand Up @@ -930,9 +937,8 @@ def test_scikit_learn_compatibility(self):
temp_equation_file=True,
) # Return early.

check_generator = check_estimator(model, generate_only=True)
exception_messages = []
for _, check in check_generator:
for _, check in estimator_checks_generator(model):
if check.func.__name__ in {
# We can use complex data, so avoid this check
"check_complex_data",
Expand Down

0 comments on commit 47d19e2

Please sign in to comment.