diff --git a/.gitignore b/.gitignore index 6de8849a..a13d13c3 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ google_jetstream.egg-info/ data/ logs/ tmp/ -venv/ \ No newline at end of file +venv/ +.vscode/ diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py index 8b5268e5..34b45d8e 100644 --- a/jetstream/core/orchestrator.py +++ b/jetstream/core/orchestrator.py @@ -93,8 +93,9 @@ from jetstream.core.proto import jetstream_pb2_grpc from jetstream.core.utils import async_multifuture from jetstream.engine import engine_api -import numpy as np +import numpy as np +import prometheus_client root = logging.getLogger() root.setLevel(logging.DEBUG) @@ -209,6 +210,9 @@ class Driver: # todo: remove jax_padding after all then engine migrate to np padding _jax_padding = True + # Record metrics for prefill_backlog size + _prefill_backlog_size_metric: prometheus_client.Gauge + def __init__( self, prefill_engines: Optional[list[engine_api.Engine]] = None, @@ -242,6 +246,10 @@ def __init__( # Stage 1 # At first, a request is placed here in order to get prefilled. self._prefill_backlog = queue.Queue() + self._prefill_backlog_size_metric = prometheus_client.Gauge( + "jetstream_prefill_backlog_size", "Size of prefill queue" + ) + # Stage 2 # After prefilling, it is placed here in order to get transferred to # one of the generate backlogs. @@ -421,6 +429,7 @@ def place_request_on_prefill_queue(self, request: ActiveRequest): """Used to place new requests for prefilling and generation.""" # Don't block so we can fail and shed load when the queue is full. self._prefill_backlog.put(request, block=False) + self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize()) def _load_cache_history(self, path: str) -> Union[None, Any]: """Loads previous kv cache for a longer conversation.""" @@ -442,6 +451,8 @@ def _prefill_thread(self, idx: int): my_transfer_backlog = self._transfer_backlogs[idx] # The prefill thread can just sleep until it has work to do. request = self._prefill_backlog.get(block=True) + self._prefill_backlog_size_metric.set(self._prefill_backlog.qsize()) + if request is None: break # Tokenize, and introduce a leading dimension diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index d66af518..cc983535 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -20,6 +20,7 @@ import asyncio from concurrent import futures import logging +import os import threading from typing import Any, Type @@ -29,8 +30,14 @@ from jetstream.core import orchestrator from jetstream.core.proto import jetstream_pb2_grpc +from prometheus_client import start_http_server _HOST = "[::]" +PROMETHEUS_ENABLED_ON_PORT = ( + int(os.getenv("PROMETHEUS_ENABLED_ON_PORT")) + if os.getenv("PROMETHEUS_ENABLED_ON_PORT") + else None +) class JetStreamServer: @@ -130,6 +137,17 @@ def run( logging.info("Starting server on port %d with %d threads", port, threads) jetstream_server.start() + + # Setup Prometheus server + if PROMETHEUS_ENABLED_ON_PORT is not None: + logging.info( + "Starting Prometheus server on port %d", PROMETHEUS_ENABLED_ON_PORT + ) + start_http_server(PROMETHEUS_ENABLED_ON_PORT) + else: + logging.info( + "Not starting Prometheus server: PROMETHEUS_ENABLED_ON_PORT not set" + ) return jetstream_server diff --git a/requirements.in b/requirements.in index bc5ba8fc..6f4ebb60 100644 --- a/requirements.in +++ b/requirements.in @@ -6,7 +6,8 @@ jax jaxlib numpy portpicker +prometheus-client pytest seqio tiktoken -blobfile \ No newline at end of file +blobfile diff --git a/requirements.txt b/requirements.txt index 20029e10..f2a23c7b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -177,6 +177,8 @@ pluggy==1.4.0 # via pytest portpicker==1.6.0 # via -r requirements.in +prometheus-client==0.20.0 + # via -r requirements.in promise==2.3 # via tfds-nightly protobuf==3.20.3