Skip to content

Commit

Permalink
Code to get stats
Browse files Browse the repository at this point in the history
  • Loading branch information
jakep-allenai committed Nov 11, 2024
1 parent 6b625b2 commit a9a94f2
Showing 1 changed file with 56 additions and 14 deletions.
70 changes: 56 additions & 14 deletions pdelfin/beakerpipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,20 +36,29 @@
# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)

# Global s3 client for the whole script, feel free to adjust params if you need it
# Global s3 clients fo the whole script, we have two separate ones in case your workspace and your pdfs are in different accounts
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

MAX_TOKENS = 3000
# Global variables for token statistics
total_input_tokens = 0
total_output_tokens = 0
process_start_time = time.perf_counter()
last_batch_time = process_start_time


@dataclass(frozen=True)
class PageResult:
s3_path: str
page_num: int
response: PageResponse

total_input_tokens: int
total_output_tokens: int


async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
MAX_TOKENS = 3000
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"

# Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
Expand Down Expand Up @@ -216,7 +225,9 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)

return PageResult(pdf_s3_path, page_num, page_response)
return PageResult(pdf_s3_path, page_num, page_response,
total_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
total_output_tokens=base_response_data["usage"].get("completion_tokens", 0))
except Exception as e:
logger.exception(f"Exception while processing page {page_num}: {e}")
raise
Expand Down Expand Up @@ -250,7 +261,7 @@ async def process_pdf(args, pdf_s3_path: str):


# Build the document text and page spans
document_text = ''
document_text = ""
pdf_page_spans = []
current_char_pos = 0

Expand All @@ -264,7 +275,7 @@ async def process_pdf(args, pdf_s3_path: str):
document_text += content
current_char_pos = len(document_text)
pdf_page_spans.append({
'pdf_page_number': page_num,
'pdf_page_number': page_result.page_num,
'start_char': start_pos,
'end_char': current_char_pos
})
Expand All @@ -276,6 +287,8 @@ async def process_pdf(args, pdf_s3_path: str):
metadata = {
"Source-File": pdf_s3_path,
"pdf-total-pages": num_pages,
"total-input-tokens": sum(page.total_input_tokens for page in page_results),
"total-output-tokens": sum(page.total_output_tokens for page in page_results)
}

id_ = hashlib.sha1(document_text.encode()).hexdigest()
Expand All @@ -297,16 +310,45 @@ async def process_pdf(args, pdf_s3_path: str):

async def worker(args, queue):
while True:
[work_hash, pdfs] = await queue.get()

completed_pdfs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])

# Take all the not None completed_pdfs and write them as a jsonl to the workspace output location
# under the proper work_hash location
for dolma_doc in completed_pdfs:
pass

queue.task_done()
[work_hash, pdfs] = await queue.get()

try:
dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs])
dolma_docs = [doc for doc in dolma_docs if doc is not None]

# Write the Dolma documents to a local temporary file in JSONL format
with tempfile.NamedTemporaryFile(mode='w+', delete=False) as tf:
for doc in dolma_docs:
tf.write(json.dumps(doc))
tf.write('\n')
tf.flush()

# Define the output S3 path using the work_hash
output_s3_path = os.path.join(args.workspace, 'dolma_documents', f'output_{work_hash}.jsonl')

bucket, key = parse_s3_path(output_s3_path)
workspace_s3.upload_file(tf.name, bucket, key)

# Sum up stats and report them since the last batch finished
global total_input_tokens, total_output_tokens, last_batch_time
batch_input_tokens = sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs)
batch_output_tokens = sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs)
batch_time = time.perf_counter() - last_batch_time
logger.info(f"Tokens per second (since last batch): input {batch_input_tokens / batch_time:.1f}, output {batch_output_tokens / batch_time:.1f}, total {(batch_input_tokens + batch_output_tokens) / batch_time:.1f}")

# Print statistics since process start
total_input_tokens += batch_input_tokens
total_output_tokens += batch_output_tokens
total_time = time.perf_counter() - process_start_time
logger.info(f"Tokens per second (since process start): input {total_input_tokens / total_time:.1f}, output {total_output_tokens / total_time:.1f}, total {(total_input_tokens + total_output_tokens) / total_time:.1f}")

# Update last batch time
last_batch_time = current_time
except Exception as e:
logger.exception(f"Exception occurred while processing work_hash {work_hash}: {e}")
finally:
queue.task_done()


async def sglang_server_task(args):
Expand Down

0 comments on commit a9a94f2

Please sign in to comment.