Generating parquets for hugging face

Commit 6ed6f85 (1 parent: 84c0c71)
Showing 1 changed file with 186 additions and 0 deletions.
@@ -0,0 +1,186 @@
# Script to generate parquet dataset files to upload to Hugging Face
# Input is a dataset location /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
# Each json line has a custom id that looks like {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ... more data}

# Fix this script so that it works, and that it will take a path to an input dataset, and SQLite database location
# And then it will build a parquet file with rows that look like: "id", "url", "page_number", "response"
# Where id will be the output of parse_pdf_hash plus "-" plus the page number
# The url will be the result of get_uri_from_db
# Response will be NormalizedEntry.text

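# For illustration only (the values below are hypothetical): an OpenAI-batch-style input line such as
#   {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1",
#    "response": {"body": {"choices": [{"message": {"content": "Page text..."}, "finish_reason": "stop"}]}}}
# should end up as a parquet row roughly like
#   {"id": "de80a57e6c57b45796d2e020173227f7eae44232-1", "url": "<uri looked up in the SQLite db>",
#    "page_number": 1, "response": "Page text..."}
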
import argparse
import glob
import json
import re
import sqlite3
from dataclasses import dataclass
from typing import Optional

import pandas as pd
from tqdm import tqdm


def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """
    Extracts a hash from a pretty PDF S3 URL.
    For example, given:
        s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
    it will return "de80a57e6c57b45796d2e020173227f7eae44232".
    """
    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
    match = re.match(pattern, pretty_pdf_path)
    if match:
        return match.group(1) + match.group(2)
    return None


def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """
    Looks up the URL for the given pdf_hash in the sqlite database.
    Assumes there is a table called 'pdf_mapping' with a column 'uri'.
    """
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0] if result else None
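
# Note: the mapping table's schema is not part of this commit; the query above
# assumes something along the lines of (hypothetical):
#   CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT);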


@dataclass(frozen=True)
class NormalizedEntry:
    s3_path: str
    pagenum: int
    text: Optional[str]
    finish_reason: Optional[str]
    error: Optional[str] = None

    @staticmethod
    def from_goldkey(goldkey: str, **kwargs):
        """
        Constructs a NormalizedEntry from a goldkey string.
        The goldkey is expected to be of the format:
            <s3_path>-<page_number>
        """
        s3_path = goldkey[: goldkey.rindex("-")]
        page_num = int(goldkey[goldkey.rindex("-") + 1 :])
        return NormalizedEntry(s3_path, page_num, **kwargs)

    @property
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"
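
    # Illustrative example (hypothetical values): from_goldkey(
    #     "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-3",
    #     text="...", finish_reason="stop")
    # splits on the last "-", giving s3_path="s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf"
    # and pagenum=3.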


def normalize_json_entry(data: dict) -> NormalizedEntry:
    """
    Normalizes a JSON entry from any of the supported formats.
    It supports:
      - Birr: looks for an "outputs" field.
      - Already normalized entries: if they contain s3_path, pagenum, etc.
      - OpenAI: where the response is in data["response"]["body"]["choices"].
      - SGLang: where the response is in data["response"]["choices"].
    """
    if "outputs" in data:
        # Birr case
        if data["outputs"] is None:
            text = None
            finish_reason = None
        else:
            text = data["outputs"][0]["text"]
            finish_reason = data["outputs"][0]["finish_reason"]

        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=text,
            finish_reason=finish_reason,
            error=data.get("completion_error", None),
        )
    elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
        # Already normalized
        return NormalizedEntry(**data)
    elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
        # OpenAI case
        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=data["response"]["body"]["choices"][0]["message"]["content"],
            finish_reason=data["response"]["body"]["choices"][0]["finish_reason"],
        )
    elif "response" in data and "choices" in data["response"]:
        # SGLang case: choices sit directly under "response"
        return NormalizedEntry.from_goldkey(
            goldkey=data["custom_id"],
            text=data["response"]["choices"][0]["message"]["content"],
            finish_reason=data["response"]["choices"][0]["finish_reason"],
        )
    else:
        raise ValueError("Unrecognized JSON entry format")
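
# For reference (hypothetical values), a Birr-style line such as
#   {"custom_id": "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-2",
#    "outputs": [{"text": "Page text...", "finish_reason": "stop"}]}
# is handled by the first branch of normalize_json_entry above.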


def main():
    parser = argparse.ArgumentParser(
        description="Generate a Parquet dataset file for HuggingFace upload."
    )
    parser.add_argument(
        "input_dataset",
        help="Input dataset file pattern (e.g., '/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json')",
    )
    parser.add_argument(
        "db_path", help="Path to the SQLite database file."
    )
    parser.add_argument(
        "--output", default="output.parquet", help="Output Parquet file path."
    )
    args = parser.parse_args()

    rows = []
    files = glob.glob(args.input_dataset)
    print(f"Found {len(files)} files matching pattern: {args.input_dataset}")

    for file_path in tqdm(files):
        print(f"Processing file: {file_path}")
        with open(file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Skipping invalid JSON at {file_path}:{line_num} - {e}")
                    continue

                try:
                    normalized = normalize_json_entry(data)
                except Exception as e:
                    print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
                    continue

                # Use the s3_path from the normalized entry to extract the pdf hash.
                pdf_hash = parse_pdf_hash(normalized.s3_path)
                if pdf_hash is None:
                    print(
                        f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}"
                    )
                    continue

                # The output id is the pdf hash plus '-' plus the page number.
                combined_id = f"{pdf_hash}-{normalized.pagenum}"

                # Look up the corresponding URL from the sqlite database.
                url = get_uri_from_db(args.db_path, pdf_hash)
                if url is None:
                    print(
                        f"No URL found in DB for pdf hash {pdf_hash} at {file_path}:{line_num}"
                    )
                    continue

                row = {
                    "id": combined_id,
                    "url": url,
                    "page_number": normalized.pagenum,
                    "response": normalized.text,
                }
                rows.append(row)
        # NOTE: only the first matching file is processed because of this break;
        # remove it to process every file found by the glob pattern.
        break

    if rows:
        df = pd.DataFrame(rows)
        df.to_parquet(args.output, index=False)
        print(f"Successfully wrote {len(df)} rows to {args.output}")
    else:
        print("No rows to write. Exiting.")


if __name__ == "__main__":
    main()
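
# Example invocation (the script name and database path below are illustrative):
#   python generate_hf_parquet.py \
#       "/data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json" \
#       pdf_mapping.db --output openai_batch_v5_1_iabooks.parquet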