diff --git a/pysr/test/__init__.py b/pysr/test/__init__.py index 7774f4c40..21bbd1409 100644 --- a/pysr/test/__init__.py +++ b/pysr/test/__init__.py @@ -2,3 +2,4 @@ from .test_env import runtests as runtests_env from .test_jax import runtests as runtests_jax from .test_torch import runtests as runtests_torch +from .cliTest import runtests as runtests_cli diff --git a/pysr/test/__main__.py b/pysr/test/__main__.py index 7d45c7b92..b0ec3b36d 100644 --- a/pysr/test/__main__.py +++ b/pysr/test/__main__.py @@ -11,7 +11,7 @@ parser.add_argument( "test", nargs="*", - help="Test to run. One or more of 'main', 'env', 'jax', 'torch'.", + help="Test to run. One or more of 'main', 'env', 'jax', 'torch', 'cli'.", ) # Parse args: @@ -25,7 +25,7 @@ # Run tests: for test in tests: - if test in {"main", "env", "jax", "torch"}: + if test in {"main", "env", "jax", "torch", "cli"}: cur_dir = os.path.dirname(os.path.abspath(__file__)) print(f"Running test from {cur_dir}") if test == "main": @@ -36,6 +36,8 @@ runtests_jax() elif test == "torch": runtests_torch() + elif test == "cli": + runtests_cli() else: parser.print_help() raise SystemExit(1) diff --git a/pysr/test/cliTest.py b/pysr/test/cliTest.py index cd87652fb..cbdea162f 100644 --- a/pysr/test/cliTest.py +++ b/pysr/test/cliTest.py @@ -54,5 +54,10 @@ def test_help_on_install(self): self.assertEqual(expected, actual) -if __name__ == '__main__': - unittest.main() +def runtests(): + """Run all tests in cliTest.py.""" + loader = unittest.TestLoader() + suite = unittest.TestSuite() + suite.addTests(loader.loadTestsFromTestCase(TestCli)) + runner = unittest.TextTestRunner() + return runner.run(suite)