Skip to content

Commit

Permalink
Merge pull request #565 from VikParuchuri/dev
Browse files (browse the repository at this point in the history)
Fix llm layout missing text
  • Loading branch information
VikParuchuri authored Feb 19, 2025
2 parents 434c0ce + f0c9f22 commit 141da8c
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 14 deletions.
2 changes: 1 addition & 1 deletion marker/providers/pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PdfProvider(BaseProvider):
pdftext_workers: Annotated[
int,
"The number of workers to use for pdftext.",
] = 1
] = 4
flatten_pdf: Annotated[
bool,
"Whether to flatten the PDF structure.",
Expand Down
1 change: 1 addition & 0 deletions marker/schema/blocks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ class Block(BaseModel):
metadata: BlockMetadata | None = None
lowres_image: Image.Image | None = None
highres_image: Image.Image | None = None
removed: bool = False # Has block been replaced by new block?

model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down
31 changes: 20 additions & 11 deletions marker/schema/groups/page.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from PIL import Image, ImageDraw

from pdftext.schema import Reference
from pydantic import computed_field

from marker.providers import ProviderOutput
from marker.schema import BlockTypes
from marker.schema.blocks import Block, BlockId, Text
Expand Down Expand Up @@ -45,14 +47,18 @@ def get_image(self, *args, highres: bool = False, remove_blocks: Sequence[BlockT
if remove_blocks:
image = image.copy()
draw = ImageDraw.Draw(image)
bad_blocks = [block for block in self.children if block.block_type in remove_blocks]
bad_blocks = [block for block in self.current_children if block.block_type in remove_blocks]
for bad_block in bad_blocks:
poly = bad_block.polygon.rescale(self.polygon.size, image.size).polygon
poly = [(int(p[0]), int(p[1])) for p in poly]
draw.polygon(poly, fill='white')

return image

@computed_field
@property
def current_children(self) -> List[Block]:
    """Return the children that have not been flagged as removed.

    A child is skipped when its ``removed`` attribute is truthy; the
    relative order of the remaining children is preserved.
    """
    active_children = []
    for candidate in self.children:
        if candidate.removed:
            continue
        active_children.append(candidate)
    return active_children

def get_next_block(self, block: Optional[Block] = None, ignored_block_types: Optional[List[BlockTypes]] = None):
if ignored_block_types is None:
Expand Down Expand Up @@ -102,13 +108,9 @@ def assemble_html(self, document, child_blocks, parent_structure=None):
template += f"<content-ref src='{c.id}'></content-ref>"
return template

def compute_line_block_intersections(self, provider_outputs: List[ProviderOutput]):
def compute_line_block_intersections(self, blocks: List[Block], provider_outputs: List[ProviderOutput]):
max_intersections = {}

blocks = [
block for block in self.children
if block.block_type not in self.excluded_block_types
]
block_bboxes = [block.polygon.bbox for block in blocks]
line_bboxes = [provider_output.line.polygon.bbox for provider_output in provider_outputs]

Expand Down Expand Up @@ -137,6 +139,10 @@ def replace_block(self, block: Block, new_block: Block):
for child in self.children:
child.replace_block(block, new_block)

# Mark block as removed
block.removed = True


def identify_missing_blocks(
self,
provider_line_idxs: List[int],
Expand Down Expand Up @@ -224,7 +230,12 @@ def merge_blocks(
text_extraction_method: str
):
provider_line_idxs = list(range(len(provider_outputs)))
max_intersections = self.compute_line_block_intersections(provider_outputs)
valid_blocks = [
block for block in self.current_children # ensure we only look at children that haven't been replaced
if block.block_type not in self.excluded_block_types
]

max_intersections = self.compute_line_block_intersections(valid_blocks, provider_outputs)

# Try to assign lines by intersection
assigned_line_idxs = set()
Expand All @@ -241,9 +252,7 @@ def merge_blocks(
min_dist_idx = None
provider_output: ProviderOutput = provider_outputs[line_idx]
line = provider_output.line
for block in self.children:
if block.block_type in self.excluded_block_types:
continue
for block in valid_blocks:
# We want to assign to blocks closer in y than x
dist = line.polygon.center_distance(block.polygon, x_weight=5)
if min_dist_idx is None or dist < min_dist:
Expand All @@ -265,7 +274,7 @@ def aggregate_block_metadata(self) -> BlockMetadata:
if self.metadata is None:
self.metadata = BlockMetadata()

for block in self.children:
for block in self.current_children:
if block.metadata is not None:
self.metadata = self.metadata.merge(block.metadata)
return self.metadata
2 changes: 1 addition & 1 deletion marker/scripts/run_streamlit_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
# Run streamlit_app.py (located next to this script) through `streamlit run` in a subprocess.
def streamlit_app_cli():
# Resolve the app path relative to this file's own directory.
cur_dir = os.path.dirname(os.path.abspath(__file__))
app_path = os.path.join(cur_dir, "streamlit_app.py")
# NOTE(review): this assignment is overwritten by the very next line; it appears to be
# the pre-change side of the diff hunk (same command without --server.headless) — confirm.
cmd = ["streamlit", "run", app_path, "--server.fileWatcherType", "none"]
# Effective command: file watcher type set to "none" and headless mode set to "true".
cmd = ["streamlit", "run", app_path, "--server.fileWatcherType", "none", "--server.headless", "true"]
# The child process inherits the current environment plus IN_STREAMLIT=true.
subprocess.run(cmd, env={**os.environ, "IN_STREAMLIT": "true"})
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "marker-pdf"
version = "1.5.4"
version = "1.5.5"
description = "Convert PDF to markdown with high speed and accuracy."
authors = ["Vik Paruchuri <github@vikas.sh>"]
readme = "README.md"
Expand Down
42 changes: 42 additions & 0 deletions tests/builders/test_layout_replace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pytest

from marker.builders.document import DocumentBuilder
from marker.builders.layout import LayoutBuilder
from marker.builders.line import LineBuilder
from marker.renderers.markdown import MarkdownRenderer
from marker.schema import BlockTypes
from marker.schema.registry import get_block_class


@pytest.mark.config({"page_range": [0]})
def test_layout_replace(request, config, pdf_provider, layout_model, ocr_error_model, detection_model, inline_detection_model):
    """Check that text is still merged into blocks swapped in via ``replace_block``.

    The LLM layout builder replaces blocks, so this test imitates it: every
    block returned for ``BlockTypes.Text`` on page 0 is replaced with a freshly
    built ``TextInlineMath`` block before the line builder runs. Each
    replacement must end up with non-empty text, and the rendered markdown
    must contain two known phrases from the document.
    """
    layout_builder = LayoutBuilder(layout_model, config)
    line_builder = LineBuilder(detection_model, inline_detection_model, ocr_error_model, config)
    document = DocumentBuilder(config).build_document(pdf_provider)
    layout_builder(document, pdf_provider)

    first_page = document.pages[0]
    replacements = []
    for text_block in first_page.contained_blocks(document, (BlockTypes.Text,)):
        replacement_cls = get_block_class(BlockTypes.TextInlineMath)
        # Reuse the geometry and structure of the block being replaced.
        replacement = replacement_cls(
            polygon=text_block.polygon,
            page_id=text_block.page_id,
            structure=text_block.structure,
        )
        first_page.replace_block(text_block, replacement)
        replacements.append(replacement)

    # Lines are built only after the swap, so they must attach to the new blocks.
    line_builder(document, pdf_provider)

    for replacement in replacements:
        assert replacement.raw_text(document).strip()

    markdown = MarkdownRenderer(config)(document).markdown

    assert "worst-case perturbations" in markdown
    assert "projected gradient descent" in markdown



0 comments on commit 141da8c

Please sign in to comment.