feat: allow to match text on OCRResult
raphael0202 committed Apr 3, 2023
1 parent 30bc347 commit 73ae0e6
Showing 2 changed files with 215 additions and 6 deletions.
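For context, here is a minimal usage sketch of the API this commit adds; the OCR JSON path and the lowercasing preprocessor are illustrative assumptions, not part of the commit.

import json

from robotoff.prediction.ocr.dataclass import OCRResult

# Hypothetical path to a Google Cloud Vision OCR JSON file (an assumption
# made for this example).
with open("ocr.json") as f:
    result = OCRResult.from_json(json.load(f))
assert result is not None

# Case-insensitive search: each match is a list of consecutive Word objects
# appearing in the same order as the pattern words.
matches = result.match("mélangez bien les pâtes", preprocess_func=lambda s: s.lower())

if matches is None:
    print("full text annotation not available")
else:
    for words in matches:
        print([word.text for word in words])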
154 changes: 153 additions & 1 deletion robotoff/prediction/ocr/dataclass.py
@@ -1,4 +1,5 @@
import enum
import itertools
import math
import operator
import re
@@ -224,6 +225,24 @@ def get_languages(self) -> Optional[dict[str, int]]:

        return None

    def match(
        self,
        pattern: str,
        preprocess_func: Optional[Callable[[str], str]] = None,
        strip_characters: Optional[str] = None,
    ) -> Optional[list[list["Word"]]]:
        """Find the words in the image that match the pattern words, in the
        same order.
        Return a list of matches (each match is a list of consecutive
        `Word`s), or None if the full text annotation is not available.
        See `Paragraph.match` for more details.
        """
        if self.full_text_annotation:
            return self.full_text_annotation.match(
                pattern, preprocess_func, strip_characters
            )
        return None


def get_text(
    content: Union[OCRResult, str],
@@ -308,6 +327,24 @@ def detect_orientation(self) -> OrientationResult:
        count = Counter(word_orientations)
        return OrientationResult(count)

    def match(
        self,
        pattern: str,
        preprocess_func: Optional[Callable[[str], str]] = None,
        strip_characters: Optional[str] = None,
    ) -> list[list["Word"]]:
        """Find the words in the image that match the pattern words in the
        same order.
        See `Paragraph.match` for more details.
        """
        return list(
            itertools.chain.from_iterable(
                page.match(pattern, preprocess_func, strip_characters)
                for page in self.pages
            )
        )


class TextAnnotationPage:
"""Detected page from OCR."""
@@ -335,6 +372,24 @@ def detect_words_orientation(self) -> list[ImageOrientation]:

        return word_orientations

    def match(
        self,
        pattern: str,
        preprocess_func: Optional[Callable[[str], str]] = None,
        strip_characters: Optional[str] = None,
    ) -> list[list["Word"]]:
        """Find the words in the page that match the pattern words in the
        same order.
        See `Paragraph.match` for more details.
        """
        return list(
            itertools.chain.from_iterable(
                block.match(pattern, preprocess_func, strip_characters)
                for block in self.blocks
            )
        )


class Block:
"""Logical element on the page."""
@@ -373,6 +428,24 @@ def detect_words_orientation(self) -> list[ImageOrientation]:

        return word_orientations

    def match(
        self,
        pattern: str,
        preprocess_func: Optional[Callable[[str], str]] = None,
        strip_characters: Optional[str] = None,
    ) -> list[list["Word"]]:
        """Find the words in the block that match the pattern words in the
        same order.
        See `Paragraph.match` for more details.
        """
        return list(
            itertools.chain.from_iterable(
                paragraph.match(pattern, preprocess_func, strip_characters)
                for paragraph in self.paragraphs
            )
        )


class Paragraph:
"""Structural unit of text representing a number of words in certain
@@ -411,11 +484,56 @@ def get_text(self) -> str:
"""Return the text of the paragraph, by concatenating the words."""
return "".join(w.text for w in self.words)

    def match(
        self,
        pattern: str,
        preprocess_func: Optional[Callable[[str], str]] = None,
        strip_characters: Optional[str] = None,
    ) -> list[list["Word"]]:
        """Find the words in the paragraph that match the pattern words in
        the same order.
        The pattern is first split on whitespace, then we iterate over both
        the paragraph words and the pattern words to find matches.
        See `Word.match` for a description of the `preprocess_func` and
        `strip_characters` parameters and for more details about the
        matching process.
        :param pattern: the string pattern to look for
        """
        pattern_words = pattern.split()
        matches = []
        # Iterate over the words
        for word_idx in range(len(self.words)):
            current_word = self.words[word_idx]
            stack = list(pattern_words)
            matched_words = []
            while stack:
                pattern_word = stack.pop(0)
                # if the current word does not match the pattern word, or if
                # there is no word left while the pattern stack is not empty,
                # there is no full match: break to continue to the next
                # starting word
                if not current_word.match(
                    pattern_word, preprocess_func, strip_characters
                ) or (stack and word_idx + 1 >= len(self.words)):
                    break
                matched_words.append(current_word)
                if stack:
                    # there is a partial match, continue to the next word to
                    # check whether the remaining pattern words also match
                    word_idx += 1
                    current_word = self.words[word_idx]
            else:
                # no break occurred, so it's a full match
                matches.append(matched_words)

        return matches


class Word:
"""A word representation."""

__slots__ = ("bounding_poly", "symbols", "languages")
__slots__ = ("bounding_poly", "symbols", "languages", "_text")

def __init__(self, data: JSONType):
self.bounding_poly = BoundingPoly(data["boundingBox"])
@@ -429,6 +547,9 @@ def __init__(self, data: JSONType):
                DetectedLanguage(lang) for lang in data["property"]["detectedLanguages"]
            ]

        # Attribute to store text generated from symbols
        self._text = None

    @property
    def text(self):
        if not self._text:
@@ -474,6 +595,37 @@ def on_same_line(self, word: "Word"):
            word_symbol_width,
        )

    def match(
        self,
        pattern: str,
        preprocess_func: Optional[Callable[[str], str]] = None,
        strip_characters: Optional[str] = None,
    ) -> bool:
        """Return True if the pattern is equal to the word string after
        preprocessing, False otherwise.
        A first preprocessing step is performed on the word and the
        pattern: punctuation marks, spaces and line breaks are stripped.
        :param pattern: a string to match
        :param preprocess_func: a preprocessing function to apply to the
        pattern and the word string, defaults to the identity function
        :param strip_characters: characters to strip from the word and the
        pattern before matching.
        By default, the following character list is used: "\\n .,!?"
        Pass an empty string to disable character stripping.
        """
        preprocess_func = preprocess_func or (lambda x: x)

        if strip_characters is None:
            strip_characters = "\n .,!?"

        return preprocess_func(self.text.strip(strip_characters)) == preprocess_func(
            pattern.strip(strip_characters)
        )

    def __repr__(self) -> str:
        return f"<Word: {self.text}>"


class Symbol:
"""A single symbol representation."""
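To make the word-level comparison above concrete, here is a rough standalone restatement of the normalization that `Word.match` performs (strip, then apply the optional preprocessing function to both sides). The `words_equal` helper is illustrative only; `fold` is the accent-folding utility imported by the tests below.

from typing import Callable, Optional

from robotoff.utils.fold_to_ascii import fold


def words_equal(
    word_text: str,
    pattern: str,
    preprocess_func: Optional[Callable[[str], str]] = None,
    strip_characters: str = "\n .,!?",
) -> bool:
    # Same comparison as Word.match: strip the characters from both sides,
    # apply the preprocessing function, then test for equality.
    preprocess_func = preprocess_func or (lambda x: x)
    return preprocess_func(word_text.strip(strip_characters)) == preprocess_func(
        pattern.strip(strip_characters)
    )


# Folding + lowercasing makes the comparison accent- and case-insensitive,
# as exercised by the tests below.
assert words_equal("Mélangez ", "MELANGEZ", preprocess_func=lambda s: fold(s.lower()))
# Trailing punctuation and spaces are stripped by default, so the OCR word
# "ciboulette, " still matches the bare pattern word.
assert words_equal("ciboulette, ", "ciboulette")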
67 changes: 62 additions & 5 deletions tests/unit/prediction/ocr/test_dataclass.py
@@ -1,21 +1,26 @@
import json
import pathlib
import re
from typing import Callable, Optional

import pytest

from robotoff.prediction.ocr.dataclass import OCRParsingException, OCRResult
from robotoff.utils.fold_to_ascii import fold

data_dir = pathlib.Path(__file__).parent / "data"


@pytest.mark.parametrize("ocr_name", ["3038350013804_11.json"])
def test_ocr_result_extraction_non_regression(ocr_name: str):
    with (data_dir / ocr_name).open("r") as f:
@pytest.fixture(scope="session")
def example_ocr_result():
    with (data_dir / "3038350013804_11.json").open("r") as f:
        data = json.load(f)

    result = OCRResult.from_json(data)
    assert result
    yield OCRResult.from_json(data)


def test_ocr_result_extraction_non_regression(example_ocr_result):
    assert example_ocr_result


class TestOCRResult:
@@ -45,3 +50,55 @@ def test_from_json_error_response(self):
            match=re.escape("Error in OCR response: [{'this is an error'}"),
        ):
            OCRResult.from_json({"responses": [{"error": [{"this is an error"}]}]})


@pytest.mark.parametrize(
    "pattern,expected_matches,preprocess_func",
    [
        (
            "fromage de chèvre frais",
            [
                [
                    "fromage ",
                    "de ",
                    "chèvre ",
                    "frais ",
                ]
            ],
            None,
        ),
        # no preprocessing: in the OCR text it's "Mélangez bien les pâtes"
        # (note the uppercase letter), so no match is expected
        ("mélangez bien les pâtes", [], None),
        (
            "mélangez bien les pâtes",
            [["Mélangez ", "bien ", "les ", "pâtes "]],
            lambda x: x.lower(),
        ),
        # Fold + lowercase should return a match
        (
            "MELANGEZ BIEN LES PATES",
            [["Mélangez ", "bien ", "les ", "pâtes "]],
            lambda x: fold(x.lower()),
        ),
        # Test that ', ' is stripped after the last word ('ciboulette, ')
        ("brins de ciboulette", [["brins ", "de ", "ciboulette, "]], None),
        # Test multiple matches: we expect 2 matches
        ("rondelles", [["rondelles. "], ["rondelles "]], None),
    ],
)
def test_match(
    pattern: str,
    expected_matches: Optional[list[list[str]]],
    preprocess_func: Optional[Callable[[str], str]],
    example_ocr_result: OCRResult,
):
    matches = example_ocr_result.match(pattern, preprocess_func)

    if expected_matches is None:
        assert matches is None
    else:
        assert matches is not None
        assert len(matches) == len(expected_matches)

        for match, expected_words in zip(matches, expected_matches):
            assert [word.text for word in match] == expected_words

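Finally, a rough sketch of the `strip_characters` behaviour documented in `Word.match`, using the same OCR file as the fixture above (the relative path assumes this is run from the repository root; the exact results are inferred from the test expectations).

import json
import pathlib

from robotoff.prediction.ocr.dataclass import OCRResult

data_dir = pathlib.Path("tests/unit/prediction/ocr/data")
with (data_dir / "3038350013804_11.json").open("r") as f:
    result = OCRResult.from_json(json.load(f))

# With the default strip characters ("\n .,!?"), the trailing ", " of the
# OCR word "ciboulette, " is ignored, so the pattern matches.
matches = result.match("brins de ciboulette")
assert matches is not None
assert [[word.text for word in match] for match in matches] == [
    ["brins ", "de ", "ciboulette, "]
]

# With stripping disabled, the pattern words have to equal the raw word
# texts (which keep their trailing spaces and punctuation), so no match is
# expected here.
assert result.match("brins de ciboulette", strip_characters="") == []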