Skip to content

Commit

Permalink
Merge pull request #179 from GavinHuttley/develop
Browse files Browse the repository at this point in the history
ENH: migrate to cogent3 new types
  • Loading branch information
GavinHuttley authored Jan 17, 2025
2 parents c0a70af + 3818b4a commit eff3f7a
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 179 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ license = { file = "LICENSE" }
requires-python = ">=3.10,<3.13"
dependencies = ["blosc2",
"click",
"cogent3",
"cogent3>=2024.12.19a2",
"duckdb",
"h5py",
"hdf5plugin",
Expand Down
51 changes: 33 additions & 18 deletions src/ensembl_tui/_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
from collections import defaultdict
from dataclasses import dataclass

import cogent3
import h5py
import numpy
import typing_extensions
from cogent3.app.composable import define_app
from cogent3.core.alignment import Aligned, Alignment
from cogent3.core import new_alignment as c3_align
from cogent3.core.location import _DEFAULT_GAP_DTYPE, IndelMap

from ensembl_tui import _genome as eti_genome
from ensembl_tui import _storage_mixin as eti_storage
from ensembl_tui import _util as eti_util

DNA = cogent3.get_moltype("dna", new_type=True)

_no_gaps = numpy.array([], dtype=_DEFAULT_GAP_DTYPE)

GAP_STORE_SUFFIX = "indels-hdf5_blosc2"
Expand Down Expand Up @@ -285,21 +288,28 @@ def get_distinct(self, field: str) -> list[str]:
def num_records(self) -> int:
return self.conn.sql(f"SELECT COUNT(*) from {self._tables[0]}").fetchone()[0]

def close(self) -> None:
"""closes duckdb and h5py storage"""
if self.gap_store:
self.gap_store.close()
self.conn.close()


def get_alignment(
align_db: AlignDb,
genomes: dict,
genomes: dict[str, eti_genome.Genome],
ref_species: str,
seqid: str,
ref_start: int | None = None,
ref_end: int | None = None,
namer: typing.Callable | None = None,
mask_features: list[str] | None = None,
) -> typing.Iterable[Alignment]:
) -> typing.Iterable[c3_align.Alignment]:
"""yields cogent3 Alignments"""

if ref_species not in genomes:
raise ValueError(f"unknown species {ref_species!r}")
msg = f"unknown species {ref_species!r}"
raise ValueError(msg)

