Skip to content

Commit

Permalink
Cigar tools: fix strip() logic
Browse files Browse the repository at this point in the history
  • Loading branch information
Donaim committed Nov 14, 2023
1 parent 5ab93bd commit ea58060
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 41 deletions.
101 changes: 72 additions & 29 deletions micall/tests/test_cigar_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def test_cigar_hit_ref_cut_add_prop_exhaustive(hit, cut_point):
assert left + right == hit


@pytest.mark.parametrize('hit, expected', [
lstrip_cases = [
(CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9),
CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9)),

Expand All @@ -340,39 +340,95 @@ def test_cigar_hit_ref_cut_add_prop_exhaustive(hit, cut_point):
(CigarHit('6D5M', r_st=1, r_ei=11, q_st=1, q_ei=5),
CigarHit('5M', r_st=7, r_ei=11, q_st=1, q_ei=5)),

(CigarHit('4I6D5M', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('5M', r_st=7, r_ei=11, q_st=5, q_ei=9)),
(CigarHit('6D4I5M', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('5M', r_st=7, r_ei=11, q_st=5, q_ei=9)),
CigarHit('4I5M', r_st=7, r_ei=11, q_st=1, q_ei=9)),

(CigarHit('3D3D4I5M', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('4I5M', r_st=7, r_ei=11, q_st=1, q_ei=9)),

(CigarHit('3D2I3D2I5M', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('4I5M', r_st=7, r_ei=11, q_st=1, q_ei=9)),

(CigarHit('4I6D5M', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('4I5M', r_st=7, r_ei=11, q_st=1, q_ei=9)),

(CigarHit('', r_st=1, r_ei=0, q_st=1, q_ei=0),
CigarHit('', r_st=1, r_ei=0, q_st=1, q_ei=0)),
])
]

@pytest.mark.parametrize('hit, expected', lstrip_cases)
def test_cigar_hit_lstrip(hit, expected):
assert expected == hit.lstrip_query()


@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases
if not isinstance(x[2], Exception)])
def test_cigar_hit_strip_combines_with_connect(hit, cut_point):
left, right = hit.cut_reference(cut_point)
rstrip_cases = [
(CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9),
CigarHit('9M', r_st=1, r_ei=9, q_st=1, q_ei=9)),

(CigarHit('5M6D', r_st=1, r_ei=11, q_st=1, q_ei=5),
CigarHit('5M', r_st=1, r_ei=5, q_st=1, q_ei=5)),

(CigarHit('6D5M', r_st=1, r_ei=11, q_st=1, q_ei=5),
CigarHit('6D5M', r_st=1, r_ei=11, q_st=1, q_ei=5)),

(CigarHit('5M4I6D', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('5M4I', r_st=1, r_ei=5, q_st=1, q_ei=9)),

(CigarHit('5M4I3D3D', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('5M4I', r_st=1, r_ei=5, q_st=1, q_ei=9)),

(CigarHit('5M2I3D2I3D', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('5M4I', r_st=1, r_ei=5, q_st=1, q_ei=9)),

left = left.rstrip_query()
right = right.lstrip_query()
(CigarHit('5M6D4I', r_st=1, r_ei=11, q_st=1, q_ei=9),
CigarHit('5M4I', r_st=1, r_ei=5, q_st=1, q_ei=9)),

(CigarHit('', r_st=1, r_ei=0, q_st=1, q_ei=0),
CigarHit('', r_st=1, r_ei=0, q_st=1, q_ei=0)),
]

@pytest.mark.parametrize('hit, expected', rstrip_cases)
def test_cigar_hit_rstrip(hit, expected):
assert expected == hit.rstrip_query()


strip_prop_cases_all = [x[0] for x in cigar_hit_ref_cut_cases] \
+ [x[0] for x in lstrip_cases] \
+ [x[0] for x in rstrip_cases]


@pytest.mark.parametrize('hit', strip_prop_cases_all)
def test_cigar_hit_strip_combines_with_connect(hit):
for cut_point in range(hit.r_st - 1, hit.r_ei):
left, right = hit.cut_reference(cut_point + hit.epsilon)

left = left.rstrip_query()
right = right.lstrip_query()

assert left.connect(right).coordinate_mapping == hit.coordinate_mapping

assert left.connect(right).coordinate_mapping == hit.coordinate_mapping

@pytest.mark.parametrize('hit', strip_prop_cases_all)
def test_cigar_hit_strip_combines_with_add(hit):
for cut_point in range(hit.r_st - 1, hit.r_ei):
left, right = hit.cut_reference(cut_point + hit.epsilon)

@pytest.mark.parametrize('hit', [x[0] for x in cigar_hit_ref_cut_cases])
left = left.rstrip_query()
right = right.lstrip_query()

if left.touches(right):
assert left + right == hit


@pytest.mark.parametrize('hit', strip_prop_cases_all)
def test_cigar_hit_strip_never_crashes(hit):
hit.rstrip_query().lstrip_query()
hit.lstrip_query().rstrip_query()
hit.lstrip_query().lstrip_query()
hit.rstrip_query().rstrip_query()


@pytest.mark.parametrize('hit', [x[0] for x in cigar_hit_ref_cut_cases])
@pytest.mark.parametrize('hit', strip_prop_cases_all)
def test_cigar_hit_strip_is_idempotent(hit):
h1 = hit.rstrip_query()
assert h1 == h1.rstrip_query() == h1.rstrip_query().rstrip_query()
Expand All @@ -387,25 +443,12 @@ def test_cigar_hit_strip_is_idempotent(hit):
assert h1 == h1.rstrip_query() == h1.lstrip_query()


@pytest.mark.parametrize('hit', [x[0] for x in cigar_hit_ref_cut_cases])
@pytest.mark.parametrize('hit', strip_prop_cases_all)
def test_cigar_hit_strips_are_commutative(hit):
assert hit.rstrip_query().lstrip_query() \
== hit.lstrip_query().rstrip_query()


@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases
if not isinstance(x[2], Exception)
and not 'N' in str(x[0].cigar)])
def test_cigar_hit_strip_combines_with_add(hit, cut_point):
left, right = hit.cut_reference(cut_point)

