Skip to content

Commit

Permalink
Merge pull request #429 from VikParuchuri/highquality-processors
Browse files Browse the repository at this point in the history
High Quality Layout Builder and Text Processors
  • Loading branch information
VikParuchuri authored Dec 20, 2024
2 parents a5de368 + d2c32af commit 26f68be
Show file tree
Hide file tree
Showing 18 changed files with 847 additions and 48 deletions.
2 changes: 1 addition & 1 deletion convert_single.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for a simple op, which is not supported on MPS

import time

Expand Down
169 changes: 169 additions & 0 deletions marker/builders/high_quality_layout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import json
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Optional

import google.generativeai as genai
import PIL
from google.ai.generativelanguage_v1beta.types import content
from google.api_core.exceptions import ResourceExhausted
from surya.model.layout.encoderdecoder import SuryaLayoutModel
from surya.model.ocr_error.model import DistilBertForSequenceClassification
from tqdm import tqdm

from marker.builders.layout import LayoutBuilder
from marker.providers.pdf import PdfProvider
from marker.schema import BlockTypes
from marker.schema.blocks import Block
from marker.schema.document import Document
from marker.schema.groups.page import PageGroup
from marker.schema.registry import get_block_class
from marker.settings import settings


class HighQualityLayoutBuilder(LayoutBuilder):
    """
    A builder for relabelling blocks to improve the quality of the layout.

    The regular layout builder runs first; afterwards every block whose
    predicted label has a confidence below ``confidence_threshold`` is sent
    (as a cropped image plus its top-k predictions) to a Gemini model, which
    picks the label it considers most appropriate among those predictions.

    Attributes:
        google_api_key (str):
            The Google API key to use for the Gemini model.
            Default is ``settings.GOOGLE_API_KEY``.
        confidence_threshold (float):
            The confidence threshold to use for relabeling.
            Default is 0.7.
        model_name (str):
            The name of the Gemini model to use.
            Default is "gemini-1.5-flash".
        max_retries (int):
            The maximum number of retries to use for the Gemini model.
            Default is 3.
        max_concurrency (int):
            The maximum number of concurrent requests to make to the Gemini model.
            Default is 3.
        timeout (int):
            The timeout for requests to the Gemini model.
            Default is 60 seconds.
        gemini_relabelling_prompt (str):
            The prompt to use for relabelling blocks.
            Default is a string containing the Gemini relabelling prompt.
    """

    google_api_key: Optional[str] = settings.GOOGLE_API_KEY
    confidence_threshold: float = 0.7
    model_name: str = "gemini-1.5-flash"
    max_retries: int = 3
    max_concurrency: int = 3
    timeout: int = 60

    gemini_relabelling_prompt = """You are a layout expert specializing in document analysis.
Your task is to relabel layout blocks in images to improve the accuracy of an existing layout model.
You will be provided with an image of a layout block and the top k predictions from the current model, along with their confidence scores.
Your job is to analyze the image and choose the single most appropriate label from the provided top k predictions.
Do not invent any new labels.
Carefully examine the image and consider the provided predictions.
Choose the label you believe is the most accurate representation of the layout block.
Here are the top k predictions from the model followed by the image:
"""

    def __init__(self, layout_model: SuryaLayoutModel, ocr_error_model: DistilBertForSequenceClassification, config=None):
        """
        Store the layout models and create the Gemini client.

        Raises:
            ValueError: If no Google API key is available.
        """
        # NOTE(review): `config` is accepted but never used here, and the parent
        # constructor is not called, so per-run overrides of the attributes above
        # are presumably not applied - confirm against LayoutBuilder.__init__.
        self.layout_model = layout_model
        self.ocr_error_model = ocr_error_model

        self.model = None
        if self.google_api_key is None:
            raise ValueError("Google API key is not set")

        genai.configure(api_key=self.google_api_key)
        self.model = genai.GenerativeModel(self.model_name)

    def __call__(self, document: Document, provider: PdfProvider):
        """Run the base layout builder, then relabel low-confidence blocks."""
        super().__call__(document, provider)

        self.relabel_blocks(document)

    def relabel_blocks(self, document: Document):
        """
        Submit every low-confidence block of ``document`` for relabelling.

        Blocks are processed concurrently with at most ``max_concurrency``
        worker threads. Any exception raised by a worker is re-raised here.
        """
        # Collect the work first so the progress bar can show a real total.
        pending = []
        for page in document.pages:
            for block_id in page.structure:
                block = page.get_block(block_id)
                if not block.top_k:
                    continue
                confidence = block.top_k.get(block.block_type)
                # The current label may be missing from top_k; comparing None
                # with a float would raise TypeError, so leave such blocks alone.
                if confidence is None or confidence >= self.confidence_threshold:
                    continue
                pending.append((page, block))

        # The context managers guarantee the pool is shut down and the bar is
        # closed even when a worker raises.
        with ThreadPoolExecutor(max_workers=self.max_concurrency) as executor, \
                tqdm(total=len(pending), desc="High quality layout relabelling") as pbar:
            futures = [
                executor.submit(self.process_block_relabelling, page, block)
                for page, block in pending
            ]
            for future in as_completed(futures):
                future.result()  # Raise exceptions if any occurred
                pbar.update(1)

    def process_block_relabelling(self, page: PageGroup, block: Block):
        """
        Ask Gemini for a better label for ``block`` and replace it on ``page``.

        The block is replaced only when the model returns a label that differs
        from the current one and is one of the block's own top-k candidates.
        """
        # Map the label text sent to the model back to the original top_k key,
        # so the answer can be validated instead of looked up blindly.
        candidates = {str(k): k for k in block.top_k}
        topk = {str(k): round(v, 3) for k, v in block.top_k.items()}

        prompt = self.gemini_relabelling_prompt + '```json' + json.dumps(topk) + '```\n'
        image = self.extract_image(page, block)
        response_schema = content.Schema(
            type=content.Type.OBJECT,
            enum=[],
            required=["label"],
            properties={
                "label": content.Schema(
                    type=content.Type.STRING,
                ),
            },
        )

        response = self.generate(prompt, image, response_schema)
        generated_label = None
        if isinstance(response, dict) and "label" in response:
            generated_label = response["label"]

        if not isinstance(generated_label, str) or not generated_label:
            return
        if generated_label == str(block.block_type):
            return

        # An invented or malformed label is ignored rather than raising and
        # aborting the whole conversion.
        generated_block_type = candidates.get(generated_label)
        if generated_block_type is None:
            return

        generated_block_class = get_block_class(generated_block_type)
        generated_block = generated_block_class(
            polygon=block.polygon,
            page_id=block.page_id,
            structure=block.structure,
        )
        page.replace_block(block, generated_block)

    def extract_image(self, page: PageGroup, image_block: Block, expand: float = 0.01):
        """
        Crop the region of ``image_block`` out of the page's low-res image.

        The block polygon is rescaled from page coordinates to image
        coordinates and expanded by ``expand`` in both directions before
        cropping.
        """
        page_img = page.lowres_image
        image_box = (
            image_block.polygon
            .rescale(page.polygon.size, page_img.size)
            .expand(expand, expand)
        )
        return page_img.crop(image_box.bbox)

    def generate(self, prompt: str, image: PIL.Image.Image, response_schema: content.Schema):
        """
        Send ``prompt`` and ``image`` to Gemini and return the parsed JSON reply.

        ``ResourceExhausted`` errors are retried up to ``max_retries`` times
        with a linearly growing wait. Any other error is printed and ends the
        attempt loop.

        Returns:
            The decoded JSON response, or an empty dict on failure.
        """
        tries = 0
        while tries < self.max_retries:
            try:
                responses = self.model.generate_content(
                    [prompt, image],
                    stream=False,
                    generation_config={
                        "temperature": 0,
                        "response_schema": response_schema,
                        "response_mime_type": "application/json",
                    },
                    request_options={'timeout': self.timeout}
                )
                output = responses.candidates[0].content.parts[0].text
                return json.loads(output)

            except ResourceExhausted as e:
                tries += 1
                wait_time = tries * 2
                print(f"ResourceExhausted: {e}. Retrying in {wait_time} seconds... (Attempt {tries}/{self.max_retries})")
                time.sleep(wait_time)
            except Exception as e:
                # Best effort: relabelling is optional, so report and give up.
                print(e)
                break

        return {}
