Commit

Pipeline stuff
jakep-allenai committed Nov 12, 2024
1 parent 691cc5a commit 918e2f3
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions pdelfin/beakerpipeline.py
@@ -34,13 +34,23 @@
# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False

sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False

file_handler = logging.FileHandler('beakerpipeline-debug.log', mode='a')
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))

# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
sglang_logger.addHandler(file_handler)

# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)
@@ -50,8 +60,10 @@
pdf_s3 = boto3.client('s3')

# Global variables for token statistics
-total_input_tokens = 0
-total_output_tokens = 0
+finished_input_tokens = 0
+finished_output_tokens = 0
+sglang_input_tokens = 0
+sglang_output_tokens = 0
process_start_time = time.perf_counter()
last_batch_time = process_start_time
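
The old total_* pair is split into two counter families here: finished_* is summed from documents that were successfully written out (see the worker changes below), while sglang_* is updated on every completion response in process_page, so it presumably also covers retried attempts and pages that later fail. A hypothetical illustration of why both are worth tracking, with invented numbers:

# Invented numbers, purely to show how the two counter families relate.
finished_input_tokens = 120_000   # tokens from docs that finished successfully
sglang_input_tokens = 150_000     # tokens across all sglang requests

waste = 1 - finished_input_tokens / sglang_input_tokens
print(f"~{waste:.0%} of input tokens went to retried or discarded pages")  # ~20%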

@@ -225,7 +237,6 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
COMPLETION_URL = "http://localhost:30000/v1/chat/completions"
MAX_RETRIES = 3


for attempt in range(1, MAX_RETRIES + 1):
query = await build_page_query(
pdf_local_path,
@@ -239,6 +250,11 @@ async def process_page(args, session: aiohttp.ClientSession, pdf_s3_path: str, p
response.raise_for_status()

base_response_data = await response.json()
+
+# Update global sglang token counts
+global sglang_input_tokens, sglang_output_tokens
+sglang_input_tokens += base_response_data["usage"].get("prompt_tokens", 0)
+sglang_output_tokens += base_response_data["usage"].get("completion_tokens", 0)

model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)
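
The added accounting reads the usage block that OpenAI-compatible servers such as sglang attach to each chat completion. A minimal sketch of that step with a stand-in response body (field names follow the diff; the values are invented):

# Stand-in for base_response_data from /v1/chat/completions.
base_response_data = {
    "choices": [{"message": {"content": '{"natural_text": "..."}'}}],
    "usage": {"prompt_tokens": 1432, "completion_tokens": 287},
}

usage = base_response_data["usage"]
sglang_input_tokens = usage.get("prompt_tokens", 0)
sglang_output_tokens = usage.get("completion_tokens", 0)

The .get(..., 0) defaults keep the counters safe if a response ever arrives without token counts.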
@@ -333,7 +349,7 @@ async def process_pdf(args, pdf_s3_path: str):
return dolma_doc


-async def worker(args, queue, semaphore):
+async def worker(args, queue, semaphore, worker_id):
while True:
[work_hash, pdfs] = await queue.get()

@@ -357,16 +373,21 @@ async def worker(args, queue, semaphore):
bucket, key = parse_s3_path(output_s3_path)
workspace_s3.upload_file(tf.name, bucket, key)

-# Sum up stats and report them since the last batch finished
+# Update finished token counts from successful documents
batch_input_tokens = sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs)
batch_output_tokens = sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs)


# Print statistics since process start
-global total_input_tokens, total_output_tokens, last_batch_time
-total_input_tokens += batch_input_tokens
-total_output_tokens += batch_output_tokens
+global finished_input_tokens, finished_output_tokens, last_batch_time
+finished_input_tokens += batch_input_tokens
+finished_output_tokens += batch_output_tokens
total_time = time.perf_counter() - process_start_time
logger.info(f"Processing speed: input {total_input_tokens / total_time:.1f} tok/sec, output {total_output_tokens / total_time:.1f} tok/sec, total {(total_input_tokens + total_output_tokens) / total_time:.1f} tok/sec")

# Log both finished and total sglang token statistics
logger.info(f"""Token Statistics:
Finished documents: input {finished_input_tokens / total_time:.1f} tok/sec, output {finished_output_tokens / total_time:.1f} tok/sec, total {(finished_input_tokens + finished_output_tokens) / total_time:.1f} tok/sec
All SGLang requests: input {sglang_input_tokens / total_time:.1f} tok/sec, output {sglang_output_tokens / total_time:.1f} tok/sec, total {(sglang_input_tokens + sglang_output_tokens) / total_time:.1f} tok/sec""")

# Update last batch time
last_batch_time = time.perf_counter()
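
The new report prints both rates over the same wall-clock window, so the gap between them is visible at a glance. A hedged sketch of the arithmetic with invented counts:

finished_output_tokens = 90_000    # invented: tokens in finished documents
sglang_output_tokens = 110_000     # invented: tokens across all sglang requests
total_time = 600.0                 # invented: seconds since process start

print(f"finished: {finished_output_tokens / total_time:.1f} tok/sec")     # 150.0
print(f"all sglang: {sglang_output_tokens / total_time:.1f} tok/sec")     # ~183.3

If the all-requests rate runs well ahead of the finished-documents rate, a sizable share of server work is being retried or thrown away before documents complete.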
@@ -410,11 +431,13 @@ def _kill_proc():

last_queue_req = None # To track transitions
async def process_line(line):
sglang_logger.info(line)

# Parse the line and update semaphore if necessary
match = re.search(r'#queue-req: (\d+)', line)
if match:
-logger.info(line)
queue_req = int(match.group(1))
+logger.info(f"sglang queue req: {queue_req}")

nonlocal last_queue_req
if last_queue_req is not None and last_queue_req != 0 and queue_req == 0:
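
The watcher scans sglang's own log output for "#queue-req: N" markers and reacts when the server's queue drains, i.e. on a nonzero-to-zero transition. A self-contained sketch of that pattern (what happens on a drain is a stand-in here; the diff only shows the transition check):

import re

last_queue_req = None

def process_line(line: str) -> bool:
    """Return True when the sglang queue has just drained."""
    global last_queue_req
    match = re.search(r'#queue-req: (\d+)', line)
    if not match:
        return False
    queue_req = int(match.group(1))
    drained = last_queue_req is not None and last_queue_req != 0 and queue_req == 0
    last_queue_req = queue_req
    return drained

assert process_line("... #queue-req: 5") is False
assert process_line("... #queue-req: 0") is True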
@@ -515,7 +538,7 @@ async def main():
# Create worker tasks to process the queue concurrently.
worker_tasks = []
for i in range(args.workers):
-task = asyncio.create_task(worker(args, work_queue, semaphore))
+task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
worker_tasks.append(task)

# Wait for the queue to be fully processed
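
The fan-out itself is the standard asyncio pattern: N long-lived workers pulling from one shared queue, now each tagged with a worker_id (the diff adds the parameter; its use inside the worker isn't shown, presumably per-worker logging). A minimal runnable sketch of the shape, with stand-in work items:

import asyncio

async def worker(queue: asyncio.Queue, worker_id: int) -> None:
    while True:
        item = await queue.get()
        try:
            print(f"worker {worker_id} processing {item}")
        finally:
            queue.task_done()

async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    for item in range(10):
        queue.put_nowait(item)
    tasks = [asyncio.create_task(worker(queue, worker_id=i)) for i in range(4)]
    await queue.join()      # wait until every queued item is processed
    for task in tasks:
        task.cancel()       # the workers loop forever, so stop them explicitly

asyncio.run(main())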
