diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 907d1e8305..d5f5acd3a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,32 +17,10 @@ repos: - id: clang-format exclude: dev-tools|examples verbose: true - - repo: https://github.com/asottile/reorder_python_imports - rev: v3.12.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.5 hooks: - - id: reorder-python-imports - args: [--application-directories=python, - --unclassifiable-application-module=_tskit] - - repo: https://github.com/asottile/pyupgrade - rev: v3.15.2 - hooks: - - id: pyupgrade - args: [--py3-plus, --py38-plus] - - repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black - language_version: python3 - - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - args: [--config=python/.flake8] - additional_dependencies: ["flake8-bugbear==23.9.16", "flake8-builtins==2.1.0"] - - repo: https://github.com/asottile/blacken-docs - rev: 1.16.0 - hooks: - - id: blacken-docs - args: [--skip-errors] - additional_dependencies: [black==22.3.0] - language_version: python3 + - id: ruff + args: [ "--fix", "--config", "python/ruff.toml" ] + - id: ruff-format + args: [ "--config", "python/ruff.toml" ] \ No newline at end of file diff --git a/python/.flake8 b/python/.flake8 deleted file mode 100644 index e533d063cc..0000000000 --- a/python/.flake8 +++ /dev/null @@ -1,7 +0,0 @@ -[flake8] -# Based directly on Black's recommendations: -# https://black.readthedocs.io/en/stable/the_black_code_style.html#line-length -max-line-length = 81 -select = A,C,E,F,W,B,B950 -#B305 doesn't like `.next()` that is a key Tree method. 
-ignore = E203, E501, W503, B305 diff --git a/python/benchmark/run-for-all-releases.py b/python/benchmark/run-for-all-releases.py index 3e64d443b1..2fa59614f8 100644 --- a/python/benchmark/run-for-all-releases.py +++ b/python/benchmark/run-for-all-releases.py @@ -1,9 +1,9 @@ import json import subprocess +from distutils.version import StrictVersion from urllib.request import urlopen import tqdm -from distutils.version import StrictVersion def versions(package_name): diff --git a/python/benchmark/run.py b/python/benchmark/run.py index 1f525236e9..5899391e02 100644 --- a/python/benchmark/run.py +++ b/python/benchmark/run.py @@ -14,9 +14,10 @@ tskit_dir = Path(__file__).parent.parent sys.path.append(str(tskit_dir)) -import tskit # noqa: E402 import msprime # noqa: E402 +import tskit # noqa: E402 + with open("config.yaml") as f: config = yaml.load(f, Loader=yaml.FullLoader) diff --git a/python/lwt_interface/dict_encoding_testlib.py b/python/lwt_interface/dict_encoding_testlib.py index 72acea4dce..b7c7fb9c40 100644 --- a/python/lwt_interface/dict_encoding_testlib.py +++ b/python/lwt_interface/dict_encoding_testlib.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ compiled module exporting the LightweightTableCollection class. See the test_example_c_module file for an example. 
""" + import copy import kastore @@ -98,7 +99,7 @@ def full_ts(): # The ts above is used for the whole test session, but our tests need fresh tables to # modify -@pytest.fixture +@pytest.fixture() def tables(full_ts): return full_ts.dump_tables() @@ -183,9 +184,7 @@ def test_example(self, tables): { "codec": "struct", "type": "object", - "properties": { - table: {"type": "string", "binaryFormat": "50p"} - }, + "properties": {table: {"type": "string", "binaryFormat": "50p"}}, } ) @@ -459,9 +458,7 @@ def verify_optional_column(self, tables, table_len, table_name, col_name): out[table_name][col_name], np.zeros(table_len, dtype=np.int32) - 1 ) - def verify_offset_pair( - self, tables, table_len, table_name, col_name, required=False - ): + def verify_offset_pair(self, tables, table_len, table_name, col_name, required=False): offset_col = col_name + "_offset" if not required: @@ -544,9 +541,7 @@ def test_individuals(self, tables): self.verify_offset_pair( tables, len(tables.individuals), "individuals", "location" ) - self.verify_offset_pair( - tables, len(tables.individuals), "individuals", "parents" - ) + self.verify_offset_pair(tables, len(tables.individuals), "individuals", "parents") self.verify_offset_pair( tables, len(tables.individuals), "individuals", "metadata" ) @@ -578,9 +573,7 @@ def test_migrations(self, tables): self.verify_required_columns( tables, "migrations", ["left", "right", "node", "source", "dest", "time"] ) - self.verify_offset_pair( - tables, len(tables.migrations), "migrations", "metadata" - ) + self.verify_offset_pair(tables, len(tables.migrations), "migrations", "metadata") self.verify_optional_column(tables, len(tables.nodes), "nodes", "individual") self.verify_metadata_schema(tables, "migrations") @@ -674,9 +667,7 @@ def get_refseq(d): assert get_refseq(d).is_null() # All empty strings is the same thing - d["reference_sequence"] = dict( - data="", url="", metadata_schema="", metadata=b"" - ) + d["reference_sequence"] = dict(data="", url="", 
metadata_schema="", metadata=b"") assert get_refseq(d).is_null() del refseq_dict["metadata_schema"] # handled above diff --git a/python/lwt_interface/setup.py b/python/lwt_interface/setup.py index 30ab1dc65a..9cf1e7a35a 100644 --- a/python/lwt_interface/setup.py +++ b/python/lwt_interface/setup.py @@ -1,17 +1,15 @@ import os.path import platform -from setuptools import Extension -from setuptools import setup +from setuptools import Extension, setup from setuptools.command.build_ext import build_ext - IS_WINDOWS = platform.system() == "Windows" # Obscure magic required to allow numpy be used as a 'setup_requires'. # Based on https://stackoverflow.com/questions/19919905 -class local_build_ext(build_ext): +class local_build_ext(build_ext): # noqa: N801 def finalize_options(self): build_ext.finalize_options(self) import builtins diff --git a/python/ruff.toml b/python/ruff.toml new file mode 100644 index 0000000000..7c49dc641b --- /dev/null +++ b/python/ruff.toml @@ -0,0 +1,12 @@ +line-length = 90 + +[lint] +select = ["E", "F", "B", "W", "I", "N", "UP", "A", "RUF", "PT", "NPY"] +# N803,806,802 Allow capital varnames +# E741 Allow "l" as var name +# PT011 allow pytest raises without match +ignore = ["N803", "N806", "N802", "E741", "PT011", "PT009"] + +[lint.isort] +section-order = ["future", "standard-library", "third-party", "first-party", "local-folder"] +known-first-party = ["tskit", "_tskit"] \ No newline at end of file diff --git a/python/setup.py b/python/setup.py index 44e15b9869..27d91e83a9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -1,17 +1,15 @@ import os.path import platform -from setuptools import Extension -from setuptools import setup +from setuptools import Extension, setup from setuptools.command.build_ext import build_ext - IS_WINDOWS = platform.system() == "Windows" # Obscure magic required to allow numpy be used as a 'setup_requires'. 
# Based on https://stackoverflow.com/questions/19919905 -class local_build_ext(build_ext): +class local_build_ext(build_ext): # noqa: N801 def finalize_options(self): build_ext.finalize_options(self) import builtins diff --git a/python/tests/__init__.py b/python/tests/__init__.py index f069f04f2e..5bbecbafeb 100644 --- a/python/tests/__init__.py +++ b/python/tests/__init__.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ import base64 import tskit + from . import tsutil from .simplify import * # NOQA @@ -195,9 +196,7 @@ def trees(self): pt.left = left pt.right = right # Add in all the sites - pt.site_list = [ - site for site in self._sites if left <= site.position < right - ] + pt.site_list = [site for site in self._sites if left <= site.position < right] yield pt pt.index += 1 pt.index = -1 diff --git a/python/tests/conftest.py b/python/tests/conftest.py index d23c019003..6a06c03647 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -38,9 +38,9 @@ def test_something(self, ts_fixture): Note that fixtures have a "scope" for example `ts_fixture` below is only created once per test session and re-used for subsequent tests. """ + import msprime import pytest -from pytest import fixture from . 
import tsutil @@ -81,29 +81,29 @@ def pytest_collection_modifyitems(config, items): item.add_marker(skip_slow) -@fixture +@pytest.fixture() def overwrite_viz(request): return request.config.getoption("--overwrite-expected-visualizations") -@fixture +@pytest.fixture() def draw_plotbox(request): return request.config.getoption("--draw-svg-debug-box") -@fixture(scope="session") +@pytest.fixture(scope="session") def simple_degree1_ts_fixture(): return msprime.simulate(10, random_seed=42) -@fixture(scope="session") +@pytest.fixture(scope="session") def simple_degree2_ts_fixture(): ts = msprime.simulate(10, recombination_rate=0.2, random_seed=42) assert ts.num_trees == 2 return ts -@fixture(scope="session") +@pytest.fixture(scope="session") def ts_fixture(): """ A tree sequence with data in all fields @@ -111,7 +111,7 @@ def ts_fixture(): return tsutil.all_fields_ts() -@fixture(scope="session") +@pytest.fixture(scope="session") def ts_fixture_for_simplify(): """ A tree sequence with data in all fields execpt edge metadata and migrations @@ -119,7 +119,7 @@ def ts_fixture_for_simplify(): return tsutil.all_fields_ts(edge_metadata=False, migrations=False) -@fixture(scope="session") +@pytest.fixture(scope="session") def replicate_ts_fixture(): """ A list of tree sequences diff --git a/python/tests/ibd.py b/python/tests/ibd.py index 53e28cc5c0..04d20798b8 100644 --- a/python/tests/ibd.py +++ b/python/tests/ibd.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2022 Tskit Developers +# Copyright (c) 2020-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Python implementation of the IBD-finding algorithms. 
""" + import argparse import collections @@ -46,9 +47,7 @@ def __init__(self, left=None, right=None, node=None, next_seg=None): self.next = next_seg def __str__(self): - s = "({}-{}->{}:next={})".format( - self.left, self.right, self.node, repr(self.next) - ) + s = f"({self.left}-{self.right}->{self.node}:next={self.next!r})" return s def __repr__(self): diff --git a/python/tests/simplify.py b/python/tests/simplify.py index 1e62c9d11b..0f45848bad 100644 --- a/python/tests/simplify.py +++ b/python/tests/simplify.py @@ -23,6 +23,7 @@ """ Python implementation of the simplify algorithm. """ + import sys import numpy as np @@ -82,9 +83,7 @@ def __init__(self, left=None, right=None, node=None, next_segment=None): self.next = next_segment def __str__(self): - s = "({}-{}->{}:next={})".format( - self.left, self.right, self.node, repr(self.next) - ) + s = f"({self.left}-{self.right}->{self.node}:next={self.next!r})" return s def __repr__(self): @@ -305,7 +304,8 @@ def merge_labeled_ancestors(self, S, input_id): if is_sample: # Free up the existing ancestry mapping. 
x = self.A_tail[input_id] - assert x.left == 0 and x.right == self.sequence_length + assert x.left == 0 + assert x.right == self.sequence_length self.A_tail[input_id] = None self.A_head[input_id] = None @@ -333,8 +333,7 @@ def merge_labeled_ancestors(self, S, input_id): # Fill in any gaps in the ancestry for the sample self.add_ancestry(input_id, prev_right, left, output_id) if self.keep_unary or ( - self.keep_unary_in_individuals - and self.ts.node(input_id).individual >= 0 + self.keep_unary_in_individuals and self.ts.node(input_id).individual >= 0 ): ancestry_node = output_id self.add_ancestry(input_id, left, right, ancestry_node) @@ -637,9 +636,7 @@ def process_parent_edges(self, edges): x = self.A_head[edge.child] while x is not None: if x.right > edge.left and edge.right > x.left: - y = Segment( - max(x.left, edge.left), min(x.right, edge.right), x.node - ) + y = Segment(max(x.left, edge.left), min(x.right, edge.right), x.node) S.append(y) x = x.next self.merge_labeled_ancestors(S, parent) @@ -654,7 +651,8 @@ def merge_labeled_ancestors(self, S, input_id): if is_sample: # Free up the existing ancestry mapping. x = self.A_tail[input_id] - assert x.left == 0 and x.right == self.sequence_length + assert x.left == 0 + assert x.right == self.sequence_length self.A_tail[input_id] = None self.A_head[input_id] = None diff --git a/python/tests/test_avl_tree.py b/python/tests/test_avl_tree.py index 999b09b528..c069d35aad 100644 --- a/python/tests/test_avl_tree.py +++ b/python/tests/test_avl_tree.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2021 Tskit Developers +# Copyright (c) 2021-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ Note there is a bug in that Python translation which is missing P.B = 0 at the end of A9. 
""" + from __future__ import annotations import dataclasses @@ -36,7 +37,6 @@ import numpy as np import pytest - # The nodes of the tree are assumed to contain KEY, LLINK, and RLINK fields. # We also have a new field # diff --git a/python/tests/test_balance_metrics.py b/python/tests/test_balance_metrics.py index dc77f95e6b..90e9cd4c47 100644 --- a/python/tests/test_balance_metrics.py +++ b/python/tests/test_balance_metrics.py @@ -22,6 +22,7 @@ """ Tests for tree balance/imbalance metrics. """ + import math import numpy as np diff --git a/python/tests/test_cli.py b/python/tests/test_cli.py index 67b3890c2e..022410cc7e 100644 --- a/python/tests/test_cli.py +++ b/python/tests/test_cli.py @@ -23,6 +23,7 @@ """ Test cases for the command line interfaces to tskit """ + import io import os import sys @@ -36,10 +37,11 @@ import tskit import tskit.cli as cli + from . import tsutil -class TestException(Exception): +class TestException(Exception): # noqa: N818 __test__ = False """ Custom exception we can throw for testing. @@ -52,8 +54,7 @@ def capture_output(func, *args, **kwargs): tuple (stdout, stderr) as strings. 
""" buffer_class = io.BytesIO - if sys.version_info[0] == 3: - buffer_class = io.StringIO + buffer_class = io.StringIO stdout = sys.stdout sys.stdout = buffer_class() stderr = sys.stderr @@ -264,12 +265,12 @@ def test_fasta_long_args(self): assert args.wrap == 50 @pytest.mark.parametrize( - "flags,expected", - ( - [[], None], - [["-P", "2"], 2], - [["--ploidy", "5"], 5], - ), + ("flags", "expected"), + [ + ([], None), + (["-P", "2"], 2), + (["--ploidy", "5"], 5), + ], ) def test_vcf_ploidy(self, flags, expected): parser = cli.get_tskit_parser() @@ -280,12 +281,12 @@ def test_vcf_ploidy(self, flags, expected): assert args.ploidy == expected @pytest.mark.parametrize( - "flags,expected", - ( - [[], "1"], - [["-c", "chrX"], "chrX"], - [["--contig-id", "chr20"], "chr20"], - ), + ("flags", "expected"), + [ + ([], "1"), + (["-c", "chrX"], "chrX"), + (["--contig-id", "chr20"], "chr20"), + ], ) def test_vcf_contig_id(self, flags, expected): parser = cli.get_tskit_parser() @@ -296,12 +297,12 @@ def test_vcf_contig_id(self, flags, expected): assert args.contig_id == expected @pytest.mark.parametrize( - "flags,expected", - ( - [[], False], - [["-0"], True], - [["--allow-position-zero"], True], - ), + ("flags", "expected"), + [ + ([], False), + (["-0"], True), + (["--allow-position-zero"], True), + ], ) def test_vcf_allow_position_zero(self, flags, expected): parser = cli.get_tskit_parser() @@ -408,9 +409,7 @@ def setUpClass(cls): cls._tree_sequence = tsutil.insert_random_ploidy_individuals( ts, samples_only=True ) - fd, cls._tree_sequence_file = tempfile.mkstemp( - prefix="tsk_cli", suffix=".trees" - ) + fd, cls._tree_sequence_file = tempfile.mkstemp(prefix="tsk_cli", suffix=".trees") os.close(fd) cls._tree_sequence.dump(cls._tree_sequence_file) diff --git a/python/tests/test_coalrate.py b/python/tests/test_coalrate.py index dcfcfc8d66..76a8bbd9ad 100644 --- a/python/tests/test_coalrate.py +++ b/python/tests/test_coalrate.py @@ -22,6 +22,7 @@ """ Test cases for coalescence 
rate calculation in tskit. """ + import itertools import msprime @@ -627,9 +628,7 @@ def example_ts(self): def test_oor_windows(self): ts = self.example_ts() with pytest.raises(ValueError, match="must be sequence boundary"): - ts.pair_coalescence_counts( - windows=np.array([0.0, 2.0]) * ts.sequence_length - ) + ts.pair_coalescence_counts(windows=np.array([0.0, 2.0]) * ts.sequence_length) def test_unsorted_windows(self): ts = self.example_ts() @@ -700,16 +699,12 @@ def test_output_dim(self): implm = ts.pair_coalescence_counts(sample_sets=ss, windows=None, indexes=None) assert implm.shape == (ts.num_nodes,) windows = np.linspace(0.0, ts.sequence_length, 2) - implm = ts.pair_coalescence_counts( - sample_sets=ss, windows=windows, indexes=None - ) + implm = ts.pair_coalescence_counts(sample_sets=ss, windows=windows, indexes=None) assert implm.shape == (1, ts.num_nodes) indexes = [(0, 1)] implm = ts.pair_coalescence_counts( sample_sets=ss, windows=windows, indexes=indexes ) assert implm.shape == (1, ts.num_nodes, 1) - implm = ts.pair_coalescence_counts( - sample_sets=ss, windows=None, indexes=indexes - ) + implm = ts.pair_coalescence_counts(sample_sets=ss, windows=None, indexes=indexes) assert implm.shape == (ts.num_nodes, 1) diff --git a/python/tests/test_combinatorics.py b/python/tests/test_combinatorics.py index 97c85af8e8..bc35a347d2 100644 --- a/python/tests/test_combinatorics.py +++ b/python/tests/test_combinatorics.py @@ -1,7 +1,7 @@ # # MIT License # -# Copyright (c) 2020-2023 Tskit Developers +# Copyright (c) 2020-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,7 @@ """ Test cases for combinatorial algorithms. 
""" + import collections import io import itertools @@ -38,8 +39,7 @@ import tskit import tskit.combinatorics as comb from tests import test_stats -from tskit.combinatorics import Rank -from tskit.combinatorics import RankTree +from tskit.combinatorics import Rank, RankTree class TestCombination: @@ -1315,12 +1315,8 @@ def test_mutation_within_eps_parent(self): tables = tskit.Tree.generate_star(3).tree_sequence.dump_tables() site = tables.sites.add_row(position=0.5, ancestral_state="0") branch_length = np.nextafter(1, 0) - tables.mutations.add_row( - site=site, time=branch_length, node=0, derived_state="1" - ) - tables.mutations.add_row( - site=site, time=branch_length, node=1, derived_state="1" - ) + tables.mutations.add_row(site=site, time=branch_length, node=0, derived_state="1") + tables.mutations.add_row(site=site, time=branch_length, node=1, derived_state="1") tree = tables.tree_sequence().first() with pytest.raises( tskit.LibraryError, @@ -1344,7 +1340,7 @@ def test_kwargs(self): split_tree = tree.split_polytomies(random_seed=14, tracked_samples=[0, 1]) assert split_tree.num_tracked_samples() == 2 - @pytest.mark.slow + @pytest.mark.slow() @pytest.mark.parametrize("n", [3, 4, 5]) def test_all_topologies(self, n): N = num_leaf_labelled_binary_trees(n) @@ -1490,7 +1486,7 @@ class TestGenerateRandomBinary(TreeGeneratorTestBase): def method(self, n, **kwargs): return tskit.Tree.generate_random_binary(n, random_seed=53, **kwargs) - @pytest.mark.slow + @pytest.mark.slow() @pytest.mark.parametrize("n", [3, 4, 5]) def test_all_topologies(self, n): N = num_leaf_labelled_binary_trees(n) @@ -1523,7 +1519,7 @@ class TestGenerateComb(TreeGeneratorTestBase): method_name = "generate_comb" # Hard-code in some pre-computed ranks for the comb(n) tree. 
- @pytest.mark.parametrize(["n", "rank"], [(2, 0), (3, 1), (4, 3), (5, 8), (6, 20)]) + @pytest.mark.parametrize(("n", "rank"), [(2, 0), (3, 1), (4, 3), (5, 8), (6, 20)]) def test_unrank_equal(self, n, rank): for extra_params in [{}, {"span": 2.5}, {"branch_length": 3}]: ts = tskit.Tree.generate_comb(n, **extra_params).tree_sequence diff --git a/python/tests/test_dict_encoding.py b/python/tests/test_dict_encoding.py index 7ef4c70836..a16822cd94 100644 --- a/python/tests/test_dict_encoding.py +++ b/python/tests/test_dict_encoding.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2020 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,13 +23,14 @@ Test cases for the low-level dictionary encoding used to move data around in C. """ + import pathlib import pickle -import _tskit import lwt_interface.dict_encoding_testlib -import tskit +import _tskit +import tskit lwt_interface.dict_encoding_testlib.lwt_module = _tskit # Bring the tests defined in dict_encoding_testlib into the current namespace diff --git a/python/tests/test_divmat.py b/python/tests/test_divmat.py index ea83cc560d..9086c1c21f 100644 --- a/python/tests/test_divmat.py +++ b/python/tests/test_divmat.py @@ -22,6 +22,7 @@ """ Test cases for divergence matrix based pairwise stats """ + import array import collections import functools @@ -212,14 +213,8 @@ def branch_divergence_matrix(ts, sample_sets=None, windows=None, span_normalise= tu = ts.nodes_time[w] - ts.nodes_time[u] tv = ts.nodes_time[w] - ts.nodes_time[v] else: - tu = ( - ts.nodes_time[local_root(tree, u)] - - ts.nodes_time[u] - ) - tv = ( - ts.nodes_time[local_root(tree, v)] - - ts.nodes_time[v] - ) + tu = ts.nodes_time[local_root(tree, u)] - ts.nodes_time[u] + tv = ts.nodes_time[local_root(tree, v)] - ts.nodes_time[v] d = (tu + tv) * span D[i, j, k] += d tree.next() @@ 
-303,7 +298,7 @@ def stats_api_matrix_method( # contiguous, so that we just look at specific sections of the genome. drop = [] if windows[0] != 0: - windows = [0] + windows + windows = [0, *windows] drop.append(0) if windows[-1] != ts.sequence_length: windows.append(ts.sequence_length) @@ -656,9 +651,7 @@ def test_single_tree_multiroot(self, mode): ) np.testing.assert_array_equal(D1, D2) - @pytest.mark.parametrize( - ["left", "right"], [(0, 10), (1, 3), (3.25, 3.75), (5, 10)] - ) + @pytest.mark.parametrize(("left", "right"), [(0, 10), (1, 3), (3.25, 3.75), (5, 10)]) def test_single_tree_interval(self, left, right): # 2.00┊ 6 ┊ # ┊ ┏━┻━┓ ┊ @@ -667,9 +660,7 @@ def test_single_tree_interval(self, left, right): # 0.00┊ 0 1 2 3 ┊ # 0 1 ts = tskit.Tree.generate_balanced(4, span=10).tree_sequence - D1 = check_divmat( - ts, windows=[left, right], mode="branch", span_normalise=False - ) + D1 = check_divmat(ts, windows=[left, right], mode="branch", span_normalise=False) D2 = np.array( [ [0.0, 2.0, 4.0, 4.0], @@ -725,19 +716,20 @@ def test_all_trees_interval(self, interval, mode, span_normalise): check_divmat(ts, windows=interval, mode=mode, span_normalise=span_normalise) @pytest.mark.parametrize( - ["windows"], + "windows", [ - ([0, 26],), - ([0, 1, 2],), - (list(range(27)),), - ([5, 7, 9, 20],), - ([5.1, 5.2, 5.3, 5.5, 6],), - ([5.1, 5.2, 6.5],), + [0, 26], + [0, 1, 2], + list(range(27)), + [5, 7, 9, 20], + [5.1, 5.2, 5.3, 5.5, 6], + [5.1, 5.2, 6.5], ], ) @pytest.mark.parametrize("mode", DIVMAT_MODES) @pytest.mark.parametrize("span_normalise", [True, False]) def test_all_trees_windows(self, windows, mode, span_normalise): + print(windows) ts = tsutil.all_trees_ts(4) ts = tsutil.insert_branch_sites(ts) assert ts.sequence_length == 26 @@ -768,9 +760,7 @@ def test_small_sims(self, n, seed, mode): random_seed=seed, ) assert ts.num_trees >= 2 - ts = msprime.sim_mutations( - ts, rate=0.1, discrete_genome=False, random_seed=seed - ) + ts = msprime.sim_mutations(ts, rate=0.1, 
discrete_genome=False, random_seed=seed) assert ts.num_mutations > 1 check_divmat(ts, verbosity=0, mode=mode) @@ -1050,16 +1040,16 @@ def check(self, ts, num_threads, *, windows, samples=None, mode=None): @pytest.mark.parametrize("num_threads", [1, 2, 3, 5, 26, 27]) @pytest.mark.parametrize( - ["windows"], + "windows", [ - ([0, 26],), - ([0, 1, 2],), - (list(range(27)),), - ([5, 7, 9, 20],), - ([5.1, 5.2, 5.3, 5.5, 6],), - ([5.1, 5.2, 6.5],), - ("trees",), - ("sites",), + [0, 26], + [0, 1, 2], + list(range(27)), + [5, 7, 9, 20], + [5.1, 5.2, 5.3, 5.5, 6], + [5.1, 5.2, 6.5], + "trees", + "sites", ], ) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -1070,12 +1060,12 @@ def test_all_trees(self, num_threads, windows, mode): @pytest.mark.parametrize("samples", [None, [0, 1]]) @pytest.mark.parametrize( - ["windows"], + "windows", [ - ([0, 26],), - (None,), - ("trees",), - ("sites",), + [0, 26], + None, + "trees", + "sites", ], ) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -1085,15 +1075,15 @@ def test_all_trees_samples(self, samples, windows, mode): @pytest.mark.parametrize("num_threads", range(1, 5)) @pytest.mark.parametrize( - ["windows"], + "windows", [ - ([0, 100],), - ([0, 50, 75, 95, 100],), - ([50, 75, 95, 100],), - ([0, 50, 75, 95],), - (list(range(100)),), - ("trees",), - ("sites",), + [0, 100], + [0, 50, 75, 95, 100], + [50, 75, 95, 100], + [0, 50, 75, 95], + list(range(100)), + "trees", + "sites", ], ) @pytest.mark.parametrize("mode", DIVMAT_MODES) @@ -1120,7 +1110,7 @@ class TestChunkByTree: # These are based on what we get from np.array_split, there's nothing # particularly critical about exactly how we portion things up. 
@pytest.mark.parametrize( - ["num_chunks", "expected"], + ("num_chunks", "expected"), [ (1, [[0, 26]]), (2, [[0, 13], [13, 26]]), @@ -1135,7 +1125,7 @@ def test_all_trees_ts_26(self, num_chunks, expected): np.testing.assert_equal(actual, expected) @pytest.mark.parametrize( - ["num_chunks", "expected"], + ("num_chunks", "expected"), [ (1, [[0, 4]]), (2, [[0, 2], [2, 4]]), @@ -1153,7 +1143,7 @@ def test_all_trees_ts_4(self, num_chunks, expected): @pytest.mark.parametrize("span", [1, 2, 5, 0.3]) @pytest.mark.parametrize( - ["num_chunks", "expected"], + ("num_chunks", "expected"), [ (1, [[0, 4]]), (2, [[0, 2], [2, 4]]), @@ -1198,7 +1188,7 @@ class TestChunkWindows: # These are based on what we get from np.array_split, there's nothing # particularly critical about exactly how we portion things up. @pytest.mark.parametrize( - ["windows", "num_chunks", "expected"], + ("windows", "num_chunks", "expected"), [ ([0, 10], 1, [[0, 10]]), ([0, 10], 2, [[0, 10]]), @@ -1219,7 +1209,7 @@ def test_bad_chunks(self, num_chunks): class TestGroupAlleles: @pytest.mark.parametrize( - ["G", "num_alleles", "A", "offsets"], + ("G", "num_alleles", "A", "offsets"), [ ([0, 1], 2, [0, 1], [0, 1, 2]), ([0, 1], 3, [0, 1], [0, 1, 2, 2]), @@ -1263,7 +1253,7 @@ def test_simple_simulation(self): class TestSampleSetParsing: @pytest.mark.parametrize( - ["arg", "flattened", "sizes"], + ("arg", "flattened", "sizes"), [ ([], [], []), ([1], [1], [1]), diff --git a/python/tests/test_drawing.py b/python/tests/test_drawing.py index d9010b3c0b..24f0c16098 100644 --- a/python/tests/test_drawing.py +++ b/python/tests/test_drawing.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Test cases for visualisation in tskit. 
""" + import collections import io import logging @@ -44,7 +45,6 @@ import tskit from tskit import drawing - IS_WINDOWS = platform.system() == "Windows" @@ -241,7 +241,8 @@ def get_ts_varying_min_times(self, *args, **kwargs): flags[3] = 0 tables.nodes.flags = flags edges = tables.edges - assert edges[3].child == 3 and edges[3].parent == 5 + assert edges[3].child == 3 + assert edges[3].parent == 5 edges[3] = edges[3].replace(left=ts.breakpoints(True)[1]) tables.sort() tables.nodes.flags = flags @@ -251,7 +252,7 @@ def fail(self, *args, **kwargs): """ Required for xmlunittest.XmlTestMixin to work with pytest not unittest """ - pytest.fail(*args, **kwargs) + pytest.fail(*args, **kwargs) # noqa: PT016 def closest_left_node(tree, u): @@ -684,13 +685,13 @@ def test_simple_tree(self): 0 1 2 1 """ ) + # fmt: off tree = ( - # fmt: off " 2 \n" "┏┻┓\n" "0 1" - # fmt: on ) + # fmt: on ts = tskit.load_text(nodes, edges, strict=False) t = next(ts.trees()) drawn = t.draw(format="unicode", order="tree") @@ -698,32 +699,33 @@ def test_simple_tree(self): drawn = t.draw_text() self.verify_text_rendering(drawn, tree) + # fmt: off tree = ( - # fmt: off " 2 \n" "+++\n" "0 1\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(use_ascii=True, order="tree") self.verify_text_rendering(drawn, tree) + # fmt: off tree = ( - # fmt: off " ┏0\n" "2┫ \n" " ┗1\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(orientation="left", order="tree") self.verify_text_rendering(drawn, tree) + + # fmt: off tree = ( - # fmt: off " +0\n" "2+ \n" " +1\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(orientation="left", use_ascii=True, order="tree") self.verify_text_rendering(drawn, tree) @@ -743,25 +745,25 @@ def test_simple_tree_long_label(self): 0 1 2 1 """ ) + # fmt: off tree = ( - # fmt: off "ABCDEF\n" "┏┻┓ \n" "0 1 \n" - # fmt: on ) + # fmt: on ts = tskit.load_text(nodes, edges, strict=False) t = next(ts.trees()) drawn = t.draw_text(node_labels={0: "0", 1: "1", 2: "ABCDEF"}, order="tree") 
self.verify_text_rendering(drawn, tree) + # fmt: off tree = ( - # fmt: off "0┓ \n" " ┣ABCDEF\n" "1┛ \n" - # fmt: on ) + # fmt: on drawn = t.draw_text( node_labels={0: "0", 1: "1", 2: "ABCDEF"}, orientation="right", order="tree" ) @@ -770,22 +772,22 @@ def test_simple_tree_long_label(self): drawn = t.draw_text( node_labels={0: "ABCDEF", 1: "1", 2: "2"}, orientation="right", order="tree" ) + # fmt: off tree = ( - # fmt: off "ABCDEF┓ \n" " ┣2\n" "1━━━━━┛ \n" - # fmt: on ) + # fmt: on self.verify_text_rendering(drawn, tree) + # fmt: off tree = ( - # fmt: off " ┏0\n" "ABCDEF┫ \n" " ┗1\n" - # fmt: on ) + # fmt: on drawn = t.draw_text( node_labels={0: "0", 1: "1", 2: "ABCDEF"}, orientation="left", order="tree" ) @@ -934,40 +936,40 @@ def test_trident_tree(self): 0 1 3 2 """ ) + # fmt: off tree = ( - # fmt: off " 3 \n" "┏━╋━┓\n" "0 1 2\n" - # fmt: on ) + # fmt: on ts = tskit.load_text(nodes, edges, strict=False) t = next(ts.trees()) drawn = t.draw(format="unicode", order="tree") self.verify_text_rendering(drawn, tree) self.verify_text_rendering(t.draw_text(), tree) + # fmt: off tree = ( - # fmt: off " ┏0\n" " ┃\n" "3╋1\n" " ┃\n" " ┗2\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(orientation="left") self.verify_text_rendering(drawn, tree) + # fmt: off tree = ( - # fmt: off "0┓\n" " ┃\n" "1╋3\n" " ┃\n" "2┛\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(orientation="right") self.verify_text_rendering(drawn, tree) @@ -993,43 +995,43 @@ def test_pitchfork_tree(self): ) ts = tskit.load_text(nodes, edges, strict=False) t = next(ts.trees()) + # fmt: off tree = ( - # fmt: off " 4 \n" "┏━┳┻┳━┓\n" "0 1 2 3\n" - # fmt: on ) + # fmt: on drawn = t.draw(format="unicode", order="tree") self.verify_text_rendering(drawn, tree) self.verify_text_rendering(t.draw_text(), tree) # No labels + # fmt: off tree = ( - # fmt: off " ┃ \n" "┏━┳┻┳━┓\n" "┃ ┃ ┃ ┃\n" - # fmt: on ) + # fmt: on drawn = t.draw(format="unicode", node_labels={}, order="tree") self.verify_text_rendering(drawn, tree) 
self.verify_text_rendering(t.draw_text(node_labels={}), tree) # Some labels + # fmt: off tree = ( - # fmt: off " ┃ \n" "┏━┳┻┳━┓\n" "0 ┃ ┃ 3\n" - # fmt: on ) + # fmt: on labels = {0: "0", 3: "3"} drawn = t.draw(format="unicode", node_labels=labels, order="tree") self.verify_text_rendering(drawn, tree) self.verify_text_rendering(t.draw_text(node_labels=labels), tree) + # fmt: off tree = ( - # fmt: off " ┏0\n" " ┃\n" " ┣1\n" @@ -1037,13 +1039,13 @@ def test_pitchfork_tree(self): " ┣2\n" " ┃\n" " ┗3\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(orientation="left") self.verify_text_rendering(drawn, tree) + # fmt: off tree = ( - # fmt: off "0┓\n" " ┃\n" "1┫\n" @@ -1051,8 +1053,8 @@ def test_pitchfork_tree(self): "2┫\n" " ┃\n" "3┛\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(orientation="right") self.verify_text_rendering(drawn, tree) @@ -1072,30 +1074,30 @@ def test_stick_tree(self): 0 1 2 1 """ ) + # fmt: off tree = ( - # fmt: off "2\n" "┃\n" "1\n" "┃\n" "0\n" - # fmt: on ) + # fmt: on ts = tskit.load_text(nodes, edges, strict=False) t = next(ts.trees()) drawn = t.draw(format="unicode", order="tree") self.verify_text_rendering(drawn, tree) self.verify_text_rendering(t.draw_text(), tree) + # fmt: off tree = ( - # fmt: off "0\n" "┃\n" "1\n" "┃\n" "2\n" - # fmt: on ) + # fmt: on drawn = t.draw_text(orientation="bottom") self.verify_text_rendering(drawn, tree) @@ -1573,17 +1575,11 @@ def test_draw_defaults(self): svg = t.draw_svg() self.verify_basic_svg(svg) - @pytest.mark.parametrize("y_axis", (True, False)) - @pytest.mark.parametrize("y_label", (True, False)) - @pytest.mark.parametrize( - "time_scale", - ( - "rank", - "time", - ), - ) - @pytest.mark.parametrize("y_ticks", ([], [0, 1], None)) - @pytest.mark.parametrize("y_gridlines", (True, False)) + @pytest.mark.parametrize("y_axis", [True, False]) + @pytest.mark.parametrize("y_label", [True, False]) + @pytest.mark.parametrize("time_scale", ["rank", "time"]) + @pytest.mark.parametrize("y_ticks", [[], [0, 1], None]) 
+ @pytest.mark.parametrize("y_gridlines", [True, False]) def test_draw_svg_y_axis_parameter_combos( self, y_axis, y_label, time_scale, y_ticks, y_gridlines ): @@ -1862,7 +1858,7 @@ def test_min_ts_time(self): # def test_all_edges_colour(self): t = self.get_binary_tree() - colours = {u: "rgb({u},255,{u})".format(u=u) for u in t.nodes() if u != t.root} + colours = {u: f"rgb({u},255,{u})" for u in t.nodes() if u != t.root} svg = t.draw(format="svg", edge_colours=colours) self.verify_basic_svg(svg) for colour in colours.values(): @@ -1918,9 +1914,7 @@ def test_one_mutation_colour(self): def test_all_mutations_colour(self): t = self.get_binary_tree() - colours = { - mut.id: f"rgb({mut.id}, {mut.id}, {mut.id})" for mut in t.mutations() - } + colours = {mut.id: f"rgb({mut.id}, {mut.id}, {mut.id})" for mut in t.mutations()} svg = t.draw(format="svg", mutation_colours=colours) self.verify_basic_svg(svg) for colour in colours.values(): @@ -2347,9 +2341,7 @@ def test_xlim_maintains_tree_ids(self): svg = ts.draw_svg(x_lim=[breaks[1], breaks[4]]) assert "t0" not in svg assert "t4" not in svg - svg = ts.draw_svg( - x_lim=[np.nextafter(breaks[1], 0), np.nextafter(breaks[4], 1)] - ) + svg = ts.draw_svg(x_lim=[np.nextafter(breaks[1], 0), np.nextafter(breaks[4], 1)]) assert "t0" in svg assert "t4" in svg @@ -2422,7 +2414,7 @@ def test_tree_root_branch(self): assert snippet2a.startswith('"M 0 0') assert snippet2b.startswith('"M 0 0') assert "H 0" in snippet1 - assert not ("H 0" in snippet2a) # No root branch + assert "H 0" not in snippet2a # No root branch assert "H 0" in snippet2b def test_debug_box(self): @@ -2475,9 +2467,7 @@ class TestDrawKnownSvg(TestDrawSvgBase): def verify_known_svg(self, svg, filename, save=False, **kwargs): # expected SVG files can be inspected in tests/data/svg/*.svg - svg = xml.dom.minidom.parseString( - svg - ).toprettyxml() # Prettify for easy viewing + svg = xml.dom.minidom.parseString(svg).toprettyxml() # Prettify for easy viewing 
self.verify_basic_svg(svg, **kwargs) svg_fn = pathlib.Path(__file__).parent / "data" / "svg" / filename if save: diff --git a/python/tests/test_extend_edges.py b/python/tests/test_extend_edges.py index 7ee8c8f471..8764283ca3 100644 --- a/python/tests/test_extend_edges.py +++ b/python/tests/test_extend_edges.py @@ -48,9 +48,7 @@ def _slide_mutation_nodes_up(ts, mutations): mut = 0 for tree in ts.trees(): _, right = tree.interval - while ( - mut < mutations.num_rows and ts.sites_position[mutations.site[mut]] < right - ): + while mut < mutations.num_rows and ts.sites_position[mutations.site[mut]] < right: t = mutations.time[mut] c = mutations.node[mut] p = tree.parent(c) @@ -282,11 +280,7 @@ def verify_extend_edges(self, ts, max_iter=10): else: this_chains = chains[j] for a, b, c in this_chains: - if ( - a in tt.nodes() - and tt.parent(a) == c - and b not in tt.nodes() - ): + if a in tt.nodes() and tt.parent(a) == c and b not in tt.nodes(): # the relationship a <- b <- c should still be in the tree, # although maybe they aren't direct parent-offspring # UNLESS we've got an ambiguous case, where on the opposite @@ -345,9 +339,7 @@ def test_migrations_disallowed(self): tables.populations.add_row() tables.migrations.add_row(0, 1, 0, 0, 1, 0) ts = tables.tree_sequence() - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" - ): + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED"): _ = ts.extend_edges() def test_unknown_times(self): diff --git a/python/tests/test_file_format.py b/python/tests/test_file_format.py index 2de38c487e..3d7d061fde 100644 --- a/python/tests/test_file_format.py +++ b/python/tests/test_file_format.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2016-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Test cases for tskit's 
file format. """ + import json import os import sys @@ -633,9 +634,7 @@ def verify_dump_format(self, ts): assert np.array_equal(tables.metadata, b"".join(store["metadata"])) assert np.array_equal(tables.individuals.flags, store["individuals/flags"]) - assert np.array_equal( - tables.individuals.location, store["individuals/location"] - ) + assert np.array_equal(tables.individuals.location, store["individuals/location"]) assert np.array_equal( tables.individuals.location_offset, store["individuals/location_offset"] ) @@ -643,9 +642,7 @@ def verify_dump_format(self, ts): assert np.array_equal( tables.individuals.parents_offset, store["individuals/parents_offset"] ) - assert np.array_equal( - tables.individuals.metadata, store["individuals/metadata"] - ) + assert np.array_equal(tables.individuals.metadata, store["individuals/metadata"]) assert np.array_equal( tables.individuals.metadata_offset, store["individuals/metadata_offset"] ) @@ -749,9 +746,7 @@ def verify_dump_format(self, ts): store["mutations/metadata_schema"].astype("U") ) - assert np.array_equal( - tables.populations.metadata, store["populations/metadata"] - ) + assert np.array_equal(tables.populations.metadata, store["populations/metadata"]) assert np.array_equal( tables.populations.metadata_offset, store["populations/metadata_offset"] ) @@ -988,9 +983,7 @@ def verify_missing_fields(self, ts): data = dict(all_data) del data[key] kastore.dump(data, self.temp_file) - with pytest.raises( - (exceptions.FileFormatError, exceptions.LibraryError) - ): + with pytest.raises((exceptions.FileFormatError, exceptions.LibraryError)): tskit.load(self.temp_file) def test_missing_fields(self): @@ -1000,9 +993,7 @@ def verify_equal_length_columns(self, ts, table): ts.dump(self.temp_file) with kastore.load(self.temp_file) as store: all_data = dict(store) - table_cols = [ - colname for colname in all_data.keys() if colname.startswith(table) - ] + table_cols = [colname for colname in all_data.keys() if 
colname.startswith(table)] # Remove all the 'offset' columns for col in list(table_cols): if col.endswith("_offset"): @@ -1293,9 +1284,7 @@ def test_table_collection_load_stream(self, tmp_path, ts_fixture): save_path = tmp_path / "tmp.trees" ts_fixture.dump(save_path) with open(save_path, "rb") as f: - tables_no_refseq = tskit.TableCollection.load( - f, skip_reference_sequence=True - ) + tables_no_refseq = tskit.TableCollection.load(f, skip_reference_sequence=True) tables = ts_fixture.tables assert not tables_no_refseq.equals(tables) assert tables_no_refseq.equals(tables, ignore_reference_sequence=True) diff --git a/python/tests/test_fileobj.py b/python/tests/test_fileobj.py index 1740094462..47e1618c81 100644 --- a/python/tests/test_fileobj.py +++ b/python/tests/test_fileobj.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Test cases for loading and dumping different types of files and streams """ + import io import multiprocessing import os @@ -36,17 +37,15 @@ import pytest import tszip -from pytest import fixture import tskit - IS_WINDOWS = platform.system() == "Windows" IS_OSX = platform.system() == "Darwin" class TestPath: - @fixture + @pytest.fixture() def tempfile_name(self): with tempfile.TemporaryDirectory() as tmp_dir: yield f"{tmp_dir}/plain_path" @@ -58,7 +57,7 @@ def test_pathlib(self, ts_fixture, tempfile_name): class TestPathLib: - @fixture + @pytest.fixture() def pathlib_tempfile(self): fd, path = tempfile.mkstemp(prefix="tskit_test_pathlib") os.close(fd) @@ -73,7 +72,7 @@ def test_pathlib(self, ts_fixture, pathlib_tempfile): class TestFileObj: - @fixture + @pytest.fixture() def fileobj(self): with tempfile.TemporaryDirectory() as tmp_dir: with open(f"{tmp_dir}/fileobj", "wb") as f: @@ -100,7 
+99,7 @@ def test_fileobj_multi(self, replicate_ts_fixture, fileobj): class TestFileObjRW: - @fixture + @pytest.fixture() def fileobj(self): with tempfile.TemporaryDirectory() as tmp_dir: pathlib.Path(f"{tmp_dir}/fileobj").touch() @@ -127,7 +126,7 @@ def test_fileobj_multi(self, replicate_ts_fixture, fileobj): class TestFD: - @fixture + @pytest.fixture() def fd(self): with tempfile.TemporaryDirectory() as tmp_dir: pathlib.Path(f"{tmp_dir}/fd").touch() @@ -245,7 +244,7 @@ def stream(fifo, ts_list): @pytest.mark.skipif(IS_WINDOWS, reason="No FIFOs on Windows") @pytest.mark.skipif(IS_OSX, reason="FIFO flakey on OS X, issue #1170") class TestFIFO: - @fixture + @pytest.fixture() def fifo(self): temp_dir = tempfile.mkdtemp(prefix="tsk_test_streaming") temp_fifo = os.path.join(temp_dir, "fifo") @@ -288,7 +287,7 @@ def server_process(q): @pytest.mark.skipif(IS_WINDOWS or IS_OSX, reason="Errors on systems without proper fds") class TestSocket: - @fixture + @pytest.fixture() def client_fd(self): # Use a queue to synchronise the startup of the server and the client. q = multiprocessing.Queue() diff --git a/python/tests/test_genotype_matching.py b/python/tests/test_genotype_matching.py index a04e3873a6..2b0809a27a 100644 --- a/python/tests/test_genotype_matching.py +++ b/python/tests/test_genotype_matching.py @@ -300,9 +300,7 @@ def update_tree(self): st.value_list.append( InternalValueTransition( tree_node=edge.parent, - value=st.value_list.copy()[ - T_index[edge.child] - ].value, + value=st.value_list.copy()[T_index[edge.child]].value, ) ) else: @@ -310,7 +308,8 @@ def update_tree(self): while T_index[u] == -1: u = parent[u] assert u != -1 - assert T_index[u] != -1 and T_index[edge.child] != -1 + assert T_index[u] != -1 + assert T_index[edge.child] != -1 if ( T[T_index[u]].value_list == T[T_index[edge.child]].value_list ): # DEV: is this fine? 
@@ -526,8 +525,7 @@ def process_site( if st1.tree_node != tskit.NULL: for st2 in st1.value_list: st2.value = ( - ((self.rho[site.id] / self.ts.num_samples) ** 2) - * b_last_sum + ((self.rho[site.id] / self.ts.num_samples) ** 2) * b_last_sum + (1 - self.rho[site.id]) * (self.rho[site.id] / self.ts.num_samples) * st2.inner_summation @@ -1168,9 +1166,7 @@ def ls_forward_tree(g, ts, rho, mu, precision=30): def ls_backward_tree(g, ts_mirror, rho, mu, normalisation_factor, precision=30): """Backward matrix computation based on a tree sequence.""" - ba = BackwardAlgorithm( - ts_mirror, rho, mu, normalisation_factor, precision=precision - ) + ba = BackwardAlgorithm(ts_mirror, rho, mu, normalisation_factor, precision=precision) return ba.run_backward(g) @@ -1226,7 +1222,7 @@ def example_genotypes(self, ts): return H, G, genotypes def example_parameters_genotypes(self, ts, seed=42): - np.random.seed(seed) + rng = np.random.default_rng(seed) H, G, genotypes = self.example_genotypes(ts) n = H.shape[1] m = ts.get_num_sites() @@ -1242,8 +1238,8 @@ def example_parameters_genotypes(self, ts, seed=42): yield n, m, G, s, e, r, mu # Mixture of random and extremes - rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] + rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, rng.random(m)] + mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, rng.random(m) * 0.33] s = genotypes[0] for r, mu in itertools.product(rs, mus): @@ -1258,9 +1254,7 @@ def assertAllClose(self, A, B): # Define a bunch of very small tree-sequences for testing a collection of # parameters on def test_simple_n_10_no_recombination(self): - ts = msprime.simulate( - 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 - ) + ts = msprime.simulate(10, recombination_rate=0, mutation_rate=0.5, random_seed=42) assert ts.num_sites > 3 self.verify(ts) diff --git a/python/tests/test_genotypes.py b/python/tests/test_genotypes.py index 
329867b600..29be324d29 100644 --- a/python/tests/test_genotypes.py +++ b/python/tests/test_genotypes.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2023 Tskit Developers +# Copyright (c) 2019-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Test cases for generating genotypes/haplotypes. """ + import itertools import logging import random @@ -107,9 +108,7 @@ def test_single_tree(self): self.verify(ts) def test_many_trees(self): - ts = msprime.simulate( - 8, recombination_rate=10, mutation_rate=10, random_seed=234 - ) + ts = msprime.simulate(8, recombination_rate=10, mutation_rate=10, random_seed=234) assert ts.num_trees > 1 assert ts.num_sites > 1 self.verify(ts) @@ -668,7 +667,7 @@ def test_simple_case(self, ts_fixture): assert v.alleles == test_variant.alleles @pytest.mark.parametrize( - ["left", "expected"], + ("left", "expected"), [ (None, [0, 1, 2, 3, 4]), (0, [0, 1, 2, 3, 4]), @@ -690,7 +689,7 @@ def test_left(self, left, expected): assert positions == expected @pytest.mark.parametrize( - ["right", "expected"], + ("right", "expected"), [ (None, [0, 1, 2, 3, 4]), (5, [0, 1, 2, 3, 4]), @@ -913,9 +912,7 @@ def test_missing_data(self): h = list(ts.haplotypes(isolated_as_missing=False, impute_missing_data=True)) assert h == ["A", "A"] with pytest.warns(FutureWarning): - h = list( - ts.haplotypes(isolated_as_missing=False, impute_missing_data=False) - ) + h = list(ts.haplotypes(isolated_as_missing=False, impute_missing_data=False)) assert h == ["A", "A"] def test_restrict_samples(self): @@ -967,9 +964,7 @@ def test_simple_01_trailing_alleles(self): alleles = ("0", "1", "2", "xxxxx") G2 = ts.genotype_matrix(alleles=alleles) assert np.array_equal(G1, G2) - for v1, v2 in itertools.zip_longest( - ts.variants(), ts.variants(alleles=alleles) - ): + for v1, v2 in itertools.zip_longest(ts.variants(), 
ts.variants(alleles=alleles)): assert v2.alleles == alleles assert v1.site == v2.site assert np.array_equal(v1.genotypes, v2.genotypes) @@ -981,9 +976,7 @@ def test_simple_01_leading_alleles(self): alleles = ("A", "B", "C", "0", "1") G2 = ts.genotype_matrix(alleles=alleles) assert np.array_equal(G1 + 3, G2) - for v1, v2 in itertools.zip_longest( - ts.variants(), ts.variants(alleles=alleles) - ): + for v1, v2 in itertools.zip_longest(ts.variants(), ts.variants(alleles=alleles)): assert v2.alleles == alleles assert v1.site == v2.site assert np.array_equal(v1.genotypes + 3, v2.genotypes) @@ -997,9 +990,7 @@ def test_simple_01_duplicate_alleles(self): index = np.where(G1 == 1) G1[index] = 2 assert np.array_equal(G1, G2) - for v1, v2 in itertools.zip_longest( - ts.variants(), ts.variants(alleles=alleles) - ): + for v1, v2 in itertools.zip_longest(ts.variants(), ts.variants(alleles=alleles)): assert v2.alleles == alleles assert v1.site == v2.site g = np.array(v1.genotypes) @@ -1015,9 +1006,7 @@ def test_simple_acgt(self): assert ts.num_sites > 2 alleles = tskit.ALLELES_ACGT G = ts.genotype_matrix(alleles=alleles) - for v1, v2 in itertools.zip_longest( - ts.variants(), ts.variants(alleles=alleles) - ): + for v1, v2 in itertools.zip_longest(ts.variants(), ts.variants(alleles=alleles)): assert v2.alleles == alleles assert v1.site == v2.site h1 = "".join(v1.alleles[g] for g in v1.genotypes) @@ -1091,9 +1080,7 @@ class TestUserAllelesRoundTrip: """ def verify(self, ts, alleles): - for v1, v2 in itertools.zip_longest( - ts.variants(), ts.variants(alleles=alleles) - ): + for v1, v2 in itertools.zip_longest(ts.variants(), ts.variants(alleles=alleles)): h1 = [v1.alleles[g] for g in v1.genotypes] h2 = [v2.alleles[g] for g in v2.genotypes] assert h1 == h2 @@ -1534,9 +1521,7 @@ def test_alignments_fails(self): @pytest.mark.skip("Missing data in alignments: #1896") def test_alignments_impute_missing(self): ref = "N" * 10 - A = list( - self.ts().alignments(reference_sequence=ref, 
isolated_as_missing=False) - ) + A = list(self.ts().alignments(reference_sequence=ref, isolated_as_missing=False)) assert A[0] == "NNGNNNNNNT" assert A[1] == "NNANNNNNNC" assert A[2] == "NNANNNNNNC" @@ -1570,9 +1555,7 @@ def test_alignments_reference_sequence(self): @pytest.mark.skip("Missing data in alignments: #1896") def test_alignments_reference_sequence_missing_data_char(self): ref = "0123456789" - A = list( - self.ts().alignments(reference_sequence=ref, missing_data_character="Q") - ) + A = list(self.ts().alignments(reference_sequence=ref, missing_data_character="Q")) assert A[0] == "01G345678T" assert A[1] == "01A345678C" assert A[2] == "01A345678C" @@ -1658,9 +1641,7 @@ def test_nexus_reference_sequence(self): END; """ ) - assert expected == self.ts().as_nexus( - reference_sequence=ref, include_trees=False - ) + assert expected == self.ts().as_nexus(reference_sequence=ref, include_trees=False) @pytest.mark.skip("Missing data in alignments: #1896") def test_nexus_reference_sequence_missing_data_char(self): @@ -1816,13 +1797,13 @@ def test_reference_sequence_length_mismatch(self, ref): with pytest.raises(ValueError, match="shorter than"): list(ts.alignments(reference_sequence=ref)) - @pytest.mark.parametrize("ref", ["À", "┃", "α"]) + @pytest.mark.parametrize("ref", ["À", "┃", "α"]) # noqa: RUF001 def test_non_ascii_references(self, ref): ts = self.simplest_ts() with pytest.raises(UnicodeEncodeError): list(ts.alignments(reference_sequence=ref)) - @pytest.mark.parametrize("ref", ["À", "┃", "α"]) + @pytest.mark.parametrize("ref", ["À", "┃", "α"]) # noqa: RUF001 def test_non_ascii_embedded_references(self, ref): tables = tskit.TableCollection(1) tables.nodes.add_row(flags=1, time=0) @@ -1831,7 +1812,7 @@ def test_non_ascii_embedded_references(self, ref): with pytest.raises(UnicodeEncodeError): list(ts.alignments()) - @pytest.mark.parametrize("missing_data_char", ["À", "┃", "α"]) + @pytest.mark.parametrize("missing_data_char", ["À", "┃", "α"]) # noqa: RUF001 def 
test_non_ascii_missing_data_char(self, missing_data_char): ts = self.simplest_ts() with pytest.raises(UnicodeEncodeError): @@ -1899,7 +1880,7 @@ def test_reference_sequence(self, ts): # Tests for allele_remap # @pytest.mark.parametrize( - "alleles_from, alleles_to, allele_map", + ("alleles_from", "alleles_to", "allele_map"), [ # Case 1: alleles_to is longer than alleles_from. ( @@ -1957,8 +1938,8 @@ def test_reference_sequence(self, ts): ), # Case 10: Lists contain unicode characters. ( - ["\u1F1E8", "\u1F1EC"], - ["\u1F1EC", "\u1F1E8", "\u1F1E6", "\u1F1F3"], + ["\u1f1e8", "\u1f1eC"], + ["\u1f1eC", "\u1f1e8", "\u1f1e6", "\u1f1f3"], np.array([1, 0], dtype="uint32"), ), ], @@ -1987,14 +1968,14 @@ def test_not_decoded(self, ts_fixture): variant = tskit.Variant(ts_fixture) assert variant.index == tskit.NULL with pytest.raises(ValueError, match="not yet been decoded"): - variant.site + variant.site # noqa: B018 assert variant.alleles == () with pytest.raises(ValueError, match="not yet been decoded"): assert variant.genotypes assert not variant.has_missing_data assert variant.num_alleles == 0 with pytest.raises(ValueError, match="not yet been decoded"): - variant.position + variant.position # noqa: B018 assert np.array_equal(variant.samples, np.array(ts_fixture.samples())) def test_variant_decode(self, ts_fixture): @@ -2040,9 +2021,7 @@ def test_variant_simple_frequencies(self): tables.sites.add_row(position=0.6, ancestral_state="AS1") tables.mutations.add_row(site=0, derived_state="DS0_0", node=0) tables.mutations.add_row(site=0, derived_state="DS0_3", node=3) - tables.mutations.add_row( - site=1, derived_state="DS1", node=simple_tree.parent(0) - ) + tables.mutations.add_row(site=1, derived_state="DS1", node=simple_tree.parent(0)) ts = tables.tree_sequence() variant_0 = next(ts.variants()) freqs = variant_0.frequencies() @@ -2166,9 +2145,7 @@ def test_variant_str(self): ╟─+┼─+╢ ║Isolated as missing\s*│\s*True║ ╚═+╧═+╝ - """[ - 1: - ] + """[1:] ), str(v), ) @@ -2219,7 
+2196,8 @@ def test_variant_html_repr_no_site(self): def test_variant_repr(self, ts_fixture): v = next(ts_fixture.variants()) str_rep = repr(v) - assert len(str_rep) > 0 and len(str_rep) < 10000 + assert len(str_rep) > 0 + assert len(str_rep) < 10000 assert re.search(r"\AVariant", str_rep) assert re.search(rf"\'site\': Site\(id={v.site.id}", str_rep) assert re.search(rf"position={v.position}", str_rep) diff --git a/python/tests/test_haplotype_matching.py b/python/tests/test_haplotype_matching.py index dcc1d684fb..ee8a18e65e 100644 --- a/python/tests/test_haplotype_matching.py +++ b/python/tests/test_haplotype_matching.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2023 Tskit Developers +# Copyright (c) 2019-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Python implementation of the Li and Stephens forwards and backwards algorithms. 
""" + import warnings import lshmm as ls @@ -162,9 +163,7 @@ def compute(u, parent_state): optimal_set[u, value_count == max_value_count] = 1 optimal_set = np.zeros((tree.tree_sequence.num_nodes, len(values)), dtype=int) - t_node_time = [ - -1 if st.tree_node == -1 else tree.time(st.tree_node) for st in T - ] + t_node_time = [-1 if st.tree_node == -1 else tree.time(st.tree_node) for st in T] order = np.argsort(t_node_time) for j in order: st = T[j] @@ -258,9 +257,7 @@ def update_tree(self, direction=tskit.FORWARD): u = parent[u] assert u != -1 T_index[edge.child] = len(T) - T.append( - ValueTransition(tree_node=edge.child, value=T[T_index[u]].value) - ) + T.append(ValueTransition(tree_node=edge.child, value=T[T_index[u]].value)) parent[edge.child] = -1 for j in range( @@ -284,7 +281,8 @@ def update_tree(self, direction=tskit.FORWARD): while T_index[u] == -1: u = parent[u] assert u != -1 - assert T_index[u] != -1 and T_index[edge.child] != -1 + assert T_index[u] != -1 + assert T_index[edge.child] != -1 if T[T_index[u]].value == T[T_index[edge.child]].value: st = T[T_index[edge.child]] # Mark the lower ValueTransition as unused. 
@@ -333,9 +331,7 @@ def update_probabilities(self, site, haplotype_state): while allelic_state[v] == -1: v = tree.parent(v) assert v != -1 - match = ( - haplotype_state == MISSING or haplotype_state == allelic_state[v] - ) + match = haplotype_state == MISSING or haplotype_state == allelic_state[v] # Note that the node u is used only by Viterbi st.value = self.compute_next_probability(site.id, st.value, match, u) @@ -772,7 +768,6 @@ def example_haplotypes(self, ts): def example_parameters_haplotypes(self, ts, seed=42): """Returns an iterator over combinations of haplotype, recombination and mutation rates.""" - np.random.seed(seed) H, haplotypes = self.example_haplotypes(ts) n = H.shape[1] m = ts.get_num_sites() @@ -789,8 +784,9 @@ def example_parameters_haplotypes(self, ts, seed=42): # We'll be refactoring all this to use pytest anyway, so let's not # worry too much about coverage for now. # # Mixture of random and extremes - # rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, np.random.rand(m)] - # mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, np.random.rand(m) * 0.33] + # rng = np.random.default_rng(seed) + # rs = [np.zeros(m) + 0.999, np.zeros(m) + 1e-6, rng.random(m)] + # mus = [np.zeros(m) + 0.33, np.zeros(m) + 1e-6, rng.random(m) * 0.33] # import itertools # for s, r, mu in itertools.product(haplotypes, rs, mus): @@ -804,9 +800,7 @@ def assertAllClose(self, A, B): # Define a bunch of very small tree-sequences for testing a collection # of parameters on def test_simple_n_10_no_recombination(self): - ts = msprime.simulate( - 10, recombination_rate=0, mutation_rate=0.5, random_seed=42 - ) + ts = msprime.simulate(10, recombination_rate=0, mutation_rate=0.5, random_seed=42) assert ts.num_sites > 3 self.verify(ts) diff --git a/python/tests/test_highlevel.py b/python/tests/test_highlevel.py index 4a9337dcb4..e8e39a29d1 100644 --- a/python/tests/test_highlevel.py +++ b/python/tests/test_highlevel.py @@ -23,6 +23,7 @@ """ Test cases for the high level interface to 
tskit. """ + import collections import dataclasses import decimal @@ -627,8 +628,8 @@ def f(node=None, order=None): assert f(q, order="inorder") == [0, 7, 1, q, 2, 8, 3, 4, 9, 5, 6] assert f(q, order="levelorder") == [q, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6] assert f(q, order="breadthfirst") == [q, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6] - assert f(q, order="timeasc") == list(range(10)) + [q] - assert f(q, order="timedesc") == [q] + list(reversed(range(10))) + assert f(q, order="timeasc") == [*list(range(10)), q] + assert f(q, order="timedesc") == [q, *list(reversed(range(10)))] assert f(q, order="minlex_postorder") == [0, 1, 7, 2, 3, 8, 4, 5, 6, 9, q] assert f(9, order="preorder") == [9, 4, 5, 6] @@ -728,7 +729,7 @@ def f(node=None, order=None): assert minlex == [0, 2, 10, 9, 11, 5, 13, 1, 8, 14, 3, 7, 12, 4, 6] @pytest.mark.parametrize( - ["order", "expected"], + ("order", "expected"), [ ("preorder", [[9, 6, 2, 3, 7, 4, 5, 0, 1], [10, 4, 8, 5, 0, 1, 6, 2, 3]]), ("inorder", [[2, 6, 3, 9, 4, 7, 0, 5, 1], [4, 10, 0, 5, 1, 8, 2, 6, 3]]), @@ -897,7 +898,7 @@ class TestMRCA: # ┃ ┏┻┓ # 0 1 2 - @pytest.mark.parametrize("args, expected", [((2, 1), 3), ((0, 1, 2), 4)]) + @pytest.mark.parametrize(("args", "expected"), [((2, 1), 3), ((0, 1, 2), 4)]) def test_two_or_more_args(self, args, expected): assert self.t.mrca(*args) == expected assert self.t.tmrca(*args) == self.t.tree_sequence.nodes_time[expected] @@ -1137,10 +1138,10 @@ def verify_trees(self, ts): if len(roots) == 0: assert st1.root == tskit.NULL elif len(roots) == 1: - assert st1.root == list(roots)[0] + assert st1.root == next(iter(roots)) else: with pytest.raises(ValueError): - st1.root + st1.root # noqa: B018 assert st2 == st1 assert not (st2 != st1) left, right = st1.get_interval() @@ -1489,7 +1490,7 @@ def verify_pairwise_diversity(self, ts): assert pi1 >= 0.0 assert not math.isnan(pi1) - @pytest.mark.slow + @pytest.mark.slow() @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_pairwise_diversity(self, ts): 
self.verify_pairwise_diversity(ts) @@ -1705,7 +1706,7 @@ def test_first_last(self, ts): for kwargs in [{}, {"tracked_samples": ts.samples()}]: t1 = ts.first(**kwargs) t2 = next(ts.trees()) - assert not (t1 is t2) + assert t1 is not t2 assert t1.parent_dict == t2.parent_dict assert t1.index == 0 if "tracked_samples" in kwargs: @@ -1715,7 +1716,7 @@ def test_first_last(self, ts): t1 = ts.last(**kwargs) t2 = next(reversed(ts.trees())) - assert not (t1 is t2) + assert t1 is not t2 assert t1.parent_dict == t2.parent_dict assert t1.index == ts.num_trees - 1 if "tracked_samples" in kwargs: @@ -2067,7 +2068,7 @@ def test_reversed_trees(self, ts): def test_at_index(self, ts): for kwargs in [{}, {"tracked_samples": ts.samples()}]: tree_list = ts.aslist(**kwargs) - for index in list(range(ts.num_trees)) + [-1]: + for index in [*list(range(ts.num_trees)), -1]: t1 = tree_list[index] t2 = ts.at_index(index, **kwargs) assert t1 == t2 @@ -2141,9 +2142,7 @@ def test_load_tables(self, ts): tables.drop_index() # Tables not in tc not rebuilt as per default, so error - with pytest.raises( - _tskit.LibraryError, match="Table collection must be indexed" - ): + with pytest.raises(_tskit.LibraryError, match="Table collection must be indexed"): assert tskit.TreeSequence.load_tables(tables).dump_tables().has_index() # Tables not in tc, but rebuilt @@ -2282,7 +2281,7 @@ def modify(ts, func): assert t1.equals(t2, ignore_ts_metadata=True, ignore_provenance=True) t1 = modify(t1, lambda tc: tc.provenances.clear()) - t2 = modify(t2, lambda tc: setattr(tc, "metadata", t1.metadata)) # noqa: B010 + t2 = modify(t2, lambda tc: setattr(tc, "metadata", t1.metadata)) assert t1.equals(t2) assert t2.equals(t1) @@ -2394,9 +2393,7 @@ def test_individuals_time_errors(self): for j in range(2): t.nodes.add_row(time=j, individual=0) ts = t.tree_sequence() - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_INDIVIDUAL_TIME_MISMATCH" - ): + with pytest.raises(_tskit.LibraryError, 
match="TSK_ERR_INDIVIDUAL_TIME_MISMATCH"): _ = ts.individuals_time @pytest.mark.parametrize("n", [1, 10]) @@ -2719,13 +2716,9 @@ def verify_simplify_topology(self, ts, sample): def verify_simplify_equality(self, ts, sample): for filter_sites in [False, True]: - s1, node_map1 = ts.simplify( - sample, map_nodes=True, filter_sites=filter_sites - ) + s1, node_map1 = ts.simplify(sample, map_nodes=True, filter_sites=filter_sites) t1 = s1.dump_tables() - s2, node_map2 = simplify_tree_sequence( - ts, sample, filter_sites=filter_sites - ) + s2, node_map2 = simplify_tree_sequence(ts, sample, filter_sites=filter_sites) t2 = s2.dump_tables() assert s1.num_samples == len(sample) assert s2.num_samples == len(sample) @@ -2777,7 +2770,7 @@ def test_simplify_provenance(self, ts): # TODO this test needs to be broken up into discrete bits, so that we can # test them independently. A way of getting a random-ish subset of samples # from the pytest param would be useful. - @pytest.mark.slow + @pytest.mark.slow() @pytest.mark.parametrize("ts", get_example_tree_sequences()) def test_simplify(self, ts): # Can't simplify edges with metadata @@ -3047,7 +3040,7 @@ def test_trees_params(self): class TestTreeSequenceMetadata: - metadata_tables = [ + metadata_tables = ( "node", "edge", "site", @@ -3055,7 +3048,7 @@ class TestTreeSequenceMetadata: "migration", "individual", "population", - ] + ) metadata_schema = tskit.MetadataSchema( { "codec": "json", @@ -3126,27 +3119,23 @@ def test_table_metadata_schemas(self): assert repr(getattr(tables, f"{table}s").metadata_schema) == repr(schema) for other_table in self.metadata_tables: if other_table != table: - assert ( - repr(getattr(tables, f"{other_table}s").metadata_schema) == "" - ) + assert repr(getattr(tables, f"{other_table}s").metadata_schema) == "" # Check via tree-sequence API new_ts = tskit.TreeSequence.load_tables(tables) assert repr(getattr(new_ts.table_metadata_schemas, table)) == repr(schema) for other_table in self.metadata_tables: if 
other_table != table: - assert ( - repr(getattr(new_ts.table_metadata_schemas, other_table)) == "" - ) + assert repr(getattr(new_ts.table_metadata_schemas, other_table)) == "" # Can't set schema via this API with pytest.raises(AttributeError): new_ts.table_metadata_schemas = {} - # or modify the schema tuple return object - with pytest.raises(dataclasses.exceptions.FrozenInstanceError): - setattr( - new_ts.table_metadata_schemas, - table, - tskit.MetadataSchema({"codec": "json"}), - ) + # or modify the schema tuple return object + with pytest.raises(dataclasses.FrozenInstanceError): + setattr( + new_ts.table_metadata_schemas, + table, + tskit.MetadataSchema({"codec": "json"}), + ) def test_table_metadata_round_trip_via_row_getters(self): # A tree sequence with all entities @@ -3378,9 +3367,7 @@ def convert(v): else: assert repr(mutation.metadata) == splits[5] - def verify_individuals_format( - self, ts, individuals_file, precision, base64_metadata - ): + def verify_individuals_format(self, ts, individuals_file, precision, base64_metadata): """ Verifies that the individuals we output have the correct form. """ @@ -3408,9 +3395,7 @@ def convert(v): else: assert repr(individual.metadata) == splits[4] - def verify_populations_format( - self, ts, populations_file, precision, base64_metadata - ): + def verify_populations_format(self, ts, populations_file, precision, base64_metadata): """ Verifies that the populations we output have the correct form. """ @@ -3730,9 +3715,7 @@ def test_str(self, ts_fixture): ╟───────────────────┼─────────╢ ║Total Branch Length│[0-9\. 
]*║ ╚═══════════════════╧═════════╝ - """[ - 1: - ] + """[1:] ), str(t), ) @@ -4037,7 +4020,7 @@ def test_deprecated_api_warnings(self): # Deprecated and will be removed t1 = self.get_tree() with pytest.warns(FutureWarning, match="Tree.tree_sequence.num_nodes"): - t1.num_nodes + t1.num_nodes # noqa: B018 def test_seek_index(self): ts = msprime.simulate(10, recombination_rate=3, length=5, random_seed=42) @@ -4433,7 +4416,7 @@ def test_multiroot_tree(self): assert 7 == t.virtual_root assert t.siblings(7) == tuple() - @pytest.mark.parametrize("flag,expected", [(0, ()), (1, (2,))]) + @pytest.mark.parametrize(("flag", "expected"), [(0, ()), (1, (2,))]) def test_isolated_node(self, flag, expected): tables = tskit.Tree.generate_balanced(2, arity=2).tree_sequence.dump_tables() tables.nodes.add_row(flags=flag) # Add node 3 @@ -5160,8 +5143,8 @@ def test_position_continuous_coordinates(self, position): @pytest.mark.parametrize("position", [0, 2.999999999, 5.000000001, 9]) def test_position_not_found(self, position): + ts = self.get_example_ts_discrete_coordinates() with pytest.raises(ValueError, match=r"There is no site at position"): - ts = self.get_example_ts_discrete_coordinates() ts.site(position=position) @pytest.mark.parametrize( @@ -5177,36 +5160,36 @@ def test_position_good_type(self, position): ts.site(position=position) def test_position_not_scalar(self): + ts = self.get_example_ts_discrete_coordinates() with pytest.raises( ValueError, match="Position must be provided as a scalar value." 
): - ts = self.get_example_ts_discrete_coordinates() ts.site(position=[1, 4, 8]) @pytest.mark.parametrize("position", [-1, 10, 11]) def test_position_out_of_bounds(self, position): + ts = self.get_example_ts_discrete_coordinates() with pytest.raises( ValueError, match="Position is beyond the coordinates defined by sequence length.", ): - ts = self.get_example_ts_discrete_coordinates() ts.site(position=position) def test_query_position_siteless_ts(self): + ts = self.get_example_ts_without_sites() with pytest.raises(ValueError, match=r"There is no site at position"): - ts = self.get_example_ts_without_sites() ts.site(position=1) def test_site_id_and_position_are_none(self): + ts = self.get_example_ts_discrete_coordinates() with pytest.raises(TypeError, match="Site id or position must be provided."): - ts = self.get_example_ts_discrete_coordinates() ts.site(None, position=None) def test_site_id_and_position_are_specified(self): + ts = self.get_example_ts_discrete_coordinates() with pytest.raises( TypeError, match="Only one of site id or position needs to be provided." 
): - ts = self.get_example_ts_discrete_coordinates() ts.site(0, position=3) @@ -5247,7 +5230,7 @@ def test_number_types(self, t): # 0.00┊ 0 1 2 3 4 5 6 7 8 ┊ # 0 1 @pytest.mark.parametrize( - ["t", "expected"], + ("t", "expected"), [ (-0.00001, 0), (0, 9), @@ -5272,7 +5255,7 @@ def test_balanced_ternary(self, t, expected): # 0.00┊ 0 1 2 3 4 5 6 7 8 ┊ # 0 1 @pytest.mark.parametrize( - ["t", "expected"], + ("t", "expected"), [ (-0.00001, 0), (0, 9), @@ -5308,14 +5291,13 @@ def test_multiroot_different_times(self, t, expected): # 0.00┊ 3 4 ┊ # 0 1 @pytest.mark.parametrize( - ["t", "expected"], + ("t", "expected"), [ (-0.00001, 0), (0, 2), (1, 2), (2, 2), (3, 2), - (3, 2), (4, 0), ], ) @@ -5331,7 +5313,7 @@ def test_comb_different_leaf_times(self, t, expected): assert tree.num_lineages(t) == expected @pytest.mark.parametrize( - ["t", "expected"], + ("t", "expected"), [ (-0.00001, 0), (0, 0), diff --git a/python/tests/test_ibd.py b/python/tests/test_ibd.py index b0a653d09e..749591c39c 100644 --- a/python/tests/test_ibd.py +++ b/python/tests/test_ibd.py @@ -750,7 +750,7 @@ def test_length(self): (1, 3): [tskit.IdentitySegment(0.3, 1, 4)], (2, 3): [tskit.IdentitySegment(0.3, 1, 5)], } - (ibd_segs, true_segs) + assert_ibd_equal(ibd_segs, true_segs) def test_input_within(self): ibd_segs = ibd_segments(self.ts(), within=[0, 1, 2]) @@ -1122,12 +1122,12 @@ def test_list_semantics(self): def test_str(self): result = self.example_ts.ibd_segments(store_segments=True) - seglist = list(result.values())[0] + seglist = next(iter(result.values())) assert str(seglist).startswith("IdentitySegmentList") def test_repr(self): result = self.example_ts.ibd_segments(store_segments=True) - seglist = list(result.values())[0] + seglist = next(iter(result.values())) assert repr(seglist).startswith("IdentitySegmentList([IdentitySegment") def test_eq_semantics(self): diff --git a/python/tests/test_intervals.py b/python/tests/test_intervals.py index e23f221e62..8d99144f8c 100644 --- 
a/python/tests/test_intervals.py +++ b/python/tests/test_intervals.py @@ -24,6 +24,7 @@ """ Test cases for the intervals module. """ + import decimal import fractions import gzip @@ -113,14 +114,14 @@ def test_read_only(self): class TestGetRateAllKnown: - examples = [ + examples = ( tskit.RateMap(position=[0, 1], rate=[0]), tskit.RateMap(position=[0, 1], rate=[0.1]), tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), tskit.RateMap(position=[0, 1, 2], rate=[0, 0.2]), tskit.RateMap(position=[0, 1, 2], rate=[0.1, 1e-6]), tskit.RateMap(position=range(100), rate=range(99)), - ] + ) @pytest.mark.parametrize("rate_map", examples) def test_get_rate_mid(self, rate_map): @@ -145,7 +146,7 @@ def test_get_rate_right(self, rate_map): class TestOperations: - examples = [ + examples = ( tskit.RateMap.uniform(sequence_length=1, rate=0), tskit.RateMap.uniform(sequence_length=1, rate=0.1), tskit.RateMap(position=[0, 1, 2], rate=[0.1, 0.2]), @@ -156,7 +157,7 @@ class TestOperations: tskit.RateMap(position=[0, 1, 2], rate=[np.nan, 0]), tskit.RateMap(position=[0, 1, 2], rate=[0, np.nan]), tskit.RateMap(position=[0, 1, 2, 3], rate=[0, np.nan, 1]), - ] + ) @pytest.mark.parametrize("rate_map", examples) def test_num_intervals(self, rate_map): @@ -639,8 +640,8 @@ def test_slice_with_floats(self): ) b = a.slice(left=50 * np.pi) assert a.sequence_length == b.sequence_length - assert_array_equal([0, 50 * np.pi] + list(a.position[1:]), b.position) - assert_array_equal([np.nan] + list(a.rate), b.rate) + assert_array_equal([0, 50 * np.pi, *list(a.position[1:])], b.position) + assert_array_equal([np.nan, *list(a.rate)], b.rate) def test_slice_trim_left(self): a = tskit.RateMap(position=[0, 100, 200, 300, 400], rate=[1, 2, 3, 4]) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index ee8db8c009..6775005d1e 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,20 +22,13 @@ """ Test cases for two-locus statistics """ + import contextlib 
import io from copy import deepcopy from dataclasses import dataclass -from itertools import combinations -from itertools import combinations_with_replacement -from itertools import permutations -from itertools import product -from typing import Any -from typing import Callable -from typing import Dict -from typing import Generator -from typing import List -from typing import Tuple +from itertools import combinations, combinations_with_replacement, permutations, product +from typing import Any, Callable, Dict, Generator, List, Tuple import msprime import numpy as np @@ -577,9 +570,7 @@ def two_site_count_stat( :returns: 3D array of results, dimensions (sample_sets, row_sites, col_sites). """ params = {"sample_set_sizes": sample_set_sizes} - result = np.zeros( - (num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64 - ) + result = np.zeros((num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64) sites, row_idx, col_idx = get_site_row_col_indices(row_sites, col_sites) num_alleles, site_offsets, allele_samples = get_mutation_samples( @@ -643,9 +634,7 @@ def two_branch_count_stat( :returns: 3D array of results, dimensions (sample_sets, row_sites, col_sites). """ params = {"sample_set_sizes": sample_set_sizes} - result = np.zeros( - (num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64 - ) + result = np.zeros((num_sample_sets, len(row_sites), len(col_sites)), dtype=np.float64) # TODO: get_pos_row_col_indices? 
# sites, row_idx, col_idx = get_site_row_col_indices(row_sites, col_sites) @@ -1153,9 +1142,9 @@ def get_paper_ex_ts(): # fmt:off # true r2 values for the tree sequence from the tskit paper PAPER_EX_TRUTH_MATRIX = np.array( - [[1.0, 0.11111111, 0.11111111], # noqa: E241 - [0.11111111, 1.0, 1.0], # noqa: E241 - [0.11111111, 1.0, 1.0]] # noqa: E241 + [[1.0, 0.11111111, 0.11111111], + [0.11111111, 1.0, 1.0], + [0.11111111, 1.0, 1.0]] ) # fmt:on @@ -1201,9 +1190,7 @@ def test_subset_sites(partition): ld_matrix(ts, sites=partition), PAPER_EX_TRUTH_MATRIX[a[0] : a[-1] + 1, b[0] : b[-1] + 1], ) - np.testing.assert_equal( - ld_matrix(ts, sites=partition), ts.ld_matrix(sites=partition) - ) + np.testing.assert_equal(ld_matrix(ts, sites=partition), ts.ld_matrix(sites=partition)) @pytest.mark.parametrize("sites", [[0, 1, 2], [1, 2], [0, 1], [0], [1]]) @@ -1562,7 +1549,7 @@ def tmrca(tr, x, y): if y in set(tr.samples(r)): y_mrca = r if x_mrca == -1 or y_mrca == -1: - raise ValueError + raise ValueError from e return (tr.time(x_mrca) + tr.time(y_mrca)) / 2 @@ -1656,7 +1643,7 @@ def combine(samples): (i, j, samples[k], samples[l]) for i, j in combinations(samples, 2) for k in range(len(samples)) - for l in range(k + 1, len(samples)) # noqa: E741 + for l in range(k + 1, len(samples)) if i != samples[k] and j != samples[k] and samples[l] != i and samples[l] != j ] return ij, ijk, ijkl @@ -1705,7 +1692,7 @@ def naive_matrix(ts, stat_func, sample_set=None): ], ) @pytest.mark.parametrize( - "stat,stat_func", + ("stat", "stat_func"), zip( ["d2_unbiased", "dz_unbiased", "pi2_unbiased"], [compute_D2, compute_Dz, compute_pi2], @@ -1742,9 +1729,9 @@ def get_test_branch_sample_set_test_cases(): ] -@pytest.mark.parametrize("ts,sample_set", get_test_branch_sample_set_test_cases()) +@pytest.mark.parametrize(("ts", "sample_set"), get_test_branch_sample_set_test_cases()) @pytest.mark.parametrize( - "stat,stat_func", + ("stat", "stat_func"), zip( ["d2_unbiased", "dz_unbiased", "pi2_unbiased"], 
[compute_D2, compute_Dz, compute_pi2], diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 54b19b1d6f..09b0e35c0d 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -23,6 +23,7 @@ """ Test cases for the low level C interface to tskit. """ + import collections import gc import inspect @@ -38,7 +39,6 @@ import _tskit import tskit - NON_UTF8_STRING = "\ud861\udd37" @@ -151,7 +151,7 @@ def verify_iterator(self, iterator): class MetadataTestMixin: - metadata_tables = [ + metadata_tables = ( "node", "edge", "site", @@ -159,7 +159,7 @@ class MetadataTestMixin: "migration", "individual", "population", - ] + ) class TestTableCollection(LowLevelTestCase): @@ -509,9 +509,9 @@ def test_uninitialised(self): with pytest.raises(SystemError): result.print_state() with pytest.raises(SystemError): - result.num_segments + result.num_segments # noqa: B018 with pytest.raises(SystemError): - result.total_span + result.total_span # noqa: B018 with pytest.raises(SystemError): result.get_keys() @@ -530,7 +530,7 @@ def test_store_pairs(self): with pytest.raises(_tskit.IdentityPairsNotStoredError): result.get_keys() with pytest.raises(_tskit.IdentityPairsNotStoredError): - result.num_pairs + result.num_pairs # noqa: B018 with pytest.raises(_tskit.IdentityPairsNotStoredError): result.get(0, 1) @@ -543,11 +543,11 @@ def test_store_pairs(self): assert seglist.num_segments == 1 assert seglist.total_span == 1 with pytest.raises(_tskit.IdentitySegmentsNotStoredError): - seglist.node + seglist.node # noqa: B018 with pytest.raises(_tskit.IdentitySegmentsNotStoredError): - seglist.left + seglist.left # noqa: B018 with pytest.raises(_tskit.IdentitySegmentsNotStoredError): - seglist.right + seglist.right # noqa: B018 def test_within_all_pairs(self): ts = msprime.simulate(10, random_seed=1) @@ -737,7 +737,7 @@ def test_table_extend(self, table_name, ts_fixture): @pytest.mark.parametrize("table_name", tskit.TABLE_NAMES) @pytest.mark.parametrize( 
- ["row_indexes", "expected_rows"], + ("row_indexes", "expected_rows"), [ ([0], [0]), ([4] * 1000, [4] * 1000), @@ -750,9 +750,7 @@ def test_table_extend(self, table_name, ts_fixture): (range(2, -1, -1), [2, 1, 0]), ], ) - def test_table_extend_types( - self, ts_fixture, table_name, row_indexes, expected_rows - ): + def test_table_extend_types(self, ts_fixture, table_name, row_indexes, expected_rows): table = getattr(ts_fixture.tables, table_name) assert len(table) >= 5 ll_table = table.ll_table @@ -806,13 +804,11 @@ def test_mutation_table_keep_rows_ref_error(self): def test_individual_table_keep_rows_ref_error(self): table = _tskit.IndividualTable() table.add_row(parents=[2]) - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS" - ): + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_INDIVIDUAL_OUT_OF_BOUNDS"): table.keep_rows([True]) @pytest.mark.parametrize( - ["table_name", "column_name"], + ("table_name", "column_name"), [ (t, c) for t in tskit.TABLE_NAMES @@ -1098,9 +1094,7 @@ def test_index(self): modify_indexes = tc.indexes modify_indexes["edge_insertion_order"] = np.arange(42, 42 + 18, dtype=np.int32) - modify_indexes["edge_removal_order"] = np.arange( - 4242, 4242 + 18, dtype=np.int32 - ) + modify_indexes["edge_removal_order"] = np.arange(4242, 4242 + 18, dtype=np.int32) tc.indexes = modify_indexes assert np.array_equal( tc.indexes["edge_insertion_order"], np.arange(42, 42 + 18, dtype=np.int32) @@ -1148,9 +1142,7 @@ def test_bad_indexes(self): ): tc.indexes = d - tc = msprime.simulate( - 10, recombination_rate=10, random_seed=42 - ).tables._ll_tables + tc = msprime.simulate(10, recombination_rate=10, random_seed=42).tables._ll_tables modify_indexes = tc.indexes shape = modify_indexes["edge_insertion_order"].shape modify_indexes["edge_insertion_order"] = np.zeros(shape, dtype=np.int32) @@ -1176,7 +1168,7 @@ class TestTreeSequence(LowLevelTestCase, MetadataTestMixin): Tests for the low-level interface for the 
TreeSequence. """ - ARRAY_NAMES = [ + ARRAY_NAMES = ( "individuals_flags", "nodes_time", "nodes_flags", @@ -1199,7 +1191,7 @@ class TestTreeSequence(LowLevelTestCase, MetadataTestMixin): "migrations_time", "indexes_edge_insertion_order", "indexes_edge_removal_order", - ] + ) def setUp(self): fd, self.temp_file = tempfile.mkstemp(prefix="msp_ll_ts_") @@ -1507,9 +1499,7 @@ def test_extend_edges_bad_args(self): with pytest.raises(_tskit.LibraryError, match="positive"): ts1.extend_edges(-1) tsm = self.get_example_migration_tree_sequence() - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED" - ): + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_MIGRATIONS_NOT_SUPPORTED"): tsm.extend_edges(1) @pytest.mark.parametrize( @@ -1546,75 +1536,65 @@ def test_ld_matrix(self, stat_method_name): assert a.shape == (10, 10, 1) # CPython API errors + bad_sample_sets = np.array([], dtype=np.int32) with pytest.raises(ValueError, match="Sum of sample_set_sizes"): - bad_sample_sets = np.array([], dtype=np.int32) stat_method(sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) + bad_sample_sets = np.array(ts.get_samples(), dtype=np.uint32) with pytest.raises(TypeError, match="cast array data"): - bad_sample_sets = np.array(ts.get_samples(), dtype=np.uint32) stat_method(sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) with pytest.raises(ValueError, match="Unrecognised stats mode"): stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, "bla") with pytest.raises(TypeError, match="at most"): - stat_method( - sample_set_sizes, sample_sets, row_sites, col_sites, mode, "abc" - ) + stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, mode, "abc") + bad_sites = ["abadsite", 0, 3, 2] with pytest.raises(ValueError, match="invalid literal"): - bad_sites = ["abadsite", 0, 3, 2] stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + bad_sites = [None, 0, 3, 2] with pytest.raises(TypeError): - 
bad_sites = [None, 0, 3, 2] stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + bad_sites = [{}, 0, 3, 2] with pytest.raises(TypeError): - bad_sites = [{}, 0, 3, 2] stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + bad_sites = np.array([0, 1, 2], dtype=np.uint32) with pytest.raises(TypeError, match="Cannot cast array data"): - bad_sites = np.array([0, 1, 2], dtype=np.uint32) stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + bad_sites = ["abadsite", 0, 3, 2] with pytest.raises(ValueError, match="invalid literal"): - bad_sites = ["abadsite", 0, 3, 2] stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + bad_sites = [None, 0, 3, 2] + bad_sites = [None, 0, 3, 2] with pytest.raises(TypeError): - bad_sites = [None, 0, 3, 2] stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + bad_sites = [{}, 0, 3, 2] with pytest.raises(TypeError): - bad_sites = [{}, 0, 3, 2] stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + bad_sites = np.array([0, 1, 2], dtype=np.uint32) with pytest.raises(TypeError, match="Cannot cast array data"): - bad_sites = np.array([0, 1, 2], dtype=np.uint32) stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) + # C API errors + bad_sites = np.array([1, 0, 2], dtype=np.int32) with pytest.raises(tskit.LibraryError, match="TSK_ERR_UNSORTED_SITES"): - bad_sites = np.array([1, 0, 2], dtype=np.int32) stat_method(sample_set_sizes, sample_sets, bad_sites, col_sites, mode) + bad_sites = np.array([1, 0, 2], dtype=np.int32) with pytest.raises(tskit.LibraryError, match="TSK_ERR_UNSORTED_SITES"): - bad_sites = np.array([1, 0, 2], dtype=np.int32) stat_method(sample_set_sizes, sample_sets, row_sites, bad_sites, mode) - with pytest.raises( - _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" - ): - bad_sample_sets = np.array([], dtype=np.int32) - bad_sample_set_sizes = np.array([], dtype=np.uint32) - stat_method( - 
bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + bad_sample_sets = np.array([], dtype=np.int32) + bad_sample_set_sizes = np.array([], dtype=np.uint32) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS"): + stat_method(bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) + bad_sample_sets = np.array([], dtype=np.int32) + bad_sample_set_sizes = np.array([0], dtype=np.uint32) with pytest.raises(_tskit.LibraryError, match="TSK_ERR_EMPTY_SAMPLE_SET"): - bad_sample_sets = np.array([], dtype=np.int32) - bad_sample_set_sizes = np.array([0], dtype=np.uint32) - stat_method( - bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + stat_method(bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) + bad_sample_sets = np.array([1000], dtype=np.int32) + bad_sample_set_sizes = np.array([1], dtype=np.uint32) with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): - bad_sample_sets = np.array([1000], dtype=np.int32) - bad_sample_set_sizes = np.array([1], dtype=np.uint32) - stat_method( - bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + stat_method(bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) + bad_sample_sets = np.array([2, 2], dtype=np.int32) + bad_sample_set_sizes = np.array([2], dtype=np.uint32) with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): - bad_sample_sets = np.array([2, 2], dtype=np.int32) - bad_sample_set_sizes = np.array([2], dtype=np.uint32) - stat_method( - bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode - ) + stat_method(bad_sample_set_sizes, bad_sample_sets, row_sites, col_sites, mode) with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): stat_method(sample_set_sizes, sample_sets, row_sites, col_sites, "branch") @@ -1720,9 +1700,7 @@ def test_load_tables_build_indexes(self): assert tables4.has_index() def test_clear_table(self, 
ts_fixture): - tables = _tskit.TableCollection( - sequence_length=ts_fixture.get_sequence_length() - ) + tables = _tskit.TableCollection(sequence_length=ts_fixture.get_sequence_length()) ts_fixture.ll_tree_sequence.dump_tables(tables) tables.clear() data_tables = [t for t in tskit.TABLE_NAMES if t != "provenances"] @@ -2182,11 +2160,11 @@ def test_output_dims(self): jafs = ts.allele_frequency_spectrum( s, samples, windows, mode=mode, polarised=True ) - assert jafs.shape == tuple([len(windows) - 1] + list(s + 1)) + assert jafs.shape == tuple([len(windows) - 1, *list(s + 1)]) jafs = ts.allele_frequency_spectrum( s, samples, windows, mode=mode, polarised=False ) - assert jafs.shape == tuple([len(windows) - 1] + list(s + 1)) + assert jafs.shape == tuple([len(windows) - 1, *list(s + 1)]) def test_node_mode_not_supported(self): ts = self.get_example_tree_sequence() @@ -2245,9 +2223,7 @@ def test_output_dims(self): assert div.shape == (1, N, 1) div = method([2, 2, n - 4], samples, [[0, 1], [1, 2]], windows, mode=mode) assert div.shape == (1, N, 2) - div = method( - [2, 2, n - 4], samples, [[0, 1], [1, 2], [0, 1]], windows, mode=mode - ) + div = method([2, 2, n - 4], samples, [[0, 1], [1, 2], [0, 1]], windows, mode=mode) assert div.shape == (1, N, 3) def test_set_index_errors(self): @@ -2653,7 +2629,10 @@ def f(x): for bad_array in [[1, 1], range(10)]: with pytest.raises(ValueError): ts.general_stat( - W, lambda x: bad_array, 1, ts.get_breakpoints() # noqa:B023 + W, + lambda x, bad_array=bad_array: bad_array, + 1, + ts.get_breakpoints(), ) with pytest.raises(ValueError): ts.general_stat(W, lambda x: [1], 2, ts.get_breakpoints()) @@ -2662,7 +2641,10 @@ def f(x): for bad_array in [["sdf"], 0, "w4", None]: with pytest.raises(ValueError): ts.general_stat( - W, lambda x: bad_array, 1, ts.get_breakpoints() # noqa:B023 + W, + lambda x, bad_array=bad_array: bad_array, + 1, + ts.get_breakpoints(), ) @@ -2829,9 +2811,7 @@ def test_copy(self, isolated_as_missing, samples, alleles): 
alleles = variant.alleles site_id = variant.site_id variant.decode(1) - with pytest.raises( - tskit.LibraryError, match="Can't decode a copy of a variant" - ): + with pytest.raises(tskit.LibraryError, match="Can't decode a copy of a variant"): variant2.decode(1) assert site_id == variant2.site_id assert alleles == variant2.alleles @@ -3119,7 +3099,7 @@ class TestTree(LowLevelTestCase): Tests on the low-level tree interface. """ - ARRAY_NAMES = [ + ARRAY_NAMES = ( "parent", "left_child", "right_child", @@ -3127,7 +3107,7 @@ class TestTree(LowLevelTestCase): "right_sib", "num_children", "edge", - ] + ) def test_options(self): ts = self.get_example_tree_sequence() diff --git a/python/tests/test_metadata.py b/python/tests/test_metadata.py index d57e6ea9d3..df7bc38b72 100644 --- a/python/tests/test_metadata.py +++ b/python/tests/test_metadata.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Tests for metadata handling. 
""" + import collections import io import json @@ -88,8 +89,7 @@ def test_pickle(self): tables = ts.dump_tables() # For each node, we create some Python metadata that can be pickled metadata = [ - {"one": j, "two": 2 * j, "three": list(range(j))} - for j in range(ts.num_nodes) + {"one": j, "two": 2 * j, "three": list(range(j))} for j in range(ts.num_nodes) ] encoded, offset = tskit.pack_bytes(list(map(pickle.dumps, metadata))) tables.nodes.set_columns( @@ -211,9 +211,7 @@ def test_nodes(self): 2 0 1 !@#$%^&*() """ ) - n = tskit.parse_nodes( - nodes, strict=False, encoding="utf8", base64_metadata=False - ) + n = tskit.parse_nodes(nodes, strict=False, encoding="utf8", base64_metadata=False) expected = ["abc", "XYZ+", "!@#$%^&*()"] for a, b in zip(expected, n): assert a.encode("utf8") == b.metadata @@ -227,9 +225,7 @@ def test_sites(self): 0.8 G !@#$%^&*() """ ) - s = tskit.parse_sites( - sites, strict=False, encoding="utf8", base64_metadata=False - ) + s = tskit.parse_sites(sites, strict=False, encoding="utf8", base64_metadata=False) expected = ["abc", "XYZ+", "!@#$%^&*()"] for a, b in zip(expected, s): assert a.encode("utf8") == b.metadata @@ -265,7 +261,8 @@ def test_populations(self): assert a.encode("utf8") == b.metadata @pytest.mark.parametrize( - "base64_metadata,expected", [(True, ["pop", "gen"]), (False, ["cG9w", "Z2Vu"])] + ("base64_metadata", "expected"), + [(True, ["pop", "gen"]), (False, ["cG9w", "Z2Vu"])], ) def test_migrations(self, base64_metadata, expected): migrations = io.StringIO( @@ -522,9 +519,7 @@ def test_simple_default(self): ms = tskit.MetadataSchema(schema) assert ms.decode_row(b"") == {"number": 5} assert ms.decode_row(ms.validate_and_encode_row({})) == {"number": 5} - assert ms.decode_row(ms.validate_and_encode_row({"number": 42})) == { - "number": 42 - } + assert ms.decode_row(ms.validate_and_encode_row({"number": 42})) == {"number": 42} def test_nested_default_error(self): schema = { diff --git a/python/tests/test_ms.py 
b/python/tests/test_ms.py index f98be6d876..620b9862df 100644 --- a/python/tests/test_ms.py +++ b/python/tests/test_ms.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2022 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ is not used but an iterator over tree_sequences if the num_replicates argument is used. """ + import collections import itertools import os diff --git a/python/tests/test_parsimony.py b/python/tests/test_parsimony.py index 80cec45b96..f1f8658eba 100644 --- a/python/tests/test_parsimony.py +++ b/python/tests/test_parsimony.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2022 Tskit Developers +# Copyright (c) 2019-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Tests for the tree parsimony methods. 
""" + import dataclasses import io import itertools @@ -34,7 +35,6 @@ import tests.tsutil as tsutil import tskit - INF = np.inf @@ -373,9 +373,7 @@ def test_felsenstein_example_reconstruct(self): [[0, 2.5, 1, 2.5], [2.5, 0, 2.5, 1], [1, 2.5, 0, 2.5], [2.5, 1, 2.5, 0]] ) S = sankoff_score(tree, genotypes, cost_matrix) - ancestral_state, transitions = reconstruct_states( - tree, genotypes, S, cost_matrix - ) + ancestral_state, transitions = reconstruct_states(tree, genotypes, S, cost_matrix) assert {2: 1, 4: 2, 0: 1} == transitions assert 0 == ancestral_state @@ -384,9 +382,7 @@ def verify_infinite_sites(self, ts): assert ts.num_sites > 5 tree = ts.first() for variant in ts.variants(): - ancestral_state, transitions = sankoff_map_mutations( - tree, variant.genotypes - ) + ancestral_state, transitions = sankoff_map_mutations(tree, variant.genotypes) assert len(transitions) == 1 assert ancestral_state == 0 assert transitions[variant.site.mutations[0].node] == 1 @@ -576,6 +572,7 @@ class TestParsimonyRoundTrip(TestParsimonyBase): def verify(self, ts): G = ts.genotype_matrix(isolated_as_missing=False) alleles = [v.alleles for v in ts.variants()] + rng = np.random.default_rng(42) for randomize_ancestral_states in [False, True]: tables = ts.dump_tables() tables.sites.clear() @@ -587,7 +584,7 @@ def verify(self, ts): num_alleles = len(alleles[site.id]) if alleles[site.id][-1] is None: num_alleles -= 1 - fixed_anc_state = np.random.randint(num_alleles) + fixed_anc_state = rng.integers(num_alleles) ancestral_state, mutations = self.do_map_mutations( tree, G[site.id], @@ -838,9 +835,7 @@ def test_one_missing(self, n): for j in range(n): genotypes = np.zeros(n, dtype=np.int8) - 1 genotypes[j] = 0 - ancestral_state, transitions = self.do_map_mutations( - tree, genotypes, alleles - ) + ancestral_state, transitions = self.do_map_mutations(tree, genotypes, alleles) assert ancestral_state == "0" assert len(transitions) == 0 @@ -852,9 +847,7 @@ def test_one_missing_balanced(self, 
arity): for j in range(n): genotypes = np.zeros(n, dtype=np.int8) - 1 genotypes[j] = 0 - ancestral_state, transitions = self.do_map_mutations( - tree, genotypes, alleles - ) + ancestral_state, transitions = self.do_map_mutations(tree, genotypes, alleles) assert ancestral_state == "0" assert len(transitions) == 0 diff --git a/python/tests/test_phylo_formats.py b/python/tests/test_phylo_formats.py index 2125f506e9..9bf67d15a8 100644 --- a/python/tests/test_phylo_formats.py +++ b/python/tests/test_phylo_formats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2021 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Tests for phylogenetics export functions, newick, nexus, FASTA etc. """ + import functools import io import textwrap @@ -47,9 +48,7 @@ @functools.lru_cache(maxsize=100) def alignment_example(sequence_length, include_reference=True): - ts = msprime.sim_ancestry( - samples=5, sequence_length=sequence_length, random_seed=123 - ) + ts = msprime.sim_ancestry(samples=5, sequence_length=sequence_length, random_seed=123) ts = msprime.sim_mutations(ts, rate=0.1, random_seed=1234) tables = ts.dump_tables() if include_reference: @@ -299,9 +298,7 @@ def test_nexus_no_trees(self): END; """ ) - assert expected == self.ts().as_nexus( - reference_sequence=ref, include_trees=False - ) + assert expected == self.ts().as_nexus(reference_sequence=ref, include_trees=False) def test_nexus_no_alignments(self): expected = textwrap.dedent( @@ -501,9 +498,7 @@ def test_as_newick_default(self): def test_c_and_py_output_equal(self): t = self.tree() - assert t.as_newick() == t.as_newick( - node_labels={u: f"n{u}" for u in t.samples()} - ) + assert t.as_newick() == t.as_newick(node_labels={u: f"n{u}" for u in t.samples()}) def test_as_newick_precision_3(self): s = "(n0:0.667,(n1:0.333,n2:0.333):0.333);" @@ -938,7 +933,7 @@ 
def test_nexus_defaults(self): TREE t0^2 = [&R] (n0:3.25000000000000000,(n1:2.00000000000000000,n2:2.00000000000000000):1.25000000000000000); TREE t2^10 = [&R] (n1:2.00000000000000000,(n0:1.00000000000000000,n2:1.00000000000000000):1.00000000000000000); END; - """ # noqa: B950 + """ # noqa: E501 ) assert ts.as_nexus() == expected @@ -1009,7 +1004,7 @@ def test_nexus_defaults(self): TREE t0.00000000000000000^2.50000000000000000 = [&R] (n0:3,(n1:2,n2:2):1); TREE t2.50000000000000000^10.00000000000000000 = [&R] (n1:2,(n0:1,n2:1):1); END; - """ # noqa: B950 + """ ) assert ts.as_nexus() == expected @@ -1389,8 +1384,7 @@ def assert_missing_data_encoded(self, d): ) else: assert ( - a[j].state_denomination - == dendropy.StateAlphabet.AMBIGUOUS_STATE + a[j].state_denomination == dendropy.StateAlphabet.AMBIGUOUS_STATE ) def test_fasta(self): diff --git a/python/tests/test_provenance.py b/python/tests/test_provenance.py index 0f7c662523..eeaf6bcd53 100644 --- a/python/tests/test_provenance.py +++ b/python/tests/test_provenance.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2020 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (C) 2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Tests for the provenance information attached to tree sequences. """ + import json import os import platform diff --git a/python/tests/test_reference_sequence.py b/python/tests/test_reference_sequence.py index 99e849fd87..dca9a59741 100644 --- a/python/tests/test_reference_sequence.py +++ b/python/tests/test_reference_sequence.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2021-2022 Tskit Developers +# Copyright (c) 2021-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Tests for reference sequence support. 
""" + import pytest import tskit @@ -81,9 +82,7 @@ def test_asdict_reference_no_metadata(self): def test_asdict_reference_metadata(self): tables = tskit.TableCollection(1) - tables.reference_sequence.metadata_schema = ( - tskit.MetadataSchema.permissive_json() - ) + tables.reference_sequence.metadata_schema = tskit.MetadataSchema.permissive_json() tables.reference_sequence.metadata = {"a": "ABCDEF"} d = tables.asdict()["reference_sequence"] assert d["data"] == "" @@ -113,9 +112,7 @@ def test_fromdict_reference_url(self): def test_fromdict_reference_metadata(self): tables = tskit.TableCollection(1) - tables.reference_sequence.metadata_schema = ( - tskit.MetadataSchema.permissive_json() - ) + tables.reference_sequence.metadata_schema = tskit.MetadataSchema.permissive_json() tables.reference_sequence.metadata = {"a": "ABCDEF"} tables = tskit.TableCollection.fromdict(tables.asdict()) assert tables.has_reference_sequence() @@ -133,9 +130,7 @@ def test_fromdict_no_reference(self): def test_fromdict_all_values_empty(self): d = tskit.TableCollection(1).asdict() - d["reference_sequence"] = dict( - data="", url="", metadata_schema="", metadata=b"" - ) + d["reference_sequence"] = dict(data="", url="", metadata_schema="", metadata=b"") tables = tskit.TableCollection.fromdict(d) assert not tables.has_reference_sequence() @@ -275,9 +270,7 @@ def test_write_metadata_schema_fails(self): tables.reference_sequence.data = "abc" ts = tables.tree_sequence() with pytest.raises(AttributeError, match="read-only"): - ts.reference_sequence.metadata_schema = ( - tskit.MetadataSchema.permissive_json() - ) + ts.reference_sequence.metadata_schema = tskit.MetadataSchema.permissive_json() def test_write_object_fails(self, ts_fixture): tables = tskit.TableCollection(1) diff --git a/python/tests/test_stats.py b/python/tests/test_stats.py index bfe2d270cf..0b6ca1e329 100644 --- a/python/tests/test_stats.py +++ b/python/tests/test_stats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2021 
Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (C) 2016 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Test cases for stats calculations in tskit. """ + import contextlib import io @@ -197,7 +198,7 @@ def ts(self): tables.mutations.add_row(site=2, node=0, derived_state="G") return tables.tree_sequence() - @pytest.mark.parametrize(["a", "b", "expected"], [(0, 0, 1), (0, 1, 1), (0, 2, 1)]) + @pytest.mark.parametrize(("a", "b", "expected"), [(0, 0, 1), (0, 1, 1), (0, 2, 1)]) def test_r2(self, a, b, expected): ts = self.ts() A = get_r2_matrix(ts) @@ -612,8 +613,8 @@ def set_partitions(collection): first = collection[0] for smaller in set_partitions(collection[1:]): for n, subset in enumerate(smaller): - yield smaller[:n] + [[first] + subset] + smaller[n + 1 :] - yield [[first]] + smaller + yield smaller[:n] + [[first, *subset]] + smaller[n + 1 :] + yield [[first], *smaller] def naive_mean_descendants(ts, reference_sets): @@ -1358,7 +1359,6 @@ def test_span_normalise(self): ts = self.get_two_tree_ts() sample_sets = [[0, 1], [2]] focal = [0] - np.random.seed(5) windows = ts.sequence_length * np.array([0.2, 0.4, 0.6, 0.8, 1]) windows.sort() windows[0] = 0.0 diff --git a/python/tests/test_table_transforms.py b/python/tests/test_table_transforms.py index 35285fc4dc..7478487976 100644 --- a/python/tests/test_table_transforms.py +++ b/python/tests/test_table_transforms.py @@ -22,6 +22,7 @@ """ Test cases for table transformation operations like trim(), decapitate, etc. """ + import decimal import fractions import io diff --git a/python/tests/test_tables.py b/python/tests/test_tables.py index 51852a2c79..4f77fa2691 100644 --- a/python/tests/test_tables.py +++ b/python/tests/test_tables.py @@ -24,6 +24,7 @@ Test cases for the low-level tables used to transfer information between simulations and the tree sequence. 
""" + import dataclasses import io import json @@ -107,14 +108,10 @@ def make_transposed_input_data(self, num_rows): if len(data) == num_rows else ( bytes( - data[ - cols[f"{col}_offset"][j] : cols[f"{col}_offset"][j + 1] - ] + data[cols[f"{col}_offset"][j] : cols[f"{col}_offset"][j + 1]] ) if "metadata" in col - else data[ - cols[f"{col}_offset"][j] : cols[f"{col}_offset"][j + 1] - ] + else data[cols[f"{col}_offset"][j] : cols[f"{col}_offset"][j + 1]] ) ) for col, data in cols.items() @@ -123,7 +120,7 @@ def make_transposed_input_data(self, num_rows): for j in range(num_rows) ] - @pytest.fixture + @pytest.fixture() def test_rows(self, scope="session"): test_rows = self.make_transposed_input_data(10) # Annoyingly we have to tweak some types as once added to a row and then put in @@ -134,14 +131,14 @@ def test_rows(self, scope="session"): test_rows[n][col] = bytes(test_rows[n][col]).decode("ascii") return test_rows - @pytest.fixture + @pytest.fixture() def table(self, test_rows): table = self.table_class() for row in test_rows: table.add_row(**row) return table - @pytest.fixture + @pytest.fixture() def table_5row(self, test_rows): table_5row = self.table_class() for row in test_rows[:5]: @@ -571,13 +568,9 @@ def test_truncate(self): for list_col, offset_col in self.ragged_list_columns: offset = getattr(table, offset_col.name) assert offset.shape == (num_rows + 1,) - assert np.array_equal( - input_data[offset_col.name][: num_rows + 1], offset - ) + assert np.array_equal(input_data[offset_col.name][: num_rows + 1], offset) list_data = getattr(table, list_col.name) - assert np.array_equal( - list_data, input_data[list_col.name][: offset[-1]] - ) + assert np.array_equal(list_data, input_data[list_col.name][: offset[-1]]) used.add(offset_col.name) used.add(list_col.name) for name, data in input_data.items(): @@ -914,9 +907,7 @@ def test_random_metadata(self): input_data["metadata"] = metadata input_data["metadata_offset"] = metadata_offset table.set_columns(**input_data) 
- unpacked_metadatas = tskit.unpack_bytes( - table.metadata, table.metadata_offset - ) + unpacked_metadatas = tskit.unpack_bytes(table.metadata, table.metadata_offset) assert metadatas == unpacked_metadatas def test_drop_metadata(self): @@ -1101,8 +1092,7 @@ def test_round_trip_set_columns(self): del input_data["metadata_offset"] metadata_column = [self.metadata_example_data() for _ in range(num_rows)] encoded_metadata_column = [ - table.metadata_schema.validate_and_encode_row(r) - for r in metadata_column + table.metadata_schema.validate_and_encode_row(r) for r in metadata_column ] packed_metadata, metadata_offset = tskit.util.pack_bytes( encoded_metadata_column @@ -1235,9 +1225,7 @@ def verify_metadata_vector(self, table, key, dtype, default_value=9999): # does this more elegantly has_default = default_value != 9999 if has_default: - md_vec = table.metadata_vector( - key, default_value=default_value, dtype=dtype - ) + md_vec = table.metadata_vector(key, default_value=default_value, dtype=dtype) else: md_vec = table.metadata_vector(key, dtype=dtype) assert isinstance(md_vec, np.ndarray) @@ -1254,7 +1242,7 @@ def verify_metadata_vector(self, table, key, dtype, default_value=9999): else: md = default_value break - assert np.all(np.cast[dtype](md) == x) + assert np.all(np.asarray(md, dtype=dtype) == x) def test_metadata_vector_errors(self): table = self.table_class() @@ -1319,9 +1307,7 @@ def test_metadata_vector_nodefault(self): assert np.all(np.equal(md_vec, [d["abc"] for d in metadata_list])) # now automated ones for dtype in [None, "int", "float", "object"]: - self.verify_metadata_vector( - table, key="abc", dtype=dtype, default_value=9999 - ) + self.verify_metadata_vector(table, key="abc", dtype=dtype, default_value=9999) self.verify_metadata_vector( table, key=["abc"], dtype=dtype, default_value=9999 ) @@ -1361,9 +1347,7 @@ def test_metadata_vector(self): # now some automated ones for dtype in [None, "int", "float", "object"]: self.verify_metadata_vector(table, 
key="abc", dtype=dtype, default_value=-1) - self.verify_metadata_vector( - table, key=["abc"], dtype=dtype, default_value=-1 - ) + self.verify_metadata_vector(table, key=["abc"], dtype=dtype, default_value=-1) self.verify_metadata_vector(table, key=["x"], dtype=dtype, default_value=-1) self.verify_metadata_vector( table, key=["b", "c"], dtype=dtype, default_value=-1 @@ -1677,24 +1661,22 @@ def test_bad_indexes(self, table): class TestIndividualTable(*common_tests): - columns = [UInt32Column("flags")] - ragged_list_columns = [ + columns = (UInt32Column("flags"),) + ragged_list_columns = ( (DoubleColumn("location"), UInt32Column("location_offset")), (Int32Column("parents"), UInt32Column("parents_offset")), (CharColumn("metadata"), UInt32Column("metadata_offset")), - ] - string_colnames = [] - binary_colnames = ["metadata"] - input_parameters = [("max_rows_increment", 0)] - equal_len_columns = [["flags"]] + ) + string_colnames = () + binary_colnames = ("metadata",) + input_parameters = (("max_rows_increment", 0),) + equal_len_columns = (("flags",),) table_class = tskit.IndividualTable def test_simple_example(self): t = tskit.IndividualTable() t.add_row(flags=0, location=[], parents=[], metadata=b"123") - t.add_row( - flags=1, location=(0, 1, 2, 3), parents=(4, 5, 6, 7), metadata=b"\xf0" - ) + t.add_row(flags=1, location=(0, 1, 2, 3), parents=(4, 5, 6, 7), metadata=b"\xf0") s = str(t) assert len(s) > 0 assert len(t) == 2 @@ -1821,17 +1803,17 @@ def test_keep_rows_data(self): class TestNodeTable(*common_tests): - columns = [ + columns = ( UInt32Column("flags"), DoubleColumn("time"), Int32Column("individual"), Int32Column("population"), - ] - ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))] - string_colnames = [] - binary_colnames = ["metadata"] - input_parameters = [("max_rows_increment", 0)] - equal_len_columns = [["time", "flags", "population"]] + ) + ragged_list_columns = ((CharColumn("metadata"), UInt32Column("metadata_offset")),) + 
string_colnames = () + binary_colnames = ("metadata",) + input_parameters = (("max_rows_increment", 0),) + equal_len_columns = (("time", "flags", "population"),) table_class = tskit.NodeTable def test_simple_example(self): @@ -1901,17 +1883,17 @@ def test_add_row_bad_data(self): class TestEdgeTable(*common_tests): - columns = [ + columns = ( DoubleColumn("left"), DoubleColumn("right"), Int32Column("parent"), Int32Column("child"), - ] - equal_len_columns = [["left", "right", "parent", "child"]] - string_colnames = [] - binary_colnames = ["metadata"] - ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))] - input_parameters = [("max_rows_increment", 0)] + ) + equal_len_columns = (("left", "right", "parent", "child"),) + string_colnames = () + binary_colnames = ("metadata",) + ragged_list_columns = ((CharColumn("metadata"), UInt32Column("metadata_offset")),) + input_parameters = (("max_rows_increment", 0),) table_class = tskit.EdgeTable def test_simple_example(self): @@ -1948,15 +1930,15 @@ def test_add_row_bad_data(self): class TestSiteTable(*common_tests): - columns = [DoubleColumn("position")] - ragged_list_columns = [ + columns = (DoubleColumn("position"),) + ragged_list_columns = ( (CharColumn("ancestral_state"), UInt32Column("ancestral_state_offset")), (CharColumn("metadata"), UInt32Column("metadata_offset")), - ] - equal_len_columns = [["position"]] - string_colnames = ["ancestral_state"] - binary_colnames = ["metadata"] - input_parameters = [("max_rows_increment", 0)] + ) + equal_len_columns = (("position",),) + string_colnames = ("ancestral_state",) + binary_colnames = ("metadata",) + input_parameters = (("max_rows_increment", 0),) table_class = tskit.SiteTable def test_simple_example(self): @@ -1994,29 +1976,27 @@ def test_packset_ancestral_state(self): table = self.table_class() table.set_columns(**input_data) ancestral_states = [tsutil.random_strings(10) for _ in range(num_rows)] - ancestral_state, ancestral_state_offset = 
tskit.pack_strings( - ancestral_states - ) + ancestral_state, ancestral_state_offset = tskit.pack_strings(ancestral_states) table.packset_ancestral_state(ancestral_states) assert np.array_equal(table.ancestral_state, ancestral_state) assert np.array_equal(table.ancestral_state_offset, ancestral_state_offset) class TestMutationTable(*common_tests): - columns = [ + columns = ( Int32Column("site"), Int32Column("node"), DoubleColumn("time"), Int32Column("parent"), - ] - ragged_list_columns = [ + ) + ragged_list_columns = ( (CharColumn("derived_state"), UInt32Column("derived_state_offset")), (CharColumn("metadata"), UInt32Column("metadata_offset")), - ] - equal_len_columns = [["site", "node", "time"]] - string_colnames = ["derived_state"] - binary_colnames = ["metadata"] - input_parameters = [("max_rows_increment", 0)] + ) + equal_len_columns = (("site", "node", "time"),) + string_colnames = ("derived_state",) + binary_colnames = ("metadata",) + input_parameters = (("max_rows_increment", 0),) table_class = tskit.MutationTable def test_simple_example(self): @@ -2097,19 +2077,19 @@ def test_keep_rows_data(self): class TestMigrationTable(*common_tests): - columns = [ + columns = ( DoubleColumn("left"), DoubleColumn("right"), Int32Column("node"), Int32Column("source"), Int32Column("dest"), DoubleColumn("time"), - ] - ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))] - string_colnames = [] - binary_colnames = ["metadata"] - input_parameters = [("max_rows_increment", 0)] - equal_len_columns = [["left", "right", "node", "source", "dest", "time"]] + ) + ragged_list_columns = ((CharColumn("metadata"), UInt32Column("metadata_offset")),) + string_colnames = () + binary_colnames = ("metadata",) + input_parameters = (("max_rows_increment", 0),) + equal_len_columns = (("left", "right", "node", "source", "dest", "time"),) table_class = tskit.MigrationTable def test_simple_example(self): @@ -2148,15 +2128,15 @@ def test_add_row_bad_data(self): class 
TestProvenanceTable(CommonTestsMixin, AssertEqualsMixin): - columns = [] - ragged_list_columns = [ + columns = () + ragged_list_columns = ( (CharColumn("timestamp"), UInt32Column("timestamp_offset")), (CharColumn("record"), UInt32Column("record_offset")), - ] - equal_len_columns = [[]] - string_colnames = ["record", "timestamp"] - binary_colnames = [] - input_parameters = [("max_rows_increment", 0)] + ) + equal_len_columns = ((),) + string_colnames = ("record", "timestamp") + binary_colnames = () + input_parameters = (("max_rows_increment", 0),) table_class = tskit.ProvenanceTable def test_simple_example(self): @@ -2200,12 +2180,12 @@ def test_packset_record(self): class TestPopulationTable(*common_tests): metadata_mandatory = True - columns = [] - ragged_list_columns = [(CharColumn("metadata"), UInt32Column("metadata_offset"))] - equal_len_columns = [[]] - string_colnames = [] - binary_colnames = ["metadata"] - input_parameters = [("max_rows_increment", 0)] + columns = () + ragged_list_columns = ((CharColumn("metadata"), UInt32Column("metadata_offset")),) + equal_len_columns = ((),) + string_colnames = () + binary_colnames = ("metadata",) + input_parameters = (("max_rows_increment", 0),) table_class = tskit.PopulationTable def test_simple_example(self): @@ -2238,9 +2218,7 @@ class TestTableCollectionIndexes: def test_index(self): i = np.arange(20) r = np.arange(20)[::-1] - index = tskit.TableCollectionIndexes( - edge_insertion_order=i, edge_removal_order=r - ) + index = tskit.TableCollectionIndexes(edge_insertion_order=i, edge_removal_order=r) assert np.array_equal(index.edge_insertion_order, i) assert np.array_equal(index.edge_removal_order, r) d = index.asdict() @@ -2876,9 +2854,7 @@ def test_younger_than_node_below(self): def test_older_than_node_above(self): ts = msprime.simulate(5, mutation_rate=1, random_seed=42) tables = ts.dump_tables() - tables.mutations.time = ( - np.ones(len(tables.mutations.time), dtype=np.float64) * 42 - ) + tables.mutations.time = 
np.ones(len(tables.mutations.time), dtype=np.float64) * 42 with pytest.raises( _tskit.LibraryError, match="A mutation's time must be < the parent node of the edge on which it" @@ -2890,9 +2866,7 @@ def test_older_than_parent_node(self): ts = msprime.simulate( 10, random_seed=42, mutation_rate=0.0, recombination_rate=1.0 ) - ts = tsutil.jukes_cantor( - ts, num_sites=10, mu=1, multiple_per_node=False, seed=42 - ) + ts = tsutil.jukes_cantor(ts, num_sites=10, mu=1, multiple_per_node=False, seed=42) tables = ts.dump_tables() assert sum(tables.mutations.parent != -1) != 0 # Make all times the node time @@ -2912,9 +2886,7 @@ def test_older_than_parent_mutation(self): ts = msprime.simulate( 10, random_seed=42, mutation_rate=0.0, recombination_rate=1.0 ) - ts = tsutil.jukes_cantor( - ts, num_sites=10, mu=1, multiple_per_node=False, seed=42 - ) + ts = tsutil.jukes_cantor(ts, num_sites=10, mu=1, multiple_per_node=False, seed=42) tables = ts.dump_tables() tables.compute_mutation_times() assert sum(tables.mutations.parent != -1) != 0 @@ -2973,9 +2945,7 @@ def test_mixed_known_and_unknown(self): ts = msprime.simulate( 10, random_seed=42, mutation_rate=0.0, recombination_rate=1.0 ) - ts = tsutil.jukes_cantor( - ts, num_sites=10, mu=1, multiple_per_node=False, seed=42 - ) + ts = tsutil.jukes_cantor(ts, num_sites=10, mu=1, multiple_per_node=False, seed=42) tables = ts.dump_tables() tables.compute_mutation_times() tables.sort() @@ -3904,7 +3874,7 @@ def check_concordance(d1, d2): for k1, v1 in d1.items(): v2 = d2[k1] assert type(v1) is type(v2) - if type(v1) is dict: + if type(v1) is dict: # noqa: E721 assert set(v1.keys()) == set(v2.keys()) for sk1, sv1 in v1.items(): sv2 = v2[sk1] @@ -4013,12 +3983,8 @@ def test_equals_migration_metadata(self, ts_fixture): t1 = ts_fixture.dump_tables() t2 = t1.copy() t1.assert_equals(t2) - t1.migrations.add_row( - 0, 1, source=0, dest=1, node=0, time=0, metadata={"a": "a"} - ) - t2.migrations.add_row( - 0, 1, source=0, dest=1, node=0, time=0, 
metadata={"a": "b"} - ) + t1.migrations.add_row(0, 1, source=0, dest=1, node=0, time=0, metadata={"a": "a"}) + t2.migrations.add_row(0, 1, source=0, dest=1, node=0, time=0, metadata={"a": "b"}) assert not t1.migrations.equals(t2.migrations) assert not t1.equals(t2) assert t1.migrations.equals(t2.migrations, ignore_metadata=True) @@ -4058,11 +4024,11 @@ def test_equals_population_metadata(self, ts_fixture): class TestTableCollectionAssertEquals: - @pytest.fixture + @pytest.fixture() def t1(self, ts_fixture): return ts_fixture.dump_tables() - @pytest.fixture + @pytest.fixture() def t2(self, ts_fixture): return ts_fixture.dump_tables() @@ -4104,9 +4070,7 @@ def test_metadata(self, t1, t2): t2.metadata = {"foo": "bar"} with pytest.raises( AssertionError, - match=re.escape( - "Metadata differs: self=Test metadata other={'foo': 'bar'}" - ), + match=re.escape("Metadata differs: self=Test metadata other={'foo': 'bar'}"), ): t1.assert_equals(t2) t1.assert_equals(t2, ignore_metadata=True) @@ -4829,9 +4793,7 @@ def get_msprime_example(self, sample_size, T, seed): random_seed=seed, ) ts = tsutil.add_random_metadata(ts, seed) - ts = tsutil.insert_random_ploidy_individuals( - ts, max_ploidy=1, samples_only=True - ) + ts = tsutil.insert_random_ploidy_individuals(ts, max_ploidy=1, samples_only=True) return ts def get_wf_example(self, N, T, seed): @@ -4841,9 +4803,7 @@ def get_wf_example(self, N, T, seed): ts = ts.simplify() ts = tsutil.jukes_cantor(ts, 1, 10, seed=seed) ts = tsutil.add_random_metadata(ts, seed) - ts = tsutil.insert_random_ploidy_individuals( - ts, max_ploidy=2, samples_only=True - ) + ts = tsutil.insert_random_ploidy_individuals(ts, max_ploidy=2, samples_only=True) return ts def split_example(self, ts, T): @@ -4978,13 +4938,9 @@ def verify_union_consistency(self, tables, other, node_mapping): assert np.sum(i2.parents == tskit.NULL) == np.sum( iu.parents == tskit.NULL ) - md2 = [ - ts2.individual(p).metadata for p in i2.parents if p != tskit.NULL - ] + md2 = 
[ts2.individual(p).metadata for p in i2.parents if p != tskit.NULL] md2u = [indivs21[md] for md in md2] - mdu = [ - tsu.individual(p).metadata for p in iu.parents if p != tskit.NULL - ] + mdu = [tsu.individual(p).metadata for p in iu.parents if p != tskit.NULL] assert set(md2u) == set(mdu) else: # the individual *should* be there, but by a different name @@ -5072,9 +5028,7 @@ def test_provenance(self): ) tables_copy = tables.copy() tables.union(other, node_mapping) - uni_other_dict = json.loads(tables.provenances[-1].record)["parameters"][ - "other" - ] + uni_other_dict = json.loads(tables.provenances[-1].record)["parameters"]["other"] recovered_prov_table = tskit.ProvenanceTable() assert len(uni_other_dict["timestamp"]) == len(uni_other_dict["record"]) for timestamp, record in zip( diff --git a/python/tests/test_text_formats.py b/python/tests/test_text_formats.py index 92874af6fd..480217a381 100644 --- a/python/tests/test_text_formats.py +++ b/python/tests/test_text_formats.py @@ -22,6 +22,7 @@ """ Test cases for converting fam file to tskit """ + import dataclasses import tempfile from dataclasses import asdict diff --git a/python/tests/test_threads.py b/python/tests/test_threads.py index ad96cb3d52..a0397e053d 100644 --- a/python/tests/test_threads.py +++ b/python/tests/test_threads.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2021 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Test cases for threading enabled aspects of the API. 
""" + import platform import threading @@ -58,9 +59,7 @@ class TestLdCalculatorReplicates: num_test_sites = 25 def get_tree_sequence(self): - ts = msprime.simulate( - 20, mutation_rate=10, recombination_rate=10, random_seed=8 - ) + ts = msprime.simulate(20, mutation_rate=10, recombination_rate=10, random_seed=8) return tsutil.subsample_sites(ts, self.num_test_sites) def test_get_r2_multiple_instances(self): @@ -142,9 +141,7 @@ def worker(thread_index, results): # Temporarily skipping these on Windows and OSX See # /~https://github.com/tskit-dev/tskit/issues/344 # /~https://github.com/tskit-dev/tskit/issues/1041 -@pytest.mark.skipif( - IS_WINDOWS or IS_OSX, reason="Can't test thread support on Windows." -) +@pytest.mark.skipif(IS_WINDOWS or IS_OSX, reason="Can't test thread support on Windows.") class TestTables: """ Tests to ensure that attempts to access tables in threads correctly @@ -153,9 +150,7 @@ class TestTables: def get_tables(self): # TODO include migrations here. - ts = msprime.simulate( - 100, mutation_rate=10, recombination_rate=10, random_seed=8 - ) + ts = msprime.simulate(100, mutation_rate=10, recombination_rate=10, random_seed=8) return ts.tables def run_multiple_writers(self, writer, num_writers=32): diff --git a/python/tests/test_topology.py b/python/tests/test_topology.py index d564ec0590..1a27a1bef5 100644 --- a/python/tests/test_topology.py +++ b/python/tests/test_topology.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Test cases for the supported topological variations and operations. 
""" + import functools import io import itertools @@ -2282,7 +2283,7 @@ def test_single_tree(self): trees = [t.parent_dict for t in ts_redundant.trees()] assert len(trees) == 2 assert trees[0] == trees[1] - assert [t.parent_dict for t in ts.trees()][0] == trees[0] + assert next(t.parent_dict for t in ts.trees()) == trees[0] def test_many_trees(self): ts = msprime.simulate(20, recombination_rate=5, random_seed=self.random_seed) @@ -2297,8 +2298,7 @@ def test_many_trees(self): comparisons = 0 for t in ts.trees(): while ( - redundant_t is not None - and redundant_t.interval.right <= t.interval.right + redundant_t is not None and redundant_t.interval.right <= t.interval.right ): assert t.parent_dict == redundant_t.parent_dict comparisons += 1 @@ -5352,9 +5352,7 @@ def test_many_mutations_over_single_sample_ancestral_state(self): 0 0 0 0 """ ) - ts = tskit.load_text( - nodes, edges, sites=sites, mutations=mutations, strict=False - ) + ts = tskit.load_text(nodes, edges, sites=sites, mutations=mutations, strict=False) assert ts.sample_size == 1 assert ts.num_trees == 1 assert ts.num_sites == 1 @@ -5393,9 +5391,7 @@ def test_many_mutations_over_single_sample_derived_state(self): 0 0 1 1 """ ) - ts = tskit.load_text( - nodes, edges, sites=sites, mutations=mutations, strict=False - ) + ts = tskit.load_text(nodes, edges, sites=sites, mutations=mutations, strict=False) assert ts.sample_size == 1 assert ts.num_trees == 1 assert ts.num_sites == 1 @@ -5811,9 +5807,7 @@ def verify_keep_input_roots(self, ts, samples): assert left <= position < right new_site = new_sites[position] # We assume the metadata contains a unique key for each mutation. - new_mutations = { - mut.metadata: mut for mut in new_site.mutations - } + new_mutations = {mut.metadata: mut for mut in new_site.mutations} # Just make sure the metadata is actually unique. 
assert len(new_mutations) == len(new_site.mutations) input_site = input_sites[position] @@ -5912,10 +5906,8 @@ def verify_nodes_unchanged(self, ts_in, resample_size=None, **kwargs): if resample_size is None: samples = None else: - np.random.seed(42) - samples = np.sort( - np.random.choice(ts_in.num_nodes, resample_size, replace=False) - ) + rng = np.random.default_rng(42) + samples = np.sort(rng.choice(ts_in.num_nodes, resample_size, replace=False)) for ts in (ts_in, self.reverse_node_indexes(ts_in)): filtered, n_map = do_simplify( @@ -5933,9 +5925,7 @@ def verify_nodes_unchanged(self, ts_in, resample_size=None, **kwargs): # Check that edges are identical to the normal simplify(), # with the normal "simplify" having altered IDs - simplified, node_map = ts.simplify( - samples=samples, map_nodes=True, **kwargs - ) + simplified, node_map = ts.simplify(samples=samples, map_nodes=True, **kwargs) simplified_edges = {e for e in simplified.tables.edges} filtered_edges = { e.replace(parent=node_map[e.parent], child=node_map[e.child]) @@ -6005,9 +5995,7 @@ def test_multiroot(self, resample_size): @pytest.mark.parametrize("resample_size", [None, 10]) def test_with_metadata(self, ts_fixture_for_simplify, resample_size): assert ts_fixture_for_simplify.num_nodes > 10 - self.verify_nodes_unchanged( - ts_fixture_for_simplify, resample_size=resample_size - ) + self.verify_nodes_unchanged(ts_fixture_for_simplify, resample_size=resample_size) @pytest.mark.parametrize("resample_size", [None, 7]) def test_complex_ts_with_unary(self, resample_size): @@ -6666,9 +6654,7 @@ class TestMutationTime: def verify_times(self, ts): tables = ts.tables # Clear out the existing mutations as they come from msprime - tables.mutations.time = np.full( - tables.mutations.time.shape, -1, dtype=np.float64 - ) + tables.mutations.time = np.full(tables.mutations.time.shape, -1, dtype=np.float64) assert np.all(tables.mutations.time == -1) # Compute times with C method and dumb python method 
tables.compute_mutation_times() @@ -6730,9 +6716,7 @@ def test_example(self): tables = ts.tables python_time = tsutil.compute_mutation_times(ts) assert np.allclose(python_time, tables.mutations.time, rtol=1e-15, atol=1e-15) - tables.mutations.time = np.full( - tables.mutations.time.shape, -1, dtype=np.float64 - ) + tables.mutations.time = np.full(tables.mutations.time.shape, -1, dtype=np.float64) assert np.all(tables.mutations.time == -1) tables.compute_mutation_times() assert np.allclose(python_time, tables.mutations.time, rtol=1e-15, atol=1e-15) @@ -6969,9 +6953,7 @@ def verify(self, ts): # The python implementation here doesn't maintain roots np.testing.assert_array_equal(tree1.parent, tree2.parent_array[:-1]) np.testing.assert_array_equal(tree1.left_child, tree2.left_child_array[:-1]) - np.testing.assert_array_equal( - tree1.right_child, tree2.right_child_array[:-1] - ) + np.testing.assert_array_equal(tree1.right_child, tree2.right_child_array[:-1]) assert right == ts.sequence_length @@ -7049,9 +7031,7 @@ def verify(self, ts): assert tree_py.left_child[-1] == tree_lib.left_root np.testing.assert_array_equal(tree_py.parent, tree_lib.parent_array) np.testing.assert_array_equal(tree_py.left_child, tree_lib.left_child_array) - np.testing.assert_array_equal( - tree_py.right_child, tree_lib.right_child_array - ) + np.testing.assert_array_equal(tree_py.right_child, tree_lib.right_child_array) np.testing.assert_array_equal(tree_py.left_sib, tree_lib.left_sib_array) np.testing.assert_array_equal(tree_py.right_sib, tree_lib.right_sib_array) np.testing.assert_array_equal( @@ -7066,9 +7046,7 @@ def verify(self, ts): # except for the extra node and the details of the sib arrays. 
np.testing.assert_array_equal(tree_py.parent[:-1], tree_leg.parent) np.testing.assert_array_equal(tree_py.left_child[:-1], tree_leg.left_child) - np.testing.assert_array_equal( - tree_py.right_child[:-1], tree_leg.right_child - ) + np.testing.assert_array_equal(tree_py.right_child[:-1], tree_leg.right_child) # The sib arrays are identical except for root nodes. for u in range(ts.num_nodes): if u not in roots: @@ -7346,9 +7324,7 @@ def squash_edges(ts): last_e = edges[0] for e in edges[1:]: condition = ( - e.parent != last_e.parent - or e.child != last_e.child - or e.left != last_e.right + e.parent != last_e.parent or e.child != last_e.child or e.left != last_e.right ) if condition: squashed.append(last_e) @@ -7549,9 +7525,7 @@ def test_simple_recombination(self): self.verify(ts) def test_large_recombination(self): - ts = msprime.simulate( - 25, random_seed=12, recombination_rate=5, mutation_rate=15 - ) + ts = msprime.simulate(25, random_seed=12, recombination_rate=5, mutation_rate=15) self.verify(ts) def test_no_recombination(self): @@ -7611,8 +7585,8 @@ def verify(self, a): a = np.array(a) start, end = a[0], a[-1] # Check random values. - np.random.seed(43) - for v in np.random.uniform(start, end, 10): + rng = np.random.default_rng(43) + for v in rng.uniform(start, end, 10): assert search_sorted(a, v) == np.searchsorted(a, v) # Check equal values. 
for v in a: @@ -7630,27 +7604,27 @@ def test_negative_range(self): self.verify(-1 * np.arange(j)[::-1]) def test_random_unit_interval(self): - np.random.seed(143) + rng = np.random.default_rng(143) for size in range(1, 100): - a = np.random.random(size=size) + a = rng.random(size=size) a.sort() self.verify(a) def test_random_interval(self): - np.random.seed(143) + rng = np.random.default_rng(143) for _ in range(10): - interval = np.random.random(2) * 10 + interval = rng.random(2) * 10 interval.sort() - a = np.random.uniform(*interval, size=100) + a = rng.uniform(*interval, size=100) a.sort() self.verify(a) def test_random_negative(self): - np.random.seed(143) + rng = np.random.default_rng(143) for _ in range(10): - interval = np.random.random(2) * 5 + interval = rng.random(2) * 5 interval.sort() - a = -1 * np.random.uniform(*interval, size=100) + a = -1 * rng.uniform(*interval, size=100) a.sort() self.verify(a) @@ -7912,9 +7886,7 @@ def example_intervals(self, tables): yield [(0.25 * L, 0.5 * L)] yield [(0.25 * L, 0.5 * L), (0.75 * L, 0.8 * L)] - def do_keep_intervals( - self, tables, intervals, simplify=True, record_provenance=True - ): + def do_keep_intervals(self, tables, intervals, simplify=True, record_provenance=True): t1 = tables.copy() simple_keep_intervals(t1, intervals, simplify, record_provenance) t2 = tables.copy() @@ -7986,9 +7958,7 @@ def test_hundred_intervals(self): self.do_keep_intervals(tables, intervals, simplify, rec_prov) def test_regular_intervals(self): - ts = msprime.simulate( - 3, random_seed=1234, recombination_rate=2, mutation_rate=2 - ) + ts = msprime.simulate(3, random_seed=1234, recombination_rate=2, mutation_rate=2) tables = ts.tables eps = 0.0125 for num_intervals in range(2, 10): @@ -8132,9 +8102,7 @@ def test_tables_single_tree_delete_middle(self): def test_ts_single_tree_keep_middle(self): ts = msprime.simulate(10, random_seed=2) ts_keep = ts.keep_intervals([[0.25, 0.5]], record_provenance=False) - ts_delete = ts.delete_intervals( 
- [[0, 0.25], [0.5, 1.0]], record_provenance=False - ) + ts_delete = ts.delete_intervals([[0, 0.25], [0.5, 1.0]], record_provenance=False) assert ts_keep == ts_delete def test_ts_single_tree_delete_middle(self): @@ -8312,23 +8280,17 @@ def test_ltrim_single_tree_tiny_left(self): self.verify_ltrim(ts, ts.ltrim()) def test_ltrim_many_trees(self): - ts = msprime.simulate( - 10, recombination_rate=10, mutation_rate=12, random_seed=2 - ) + ts = msprime.simulate(10, recombination_rate=10, mutation_rate=12, random_seed=2) ts = self.clear_left_mutate(ts, 0.5, 10) self.verify_ltrim(ts, ts.ltrim()) def test_ltrim_many_trees_left_min(self): - ts = msprime.simulate( - 10, recombination_rate=10, mutation_rate=12, random_seed=2 - ) + ts = msprime.simulate(10, recombination_rate=10, mutation_rate=12, random_seed=2) ts = self.clear_left_mutate(ts, sys.float_info.min, 10) self.verify_ltrim(ts, ts.ltrim()) def test_ltrim_many_trees_left_epsilon(self): - ts = msprime.simulate( - 10, recombination_rate=10, mutation_rate=12, random_seed=2 - ) + ts = msprime.simulate(10, recombination_rate=10, mutation_rate=12, random_seed=2) ts = self.clear_left_mutate(ts, sys.float_info.epsilon, 0) self.verify_ltrim(ts, ts.ltrim()) @@ -8374,23 +8336,17 @@ def test_rtrim_single_tree_tiny_left(self): self.verify_rtrim(ts, ts.rtrim()) def test_rtrim_many_trees(self): - ts = msprime.simulate( - 10, recombination_rate=10, mutation_rate=12, random_seed=2 - ) + ts = msprime.simulate(10, recombination_rate=10, mutation_rate=12, random_seed=2) ts = self.clear_right_mutate(ts, 0.5, 10) self.verify_rtrim(ts, ts.rtrim()) def test_rtrim_many_trees_left_min(self): - ts = msprime.simulate( - 10, recombination_rate=10, mutation_rate=12, random_seed=2 - ) + ts = msprime.simulate(10, recombination_rate=10, mutation_rate=12, random_seed=2) ts = self.clear_right_mutate(ts, sys.float_info.min, 10) self.verify_rtrim(ts, ts.rtrim()) def test_rtrim_many_trees_left_epsilon(self): - ts = msprime.simulate( - 10, 
recombination_rate=10, mutation_rate=12, random_seed=2 - ) + ts = msprime.simulate(10, recombination_rate=10, mutation_rate=12, random_seed=2) ts = self.clear_right_mutate(ts, sys.float_info.epsilon, 0) self.verify_rtrim(ts, ts.rtrim()) @@ -8406,9 +8362,7 @@ def test_rtrim_multiple_mutations(self): self.assertAlmostEqual(trimmed_ts.sequence_length, 0.5) assert trimmed_ts.num_sites == 2 assert trimmed_ts.num_mutations == 5 # We should have deleted 4 - assert ( - np.max(trimmed_ts.tables.edges.right) == trimmed_ts.tables.sequence_length - ) + assert np.max(trimmed_ts.tables.edges.right) == trimmed_ts.tables.sequence_length self.verify_rtrim(ts, trimmed_ts) def test_rtrim_migrations(self): @@ -8425,9 +8379,7 @@ def test_trim_multiple_mutations(self): assert trimmed_ts.num_mutations == 3 assert trimmed_ts.num_sites == 1 assert np.min(trimmed_ts.tables.edges.left) == 0 - assert ( - np.max(trimmed_ts.tables.edges.right) == trimmed_ts.tables.sequence_length - ) + assert np.max(trimmed_ts.tables.edges.right) == trimmed_ts.tables.sequence_length def test_trims_no_effect(self): # Deleting from middle should have no effect on any trim function diff --git a/python/tests/test_tree_positioning.py b/python/tests/test_tree_positioning.py index 39f0b5ccab..2ba394e26d 100644 --- a/python/tests/test_tree_positioning.py +++ b/python/tests/test_tree_positioning.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2023 Tskit Developers +# Copyright (c) 2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -23,6 +23,7 @@ Tests for tree iterator schemes. Mostly used to develop the incremental iterator infrastructure. 
""" + import msprime import numpy as np import pytest @@ -58,7 +59,7 @@ def assert_equal(self, other): assert self.tree_pos.index == other.tree_pos.index assert self.tree_pos.interval == other.tree_pos.interval - def next(self): # NOQA: A003 + def next(self): valid = self.tree_pos.next() if valid: for j in range(self.tree_pos.out_range.start, self.tree_pos.out_range.stop): @@ -81,9 +82,7 @@ def prev(self): e = self.tree_pos.out_range.order[j] c = self.ts.edges_child[e] self.parent[c] = -1 - for j in range( - self.tree_pos.in_range.start, self.tree_pos.in_range.stop, -1 - ): + for j in range(self.tree_pos.in_range.start, self.tree_pos.in_range.stop, -1): e = self.tree_pos.in_range.order[j] c = self.ts.edges_child[e] p = self.ts.edges_parent[e] diff --git a/python/tests/test_tree_stats.py b/python/tests/test_tree_stats.py index a272c4f8cf..ef38fa8a5a 100644 --- a/python/tests/test_tree_stats.py +++ b/python/tests/test_tree_stats.py @@ -23,6 +23,7 @@ """ Test cases for generalized statistic computation. """ + import collections import contextlib import functools @@ -40,8 +41,6 @@ import tskit import tskit.exceptions as exceptions -np.random.seed(5) - def cached_np(func): """ @@ -68,18 +67,19 @@ def subset_combos(*args, p=0.5, min_tests=3): # random set is run each time. Ensures that at least min_tests are run. 
# Uncomment this line to run all tests (takes about an hour): # p = 1.0 + rng = np.random.default_rng(42) num_tests = 0 skipped_tests = [] # total_tests = 0 for x in itertools.product(*args): # total_tests = total_tests + 1 - if np.random.uniform() < p: + if rng.uniform() < p: num_tests += num_tests + 1 yield x elif len(skipped_tests) < min_tests: skipped_tests.append(x) - elif np.random.uniform() < 0.1: - skipped_tests[np.random.randint(min_tests)] = x + elif rng.uniform() < 0.1: + skipped_tests[rng.integers(min_tests)] = x while num_tests < min_tests: yield skipped_tests.pop() num_tests = num_tests + 1 @@ -165,8 +165,7 @@ def naive_branch_general_stat( s = sum(tree.branch_length(u) * f(x[u]) for u in tree.nodes()) else: s = sum( - tree.branch_length(u) * (f(x[u]) + f(total - x[u])) - for u in tree.nodes() + tree.branch_length(u) * (f(x[u]) + f(total - x[u])) for u in tree.nodes() ) sigma[tree.index] = s * tree.span if isinstance(windows, str) and windows == "trees": @@ -288,9 +287,7 @@ def windowed_sitewise_stat(ts, sigma, windows, span_normalise=True): return A -def naive_site_general_stat( - ts, W, f, windows=None, polarised=False, span_normalise=True -): +def naive_site_general_stat(ts, W, f, windows=None, polarised=False, span_normalise=True): n, K = W.shape # Hack to determine M M = len(f(W[0])) @@ -359,9 +356,7 @@ def site_general_stat( while site_index < len(sites) and sites.position[site_index] < right: assert left <= sites.position[site_index] ancestral_state = sites[site_index].ancestral_state - allele_state = collections.defaultdict( - functools.partial(np.zeros, state_dim) - ) + allele_state = collections.defaultdict(functools.partial(np.zeros, state_dim)) allele_state[ancestral_state][:] = total_weight while ( mutation_index < len(mutations) @@ -399,9 +394,7 @@ def site_general_stat( ############################## -def naive_node_general_stat( - ts, W, f, windows=None, polarised=False, span_normalise=True -): +def naive_node_general_stat(ts, W, f, 
windows=None, polarised=False, span_normalise=True): windows = ts.parse_windows(windows) n, K = W.shape M = f(W[0]).shape[0] @@ -480,9 +473,7 @@ def node_summary(u): w_right = windows[window_index + 1] # Flush the contribution of all nodes to the current window. for u in range(ts.num_nodes): - result[window_index, u] += (w_right - last_update[u]) * current_values[ - u - ] + result[window_index, u] += (w_right - last_update[u]) * current_values[u] last_update[u] = w_right window_index += 1 @@ -607,7 +598,7 @@ def test_wright_fisher_unsimplified(self): ts = tables.tree_sequence() self.verify(ts) - @pytest.mark.slow + @pytest.mark.slow() def test_wright_fisher_initial_generation(self): tables = wf.wf_sim( 6, 5, seed=3, deep_history=True, initial_generation_samples=True, num_loci=2 @@ -892,7 +883,8 @@ def example_sample_sets(ts, min_size=1): number of sample sets returned in each example must be at least min_size """ samples = ts.samples() - np.random.shuffle(samples) + rng = np.random.default_rng(42) + rng.shuffle(samples) splits = np.array_split(samples, min_size) yield splits yield [[s] for s in samples] @@ -953,16 +945,16 @@ def example_weights(self, ts, min_size=1): """ Generate a series of example weights from the specfied tree sequence. 
""" - np.random.seed(46) + rng = np.random.default_rng(46) for k in [min_size, min_size + 1, min_size + 10]: W = 1.0 + np.zeros((ts.num_samples, k)) W[0, :] = 2.0 yield W for j in range(k): - W[:, j] = np.random.exponential(1, ts.num_samples) + W[:, j] = rng.exponential(1, ts.num_samples) yield W for j in range(k): - W[:, j] = np.random.normal(0, 1, ts.num_samples) + W[:, j] = rng.normal(0, 1, ts.num_samples) yield W def transform_weights(self, W): @@ -996,9 +988,7 @@ def wrapped_summary_func(x): ts, gW, wrapped_summary_func, windows, mode=self.mode, span_normalise=sn ) sigma3 = ts_method(W, windows=windows, mode=self.mode, span_normalise=sn) - sigma4 = definition( - ts, W, windows=windows, mode=self.mode, span_normalise=sn - ) + sigma4 = definition(ts, W, windows=windows, mode=self.mode, span_normalise=sn) assert sigma1.shape == sigma2.shape assert sigma1.shape == sigma3.shape @@ -1071,9 +1061,7 @@ def wrapped_summary_func(x): M = len(wrapped_summary_func(W[0])) sigma1 = ts.general_stat(W, wrapped_summary_func, M, windows, mode=self.mode) sigma2 = general_stat(ts, W, wrapped_summary_func, windows, mode=self.mode) - sigma3 = ts_method( - sample_sets, indexes=indexes, windows=windows, mode=self.mode - ) + sigma3 = ts_method(sample_sets, indexes=indexes, windows=windows, mode=self.mode) sigma4 = definition( ts, sample_sets, indexes=indexes, windows=windows, mode=self.mode ) @@ -2004,8 +1992,7 @@ def wrapped_summary_func(x): / denom ) sigma2 = ( - general_stat(ts, W, wrapped_summary_func, windows, mode=self.mode) - / denom + general_stat(ts, W, wrapped_summary_func, windows, mode=self.mode) / denom ) sigma3 = ts_method( sample_sets, @@ -2255,10 +2242,7 @@ def verify_weighted_stat(self, ts, W, indexes, windows): def f(x): mx = np.sum(x) / n return np.array( - [ - (x[i] - W_sum[i] * mx) * (x[j] - W_sum[j] * mx) / 2 - for i, j in indexes - ] + [(x[i] - W_sum[i] * mx) * (x[j] - W_sum[j] * mx) / 2 for i, j in indexes] ) self.verify_definition( @@ -3110,8 +3094,7 @@ def 
node_f3(ts, sample_sets, indexes, windows=None, span_normalise=True): + (tA - nA) * (tA - nA - 1) * nB * nC ) SS[u] -= ( - nA * nC * (tA - nA) * (tB - nB) - + (tA - nA) * (tC - nC) * nA * nB + nA * nC * (tA - nA) * (tB - nB) + (tA - nA) * (tC - nC) * nA * nB ) S += SS * (min(end, t1.interval.right) - max(begin, t1.interval.left)) with suppress_division_by_zero_warning(): @@ -3297,12 +3280,10 @@ def node_f4(ts, sample_sets, indexes, windows=None, span_normalise=True): nD = t4.num_tracked_samples(u) # ac|bd - ad|bc SS[u] += ( - nA * nC * (tB - nB) * (tD - nD) - + (tA - nA) * (tC - nC) * nB * nD + nA * nC * (tB - nB) * (tD - nD) + (tA - nA) * (tC - nC) * nB * nD ) SS[u] -= ( - nA * nD * (tB - nB) * (tC - nC) - + (tA - nA) * (tD - nD) * nB * nC + nA * nD * (tB - nB) * (tC - nC) + (tA - nA) * (tD - nD) * nB * nC ) S += SS * (min(end, t1.interval.right) - max(begin, t1.interval.left)) with suppress_division_by_zero_warning(): @@ -3409,9 +3390,7 @@ class TestFold: def test_examples(self): A = np.arange(12) - Af = np.array( - [11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] - ) + Af = np.array([11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) assert np.all(foldit(A) == Af) @@ -3472,7 +3451,7 @@ def naive_site_allele_frequency_spectrum( windows = ts.parse_windows(windows) num_windows = len(windows) - 1 out_dim = [1 + len(sample_set) for sample_set in sample_sets] - out = np.zeros([num_windows] + out_dim) + out = np.zeros([num_windows, *out_dim]) G = ts.genotype_matrix(isolated_as_missing=False) samples = ts.samples() # Indexes of the samples within the sample sets into the samples array. 
@@ -3527,14 +3506,12 @@ def naive_branch_allele_frequency_spectrum( windows = ts.parse_windows(windows) num_windows = len(windows) - 1 out_dim = [1 + len(sample_set) for sample_set in sample_sets] - out = np.zeros([num_windows] + out_dim) + out = np.zeros([num_windows, *out_dim]) for j in range(num_windows): begin = windows[j] end = windows[j + 1] S = np.zeros(out_dim) - trees = [ - next(ts.trees(tracked_samples=sample_set)) for sample_set in sample_sets - ] + trees = [next(ts.trees(tracked_samples=sample_set)) for sample_set in sample_sets] t = trees[0] while True: tr_len = min(end, t.interval.right) - max(begin, t.interval.left) @@ -3590,7 +3567,7 @@ def branch_allele_frequency_spectrum( out_dim = [1 + len(sample_set) for sample_set in sample_sets] time = ts.tables.nodes.time - result = np.zeros([num_windows] + out_dim) + result = np.zeros([num_windows, *out_dim]) # Number of nodes in sample_set j ancestral to each node u. count = np.zeros((ts.num_nodes, num_sample_sets + 1), dtype=np.uint32) for j in range(num_sample_sets): @@ -3610,7 +3587,7 @@ def update_result(window_index, u, right): c = count[u, :num_sample_sets] if not polarised: c = fold(c, out_dim) - index = tuple([window_index] + list(c)) + index = tuple([window_index, *list(c)]) result[index] += x last_update[u] = right @@ -3673,9 +3650,9 @@ def site_allele_frequency_spectrum( num_windows = windows.shape[0] - 1 out_dim = [1 + len(sample_set) for sample_set in sample_sets] - result = np.zeros([num_windows] + out_dim) + result = np.zeros([num_windows, *out_dim]) # Add an extra sample set to count across all samples - sample_sets = list(sample_sets) + [ts.samples()] + sample_sets = [*list(sample_sets), ts.samples()] # Number of nodes in sample_set j ancestral to each node u. 
count = np.zeros((ts.num_nodes, len(sample_sets)), dtype=np.uint32) for j in range(len(sample_sets)): @@ -3786,9 +3763,7 @@ def verify_single_sample_set(self, ts): self.assertArrayEqual(a1, a2) for windows in [None, (0, L), (0, L / 2, L)]: a1 = ts.allele_frequency_spectrum(mode=self.mode, windows=windows) - a2 = ts.allele_frequency_spectrum( - [samples], mode=self.mode, windows=windows - ) + a2 = ts.allele_frequency_spectrum([samples], mode=self.mode, windows=windows) self.assertArrayEqual(a1, a2) for polarised in [True, False]: a1 = ts.allele_frequency_spectrum(mode=self.mode, polarised=polarised) @@ -3810,9 +3785,7 @@ def verify_sample_sets(self, ts, sample_sets, windows): # print(ts.draw_text()) # print("sample_sets = ", sample_sets) windows = ts.parse_windows(windows) - for span_normalise, polarised in itertools.product( - [True, False], [True, False] - ): + for span_normalise, polarised in itertools.product([True, False], [True, False]): sfs1 = naive_allele_frequency_spectrum( ts, sample_sets, @@ -3975,7 +3948,7 @@ def get_example_ts(self): def test_duplicate_samples(self): ts = self.get_example_ts() - for bad_set in [[1, 1], [1, 2, 1], list(range(10)) + [9]]: + for bad_set in [[1, 1], [1, 2, 1], [*list(range(10)), 9]]: with pytest.raises(exceptions.LibraryError): ts.diversity([bad_set]) with pytest.raises(exceptions.LibraryError): @@ -4007,10 +3980,10 @@ def test_non_samples(self): ts.sample_count_stat([[ts.num_samples]], self.identity_f(ts), 1) def test_span_normalise(self): - np.random.seed(92) + rng = np.random.default_rng(92) ts = self.get_example_ts() sample_sets = [[0, 1], [2, 3, 4], [5, 6]] - windows = ts.sequence_length * np.random.uniform(size=10) + windows = ts.sequence_length * rng.uniform(size=10) windows.sort() windows[0] = 0.0 windows[-1] = ts.sequence_length @@ -4693,14 +4666,14 @@ def test_errors(self): ts.trait_correlation(W[1:, :]) -@pytest.mark.slow +@pytest.mark.slow() class TestBranchTraitCovariance( TestTraitCovariance, 
TopologyExamplesMixin, TraitCovarianceMixin ): mode = "branch" -@pytest.mark.slow +@pytest.mark.slow() class TestNodeTraitCovariance( TestTraitCovariance, TopologyExamplesMixin, TraitCovarianceMixin ): @@ -4871,9 +4844,7 @@ def f(x): else: return x[:-1] * 0.0 - self.verify_definition( - ts, W, windows, f, ts.trait_correlation, trait_correlation - ) + self.verify_definition(ts, W, windows, f, ts.trait_correlation, trait_correlation) def test_errors(self): ts = self.get_example_ts() @@ -4915,14 +4886,14 @@ def test_normalisation(self): self.verify_standardising(ts, trait_correlation, ts.trait_correlation) -@pytest.mark.slow +@pytest.mark.slow() class TestBranchTraitCorrelation( TestTraitCorrelation, TopologyExamplesMixin, TraitCorrelationMixin ): mode = "branch" -@pytest.mark.slow +@pytest.mark.slow() class TestNodeTraitCorrelation( TestTraitCorrelation, TopologyExamplesMixin, TraitCorrelationMixin ): @@ -5109,7 +5080,7 @@ def get_example_ts(self): return ts def example_covariates(self, ts): - np.random.seed(999) + rng = np.random.default_rng(999) N = ts.num_samples for k in [1, 2, 5]: k = min(k, ts.num_samples) @@ -5117,7 +5088,7 @@ def example_covariates(self, ts): Z[1, :] = np.arange(k, 2 * k) yield Z for j in range(k): - Z[:, j] = np.random.normal(0, 1, N) + Z[:, j] = rng.normal(0, 1, N) yield Z def transform_weights(self, W, Z): @@ -5249,14 +5220,14 @@ def test_deprecation(self): ts.trait_regression(W, Z=Z, mode=self.mode) -@pytest.mark.slow +@pytest.mark.slow() class TestBranchTraitLinearModel( TestTraitLinearModel, TopologyExamplesMixin, TraitLinearModelMixin ): mode = "branch" -@pytest.mark.slow +@pytest.mark.slow() class TestNodeTraitLinearModel( TestTraitLinearModel, TopologyExamplesMixin, TraitLinearModelMixin ): @@ -5290,7 +5261,7 @@ def compare_sfs(self, ts, tree_fn, sample_sets, tsc_fn): for sample_set in sample_sets: windows = [ k * ts.sequence_length / 20 - for k in [0] + sorted(self.rng.sample(range(1, 20), 4)) + [20] + for k in [0, 
*sorted(self.rng.sample(range(1, 20), 4)), 20] ] win_args = [ {"begin": windows[i], "end": windows[i + 1]} @@ -5375,9 +5346,7 @@ def setUp(self): def get_ts(self): for N in [12, 15, 20]: - yield msprime.simulate( - N, random_seed=self.random_seed, recombination_rate=10 - ) + yield msprime.simulate(N, random_seed=self.random_seed, recombination_rate=10) @pytest.mark.skip(reason="Skipping SFS.") def test_sfs_interface(self): @@ -5636,9 +5605,7 @@ def f(x): def f(x): return np.array( [ - float( - ((x[0] == 1) and (x[1] == 0)) or ((x[0] == 0) and (x[1] == 2)) - ) + float(((x[0] == 1) and (x[1] == 0)) or ((x[0] == 0) and (x[1] == 2))) / 2.0 ] ) @@ -5741,9 +5708,7 @@ def f(x): ) self.assertArrayAlmostEqual(py_r, ts_r) self.assertArrayAlmostEqual(true_cov, py_r * (geno_var[:, np.newaxis] ** 2)) - self.assertArrayAlmostEqual( - true_cor, ts_r * geno_var[:, np.newaxis] / trait_var - ) + self.assertArrayAlmostEqual(true_cor, ts_r * geno_var[:, np.newaxis] / trait_var) def test_case_odds_and_ends(self): # Tests having (a) the first site after the first window, and @@ -5891,9 +5856,7 @@ def test_case_four_taxa(self): self.assertAlmostEqual(0.0, f4(ts, A, [(0, 1, 2, 3)], mode=mode)[0]) self.assertAlmostEqual(0.0, ts.f4(A, mode=mode)) A = [[0, 2], [1, 3]] - self.assertAlmostEqual( - branch_true_f2_02_13, f2(ts, A, [(0, 1)], mode=mode)[0][0] - ) + self.assertAlmostEqual(branch_true_f2_02_13, f2(ts, A, [(0, 1)], mode=mode)[0][0]) self.assertAlmostEqual(branch_true_f2_02_13, ts.f2(A, mode=mode)) # diversity @@ -6112,9 +6075,7 @@ def f(x): def f(x): return np.array( [ - float( - ((x[0] == 1) and (x[1] == 0)) or ((x[0] == 0) and (x[1] == 2)) - ) + float(((x[0] == 1) and (x[1] == 0)) or ((x[0] == 0) and (x[1] == 2))) / 2.0 ] ) @@ -6445,9 +6406,7 @@ def verify_three_way_stat_windows(self, ts, method): self.assertArrayEqual(x, y) mode = "node" - x = method( - [A, B, C], indexes=[[0, 1, 2], [0, 2, 1]], windows=windows, mode=mode - ) + x = method([A, B, C], indexes=[[0, 1, 2], [0, 2, 1]], 
windows=windows, mode=mode) # Three windows, N nodes and 2 triples assert x.shape == (3, N, 2) diff --git a/python/tests/test_util.py b/python/tests/test_util.py index eaed4d07e3..4a744e76b6 100644 --- a/python/tests/test_util.py +++ b/python/tests/test_util.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Tests for functions in util.py """ + import collections import itertools import math @@ -112,7 +113,7 @@ class TestNumpyArrayCasting: Tests that the safe_np_int_cast() function works. """ - dtypes_to_test = [np.int32, np.uint32, np.int8, np.uint8] + dtypes_to_test = (np.int32, np.uint32, np.int8, np.uint8) def test_basic_arrays(self): # Simple array @@ -354,7 +355,7 @@ def test_regular_cases(self): @pytest.mark.parametrize( - "value, expected", + ("value", "expected"), [ (0, "0 Bytes"), (1, "1 Byte"), @@ -373,7 +374,7 @@ def test_naturalsize(value, expected): @pytest.mark.parametrize( - "obj, expected", + ("obj", "expected"), [ (0, "Test:0"), ( diff --git a/python/tests/test_utilities.py b/python/tests/test_utilities.py index ea107db22f..cf0753d895 100644 --- a/python/tests/test_utilities.py +++ b/python/tests/test_utilities.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2019-2022 Tskit Developers +# Copyright (c) 2019-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Tests for the various testing utilities. 
""" + import msprime import numpy as np import pytest @@ -59,10 +60,7 @@ def test_silent_mutations(self): ts = tsutil.jukes_cantor(ts, 5, 2, seed=2) num_silent = 0 for m in ts.mutations(): - if ( - m.parent != -1 - and ts.mutation(m.parent).derived_state == m.derived_state - ): + if m.parent != -1 and ts.mutation(m.parent).derived_state == m.derived_state: num_silent += 1 assert num_silent > 20 @@ -116,9 +114,7 @@ def test_n_5_mutations(self): def test_n_many_mutations(self): for n in range(10, 15): for num_mutations in range(0, n - 1): - ts = tsutil.caterpillar_tree( - n, num_sites=1, num_mutations=num_mutations - ) + ts = tsutil.caterpillar_tree(n, num_sites=1, num_mutations=num_mutations) self.verify(ts, n) assert ts.num_sites == 1 assert ts.num_mutations == num_mutations diff --git a/python/tests/test_vcf.py b/python/tests/test_vcf.py index a1ce300702..7b5a221cf8 100644 --- a/python/tests/test_vcf.py +++ b/python/tests/test_vcf.py @@ -23,6 +23,7 @@ """ Test cases for VCF output in tskit. 
""" + import contextlib import io import math @@ -62,6 +63,7 @@ def ts_to_pysam(ts, *args, **kwargs): def example_individuals(ts, ploidy=1): + rng = np.random.default_rng(42) if ts.num_individuals == 0: yield None, ts.num_samples / ploidy else: @@ -70,7 +72,7 @@ def example_individuals(ts, ploidy=1): if ts.num_individuals > 3: n = ts.num_individuals - 2 yield list(range(n)), n - yield 2 + np.random.choice(np.arange(n), n, replace=False), n + yield 2 + rng.choice(np.arange(n), n, replace=False), n def legacy_write_vcf(tree_sequence, output, ploidy, contig_id): @@ -189,9 +191,7 @@ class ExamplesMixin: def test_simple_infinite_sites_random_ploidy(self): ts = msprime.simulate(10, mutation_rate=1, random_seed=2) - ts = tsutil.insert_random_ploidy_individuals( - ts, min_ploidy=1, samples_only=True - ) + ts = tsutil.insert_random_ploidy_individuals(ts, min_ploidy=1, samples_only=True) assert ts.num_sites > 2 self.verify(ts) @@ -211,9 +211,7 @@ def test_simple_infinite_sites_ploidy_2_reversed_samples(self): def test_simple_jukes_cantor_random_ploidy(self): ts = msprime.simulate(10, random_seed=2) ts = tsutil.jukes_cantor(ts, num_sites=10, mu=1, seed=2) - ts = tsutil.insert_random_ploidy_individuals( - ts, min_ploidy=1, samples_only=True - ) + ts = tsutil.insert_random_ploidy_individuals(ts, min_ploidy=1, samples_only=True) self.verify(ts) def test_single_tree_multichar_mutations(self): @@ -639,8 +637,7 @@ def test_no_masks_triploid(self): 1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|0|1""" expected = textwrap.dedent(s) assert ( - drop_header(self.ts().as_vcf(ploidy=3, allow_position_zero=True)) - == expected + drop_header(self.ts().as_vcf(ploidy=3, allow_position_zero=True)) == expected ) def test_site_0_masked(self): @@ -933,8 +930,7 @@ def test_individual_0(self): 1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0|1""" expected = textwrap.dedent(s) assert ( - drop_header(ts.as_vcf(individuals=[0], allow_position_zero=True)) - == expected + drop_header(ts.as_vcf(individuals=[0], 
allow_position_zero=True)) == expected ) def test_individual_1(self): @@ -947,8 +943,7 @@ def test_individual_1(self): 1\t6\t3\t0\t1\t.\tPASS\t.\tGT\t0""" expected = textwrap.dedent(s) assert ( - drop_header(ts.as_vcf(individuals=[1], allow_position_zero=True)) - == expected + drop_header(ts.as_vcf(individuals=[1], allow_position_zero=True)) == expected ) def test_reversed(self): diff --git a/python/tests/test_version.py b/python/tests/test_version.py index 19cd2bc5dc..c04c59e207 100644 --- a/python/tests/test_version.py +++ b/python/tests/test_version.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2020-2022 Tskit Developers +# Copyright (c) 2020-2024 Tskit Developers # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -22,6 +22,7 @@ """ Test python package versioning """ + from packaging.version import Version from tskit import _version diff --git a/python/tests/test_wright_fisher.py b/python/tests/test_wright_fisher.py index 94245c1fb6..c182867258 100644 --- a/python/tests/test_wright_fisher.py +++ b/python/tests/test_wright_fisher.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2021 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Test various functions using messy tables output by a forwards-time simulator. 
""" + import itertools import random @@ -403,9 +404,7 @@ def test_with_mutations(self): tables = ts.tables assert tables.sites.num_rows > 0 assert tables.mutations.num_rows > 0 - samples = np.where(tables.nodes.flags == tskit.NODE_IS_SAMPLE)[0].astype( - np.int32 - ) + samples = np.where(tables.nodes.flags == tskit.NODE_IS_SAMPLE)[0].astype(np.int32) tables.sort() tables.simplify(samples) assert tables.nodes.num_rows > 0 @@ -585,10 +584,10 @@ def test_simplify_tables(self, ts, nsamples): @pytest.mark.parametrize("ts", wf_sims) @pytest.mark.parametrize("nsamples", [2, 5]) def test_simplify_keep_unary(self, ts, nsamples): - np.random.seed(123) + rng = np.random.default_rng(123) ts = tsutil.mark_metadata(ts, "nodes") sub_samples = random.sample(list(ts.samples()), min(nsamples, ts.num_samples)) - random_nodes = np.random.choice(ts.num_nodes, ts.num_nodes // 2) + random_nodes = rng.choice(ts.num_nodes, ts.num_nodes // 2, replace=False) ts = tsutil.insert_individuals(ts, random_nodes) ts = tsutil.mark_metadata(ts, "individuals") diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index d2f98c4775..7d7eb99341 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (C) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ A collection of utilities to edit and construct tree sequences. """ + import collections import dataclasses import functools @@ -399,13 +400,11 @@ def add_random_metadata(ts, seed=1, max_length=10): to the nodes, sites and mutations. 
""" tables = ts.dump_tables() - np.random.seed(seed) + rng = np.random.default_rng(seed) - length = np.random.randint(0, max_length, ts.num_nodes) + length = rng.integers(0, max_length, ts.num_nodes) offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) - # Older versions of numpy didn't have a dtype argument for randint, so - # must use astype instead. - metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + metadata = rng.integers(-127, 127, offset[-1], dtype=np.int8) nodes = tables.nodes nodes.set_columns( flags=nodes.flags, @@ -416,9 +415,9 @@ def add_random_metadata(ts, seed=1, max_length=10): individual=nodes.individual, ) - length = np.random.randint(0, max_length, ts.num_sites) + length = rng.integers(0, max_length, ts.num_sites) offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) - metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + metadata = rng.integers(-127, 127, offset[-1], dtype=np.int8) sites = tables.sites sites.set_columns( position=sites.position, @@ -428,9 +427,9 @@ def add_random_metadata(ts, seed=1, max_length=10): metadata=metadata, ) - length = np.random.randint(0, max_length, ts.num_mutations) + length = rng.integers(0, max_length, ts.num_mutations) offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) - metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + metadata = rng.integers(-127, 127, offset[-1], dtype=np.int8) mutations = tables.mutations mutations.set_columns( site=mutations.site, @@ -443,9 +442,9 @@ def add_random_metadata(ts, seed=1, max_length=10): metadata=metadata, ) - length = np.random.randint(0, max_length, ts.num_individuals) + length = rng.integers(0, max_length, ts.num_individuals) offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) - metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + metadata = rng.integers(-127, 127, offset[-1], dtype=np.int8) individuals = tables.individuals individuals.set_columns( 
flags=individuals.flags, @@ -457,12 +456,11 @@ def add_random_metadata(ts, seed=1, max_length=10): metadata=metadata, ) - length = np.random.randint(0, max_length, ts.num_populations) + length = rng.integers(0, max_length, ts.num_populations) offset = np.cumsum(np.hstack(([0], length)), dtype=np.uint32) - metadata = np.random.randint(-127, 127, offset[-1]).astype(np.int8) + metadata = rng.integers(-127, 127, offset[-1], dtype=np.int8) populations = tables.populations populations.set_columns(metadata_offset=offset, metadata=metadata) - add_provenance(tables.provenances, "add_random_metadata") ts = tables.tree_sequence() return ts @@ -829,9 +827,7 @@ def compute_mutation_times(ts): end_time = nodes[edges[edge_idx].parent].time duration = end_time - start_time for i, mut_idx in enumerate(edge_mutations): - times[mut_idx] = end_time - ( - duration * ((i + 1) / (len(edge_mutations) + 1)) - ) + times[mut_idx] = end_time - (duration * ((i + 1) / (len(edge_mutations) + 1))) # Mutations not on a edge (i.e. above a root) get given their node's time for i in range(len(mutations)): @@ -1762,7 +1758,7 @@ def set_null(self): self.interval.left = 0 self.interval.right = 0 - def next(self): # NOQA: A003 + def next(self): M = self.ts.num_edges breakpoints = self.ts.breakpoints(as_array=True) left_coords = self.ts.edges_left @@ -1857,7 +1853,8 @@ def prev(self): def seek_forward(self, index): # NOTE this is still in development and not fully tested. - assert index >= self.index and index < self.ts.num_trees + assert index >= self.index + assert index < self.ts.num_trees M = self.ts.num_edges breakpoints = self.ts.breakpoints(as_array=True) left_coords = self.ts.edges_left diff --git a/python/tskit/cli.py b/python/tskit/cli.py index b20bd260d9..221c265b33 100644 --- a/python/tskit/cli.py +++ b/python/tskit/cli.py @@ -24,6 +24,7 @@ """ Command line utilities for tskit. 
""" + import argparse import json import os @@ -122,9 +123,8 @@ def run_provenances(args): for provenance in tree_sequence.provenances(): d = json.loads(provenance.record) print( - "id={}, timestamp={}, record={}".format( - provenance.id, provenance.timestamp, json.dumps(d, indent=4) - ) + f"id={provenance.id}, timestamp={provenance.timestamp}, " + f"record={json.dumps(d, indent=4)}" ) else: tree_sequence.dump_text(provenances=sys.stdout) @@ -183,9 +183,7 @@ def get_tskit_parser(): ) parser.set_defaults(runner=run_trees) - parser = subparsers.add_parser( - "upgrade", help="Upgrade legacy tree sequence files." - ) + parser = subparsers.add_parser("upgrade", help="Upgrade legacy tree sequence files.") parser.add_argument( "source", help="The source tskit tree sequence file in legacy format" ) diff --git a/python/tskit/combinatorics.py b/python/tskit/combinatorics.py index 880ec73675..660c385fdc 100644 --- a/python/tskit/combinatorics.py +++ b/python/tskit/combinatorics.py @@ -24,6 +24,7 @@ Module for ranking and unranking trees. Trees are considered only leaf-labelled and unordered, so order of children does not influence equality. 
""" + import collections import functools import heapq @@ -268,9 +269,7 @@ def generate_balanced( raise ValueError("The arity must be at least 2") root = TreeNode.balanced_tree(range(num_leaves), arity) - tables = root.as_tables( - num_leaves=num_leaves, span=span, branch_length=branch_length - ) + tables = root.as_tables(num_leaves=num_leaves, span=span, branch_length=branch_length) if record_provenance: # TODO replace with a version of /~https://github.com/tskit-dev/tskit/pull/243 @@ -297,9 +296,7 @@ def generate_random_binary( rng = random.Random(random_seed) root = TreeNode.random_binary_tree(range(num_leaves), rng) - tables = root.as_tables( - num_leaves=num_leaves, span=span, branch_length=branch_length - ) + tables = root.as_tables(num_leaves=num_leaves, span=span, branch_length=branch_length) if record_provenance: # TODO replace with a version of /~https://github.com/tskit-dev/tskit/pull/243 @@ -1035,11 +1032,11 @@ def label_tree_group(trees, labels): k = first.num_leaves min_label = labels[0] for first_other_labels in itertools.combinations(labels[1:], k - 1): - first_labels = [min_label] + list(first_other_labels) + first_labels = [min_label, *list(first_other_labels)] rest_labels = set_minus(labels, first_labels) for labeled_first in RankTree.all_labellings(first, first_labels): for labeled_rest in RankTree.label_tree_group(rest, rest_labels): - yield [labeled_first] + labeled_rest + yield [labeled_first, *labeled_rest] def _newick(self): if self.is_leaf(): @@ -1320,9 +1317,7 @@ def group_label_ranks(rank, child_group, labels): num_t_labellings = rank_tree.num_labellings() rest_trees = child_group[i + 1 :] num_rest_assignments = num_assignments_in_group(rest_trees) - num_rest_labellings = num_rest_assignments * ( - num_t_labellings ** len(rest_trees) - ) + num_rest_labellings = num_rest_assignments * (num_t_labellings ** len(rest_trees)) num_labellings_per_label_comb = num_t_labellings * num_rest_labellings comb_rank = rank // 
num_labellings_per_label_comb @@ -1331,7 +1326,7 @@ def group_label_ranks(rank, child_group, labels): rank %= num_rest_labellings min_label = labels[0] - t_labels = [min_label] + Combination.unrank(comb_rank, labels[1:], k - 1) + t_labels = [min_label, *Combination.unrank(comb_rank, labels[1:], k - 1)] labels = set_minus(labels, t_labels) child_labels.append(t_labels) diff --git a/python/tskit/drawing.py b/python/tskit/drawing.py index e7d412e98c..937f816fb3 100644 --- a/python/tskit/drawing.py +++ b/python/tskit/drawing.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2015-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,6 +23,7 @@ """ Module responsible for visualisations. """ + import collections import itertools import math @@ -30,17 +31,14 @@ import operator import warnings from dataclasses import dataclass -from typing import List -from typing import Mapping -from typing import Union +from typing import List, Mapping, Union import numpy as np import svgwrite import tskit import tskit.util as util -from _tskit import NODE_IS_SAMPLE -from _tskit import NULL +from _tskit import NODE_IS_SAMPLE, NULL LEFT = "left" RIGHT = "right" @@ -57,6 +55,7 @@ @dataclass class Offsets: "Used when x_lim set, and displayed ts has been cut down by keep_intervals" + tree: int = 0 site: int = 0 mutation: int = 0 @@ -65,6 +64,7 @@ class Offsets: @dataclass(frozen=True) class Timescaling: "Class used to transform the time axis" + max_time: float min_time: float plot_min: float @@ -132,9 +132,7 @@ def check_min_time(min_time, allow_numeric=True): if allow_numeric: is_numeric = isinstance(min_time, numbers.Real) if min_time not in ["tree", "ts"] and not is_numeric: - raise ValueError( - "min_time must be a numeric value or one of 'tree' or 'ts'" - ) + raise ValueError("min_time must be a numeric value or one of 'tree' or 'ts'") else: 
if min_time not in ["tree", "ts"]: raise ValueError("min_time must be 'tree' or 'ts'") @@ -156,9 +154,7 @@ def check_format(format): # noqa A002 supported_formats = ["svg", "ascii", "unicode"] if fmt not in supported_formats: raise ValueError( - "Unknown format '{}'. Supported formats are {}".format( - format, supported_formats - ) + f"Unknown format '{format}'. Supported formats are {supported_formats}" ) return fmt @@ -212,7 +208,7 @@ def check_x_lim(x_lim, max_x): if x_lim[0] is not None and x_lim[1] is not None and x_lim[0] >= x_lim[1]: raise ValueError("x_lim[0] must be less than x_lim[1]") except TypeError: - raise TypeError("x_lim parameters must be numeric") + raise TypeError("x_lim parameters must be numeric") from None return x_lim @@ -322,9 +318,9 @@ def clip_ts(ts, x_min, x_max, max_num_trees=None): num_start_trees = max_num_trees // 2 + (1 if max_num_trees % 2 else 0) num_end_trees = max_num_trees // 2 assert num_start_trees + num_end_trees == max_num_trees - tree_status[ - (first_tree + num_start_trees) : (last_tree - num_end_trees + 1) - ] = (OMIT | OMIT_MIDDLE) + tree_status[(first_tree + num_start_trees) : (last_tree - num_end_trees + 1)] = ( + OMIT | OMIT_MIDDLE + ) return ts, tree_status, offsets @@ -373,9 +369,7 @@ def edge_and_sample_nodes(ts, omit_regions=None): for left, right in use_regions: used_edges = edges[np.logical_and(edges.left >= left, edges.right < right)] ids = np.concatenate((ids, used_edges.child, used_edges.parent)) - return np.unique( - np.concatenate((ids, np.where(ts.nodes_flags & NODE_IS_SAMPLE)[0])) - ) + return np.unique(np.concatenate((ids, np.where(ts.nodes_flags & NODE_IS_SAMPLE)[0]))) def draw_tree( @@ -755,9 +749,7 @@ def set_spacing(self, top=0, left=0, bottom=0, right=0): left = self.y_axis_offset # Override user-provided, so y-axis is at x=0 self.plotbox.set_padding(top, left, bottom, right) if self.debug_box: - self.root_groups["debug"] = self.dwg_base.add( - self.drawing.g(class_="debug") - ) + 
self.root_groups["debug"] = self.dwg_base.add(self.drawing.g(class_="debug")) self.plotbox.draw(self.drawing, self.root_groups["debug"]) def get_axes(self): @@ -846,9 +838,7 @@ def draw_x_axis( transform=f"translate({rnd(x)} {y})", ) ) - site.add( - dwg.line((0, 0), (0, rnd(-tick_length_upper)), class_="sym") - ) + site.add(dwg.line((0, 0), (0, rnd(-tick_length_upper)), class_="sym")) for i, m in enumerate(reversed(mutations)): mutation_class = f"mut m{m.id + self.offsets.mutation}" if m.id in self.mutations_outside_tree: @@ -908,9 +898,7 @@ def draw_y_axis( ) if gridlines: tick.add( - dwg.line( - (0, 0), (rnd(self.plotbox.right - x), 0), class_="grid" - ) + dwg.line((0, 0), (rnd(self.plotbox.right - x), 0), class_="grid") ) tick.add(dwg.line((0, 0), (rnd(-tick_length_left), 0))) self.add_text_in_group( @@ -1384,9 +1372,7 @@ def __init__( mutation_nodes = mut_t.node[focal_mutations] mutation_positions = ts.tables.sites.position[mut_t.site][focal_mutations] mutation_ids = np.arange(ts.num_mutations, dtype=int)[focal_mutations] - for m_id, node, pos in zip( - mutation_ids, mutation_nodes, mutation_positions - ): + for m_id, node, pos in zip(mutation_ids, mutation_nodes, mutation_positions): curr_edge = node_edges[node] if curr_edge >= 0: if ( @@ -1403,9 +1389,7 @@ def __init__( self.right_extent = max(self.right_extent, pos) if self.right_extent != tree.interval.right: # Use nextafter so extent of plotting incorporates the mutation - self.right_extent = np.nextafter( - self.right_extent, self.right_extent + 1 - ) + self.right_extent = np.nextafter(self.right_extent, self.right_extent + 1) # attributes for symbols half_symbol_size = f"{rnd(symbol_size / 2):g}" symbol_size = f"{rnd(symbol_size):g}" @@ -1439,9 +1423,10 @@ def __init__( m = mutation.id + self.offsets.mutation # We need to offset the mutation symbol so that it's centred self.mutation_attrs[m] = { - "d": "M -{0},-{0} l {1},{1} M -{0},{0} l {1},-{1}".format( - half_symbol_size, symbol_size - ) + "d": f"M 
-{half_symbol_size},-{half_symbol_size} " + f"l {symbol_size},{symbol_size} " + f"M -{half_symbol_size},{half_symbol_size} " + f"l {symbol_size},-{symbol_size}" } if mutation_attrs is not None and m in mutation_attrs: self.mutation_attrs[m].update(mutation_attrs[m]) @@ -1830,9 +1815,7 @@ def __init__( if position_label_format is None: position_scale_labels = create_tick_labels(tick_labels) else: - position_scale_labels = [ - position_label_format.format(x) for x in tick_labels - ] + position_scale_labels = [position_label_format.format(x) for x in tick_labels] time = ts.tables.nodes.time time_scale_labels = [ diff --git a/python/tskit/exceptions.py b/python/tskit/exceptions.py index ed0e7d0791..81b2a3b070 100644 --- a/python/tskit/exceptions.py +++ b/python/tskit/exceptions.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2021 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -23,13 +23,16 @@ """ Exceptions defined in tskit. 
""" -from _tskit import FileFormatError # noqa: F401 -from _tskit import IdentityPairsNotStoredError # noqa: F401 -from _tskit import IdentitySegmentsNotStoredError # noqa: F401 -from _tskit import LibraryError # noqa: F401 -from _tskit import TskitException # noqa: F401 -from _tskit import VersionTooNewError # noqa: F401 -from _tskit import VersionTooOldError # noqa: F401 + +from _tskit import ( + FileFormatError, # noqa: F401 + IdentityPairsNotStoredError, # noqa: F401 + IdentitySegmentsNotStoredError, # noqa: F401 + LibraryError, # noqa: F401 + TskitException, + VersionTooNewError, # noqa: F401 + VersionTooOldError, # noqa: F401 +) class DuplicatePositionsError(TskitException): diff --git a/python/tskit/formats.py b/python/tskit/formats.py index ac466ce1a8..957f735a16 100644 --- a/python/tskit/formats.py +++ b/python/tskit/formats.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -24,6 +24,7 @@ Module responsible for converting tree sequence files from older formats. """ + import datetime import json import logging @@ -231,7 +232,7 @@ def get_h5py(): raise ImportError( "Legacy formats require h5py. Install via `pip install h5py`" " or `conda install h5py`" - ) + ) from None return h5py @@ -326,7 +327,7 @@ def _dump_legacy_hdf5_v3(tree_sequence, root): trees = root.create_group("trees") # Get the breakpoints from the records. 
left = [cr.left for cr in tree_sequence.records()] - breakpoints = np.unique(left + [tree_sequence.sequence_length]) + breakpoints = np.unique([*left, tree_sequence.sequence_length]) trees.create_dataset( "breakpoints", (len(breakpoints),), data=breakpoints, dtype=float ) @@ -349,12 +350,8 @@ def _dump_legacy_hdf5_v3(tree_sequence, root): records_group.create_dataset("left", (length,), data=left, dtype="u4") records_group.create_dataset("right", (length,), data=right, dtype="u4") records_group.create_dataset("node", (length,), data=node, dtype="u4") - records_group.create_dataset( - "num_children", (length,), data=num_children, dtype="u4" - ) - records_group.create_dataset( - "children", (len(children),), data=children, dtype="u4" - ) + records_group.create_dataset("num_children", (length,), data=num_children, dtype="u4") + records_group.create_dataset("children", (len(children),), data=children, dtype="u4") indexes_group = trees.create_group("indexes") left_index = sorted(range(length), key=lambda j: (left[j], time[j])) @@ -362,9 +359,7 @@ def _dump_legacy_hdf5_v3(tree_sequence, root): indexes_group.create_dataset( "insertion_order", (length,), data=left_index, dtype="u4" ) - indexes_group.create_dataset( - "removal_order", (length,), data=right_index, dtype="u4" - ) + indexes_group.create_dataset("removal_order", (length,), data=right_index, dtype="u4") nodes_group = trees.create_group("nodes") population = np.zeros(tree_sequence.num_nodes, dtype="u4") @@ -455,9 +450,7 @@ def _dump_legacy_hdf5_v10(tree_sequence, root): _add_dataset(mutations, "node", tables.mutations.node) _add_dataset(mutations, "parent", tables.mutations.parent) _add_dataset(mutations, "derived_state", tables.mutations.derived_state) - _add_dataset( - mutations, "derived_state_offset", tables.mutations.derived_state_offset - ) + _add_dataset(mutations, "derived_state_offset", tables.mutations.derived_state_offset) _add_dataset(mutations, "metadata", tables.mutations.metadata) 
_add_dataset(mutations, "metadata_offset", tables.mutations.metadata_offset) diff --git a/python/tskit/genotypes.py b/python/tskit/genotypes.py index 239e135777..026102ec6e 100644 --- a/python/tskit/genotypes.py +++ b/python/tskit/genotypes.py @@ -1,7 +1,7 @@ # # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2015-2018 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -291,9 +291,7 @@ def frequencies(self, remove_missing=None) -> dict[str, float]: if remove_missing: total -= self.num_missing if total == 0: - logging.warning( - "No non-missing samples at this site, frequencies undefined" - ) + logging.warning("No non-missing samples at this site, frequencies undefined") return { allele: count / total if total > 0 else np.nan for allele, count in self.counts().items() @@ -349,7 +347,7 @@ def __repr__(self): "has_missing_data": self.has_missing_data, "isolated_as_missing": self.isolated_as_missing, } - return f"Variant({repr(d)})" + return f"Variant({d!r})" # diff --git a/python/tskit/intervals.py b/python/tskit/intervals.py index 0c78c50b5b..1e0072c655 100644 --- a/python/tskit/intervals.py +++ b/python/tskit/intervals.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2023 Tskit Developers +# Copyright (c) 2024 Tskit Developers # Copyright (C) 2020-2021 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -24,6 +24,7 @@ """ Utilities for working with intervals and interval maps. 
""" + from __future__ import annotations import collections.abc @@ -355,9 +356,7 @@ def _text_header_and_rows(self, limit=None): return headers, rows def __str__(self): - header, rows = self._text_header_and_rows( - limit=tskit._print_options["max_lines"] - ) + header, rows = self._text_header_and_rows(limit=tskit._print_options["max_lines"]) table = util.unicode_table( rows=rows, header=header, @@ -366,13 +365,11 @@ def __str__(self): return table def _repr_html_(self): - header, rows = self._text_header_and_rows( - limit=tskit._print_options["max_lines"] - ) + header, rows = self._text_header_and_rows(limit=tskit._print_options["max_lines"]) return util.html_table(rows, header=header) def __repr__(self): - return f"RateMap(position={repr(self.position)}, rate={repr(self.rate)})" + return f"RateMap(position={self.position!r}, rate={self.rate!r})" # # Methods for building rate maps. @@ -386,7 +383,7 @@ def copy(self) -> RateMap: # no need for copying. return RateMap(position=self.position, rate=self.rate) - def slice(self, left=None, right=None, *, trim=False) -> RateMap: # noqa: A003 + def slice(self, left=None, right=None, *, trim=False) -> RateMap: """ Returns a subset of this rate map in the specified interval. 
diff --git a/python/tskit/metadata.py b/python/tskit/metadata.py index f9e8c0c6c7..348cb6f8e1 100644 --- a/python/tskit/metadata.py +++ b/python/tskit/metadata.py @@ -22,6 +22,7 @@ """ Classes for metadata decoding, encoding and validation """ + from __future__ import annotations import abc @@ -34,8 +35,7 @@ import struct import types from itertools import islice -from typing import Any -from typing import Mapping +from typing import Any, Mapping import jsonschema @@ -46,9 +46,9 @@ def replace_root_refs(obj): - if type(obj) is list: + if type(obj) is list: # noqa: E721 return [replace_root_refs(j) for j in obj] - elif type(obj) is dict: + elif type(obj) is dict: # noqa: E721 ret = {k: replace_root_refs(v) for k, v in obj.items()} if ret.get("$ref") == "#": ret["$ref"] = "#/definitions/root" @@ -87,11 +87,11 @@ def __init__(self, schema: Mapping[str, Any]) -> None: raise NotImplementedError # pragma: no cover @classmethod - def modify_schema(self, schema: Mapping) -> Mapping: + def modify_schema(cls, schema: Mapping) -> Mapping: return schema @classmethod - def is_schema_trivial(self, schema: Mapping) -> bool: + def is_schema_trivial(cls, schema: Mapping) -> bool: return False @abc.abstractmethod @@ -121,9 +121,9 @@ def register_metadata_codec( class JSONCodec(AbstractMetadataCodec): - def default_validator(validator, types, instance, schema): + def default_validator(self, types, instance, schema): # For json codec defaults must be at the top level - if validator.is_type(instance, "object"): + if self.is_type(instance, "object"): for v in instance.get("properties", {}).values(): for v2 in v.get("properties", {}).values(): if "default" in v2: @@ -137,7 +137,7 @@ def default_validator(validator, types, instance, schema): ) @classmethod - def is_schema_trivial(self, schema: Mapping) -> bool: + def is_schema_trivial(cls, schema: Mapping) -> bool: return len(schema.get("properties", {})) == 0 def __init__(self, schema: Mapping[str, Any]) -> None: @@ -159,7 +159,7 @@ def 
encode(self, obj: Any) -> bytes: except TypeError as e: raise exceptions.MetadataEncodingError( f"Could not encode metadata of type {str(e).split()[3]}" - ) + ) from None def decode(self, encoded: bytes) -> Any: if len(encoded) == 0: @@ -227,9 +227,7 @@ def binary_format_validator(validator, types, instance, schema): def required_validator(validator, required, instance, schema): # Do the normal validation try: - yield from jsonschema._validators.required( - validator, required, instance, schema - ) + yield from jsonschema._validators.required(validator, required, instance, schema) except AttributeError: # Needed since jsonschema==4.19.1 yield from jsonschema._keywords.required(validator, required, instance, schema) @@ -404,8 +402,7 @@ def decode_object_or_null(buffer): else: buffer = iter(buffer) return { - key: sub_decoder(buffer) - for key, sub_decoder in sub_decoders.items() + key: sub_decoder(buffer) for key, sub_decoder in sub_decoders.items() } return decode_object_or_null @@ -417,9 +414,9 @@ def make_string_decode(cls, sub_schema): encoding = sub_schema.get("stringEncoding", "utf-8") null_terminated = sub_schema.get("nullTerminated", False) if not null_terminated: - return lambda buffer: struct.unpack(f, bytes(islice(buffer, size)))[ - 0 - ].decode(encoding) + return lambda buffer: struct.unpack(f, bytes(islice(buffer, size)))[0].decode( + encoding + ) else: def decode_string(buffer): @@ -484,7 +481,7 @@ def array_encode_with_length(array): raise ValueError( "Couldn't pack array size - it is likely too long" " for the specified arrayLengthFormat" - ) + ) from None return packed_length + b"".join(element_encoder(ele) for ele in array) return array_encode_with_length @@ -558,9 +555,9 @@ def modify_schema(cls, schema: Mapping) -> Mapping: # we add it here, sadly we can't do this in the metaschema as "default" isn't # used by the validator. 
def enforce_fixed_properties(obj): - if type(obj) is list: + if type(obj) is list: # noqa: E721 return [enforce_fixed_properties(j) for j in obj] - elif type(obj) is dict: + elif type(obj) is dict: # noqa: E721 ret = {k: enforce_fixed_properties(v) for k, v in obj.items()} if "object" in ret.get("type", []): if ret.get("additional_properties"): @@ -641,8 +638,8 @@ def __init__(self, schema: Mapping[str, Any] | None) -> None: except KeyError: raise exceptions.MetadataSchemaValidationError( f"Unrecognised metadata codec '{schema['codec']}'. " - f"Valid options are {str(list(codec_registry.keys()))}." - ) + f"Valid options are {list(codec_registry.keys())!s}." + ) from None # Codecs can modify the schema, for example to set defaults as the validator # does not. schema = codec_cls.modify_schema(schema) @@ -749,7 +746,9 @@ def parse_metadata_schema(encoded_schema: str) -> MetadataSchema: encoded_schema, object_pairs_hook=collections.OrderedDict ) except json.decoder.JSONDecodeError: - raise ValueError(f"Metadata schema is not JSON, found {encoded_schema}") + raise ValueError( + f"Metadata schema is not JSON, found {encoded_schema}" + ) from None return MetadataSchema(decoded) @@ -784,9 +783,7 @@ def _lazy_decode(cls): # Intercept the init to record the decoder def new_init(self, *args, metadata_decoder=None, **kwargs): - __builtins__object__setattr__( - self, "_metadata_decoder", metadata_decoder - ) + __builtins__object__setattr__(self, "_metadata_decoder", metadata_decoder) wrapped_init(self, *args, **kwargs) cls.__init__ = new_init diff --git a/python/tskit/provenance.py b/python/tskit/provenance.py index bc88e29f1a..937a467451 100644 --- a/python/tskit/provenance.py +++ b/python/tskit/provenance.py @@ -1,6 +1,6 @@ # MIT License # -# Copyright (c) 2018-2023 Tskit Developers +# Copyright (c) 2018-2024 Tskit Developers # Copyright (c) 2016-2017 University of Oxford # # Permission is hereby granted, free of charge, to any person obtaining a copy @@ -24,6 +24,7 @@ Common 
provenance methods used to determine the state and versions of various dependencies and the OS. """ + import json import os.path import platform @@ -32,6 +33,7 @@ import _tskit import tskit.exceptions as exceptions + from . import _version __version__ = _version.tskit_version diff --git a/python/tskit/stats.py b/python/tskit/stats.py index f7212e819a..5b2508fc92 100644 --- a/python/tskit/stats.py +++ b/python/tskit/stats.py @@ -22,6 +22,7 @@ """ Module responsible for computing various statistics on tree sequences. """ + import sys import threading @@ -52,9 +53,7 @@ class LdCalculator: def __init__(self, tree_sequence): self._tree_sequence = tree_sequence - self._ll_ld_calculator = _tskit.LdCalculator( - tree_sequence.get_ll_tree_sequence() - ) + self._ll_ld_calculator = _tskit.LdCalculator(tree_sequence.get_ll_tree_sequence()) # To protect low-level C code, only one method may execute on the # low-level objects at one time. self._instance_lock = threading.Lock() diff --git a/python/tskit/tables.py b/python/tskit/tables.py index 0aade18c66..5089988368 100644 --- a/python/tskit/tables.py +++ b/python/tskit/tables.py @@ -24,6 +24,7 @@ """ Tree sequence IO via the tables API. """ + import collections.abc import dataclasses import datetime @@ -33,9 +34,7 @@ from collections.abc import Mapping from dataclasses import dataclass from functools import reduce -from typing import Dict -from typing import Optional -from typing import Union +from typing import Dict, Optional, Union import numpy as np @@ -260,9 +259,7 @@ def __eq__(self, other): and self.metadata == other.metadata and ( self.time == other.time - or ( - util.is_unknown_time(self.time) and util.is_unknown_time(other.time) - ) + or (util.is_unknown_time(self.time) and util.is_unknown_time(other.time)) ) ) @@ -346,7 +343,7 @@ class BaseTable: """ # The list of columns in the table. Must be set by subclasses. 
- column_names = [] + column_names = tuple() def __init__(self, ll_table, row_class): self.ll_table = ll_table @@ -691,8 +688,7 @@ def _repr_html_(self): def _columns_all_integer(self, *colnames): # For displaying floating point values without loads of decimal places return all( - np.all(getattr(self, col) == np.floor(getattr(self, col))) - for col in colnames + np.all(getattr(self, col) == np.floor(getattr(self, col))) for col in colnames ) @@ -843,7 +839,7 @@ class IndividualTable(MetadataTable): :vartype metadata_schema: tskit.MetadataSchema """ - column_names = [ + column_names = ( "flags", "location", "location_offset", @@ -851,7 +847,7 @@ class IndividualTable(MetadataTable): "parents_offset", "metadata", "metadata_offset", - ] + ) def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -870,13 +866,9 @@ def _text_header_and_rows(self, limit=None): location_str = ", ".join(map(str, row.location)) parents_str = ", ".join(map(str, row.parents)) rows.append( - "{}\t{}\t{}\t{}\t{}".format( - j, - row.flags, - location_str, - parents_str, - util.render_metadata(row.metadata), - ).split("\t") + f"{j}\t{row.flags}\t{location_str}\t{parents_str}\t{util.render_metadata(row.metadata)}".split( + "\t" + ) ) return headers, rows @@ -1117,14 +1109,14 @@ class NodeTable(MetadataTable): :vartype metadata_schema: tskit.MetadataSchema """ - column_names = [ + column_names = ( "time", "flags", "population", "individual", "metadata", "metadata_offset", - ] + ) def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -1311,14 +1303,14 @@ class EdgeTable(MetadataTable): :vartype metadata_schema: tskit.MetadataSchema """ - column_names = [ + column_names = ( "left", "right", "parent", "child", "metadata", "metadata_offset", - ] + ) def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -1524,7 +1516,7 @@ class MigrationTable(MetadataTable): :vartype metadata_schema: tskit.MetadataSchema """ - column_names = 
[ + column_names = ( "left", "right", "node", @@ -1533,7 +1525,7 @@ class MigrationTable(MetadataTable): "time", "metadata", "metadata_offset", - ] + ) def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -1740,13 +1732,13 @@ class SiteTable(MetadataTable): :vartype metadata_schema: tskit.MetadataSchema """ - column_names = [ + column_names = ( "position", "ancestral_state", "ancestral_state_offset", "metadata", "metadata_offset", - ] + ) def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -1953,7 +1945,7 @@ class MutationTable(MetadataTable): :vartype metadata_schema: tskit.MetadataSchema """ - column_names = [ + column_names = ( "site", "node", "time", @@ -1962,7 +1954,7 @@ class MutationTable(MetadataTable): "parent", "metadata", "metadata_offset", - ] + ) def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -2225,7 +2217,7 @@ class PopulationTable(MetadataTable): :vartype metadata_schema: tskit.MetadataSchema """ - column_names = ["metadata", "metadata_offset"] + column_names = ("metadata", "metadata_offset") def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -2340,7 +2332,7 @@ class ProvenanceTable(BaseTable): :vartype timestamp_offset: numpy.ndarray, dtype=np.uint32 """ - column_names = ["record", "record_offset", "timestamp", "timestamp_offset"] + column_names = ("record", "record_offset", "timestamp", "timestamp_offset") def __init__(self, max_rows_increment=0, ll_table=None): if ll_table is None: @@ -2361,9 +2353,7 @@ def equals(self, other, ignore_timestamps=False): ret = False if type(other) is type(self): ret = bool( - self.ll_table.equals( - other.ll_table, ignore_timestamps=ignore_timestamps - ) + self.ll_table.equals(other.ll_table, ignore_timestamps=ignore_timestamps) ) return ret @@ -2610,7 +2600,7 @@ def __str__(self): ) def __repr__(self): - return f"IdentitySegmentList({repr(list(self))})" + return 
f"IdentitySegmentList({list(self)!r})" def __eq__(self, other): if not isinstance(other, IdentitySegmentList): @@ -2787,7 +2777,7 @@ def clear(self): # FIXME This is a shortcut, we want to put the values in explicitly # here to get more control over how they are displayed. def __repr__(self): - return f"ReferenceSequence({repr(self.asdict())})" + return f"ReferenceSequence({self.asdict()!r})" @property def data(self) -> str: @@ -3254,8 +3244,7 @@ def assert_equals( if self.time_units != other.time_units: raise AssertionError( - f"Time units differs: self={self.time_units} " - f"other={other.time_units}" + f"Time units differs: self={self.time_units} " f"other={other.time_units}" ) if self.sequence_length != other.sequence_length: @@ -3322,7 +3311,7 @@ def __setstate__(self, state): self._ll_tables.fromdict(state) @classmethod - def fromdict(self, tables_dict): + def fromdict(cls, tables_dict): ll_tc = _tskit.TableCollection() ll_tc.fromdict(tables_dict) return TableCollection(ll_tables=ll_tc) @@ -3667,9 +3656,7 @@ def canonicalise(self, remove_unreferenced=None): :param bool remove_unreferenced: Whether to remove unreferenced sites, individuals, and populations (default=True). 
""" - remove_unreferenced = ( - True if remove_unreferenced is None else remove_unreferenced - ) + remove_unreferenced = True if remove_unreferenced is None else remove_unreferenced self._ll_tables.canonicalise(remove_unreferenced=remove_unreferenced) # TODO add provenance @@ -3841,9 +3828,7 @@ def keep_intervals(self, intervals, simplify=True, record_provenance=True): self.sites.position >= s, self.sites.position < e ) keep_sites = np.logical_or(keep_sites, curr_keep_sites) - keep_edges = np.logical_not( - np.logical_or(edges.right <= s, edges.left >= e) - ) + keep_edges = np.logical_not(np.logical_or(edges.right <= s, edges.left >= e)) metadata, metadata_offset = keep_with_offset( keep_edges, edges.metadata, edges.metadata_offset ) @@ -4093,12 +4078,8 @@ def subset( that are not referred to by any retained entries in the tables should be removed (default: True). See the description for details. """ - reorder_populations = ( - True if reorder_populations is None else reorder_populations - ) - remove_unreferenced = ( - True if remove_unreferenced is None else remove_unreferenced - ) + reorder_populations = True if reorder_populations is None else reorder_populations + remove_unreferenced = True if remove_unreferenced is None else remove_unreferenced nodes = util.safe_np_int_cast(nodes, np.int32) self._ll_tables.subset( nodes, diff --git a/python/tskit/text_formats.py b/python/tskit/text_formats.py index 52a42483ec..78f5536a38 100644 --- a/python/tskit/text_formats.py +++ b/python/tskit/text_formats.py @@ -22,6 +22,7 @@ """ Module responsible for working with text format data. 
""" + import base64 import numpy as np @@ -73,7 +74,7 @@ def parse_fam(fam_file): ) for plink_fid, plink_iid, pat, mat, sex in individuals: sex = int(sex) - if not (sex in range(3)): + if sex not in range(3): raise ValueError( "Sex must be one of the following: 0 (unknown), 1 (male), 2 (female)" ) @@ -382,13 +383,11 @@ def dump_text( location = ",".join(map(str, individual.location)) parents = ",".join(map(str, individual.parents)) row = ( - "{id}\t" "{flags}\t" "{location}\t" "{parents}\t" "{metadata}" - ).format( - id=individual.id, - flags=individual.flags, - location=location, - parents=parents, - metadata=metadata, + f"{individual.id}\t" + f"{individual.flags}\t" + f"{location}\t" + f"{parents}\t" + f"{metadata}" ) print(row, file=individuals) @@ -396,7 +395,7 @@ def dump_text( print("id", "metadata", sep="\t", file=populations) for population in ts.populations(): metadata = text_metadata(base64_metadata, encoding, population) - row = ("{id}\t" "{metadata}").format(id=population.id, metadata=metadata) + row = f"{population.id}\t" f"{metadata}" print(row, file=populations) if migrations is not None: @@ -414,31 +413,21 @@ def dump_text( for migration in ts.migrations(): metadata = text_metadata(base64_metadata, encoding, migration) row = ( - "{left}\t" - "{right}\t" - "{node}\t" - "{source}\t" - "{dest}\t" - "{time}\t" - "{metadata}\t" - ).format( - left=migration.left, - right=migration.right, - node=migration.node, - source=migration.source, - dest=migration.dest, - time=migration.time, - metadata=metadata, + f"{migration.left}\t" + f"{migration.right}\t" + f"{migration.node}\t" + f"{migration.source}\t" + f"{migration.dest}\t" + f"{migration.time}\t" + f"{metadata}\t" ) print(row, file=migrations) if provenances is not None: print("id", "timestamp", "record", sep="\t", file=provenances) for provenance in ts.provenances(): - row = ("{id}\t" "{timestamp}\t" "{record}\t").format( - id=provenance.id, - timestamp=provenance.timestamp, - record=provenance.record, + 
row = ( + f"{provenance.id}\t" f"{provenance.timestamp}\t" f"{provenance.record}\t" ) print(row, file=provenances) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index b452ebf844..e22606cc2a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -23,6 +23,7 @@ """ Module responsible for managing trees and tree sequences. """ + from __future__ import annotations import base64 @@ -36,8 +37,7 @@ import numbers import warnings from dataclasses import dataclass -from typing import Any -from typing import NamedTuple +from typing import Any, NamedTuple import numpy as np @@ -50,9 +50,7 @@ import tskit.text_formats as text_formats import tskit.util as util import tskit.vcf as vcf -from tskit import NODE_IS_SAMPLE -from tskit import NULL -from tskit import UNKNOWN_TIME +from tskit import NODE_IS_SAMPLE, NULL, UNKNOWN_TIME LEGACY_MS_LABELS = "legacy_ms" @@ -131,7 +129,7 @@ class Individual(util.Dataclass): "metadata", "_tree_sequence", ] - id: int # noqa A003 + id: int # A003 """ The integer ID of this individual. Varies from 0 to :attr:`TreeSequence.num_individuals` - 1.""" @@ -205,7 +203,7 @@ class Node(util.Dataclass): """ __slots__ = ["id", "flags", "time", "population", "individual", "metadata"] - id: int # noqa A003 + id: int # A003 """ The integer ID of this node. Varies from 0 to :attr:`TreeSequence.num_nodes` - 1. """ @@ -277,7 +275,7 @@ class Edge(util.Dataclass): The :ref:`metadata ` for this edge, decoded if a schema applies. """ - id: int # noqa A003 + id: int # A003 """ The integer ID of this edge. Varies from 0 to :attr:`TreeSequence.num_edges` - 1. @@ -324,7 +322,7 @@ class Site(util.Dataclass): """ __slots__ = ["id", "position", "ancestral_state", "mutations", "metadata"] - id: int # noqa A003 + id: int # A003 """ The integer ID of this site. Varies from 0 to :attr:`TreeSequence.num_sites` - 1. 
""" @@ -397,7 +395,7 @@ class Mutation(util.Dataclass): "time", "edge", ] - id: int # noqa A003 + id: int # A003 """ The integer ID of this mutation. Varies from 0 to :attr:`TreeSequence.num_mutations` - 1. @@ -479,9 +477,7 @@ def __eq__(self, other): and self.metadata == other.metadata and ( self.time == other.time - or ( - util.is_unknown_time(self.time) and util.is_unknown_time(other.time) - ) + or (util.is_unknown_time(self.time) and util.is_unknown_time(other.time)) ) ) @@ -530,7 +526,7 @@ class Migration(util.Dataclass): The :ref:`metadata ` for this migration, decoded if a schema applies. """ - id: int # noqa A003 + id: int # A003 """ The integer ID of this mutation. Varies from 0 to :attr:`TreeSequence.num_mutations` - 1. @@ -548,7 +544,7 @@ class Population(util.Dataclass): """ __slots__ = ["id", "metadata"] - id: int # noqa A003 + id: int # A003 """ The integer ID of this population. Varies from 0 to :attr:`TreeSequence.num_populations` - 1. @@ -587,7 +583,7 @@ class Provenance(util.Dataclass): """ __slots__ = ["id", "timestamp", "record"] - id: int # noqa A003 + id: int # A003 timestamp: str """ The time that this entry was made @@ -657,8 +653,7 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 " - "and is ignored", + "The sample_counts option is not supported since 0.2.4 " "and is ignored", RuntimeWarning, stacklevel=4, ) @@ -758,7 +753,7 @@ def last(self): """ self._ll_tree.last() - def next(self): # noqa A002 + def next(self): # A002 """ Seeks to the next tree in the sequence. If the tree is in the initial null state we seek to the first tree (equivalent to calling :meth:`~Tree.first`). @@ -1624,9 +1619,7 @@ def is_root(self, u) -> bool: :param int u: The node of interest. :return: ``True`` if u is a root. 
""" - return ( - self.num_samples(u) >= self.root_threshold and self.parent(u) == tskit.NULL - ) + return self.num_samples(u) >= self.root_threshold and self.parent(u) == tskit.NULL def get_index(self): # Deprecated alias for self.index @@ -2534,7 +2527,7 @@ def nodes(self, root=None, order="preorder"): try: iterator = methods[order] except KeyError: - raise ValueError(f"Traversal ordering '{order}' not supported") + raise ValueError(f"Traversal ordering '{order}' not supported") from None root = -1 if root is None else root return iterator(root) @@ -3077,7 +3070,7 @@ def split_polytomies( return a :class:`Tree` created with ``sample_lists=True``. :return: A new tree with polytomies split into random bifurcations. :rtype: tskit.Tree - """ + """ # noqa: RUF002 return combinatorics.split_polytomies( self, epsilon=epsilon, @@ -3259,7 +3252,7 @@ def generate_random_binary( :return: A random binary tree. Its corresponding :class:`TreeSequence` is available via the :attr:`.tree_sequence` attribute. 
:rtype: Tree - """ + """ # noqa: RUF002 return combinatorics.generate_random_binary( num_leaves, span=span, @@ -4666,10 +4659,7 @@ def edgesets(self): children[edge.parent].add(edge.child) # Update the active edgesets for edge in itertools.chain(edges_out, edges_in): - if ( - len(children[edge.parent]) > 0 - and edge.parent not in active_edgesets - ): + if len(children[edge.parent]) > 0 and edge.parent not in active_edgesets: active_edgesets[edge.parent] = Edgeset(left, right, edge.parent, []) for parent in active_edgesets.keys(): @@ -5124,25 +5114,20 @@ def _haplotypes_array( if allele is not None: if len(allele) != 1: raise TypeError( - "Multi-letter allele or deletion detected at site {}".format( - var.site.id - ) + f"Multi-letter allele or deletion detected at " + f"site {var.site.id}" ) try: ascii_allele = allele.encode("ascii") except UnicodeEncodeError: raise TypeError( - "Non-ascii character in allele at site {}".format( - var.site.id - ) - ) + f"Non-ascii character in allele at site {var.site.id}" + ) from None allele_int8 = ord(ascii_allele) if allele_int8 == missing_int8: raise ValueError( - "The missing data character '{}' clashes with an " - "existing allele at site {}".format( - missing_data_character, var.site.id - ) + f"The missing data character '{missing_data_character}' " + f"clashes with an existing allele at site {var.site.id}" ) alleles[i] = allele_int8 H[:, var.site.id - start_site] = alleles[var.genotypes] @@ -5592,9 +5577,7 @@ def alignments( missing_data_character=missing_data_character, samples=samples, ) - site_pos = self.sites_position.astype(np.int64)[ - first_site_id : last_site_id + 1 - ] + site_pos = self.sites_position.astype(np.int64)[first_site_id : last_site_id + 1] for h in H: a[site_pos - interval.left] = h yield a.tobytes().decode("ascii") @@ -7470,7 +7453,7 @@ def f(x): window (defaults to True). :param bool strict: Whether to check that f(0) and f(total weight) are zero.
:return: A ndarray with shape equal to (num windows, num statistics). - """ # noqa: B950 + """ # noqa: E501 # helper function for common case where weights are indicators of sample sets for U in sample_sets: if len(U) != len(set(U)): @@ -7657,8 +7640,7 @@ def __k_way_sample_set_stat( drop_based_on_index = True if len(sample_sets) != k: raise ValueError( - "Must specify indexes if there are not exactly {} sample " - "sets.".format(k) + f"Must specify indexes if there are not exactly {k} sample " "sets." ) indexes = np.arange(k, dtype=np.int32) drop_dimension = False @@ -7668,8 +7650,7 @@ def __k_way_sample_set_stat( drop_dimension = True if len(indexes.shape) != 2 or indexes.shape[1] != k: raise ValueError( - "Indexes must be convertable to a 2D numpy array with {} " - "columns".format(k) + f"Indexes must be convertable to a 2D numpy array with {k} " "columns" ) stat = self.__run_windowed_stat( windows, @@ -7702,8 +7683,7 @@ def __k_way_weighted_stat( if indexes is None: if W.shape[1] != k: raise ValueError( - "Must specify indexes if there are not exactly {} columns " - "in W.".format(k) + f"Must specify indexes if there are not exactly {k} columns " "in W." ) indexes = np.arange(k, dtype=np.int32) drop_dimension = False @@ -7713,8 +7693,7 @@ def __k_way_weighted_stat( drop_dimension = True if len(indexes.shape) != 2 or indexes.shape[1] != k: raise ValueError( - "Indexes must be convertable to a 2D numpy array with {} " - "columns".format(k) + f"Indexes must be convertable to a 2D numpy array with {k} " "columns" ) stat = self.__run_windowed_stat( windows, @@ -7733,9 +7712,7 @@ def __k_way_weighted_stat( # Statistics definitions ############################################ - def diversity( - self, sample_sets=None, windows=None, mode="site", span_normalise=True - ): + def diversity(self, sample_sets=None, windows=None, mode="site", span_normalise=True): """ Computes mean genetic diversity (also known as "pi") in each of the sets of nodes from ``sample_sets``. 
The statistic is also known as @@ -8265,9 +8242,7 @@ def genetic_relatedness_weighted( :return: A ndarray with shape equal to (num windows, num statistics). """ if len(W) != self.num_samples: - raise ValueError( - "First trait dimension must be equal to number of samples." - ) + raise ValueError("First trait dimension must be equal to number of samples.") return self.__k_way_weighted_stat( self._ll_tree_sequence.genetic_relatedness_weighted, 2, @@ -8332,9 +8307,7 @@ def trait_covariance(self, W, windows=None, mode="site", span_normalise=True): If windows=None and W is a single column, a numpy scalar is returned. """ if W.shape[0] != self.num_samples: - raise ValueError( - "First trait dimension must be equal to number of samples." - ) + raise ValueError("First trait dimension must be equal to number of samples.") return self.__run_windowed_stat( windows, self._ll_tree_sequence.trait_covariance, @@ -8398,9 +8371,7 @@ def trait_correlation(self, W, windows=None, mode="site", span_normalise=True): If windows=None and W is a single column, a numpy scalar is returned. """ if W.shape[0] != self.num_samples: - raise ValueError( - "First trait dimension must be equal to number of samples." - ) + raise ValueError("First trait dimension must be equal to number of samples.") sds = np.std(W, axis=0) if np.any(sds == 0): raise ValueError( @@ -8486,9 +8457,7 @@ def trait_linear_model( If windows=None and W is a single column, a numpy scalar is returned. """ if W.shape[0] != self.num_samples: - raise ValueError( - "First trait dimension must be equal to number of samples." 
- ) + raise ValueError("First trait dimension must be equal to number of samples.") if Z is None: Z = np.ones((self.num_samples, 1)) else: @@ -8717,11 +8686,7 @@ def tjd_func(sample_set_sizes, flattened, **kwargs): g = np.array([np.sum(1 / np.arange(1, nn) ** 2) for nn in n]) with np.errstate(invalid="ignore", divide="ignore"): a = (n + 1) / (3 * (n - 1) * h) - 1 / h**2 - b = ( - 2 * (n**2 + n + 3) / (9 * n * (n - 1)) - - (n + 2) / (h * n) - + g / h**2 - ) + b = 2 * (n**2 + n + 3) / (9 * n * (n - 1)) - (n + 2) / (h * n) + g / h**2 D = (T - S / h) / np.sqrt(a * S + (b / (h**2 + g)) * S * (S - 1)) return D @@ -8794,9 +8759,7 @@ def fst_func(sample_set_sizes, flattened, indexes, **kwargs): fst.shape = divergences.shape for i, (u, v) in enumerate(indexes): denom = ( - diversities[:, :, u] - + diversities[:, :, v] - + 2 * divergences[:, :, i] + diversities[:, :, u] + diversities[:, :, v] + 2 * divergences[:, :, i] ) with np.errstate(divide="ignore", invalid="ignore"): fst[:, :, i] -= ( @@ -9221,9 +9184,7 @@ def count_topologies(self, sample_sets=None): internal samples. """ if sample_sets is None: - sample_sets = [ - self.samples(population=pop.id) for pop in self.populations() - ] + sample_sets = [self.samples(population=pop.id) for pop in self.populations()] yield from combinatorics.treeseq_count_topologies(self, sample_sets) @@ -9544,7 +9505,7 @@ def ld_matrix(self, sample_sets=None, sites=None, mode="site", stat="r2"): except KeyError: raise ValueError( f"Unknown two-locus statistic '{stat}', we support: {list(stats.keys())}" - ) + ) from None return self.__two_locus_sample_set_stat( two_locus_stat, diff --git a/python/tskit/util.py b/python/tskit/util.py index 7dbb8a138d..5ff1dd0320 100644 --- a/python/tskit/util.py +++ b/python/tskit/util.py @@ -22,6 +22,7 @@ """ Module responsible for various utility functions used in other modules. 
""" + import dataclasses import io import itertools @@ -115,10 +116,10 @@ def safe_np_int_cast(int_array, dtype, copy=False): except TypeError: if int_array.dtype == np.dtype("O"): # this occurs e.g. if we're passed a list of lists of different lengths - raise TypeError("Cannot convert to a rectangular array.") + raise TypeError("Cannot convert to a rectangular array.") from None bounds = np.iinfo(dtype) if np.any(int_array < bounds.min) or np.any(int_array > bounds.max): - raise OverflowError(f"Cannot convert safely to {dtype} type") + raise OverflowError(f"Cannot convert safely to {dtype} type") from None if int_array.dtype.kind == "i" and np.dtype(dtype).kind == "u": # Allow casting from int to unsigned int, since we have checked bounds casting = "unsafe" @@ -336,7 +337,7 @@ def obj_to_collapsed_html(d, name=None, open_depth=0): opened = "open" if open_depth > 0 else "" open_depth -= 1 name = str(name) + ":" if name is not None else "" - if type(d) is dict: + if isinstance(d, dict): return f"""
{name} @@ -347,7 +348,7 @@ def obj_to_collapsed_html(d, name=None, open_depth=0):
""" - elif type(d) is list: + elif isinstance(d, list): return f"""
{name} @@ -401,7 +402,7 @@ def unicode_table( :rtype: str """ if header is not None: - all_rows = [header] + rows + all_rows = [header, *rows] else: all_rows = rows widths = [ @@ -554,7 +555,7 @@ def tree_sequence_html(ts):
- """ # noqa: B950 + """ # noqa: E501 def tree_html(tree): @@ -565,7 +566,8 @@ def tree_html(tree): .tskit-table tbody tr td {{padding: 0.5em 0.5em;}} .tskit-table tbody tr td:first-of-type {{text-align: left;}} .tskit-details-label {{vertical-align: top; padding-right:5px;}} - .tskit-table-set {{display: inline-flex;flex-wrap: wrap;margin: -12px 0 0 -12px;width: calc(100% + 12px);}} + .tskit-table-set {{display: inline-flex;flex-wrap: wrap; + margin: -12px 0 0 -12px;width: calc(100% + 12px);}} .tskit-table-set-table {{margin: 12px 0 0 12px;}} details {{display: inline-block;}} summary {{cursor: pointer; outline: 0; display: list-item;}} @@ -594,7 +596,7 @@ def tree_html(tree): - """ # noqa: B950 + """ # noqa: E501 def variant_html(variant): @@ -614,7 +616,8 @@ def variant_html(variant): .tskit-table tbody tr td {{padding: 0.5em 0.5em;}} .tskit-table tbody tr td:first-of-type {{text-align: left;}} .tskit-details-label {{vertical-align: top; padding-right:5px;}} - .tskit-table-set {{display: inline-flex;flex-wrap: wrap;margin: -12px 0 0 -12px;width: calc(100% + 12px);}} + .tskit-table-set {{display: inline-flex;flex-wrap: wrap; + margin: -12px 0 0 -12px;width: calc(100% + 12px);}} .tskit-table-set-table {{margin: 12px 0 0 12px;}} details {{display: inline-block;}} summary {{cursor: pointer; outline: 0; display: list-item;}} @@ -625,13 +628,16 @@ def variant_html(variant): - - {class_type} + + + {class_type} + - """ # noqa: B950 + """ html_body_tail = """ @@ -642,8 +648,6 @@ def variant_html(variant): """ try: - variant.site - site_id = variant.site.id site_position = variant.site.position num_samples = len(variant.samples) @@ -683,7 +687,7 @@ def variant_html(variant): return ( html_body_head + f""" - Error{str(err)} + Error{err!s} """ + html_body_tail ) @@ -772,7 +776,7 @@ def raise_known_file_format_errors(open_file, existing_exception): header = open_file.read(4) except io.UnsupportedOperation: # If we can't seek, we can't sniff the file. 
- raise existing_exception + raise existing_exception from None if header == b"\x89HDF": raise tskit.FileFormatError( "The specified file appears to be in HDF5 format. This file " diff --git a/python/tskit/vcf.py b/python/tskit/vcf.py index dcbea01cb6..53adb42ebe 100644 --- a/python/tskit/vcf.py +++ b/python/tskit/vcf.py @@ -23,9 +23,11 @@ """ Convert tree sequences to VCF. """ + import numpy as np import tskit + from . import provenance @@ -87,9 +89,7 @@ def __init__( position_transform(tree_sequence.tables.sites.position), dtype=int ) if self.transformed_positions.shape != (tree_sequence.num_sites,): - raise ValueError( - "Position transform must return an array of the same length" - ) + raise ValueError("Position transform must return an array of the same length") self.contig_length = max( 1, int(position_transform([tree_sequence.sequence_length])[0]) ) @@ -193,12 +193,8 @@ def __write_header(self, output): print("##fileformat=VCFv4.2", file=output) print(f"##source=tskit {provenance.__version__}", file=output) print('##FILTER=', file=output) - print( - f"##contig=", file=output - ) - print( - '##FORMAT=', file=output - ) + print(f"##contig=", file=output) + print('##FORMAT=', file=output) vcf_samples = "\t".join(self.individual_names) print( "#CHROM",