Skip to content

Commit

Permalink
Applying rotation corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Oct 28, 2024
1 parent 08d51b7 commit ce2e4ba
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions pdelfin/birrpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,22 @@ def get_pdf_num_pages(s3_path: str) -> Optional[int]:

return None


def _get_page_data(page_index_entries: List[DatabaseManager.BatchInferenceRecord]) -> List[PageResponse]:
usable_page_data = [get_s3_bytes(workspace_s3, page.inference_s3_path,
start_index=page.start_index,
end_index=page.start_index + page.length - 1) for page in page_index_entries]

usable_page_final_results = []
for page_data in usable_page_data:
data = orjson.loads(page_data)
model_response_json = orjson.loads(data["outputs"][0]["text"])
page_response = PageResponse(**model_response_json)
usable_page_final_results.append(page_response)

return usable_page_final_results


def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int, target_longest_image_dim: int, target_anchor_text_len: int) -> list[dict]:
db = DatabaseManager(s3_workspace, skip_init=True)

Expand All @@ -527,10 +543,13 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
# Retry the page at least one more time regularly
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})

# TODO: If the rotation was previously invalid, then apply a rotation
# If the rotation was previously invalid, then apply a rotation
rotated_page_data = _get_page_data([page for page in existing_pages if page.page_num == target_page_num and page.error == "rotation_invalid"])
rotation_corrections = set(page_data.rotation_correction for page_data in rotated_page_data)
for correction in rotation_corrections:
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len, image_rotation=correction), "round": cur_round})


# TODO: Try to provide a smaller prompt hint
# TODO: Try to provide a smaller prompt hint if that was the error
else:
new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
except Exception as ex:
Expand All @@ -554,17 +573,7 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option

for target_page_num in range(1, pdf.num_pages + 1):
usable_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]

usable_page_data = [get_s3_bytes(workspace_s3, page.inference_s3_path,
start_index=page.start_index,
end_index=page.start_index + page.length - 1) for page in usable_pages]

usable_page_final_results = []
for page_data in usable_page_data:
data = orjson.loads(page_data)
model_response_json = orjson.loads(data["outputs"][0]["text"])
page_response = PageResponse(**model_response_json)
usable_page_final_results.append(page_response)
usable_page_final_results = _get_page_data(usable_pages)

# Sort the pages:
# 1. Prefer pages with `is_rotation_valid` set to True.
Expand Down

0 comments on commit ce2e4ba

Please sign in to comment.