align_records = align_db.get_records_matching(
species=ref_species,
Expand All @@ -322,7 +332,7 @@ def get_alignment(
genome_start = align_record.start
genome_end = align_record.stop
gap_pos, gap_lengths = align_record.gap_data
gaps = IndelMap(
imap = IndelMap(
gap_pos=gap_pos,
gap_lengths=gap_lengths,
parent_length=genome_end - genome_start,
Expand All @@ -346,14 +356,15 @@ def get_alignment(
seq_start = seq_start - genome_start
seq_end = seq_end - genome_start

align_start = gaps.get_align_index(seq_start)
align_end = gaps.get_align_index(seq_end)
align_start = imap.get_align_index(seq_start)
align_end = imap.get_align_index(seq_end)
break
else:
msg = f"no matching alignment record for {ref_species!r}"
raise ValueError(msg)

seqs = {}
gaps = {}
for align_record in block:
record_species = align_record.species
genome = genomes[record_species]
Expand All @@ -362,20 +373,20 @@ def get_alignment(
genome_start = align_record.start
genome_end = align_record.stop
gap_pos, gap_lengths = align_record.gap_data
gaps = IndelMap(
imap = IndelMap(
gap_pos=gap_pos,
gap_lengths=gap_lengths,
parent_length=genome_end - genome_start,
)

# We use the alignment indices derived for the reference sequence
# above
seq_start = gaps.get_seq_index(align_start)
seq_end = gaps.get_seq_index(align_end)
seq_start = imap.get_seq_index(align_start)
seq_end = imap.get_seq_index(align_end)
seq_length = seq_end - seq_start
if align_record.strand == "-":
# if it's neg strand, the alignment start is the genome stop
seq_start = gaps.parent_length - seq_end
seq_start = imap.parent_length - seq_end

s = genome.get_seq(
seqid=align_record.seqid,
Expand All @@ -385,7 +396,7 @@ def get_alignment(
with_annotations=False,
)
# we now trim the gaps for this sequence to the sub-alignment
gaps = gaps[align_start:align_end]
imap = imap[align_start:align_end]

if align_record.strand == "-":
s = s.rc()
Expand All @@ -394,13 +405,17 @@ def get_alignment(
strand_symbol = -1 if align_record.strand == "-" else 1
s.name = f"{s.name}:{strand_symbol}"

aligned = Aligned(gaps, s)
if aligned.name not in seqs:
seqs[aligned.name] = aligned
elif str(aligned) == str(seqs[aligned.name]):
if s.name in seqs:
print(f"duplicated {s.name}")
seqs[s.name] = numpy.array(s)
gaps[s.name] = imap.array

aln = Alignment(list(seqs.values()))
aln_data = c3_align.AlignedSeqsData.from_seqs_and_gaps(
seqs=seqs,
gaps=gaps,
alphabet=DNA.most_degen_alphabet(),
)
aln = c3_align.Alignment(seqs_data=aln_data, moltype=DNA)
aln.annotation_db = genome.annotation_db
if mask_features:
aln = aln.with_masked_annotations(biotypes=mask_features)
Expand Down Expand Up @@ -428,7 +443,7 @@ def __init__(
self._mask_features = mask_features
self._sep = sep

def main(self, segment: eti_genome.genome_segment) -> list[Alignment]:
def main(self, segment: eti_genome.genome_segment) -> list[c3_align.Alignment]:
results = []
for aln in get_alignment(
self._align_db,
Expand Down
60 changes: 0 additions & 60 deletions src/ensembl_tui/_faster_fasta.py

This file was deleted.

81 changes: 37 additions & 44 deletions src/ensembl_tui/_genome.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import dataclasses
import functools
import pathlib
import re
import sys
import typing
import uuid
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any

import cogent3
import h5py
import numpy
from cogent3 import get_moltype, make_seq
from cogent3.app.composable import define_app
from cogent3.core import new_alphabet
from cogent3.core.annotation import Feature
from cogent3.core.annotation_db import (
OptionalInt,
OptionalStr,
)
from cogent3.core.sequence import Sequence
from cogent3.core.new_sequence import Sequence
from cogent3.parse.fasta import iter_fasta_records
from cogent3.util.table import Table
from numpy.typing import NDArray

Expand All @@ -26,19 +28,16 @@
from ensembl_tui import _species as eti_species
from ensembl_tui import _storage_mixin as eti_storage
from ensembl_tui import _util as eti_util
from ensembl_tui._faster_fasta import quicka_parser

SEQ_STORE_NAME = "genome.seqs-hdf5_blosc2"

_typed_id = re.compile(
r"\b[a-z]+:",
flags=re.IGNORECASE,
) # ensembl stableid's prefixed by the type
_feature_id = re.compile(r"(?<=\bID=)[^;]+")
_exon_id = re.compile(r"(?<=\bexon_id=)[^;]+")
_parent_id = re.compile(r"(?<=\bParent=)[^;]+")
_symbol = re.compile(r"(?<=\bName=)[^;]+")
_description = re.compile(r"(?<=\bdescription=)[^;]+")
DNA = cogent3.get_moltype("dna", new_type=True)
alphabet = DNA.most_degen_alphabet()
bytes_to_array = new_alphabet.bytes_to_array(
chars=alphabet.as_bytes(),
dtype=numpy.uint8,
delete=b" \n\r\t",
)


def _rename(label: str) -> str:
Expand All @@ -47,7 +46,11 @@ def _rename(label: str) -> str:

@define_app
class fasta_to_hdf5: # noqa: N801
def __init__(self, config: eti_config.Config, label_to_name=_rename) -> None:
def __init__(
self,
config: eti_config.Config,
label_to_name: Callable[[str], str] = _rename,
) -> None:
self.config = config
self.label_to_name = label_to_name

Expand All @@ -63,8 +66,11 @@ def main(self, db_name: str) -> bool:

src_dir = src_dir / "fasta"
for path in src_dir.glob("*.fa.gz"):
for label, seq in quicka_parser(path):
seqid = self.label_to_name(label)
for seqid, seq in iter_fasta_records(
path,
converter=bytes_to_array,
label_to_name=self.label_to_name,
):
seq_store.add_record(seq, seqid)
del seq

Expand Down Expand Up @@ -125,43 +131,30 @@ class str2arr: # noqa: N801
"""convert string to array of uint8"""

def __init__(self, moltype: str = "dna", max_length: int | None = None) -> None:
moltype = get_moltype(moltype)
self.canonical = "".join(moltype)
mt = cogent3.get_moltype(moltype, new_type=True)
self.alphabet = mt.most_degen_alphabet()
self.max_length = max_length
extended = "".join(list(moltype.alphabets.degen))
self.translation = b"".maketrans(
extended.encode("utf8"),
"".join(chr(i) for i in range(len(extended))).encode("utf8"),
)

def main(self, data: str) -> numpy.ndarray:
if self.max_length:
data = data[: self.max_length]

b = data.encode("utf8").translate(self.translation)
return numpy.array(memoryview(b), dtype=numpy.uint8)
return self.alphabet.to_indices(data)


@define_app
class arr2str: # noqa: N801
"""convert array of uint8 to str"""

def __init__(self, moltype: str = "dna", max_length: int | None = None) -> None:
moltype = get_moltype(moltype)
self.canonical = "".join(moltype)
mt = cogent3.get_moltype(moltype, new_type=True)
self.alphabet = mt.most_degen_alphabet()
self.max_length = max_length
extended = "".join(list(moltype.alphabets.degen))
self.translation = b"".maketrans(
"".join(chr(i) for i in range(len(extended))).encode("utf8"),
extended.encode("utf8"),
)

def main(self, data: numpy.ndarray) -> str:
if self.max_length:
data = data[: self.max_length]

b = data.tobytes().translate(self.translation)
return bytearray(b).decode("utf8")
return self.alphabet.from_indices(data)


@dataclasses.dataclass
Expand Down Expand Up @@ -326,7 +319,7 @@ def get_seq(
stop: int | None = None,
namer: typing.Callable | None = None,
with_annotations: bool = True,
) -> str:
) -> Sequence:
"""returns annotated sequence
Parameters
Expand All @@ -350,20 +343,20 @@ def get_seq(
-----
Full annotations are bound to the instance.
"""
seq = self._seqs.get_seq_str(seqid=seqid, start=start, stop=stop)
seq = self._seqs.get_seq_arr(seqid=seqid, start=start, stop=stop)
if namer:
name = namer(self.species, seqid, start, stop)
else:
name = f"{self.species}:{seqid}:{start}-{stop}"
# we use seqid to make the sequence here because that identifies the
# parent seq identity, required for querying annotations
try:
seq = make_seq(seq, name=seqid, moltype="dna", annotation_offset=start or 0)
except TypeError:
# older version of cogent3
seq = make_seq(seq, name=seqid, moltype="dna")
seq.annotation_offset = start or 0

seq = cogent3.make_seq(
seq,
name=seqid,
moltype="dna",
annotation_offset=start or 0,
new_type=True,
)
seq.name = name
seq.annotation_db = self.annotation_db if with_annotations else None
return seq
Expand Down
3 changes: 2 additions & 1 deletion src/ensembl_tui/_homology.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,5 +228,6 @@ def main(self, homologs: homolog_group) -> SeqsCollectionType:
return make_unaligned_seqs(
data=seqs,
moltype="dna",
info=dict(source=homologs.source),
info={"source": homologs.source},
new_type=True,
)
Loading

0 comments on commit eff3f7a

Please sign in to comment.