From 4f2f4fda7d5384e124bcdb8f90c1b39a1a5788ec Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Tue, 12 Nov 2024 08:18:22 -0800 Subject: [PATCH] Quicker results by limited workers via semaphore while still utilizing gpu --- pdelfin/beakerpipeline.py | 60 +++++++++++++++++++++++++++++++++------ 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/pdelfin/beakerpipeline.py b/pdelfin/beakerpipeline.py index bf3c562..ba1a39d 100644 --- a/pdelfin/beakerpipeline.py +++ b/pdelfin/beakerpipeline.py @@ -14,6 +14,7 @@ import aiohttp import datetime import tempfile +import re from tqdm import tqdm from io import BytesIO @@ -73,7 +74,7 @@ async def build_page_query(local_pdf_path: str, page: int, target_longest_image_ # Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread) image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim) - + # GET ANCHOR TEXT IS NOT THREAD SAFE!! Ahhhh..... don't try to do it # and it's also CPU bound, so it needs to run in a process pool loop = asyncio.get_running_loop() @@ -287,7 +288,7 @@ async def process_pdf(args, pdf_s3_path: str): logger.exception(f"Could not load page for {pdf_s3_path}, aborting document") return None - + # Build the document text and page spans document_text = "" pdf_page_spans = [] @@ -332,11 +333,14 @@ async def process_pdf(args, pdf_s3_path: str): return dolma_doc -async def worker(args, queue): +async def worker(args, queue, semaphore): while True: [work_hash, pdfs] = await queue.get() try: + # Wait until allowed to proceed + await semaphore.acquire() + dolma_docs = await asyncio.gather(*[process_pdf(args, pdf) for pdf in pdfs]) dolma_docs = [doc for doc in dolma_docs if doc is not None] @@ -372,7 +376,7 @@ async def worker(args, queue): queue.task_done() -async def sglang_server_task(args): +async def sglang_server_task(args, semaphore): model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', 'model') # TODO cache locally #download_directory(args.model, model_cache_dir) @@ -390,20 +394,53 @@ async def sglang_server_task(args): proc = await asyncio.create_subprocess_exec( "python3", - "-m", "sglang.launch_server", "--model-path", model_cache_dir, "--chat-template", args.model_chat_template, "--context-length", str(args.model_max_context), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, ) - # Make really sure we kill this subprocess on exit + # Make sure we kill this subprocess on exit def _kill_proc(): proc.terminate() atexit.register(_kill_proc) + last_queue_req = None # To track transitions + async def process_line(line): + # Parse the line and update semaphore if necessary + match = re.search(r'#running-req: (\d+), #queue-req: (\d+)', line) + if match: + logger.info(line) + running_req = int(match.group(1)) + queue_req = int(match.group(2)) + + nonlocal last_queue_req + if last_queue_req is not None and last_queue_req != 0 and queue_req == 0: + # Release the semaphore when queue_req transitions from non-zero to zero + if semaphore.locked(): + semaphore.release() + logger.info("Semaphore released, allowing a worker to proceed.") + + last_queue_req = queue_req + + async def read_stream(stream): + while True: + line = await stream.readline() + if not line: + break + line = line.decode('utf-8').rstrip() + await process_line(line) + + # Start tasks to read stdout and stderr + stdout_task = asyncio.create_task(read_stream(proc.stdout)) + stderr_task = asyncio.create_task(read_stream(proc.stderr)) + await proc.wait() + await stdout_task + await stderr_task async def sglang_server_ready(): @@ -463,7 +500,13 @@ async def main(): if args.pdfs: await populate_pdf_work_queue(args) - sglang_server = asyncio.create_task(sglang_server_task(args)) + # Create a semaphore to control worker access + # We only allow one worker to move forward with requests, until the server has no more requests in its queue + # This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible + # As soon as one worker is no longer saturating the gpu, the next one can start sending requests + semaphore = asyncio.Semaphore(1) + + sglang_server = asyncio.create_task(sglang_server_task(args, semaphore)) work_queue = await load_pdf_work_queue(args) logger.info(f"Work queue prepared with {work_queue.qsize()} items") @@ -473,7 +516,7 @@ async def main(): # Create worker tasks to process the queue concurrently. worker_tasks = [] for i in range(args.workers): - task = asyncio.create_task(worker(args, work_queue)) + task = asyncio.create_task(worker(args, work_queue, semaphore)) worker_tasks.append(task) # Wait for the queue to be fully processed @@ -501,4 +544,3 @@ async def main(): # TODO # Possible future addon, in beaker, discover other nodes on this same job # Send them a message when you take a work item off the queue -