4 changes: 2 additions & 2 deletions marker/builders/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,12 @@ def add_blocks_to_pages(self, pages: List[PageGroup], layout_results: List[Layou
for page, layout_result in zip(pages, layout_results):
layout_page_size = PolygonBox.from_bbox(layout_result.image_bbox).size
provider_page_size = page.polygon.size
page.layout_sliced = layout_result.sliced # This indicates if the page was sliced by the layout model
page.layout_sliced = layout_result.sliced # This indicates if the page was sliced by the layout model
for bbox in sorted(layout_result.bboxes, key=lambda x: x.position):
block_cls = get_block_class(BlockTypes[bbox.label])
layout_block = page.add_block(block_cls, PolygonBox(polygon=bbox.polygon))
layout_block.polygon = layout_block.polygon.rescale(layout_page_size, provider_page_size)
layout_block.top_k = {BlockTypes[label]: prob for (label, prob) in bbox.top_k.items()}
page.add_structure(layout_block)

# Ensure page has non-empty structure
Expand Down Expand Up @@ -177,4 +178,3 @@ def check_layout_coverage(
if not text_okay and (total_blocks == 1 and large_text_blocks == 1):
text_okay = True
return text_okay

4 changes: 4 additions & 0 deletions marker/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def common_options(fn):
fn = click.option("--disable_multiprocessing", is_flag=True, default=False, help="Disable multiprocessing.")(fn)
fn = click.option("--paginate_output", is_flag=True, default=False, help="Paginate output.")(fn)
fn = click.option("--disable_image_extraction", is_flag=True, default=False, help="Disable image extraction.")(fn)
fn = click.option("--high_quality", is_flag=True, default=False, help="Enable high quality processing with Gemini.")(fn)
return fn

def generate_config_dict(self) -> Dict[str, any]:
Expand Down Expand Up @@ -69,6 +70,9 @@ def generate_config_dict(self) -> Dict[str, any]:
case "disable_image_extraction":
if v:
config["extract_images"] = False
case "high_quality":
if v:
config["high_quality"] = True
return config

def get_renderer(self):
Expand Down
16 changes: 12 additions & 4 deletions marker/converters/pdf.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning
os.environ["TOKENIZERS_PARALLELISM"] = "false" # disables a tokenizers warning

import inspect
from collections import defaultdict
from typing import Any, Dict, List, Type

from marker.builders.document import DocumentBuilder
from marker.builders.high_quality_layout import HighQualityLayoutBuilder
from marker.builders.layout import LayoutBuilder
from marker.builders.ocr import OcrBuilder
from marker.builders.structure import StructureBuilder
Expand All @@ -16,6 +17,7 @@
from marker.processors.document_toc import DocumentTOCProcessor
from marker.processors.equation import EquationProcessor
from marker.processors.footnote import FootnoteProcessor
from marker.processors.high_quality_text import HighQualityTextProcessor
from marker.processors.ignoretext import IgnoreTextProcessor
from marker.processors.line_numbers import LineNumbersProcessor
from marker.processors.list import ListProcessor
Expand All @@ -36,17 +38,18 @@ class PdfConverter(BaseConverter):
A converter for processing and rendering PDF files into Markdown, JSON, HTML and other formats.
Attributes:
override_map (Dict[BlockTypes, Type[Block]]):
override_map (Dict[BlockTypes, Type[Block]]):
A mapping to override the default block classes for specific block types.
The keys are `BlockTypes` enum values, representing the types of blocks,
and the values are corresponding `Block` class implementations to use
instead of the defaults.
"""
override_map: Dict[BlockTypes, Type[Block]] = defaultdict()
high_quality: bool = False

def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | None = None, renderer: str | None = None, config=None):
super().__init__(config)

for block_type, override_block_type in self.override_map.items():
register_block_class(block_type, override_block_type)

Expand All @@ -66,6 +69,7 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | No
SectionHeaderProcessor,
TableProcessor,
TextProcessor,
HighQualityTextProcessor,
DebugProcessor,
]

Expand All @@ -78,6 +82,10 @@ def __init__(self, artifact_dict: Dict[str, Any], processor_list: List[str] | No
self.processor_list = processor_list
self.renderer = renderer

self.layout_builder_class = LayoutBuilder
if self.high_quality:
self.layout_builder_class = HighQualityLayoutBuilder

def resolve_dependencies(self, cls):
init_signature = inspect.signature(cls.__init__)
parameters = init_signature.parameters
Expand All @@ -99,7 +107,7 @@ def resolve_dependencies(self, cls):

def __call__(self, filepath: str):
pdf_provider = PdfProvider(filepath, self.config)
layout_builder = self.resolve_dependencies(LayoutBuilder)
layout_builder = self.resolve_dependencies(self.layout_builder_class)
ocr_builder = self.resolve_dependencies(OcrBuilder)
document = DocumentBuilder(self.config)(pdf_provider, layout_builder, ocr_builder)
StructureBuilder(self.config)(document)
Expand Down
Loading

0 comments on commit 26f68be

Please sign in to comment.