left = left.rstrip_query()
right = right.lstrip_query()

if left.touches(right):
assert left + right == hit


@pytest.mark.parametrize('hit, cut_point', [(x[0], x[1]) for x in cigar_hit_ref_cut_cases
if not isinstance(x[2], Exception)])
def test_cigar_hit_ref_cut_add_associativity(hit, cut_point):
Expand Down
49 changes: 37 additions & 12 deletions micall/utils/cigar_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,37 @@ def slice_operations(self, start_inclusive, end_noninclusive) -> 'Cigar':
[start_inclusive:end_noninclusive])


def lstrip_query(self) -> 'Cigar':
""" Return a copy of the Cigar with leading (unmatched) query elements removed. """

if self.query_length == 0:
return self

min_q = min(self.coordinate_mapping.query_to_ref.keys(), default=0)
min_op = self.coordinate_mapping.query_to_op[min_q]

ops = [(1, op) for i, (op, ref_pointer, query_pointer)
in enumerate(self.iterate_operations_with_pointers())
if ref_pointer is None or i >= min_op]
return Cigar.coerce(ops)


def rstrip_query(self) -> 'Cigar':
""" Return a copy of the Cigar with trailing (unmatched) query elements removed. """

if self.query_length == 0:
return self

max_q = max(self.coordinate_mapping.query_to_ref.keys(),
default=len(self.coordinate_mapping.query_to_op) - 1)
max_op = self.coordinate_mapping.query_to_op[max_q]

ops = [(1, op) for i, (op, ref_pointer, query_pointer)
in enumerate(self.iterate_operations_with_pointers())
if ref_pointer is None or i <= max_op]
return Cigar.coerce(ops)


@cached_property
def coordinate_mapping(self) -> CoordinateMapping:
"""
Expand Down Expand Up @@ -580,23 +611,17 @@ def cut_reference(self, cut_point: float) -> Tuple['CigarHit', 'CigarHit']:
def lstrip_query(self) -> 'CigarHit':
""" Return a copy of the CigarHit with leading (unmatched) query elements removed. """

if len(self.coordinate_mapping.ref_to_query) == 0:
return self

closest_ref = self.coordinate_mapping.ref_to_query.closest_key(self.r_st - 1)
remainder, stripped = self.cut_reference(closest_ref - self.epsilon)
return stripped
cigar = self.cigar.lstrip_query()
return CigarHit(cigar, r_st=self.r_ei - cigar.ref_length + 1, r_ei=self.r_ei,
q_st=self.q_ei - cigar.query_length + 1, q_ei=self.q_ei)


def rstrip_query(self) -> 'CigarHit':
""" Return a copy of the CigarHit with trailing (unmatched) query elements removed. """

if len(self.coordinate_mapping.ref_to_query) == 0:
return self

closest_ref = self.coordinate_mapping.ref_to_query.closest_key(self.r_ei + 1)
stripped, remainder = self.cut_reference(closest_ref + self.epsilon)
return stripped
cigar = self.cigar.rstrip_query()
return CigarHit(cigar, r_st=self.r_st, r_ei=self.r_st + cigar.ref_length - 1,
q_st=self.q_st, q_ei=self.q_st + cigar.query_length - 1)


@cached_property
Expand Down

0 comments on commit ea58060

Please sign in to comment.